Add minimal vLLM 0.16.1 build repo for BI-V150
167
vllm/model_executor/layers/quantization/__init__.py
Normal file
@@ -0,0 +1,167 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Literal, get_args

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.platforms import current_platform

logger = init_logger(__name__)

QuantizationMethods = Literal[
    "awq",
    "fp8",
    "ptpc_fp8",
    "fbgemm_fp8",
    "fp_quant",
    "modelopt",
    "modelopt_fp4",
    "modelopt_mxfp8",
    "gguf",
    "gptq_marlin",
    "awq_marlin",
    "gptq",
    "compressed-tensors",
    "bitsandbytes",
    "experts_int8",
    "quark",
    "moe_wna16",
    "torchao",
    "inc",
    "mxfp4",
    "petit_nvfp4",
    "cpu_awq",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))

DEPRECATED_QUANTIZATION_METHODS = [
    "tpu_int8",
    "ptpc_fp8",
    "fbgemm_fp8",
    "fp_quant",
    "experts_int8",
    "petit_nvfp4",
]

# The customized quantization methods which will be added to this dict.
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}


def register_quantization_config(quantization: str):
    """Register a customized vllm quantization config.

    When a quantization method is not supported by vllm, you can register a customized
    quantization config to support it.

    Args:
        quantization (str): The quantization method name.

    Examples:
        >>> from vllm.model_executor.layers.quantization import (
        ...     register_quantization_config,
        ... )
        >>> from vllm.model_executor.layers.quantization import get_quantization_config
        >>> from vllm.model_executor.layers.quantization.base_config import (
        ...     QuantizationConfig,
        ... )
        >>>
        >>> @register_quantization_config("my_quant")
        ... class MyQuantConfig(QuantizationConfig):
        ...     pass
        >>>
        >>> get_quantization_config("my_quant")
        <class 'MyQuantConfig'>
    """  # noqa: E501

    def _wrapper(quant_config_cls):
        if quantization in QUANTIZATION_METHODS:
            logger.warning(
                "The quantization method '%s' already exists and will be "
                "overwritten by the quantization config %s.",
                quantization,
                quant_config_cls,
            )
        else:
            QUANTIZATION_METHODS.append(quantization)
            # Automatically assume the custom quantization config is supported
            if sq := current_platform.supported_quantization:
                sq.append(quantization)

        if not issubclass(quant_config_cls, QuantizationConfig):
            raise ValueError(
                "The quantization config must be a subclass of `QuantizationConfig`."
            )
        _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls
        return quant_config_cls

    return _wrapper


def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")

    # lazy import to avoid triggering `torch.compile` too early
    from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig

    from .awq import AWQConfig
    from .awq_marlin import AWQMarlinConfig
    from .bitsandbytes import BitsAndBytesConfig
    from .compressed_tensors.compressed_tensors import (
        CompressedTensorsConfig,
    )
    from .cpu_wna16 import CPUAWQConfig
    from .experts_int8 import ExpertsInt8Config
    from .fbgemm_fp8 import FBGEMMFp8Config
    from .fp8 import Fp8Config
    from .fp_quant import FPQuantConfig
    from .gguf import GGUFConfig
    from .gptq import GPTQConfig
    from .gptq_marlin import GPTQMarlinConfig
    from .inc import INCConfig
    from .modelopt import ModelOptFp8Config, ModelOptMxFp8Config, ModelOptNvFp4Config
    from .moe_wna16 import MoeWNA16Config
    from .mxfp4 import Mxfp4Config
    from .petit import PetitNvFp4Config
    from .ptpc_fp8 import PTPCFp8Config
    from .torchao import TorchAOConfig

    method_to_config: dict[str, type[QuantizationConfig]] = {
        "awq": AWQConfig,
        "fp8": Fp8Config,
        "fbgemm_fp8": FBGEMMFp8Config,
        "fp_quant": FPQuantConfig,
        "modelopt": ModelOptFp8Config,
        "modelopt_fp4": ModelOptNvFp4Config,
        "modelopt_mxfp8": ModelOptMxFp8Config,
        "gguf": GGUFConfig,
        "gptq_marlin": GPTQMarlinConfig,
        "awq_marlin": AWQMarlinConfig,
        "gptq": GPTQConfig,
        "compressed-tensors": CompressedTensorsConfig,
        "bitsandbytes": BitsAndBytesConfig,
        "ptpc_fp8": PTPCFp8Config,
        "experts_int8": ExpertsInt8Config,
        "quark": QuarkConfig,
        "moe_wna16": MoeWNA16Config,
        "torchao": TorchAOConfig,
        "auto-round": INCConfig,
        "inc": INCConfig,
        "mxfp4": Mxfp4Config,
        "petit_nvfp4": PetitNvFp4Config,
        "cpu_awq": CPUAWQConfig,
    }
    # Update the `method_to_config` with customized quantization methods.
    method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)

    return method_to_config[quantization]


__all__ = [
    "QuantizationConfig",
    "QuantizationMethods",
    "get_quantization_config",
    "register_quantization_config",
    "QUANTIZATION_METHODS",
]
279
vllm/model_executor/layers/quantization/awq.py
Normal file
@@ -0,0 +1,279 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING, Any, Union

import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from vllm.transformers_utils.config import get_safetensors_params_metadata

if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization import QuantizationMethods
    from vllm.model_executor.models.utils import WeightsMapper

logger = init_logger(__name__)


class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        modules_to_not_convert: list[str] | None = None,
    ) -> None:
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.modules_to_not_convert = modules_to_not_convert or []

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits."
            )
        self.pack_factor = 32 // self.weight_bits
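        # With the 4-bit restriction above, pack_factor == 8: eight 4-bit
        # weights (or zero-points) are packed into each int32 of qweight/qzeros.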

    def __repr__(self) -> str:
        return (
            f"AWQConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size}, "
            f"zero_point={self.zero_point}, "
            f"modules_to_not_convert={self.modules_to_not_convert})"
        )

    def get_name(self) -> "QuantizationMethods":
        return "awq"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> list[str]:
        return [
            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Union["LinearMethodBase", "QuantizeMethodBase"] | None:
        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix,
                self.modules_to_not_convert,
                self.packed_modules_mapping,
                skip_with_substr=True,
            ):
                return UnquantizedLinearMethod()
            return AWQLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            # Lazy import to avoid circular import.
            from .awq_marlin import AWQMarlinConfig
            from .moe_wna16 import MoeWNA16Config
            from .utils.marlin_utils import check_moe_marlin_supports_layer

            if not check_moe_marlin_supports_layer(layer, self.group_size):
                logger.warning_once(
                    f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
                    "Falling back to Moe WNA16 kernels."
                )
                config = {
                    "quant_method": "awq",
                    "bits": self.weight_bits,
                    "group_size": self.group_size,
                    "zero_point": self.zero_point,
                    "lm_head": False,
                    "modules_to_not_convert": self.modules_to_not_convert,
                }
                return MoeWNA16Config.from_config(config).get_quant_method(
                    layer, prefix
                )
            marlin_compatible_config_dict = {
                "quant_method": "awq",
                "bits": self.weight_bits,
                "group_size": self.group_size,
                "zero_point": self.zero_point,
                "lm_head": False,
                "modules_to_not_convert": self.modules_to_not_convert,
            }
            awq_marlin_config = AWQMarlinConfig.from_config(
                marlin_compatible_config_dict
            )
            return awq_marlin_config.get_quant_method(layer, prefix)
        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        if self.modules_to_not_convert:
            self.modules_to_not_convert = hf_to_vllm_mapper.apply_list(
                self.modules_to_not_convert
            )

    def maybe_update_config(self, model_name: str, revision: str | None = None):
        if self.modules_to_not_convert:
            return

        unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        metadata = get_safetensors_params_metadata(model_name, revision=revision)
        layers = {param_name.rsplit(".", 1)[0] for param_name in metadata}
        quant_layers: set[str] = {
            param_name.rsplit(".", 1)[0]
            for param_name, info in metadata.items()
            if (dtype := info.get("dtype", None))
            and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
        }
        self.modules_to_not_convert = list(layers - quant_layers)


class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        if input_size_per_partition % group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        weight_loader = extra_weight_attrs.get("weight_loader")
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        num_groups = input_size_per_partition // group_size

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.qzeros
        pack_factor = self.quant_config.pack_factor
        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
        reshaped_x = x.reshape(-1, x.shape[-1])

        # num_tokens >= threshold
        # FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
        FP16_MATMUL_HEURISTIC_CONDITION = False
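        # Force-disabling the heuristic means the dequantize-then-matmul branch
        # below is never taken and ops.awq_gemm is always used (presumably the
        # preferred/validated path on this target).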
        if FP16_MATMUL_HEURISTIC_CONDITION:
            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
            out = torch.matmul(reshaped_x, out)
        else:
            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                               pack_factor, group_size=self.quant_config.group_size)
        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)
968
vllm/model_executor/layers/quantization/awq_marlin.py
Normal file
@@ -0,0 +1,968 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING, Any, Callable

import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.nn import Parameter

import vllm.model_executor.layers.fused_moe  # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
    UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
    set_weight_attrs,
)
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_awq_marlin_linear,
    awq_to_marlin_zero_points,
    check_marlin_supported,
    check_marlin_supports_layer,
    check_moe_marlin_supports_layer,
    get_marlin_input_dtype,
    marlin_act_int8_process_scales,
    marlin_make_empty_g_idx,
    marlin_make_workspace_new,
    marlin_moe_permute_scales,
    marlin_permute_bias,
    marlin_permute_scales,
    moe_awq_to_marlin_zero_points,
    verify_marlin_supported,
    verify_marlin_supports_shape,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.transformers_utils.config import get_safetensors_params_metadata

if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization import QuantizationMethods
    from vllm.model_executor.models.utils import WeightsMapper

logger = init_logger(__name__)


class AWQMarlinConfig(QuantizationConfig):
    """Config class for AWQ Marlin"""

    # num_bits -> type
    TYPE_MAP = {
        4: scalar_types.uint4,
    }

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        lm_head_quantized: bool,
        modules_to_not_convert: list[str] | None,
        full_config: dict[str, Any],
    ) -> None:
        super().__init__()
        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.zero_point = zero_point
        self.lm_head_quantized = lm_head_quantized
        self.weight_bits = weight_bits
        self.modules_to_not_convert = modules_to_not_convert or []
        self.full_config = full_config

        if self.weight_bits not in self.TYPE_MAP:
            raise ValueError(
                f"Unsupported num_bits = {self.weight_bits}. "
                f"Supported num_bits = {self.TYPE_MAP.keys()}"
            )

        self.quant_type = self.TYPE_MAP[self.weight_bits]

        verify_marlin_supported(
            self.quant_type, group_size=self.group_size, has_zp=self.zero_point
        )

    def __repr__(self) -> str:
        return (
            f"AWQMarlinConfig(quant_type={self.quant_type}, "
            f"group_size={self.group_size}, "
            f"zero_point={self.zero_point}, "
            f"lm_head_quantized={self.lm_head_quantized}, "
            f"modules_to_not_convert={self.modules_to_not_convert})"
        )

    @classmethod
    def get_name(cls) -> "QuantizationMethods":
        return "awq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(
            weight_bits,
            group_size,
            zero_point,
            lm_head_quantized,
            modules_to_not_convert,
            config,
        )

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> "QuantizationMethods | None":
        can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
        is_valid_user_quant = (
            user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
        )

        if can_convert and is_valid_user_quant:
            msg = (
                "The model is convertible to {} during runtime."
                " Using {} kernel.".format(cls.get_name(), cls.get_name())
            )
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "awq":
            logger.info(
                "Detected that the model can run with awq_marlin"
                ", however you specified quantization=awq explicitly,"
                " so forcing awq. Use quantization=awq_marlin for"
                " faster inference"
            )
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        if isinstance(layer, LinearBase) or (
            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
        ):
            if is_layer_skipped(
                prefix,
                self.modules_to_not_convert,
                self.packed_modules_mapping,
                skip_with_substr=True,
            ):
                return UnquantizedLinearMethod()
            # Check if the layer is supported by AWQMarlin.
            if not check_marlin_supports_layer(layer, self.group_size):
                logger.warning_once(
                    "Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.",  # noqa: E501
                    prefix,
                )
                return AWQConfig.from_config(self.full_config).get_quant_method(
                    layer, prefix
                )
            quant_method = AWQMarlinLinearMethod(self)
            quant_method.input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config

            # if is_layer_skipped(
            #     prefix,
            #     getattr(self, "modules_to_not_convert", []),
            #     skip_with_substr=True,
            # ):
            #     return UnquantizedFusedMoEMethod(layer.moe_config)
            # if not check_moe_marlin_supports_layer(layer, self.group_size):
            #     logger.warning_once(
            #         f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
            #         "Falling back to Moe WNA16 kernels."
            #     )
            #     return MoeWNA16Config.from_config(self.full_config).get_quant_method(
            #         layer, prefix
            #     )
            moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
            moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
            return moe_quant_method
        return None

    @classmethod
    def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
        # Extract data from quant config.
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
        zero_point = quant_config.get("zero_point")

        if not current_platform.is_cuda():
            return False

        if quant_method != "awq":
            return False

        # If we cannot find the info needed in the config, cannot convert.
        if num_bits is None or group_size is None or zero_point is None:
            return False

        if num_bits not in cls.TYPE_MAP:
            return False

        return check_marlin_supported(
            quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point
        )

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        if self.modules_to_not_convert:
            self.modules_to_not_convert = hf_to_vllm_mapper.apply_list(
                self.modules_to_not_convert
            )

    def maybe_update_config(self, model_name: str, revision: str | None = None):
        if self.modules_to_not_convert:
            return

        unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        metadata = get_safetensors_params_metadata(model_name, revision=revision)
        layers = {param_name.rsplit(".", 1)[0] for param_name in metadata}
        quant_layers: set[str] = {
            param_name.rsplit(".", 1)[0]
            for param_name, info in metadata.items()
            if (dtype := info.get("dtype", None))
            and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
        }
        self.modules_to_not_convert = list(layers - quant_layers)


class AWQMarlinLinearMethod(LinearMethodBase):
    """Linear method for AWQ Marlin.

    Args:
        quant_config: The AWQ Marlin quantization config.
    """

    def __init__(self, quant_config: AWQMarlinConfig) -> None:
        self.quant_config = quant_config
        self.quant_type = scalar_types.uint4
        self.input_dtype = None

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size,
        )

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        num_groups = input_size_per_partition // group_size
        layer.num_groups = num_groups

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.num_groups = num_groups

    # TODO: Update this docs
    # Checkpoints are serialized in AutoAWQ format, which is different from the
    # marlin format. This function is called after the weights are loaded.
    # Here, we handle the repacking
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.output_size_per_partition = layer.qweight.data.shape[1] * self.quant_config.pack_factor
        align_bits = 64 * 8
        align_size = align_bits // self.quant_config.weight_bits
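        # e.g. with weight_bits=4: align_bits=512 and align_size=128, so the output
        # dimension is padded up to the next multiple of 128 columns (a partition of
        # 1000 columns becomes 1024, i.e. 128 packed int32 columns at pack_factor 8).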
        if layer.output_size_per_partition % align_size != 0:
            padding_output_size_per_partition = (layer.output_size_per_partition + align_size - 1) // align_size * align_size
            layer.output_padding_size = padding_output_size_per_partition - layer.output_size_per_partition
            device = layer.qweight.device

            pad_qweight = torch.zeros(
                layer.input_size_per_partition,
                padding_output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
                device=device,
            )
            pad_qzeros = torch.zeros(
                layer.num_groups,
                padding_output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
                device=device,
            )
            pad_scales = torch.zeros(
                layer.num_groups,
                padding_output_size_per_partition,
                dtype=layer.scales.data.dtype,
                device=device,
            )
            pad_qweight[..., :layer.output_size_per_partition // self.quant_config.pack_factor] = layer.qweight.data
            pad_qzeros[..., :layer.output_size_per_partition // self.quant_config.pack_factor] = layer.qzeros.data
            pad_scales[..., :layer.output_size_per_partition] = layer.scales.data
            replace_parameter(layer, "qweight", pad_qweight)
            replace_parameter(layer, "qzeros", pad_qzeros)
            replace_parameter(layer, "scales", pad_scales)
        # TODO(gyf): the Marlin format is not supported for now, so the Marlin
        # repacking below is skipped and the AWQ-format weights are used as loaded.
        return
        device = layer.qweight.device
        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

        # Allocate marlin workspace
        layer.workspace = marlin_make_workspace_new(device)

        is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1

        if self.input_dtype == torch.float8_e4m3fn:
            ops.marlin_int4_fp8_preprocess(layer.qweight, layer.qzeros, inplace=True)
            layer.scales.data = layer.scales.data * 512

        # Repack weights from AWQ format to marlin format.
        marlin_qweight = ops.awq_marlin_repack(
            layer.qweight,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits,
            is_a_8bit=is_a_8bit,
        )
        replace_parameter(layer, "qweight", marlin_qweight)

        # Permute scales from AWQ format to marlin format.
        marlin_scales = marlin_permute_scales(
            layer.scales,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            group_size=self.quant_config.group_size,
            is_a_8bit=is_a_8bit,
        )
        if self.input_dtype == torch.int8 and layer.num_groups > 1:
            marlin_scales, input_global_scale = marlin_act_int8_process_scales(
                marlin_scales
            )
            layer.register_parameter(
                "input_global_scale", Parameter(input_global_scale, requires_grad=False)
            )

        replace_parameter(layer, "scales", marlin_scales)

        # Permute zero-points from AWQ format to marlin format.
        marlin_zp = awq_to_marlin_zero_points(
            layer.qzeros,
            size_k=layer.num_groups,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits,
            is_a_8bit=is_a_8bit,
        )
        replace_parameter(layer, "qzeros", marlin_zp)

        # Not used
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        if hasattr(layer, "bias") and layer.bias is not None:
            layer.bias.data = marlin_permute_bias(layer.bias)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # return apply_awq_marlin_linear(
        #     input=x,
        #     weight=layer.qweight,
        #     weight_scale=layer.scales,
        #     weight_zp=layer.qzeros,
        #     g_idx=layer.g_idx,
        #     g_idx_sort_indices=layer.g_idx_sort_indices,
        #     workspace=layer.workspace,
        #     quant_type=self.quant_config.quant_type,
        #     output_size_per_partition=layer.output_size_per_partition,
        #     input_size_per_partition=layer.input_size_per_partition,
        #     input_global_scale=getattr(layer, "input_global_scale", None),
        #     bias=bias,
        #     input_dtype=self.input_dtype,
        # )
        # TODO: use the plain AWQ kernel temporarily.
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.qzeros
        pack_factor = self.quant_config.pack_factor
        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
        reshaped_x = x.reshape(-1, x.shape[-1])

        out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                           pack_factor, group_size=self.quant_config.group_size)
        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)


class AWQMarlinMoEMethod(FusedMoEMethodBase):
    def __init__(
        self,
        quant_config: AWQMarlinConfig,
        moe: FusedMoEConfig,
    ):
        super().__init__(moe)
        self.quant_config = quant_config
        if self.quant_config.weight_bits != 4:
            raise ValueError("AWQMarlinMoEMethod only supports 4-bit for now.")
        self.quant_type = scalar_types.uint4
        self.input_dtype = None
        self.use_marlin = True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        layer.input_dtype = self.input_dtype
        extra_weight_attrs.update(
            {
                "is_transposed": True,
                "quant_method": FusedMoeWeightScaleSupported.GROUP.value,
            }
        )

        intermediate_size_full = extra_weight_attrs.pop(
            "intermediate_size_full", intermediate_size_per_partition
        )
        self.is_k_full = intermediate_size_per_partition == intermediate_size_full

        w13_qweight = Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)

        w2_qweight = Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                hidden_size // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

        num_groups_w13 = hidden_size // self.quant_config.group_size
        num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size
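        # Illustrative values only: hidden_size=4096, intermediate_size_per_partition=1408
        # and group_size=128 give num_groups_w13=32 and num_groups_w2=11.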
        layer.num_groups_w13 = num_groups_w13
        layer.num_groups_w2 = num_groups_w2

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        w13_scales = Parameter(
            torch.empty(
                num_experts,
                num_groups_w13,
                intermediate_size_per_partition * 2,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)

        w2_scales = Parameter(
            torch.empty(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)

        # WEIGHT_ZERO_POINT
        # Allocate 2 zero points for w1 and w3 respectively.
        w13_qzeros = Parameter(
            torch.empty(
                num_experts,
                num_groups_w13,
                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qzeros", w13_qzeros)
        set_weight_attrs(w13_qzeros, extra_weight_attrs)

        w2_qzeros = Parameter(
            torch.empty(
                num_experts,
                num_groups_w2,
                hidden_size // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qzeros", w2_qzeros)
        set_weight_attrs(w2_qzeros, extra_weight_attrs)

        device = layer.w13_qweight.device
        layer.workspace = marlin_make_workspace_new(device, 4)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # TODO(gyf): the Marlin format is not supported for now, so the repacking
        # below is skipped and the AWQ-format MoE weights are used as loaded.
        return
        num_experts = layer.w13_qweight.shape[0]
        device = layer.w13_qweight.device
        is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1

        if self.input_dtype == torch.float8_e4m3fn:
            ops.marlin_int4_fp8_preprocess(
                layer.w13_qweight.view(-1, layer.w13_qweight.size(2)),
                layer.w13_qzeros.view(-1, layer.w13_qzeros.size(2)),
                inplace=True,
            )
            ops.marlin_int4_fp8_preprocess(
                layer.w2_qweight.view(-1, layer.w2_qweight.size(2)),
                layer.w2_qzeros.view(-1, layer.w2_qzeros.size(2)),
                inplace=True,
            )
            layer.w13_scales.data = layer.w13_scales.data * 512
            layer.w2_scales.data = layer.w2_scales.data * 512

        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )

        marlin_w13_qweight = ops.awq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            size_k=layer.w13_qweight.shape[1],
            size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
            is_a_8bit=is_a_8bit,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

        marlin_w2_qweight = ops.awq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            size_k=layer.w2_qweight.shape[1],
            size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
            is_a_8bit=is_a_8bit,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)

        # The modular kernel expects w13_weight and w2_weight,
        # but AWQ uses w13_qweight and w2_qweight
        # Alias for modular kernel
        layer.w13_weight = layer.w13_qweight
        # Alias for modular kernel
        layer.w2_weight = layer.w2_qweight

        # Why does this take the intermediate size for size_k?
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
            is_a_8bit=is_a_8bit,
        )
        if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
            marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
                marlin_w13_scales
            )
            layer.register_parameter(
                "w13_input_global_scale",
                Parameter(w13_input_global_scale, requires_grad=False),
            )

        replace_parameter(layer, "w13_scales", marlin_w13_scales)

        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
            is_a_8bit=is_a_8bit,
        )
        if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
            marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
                marlin_w2_scales
            )
            layer.register_parameter(
                "w2_input_global_scale",
                Parameter(w2_input_global_scale, requires_grad=False),
            )

        replace_parameter(layer, "w2_scales", marlin_w2_scales)

        marlin_w13_zp = moe_awq_to_marlin_zero_points(
            layer.w13_qzeros,
            size_k=layer.w13_qzeros.shape[1],
            size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
            is_a_8bit=is_a_8bit,
        )
        replace_parameter(layer, "w13_qzeros", marlin_w13_zp)

        marlin_w2_zp = moe_awq_to_marlin_zero_points(
            layer.w2_qzeros,
            size_k=layer.w2_qzeros.shape[1],
            size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
            is_a_8bit=is_a_8bit,
        )
        replace_parameter(layer, "w2_qzeros", marlin_w2_zp)

        if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
            layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)

        if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
            layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        from vllm.model_executor.layers.fused_moe.config import (
            awq_marlin_moe_quant_config,
        )

        return awq_marlin_moe_quant_config(
            w1_scale=layer.w13_scales,
            w2_scale=layer.w2_scales,
            weight_bits=self.quant_config.weight_bits,
            group_size=self.quant_config.group_size,
            w1_zp=getattr(layer, "w13_qzeros", None)
            if self.quant_config.zero_point
            else None,
            w2_zp=getattr(layer, "w2_qzeros", None)
            if self.quant_config.zero_point
            else None,
            w1_bias=getattr(layer, "w13_bias", None),
            w2_bias=getattr(layer, "w2_bias", None),
        )

    def select_gemm_impl(
        self,
        prepare_finalize,
        layer: torch.nn.Module,
    ):
        """
        Select the GEMM implementation for AWQ-Marlin MoE.
        Returns MarlinExperts configured for AWQ quantization.
        This is ONLY used when LoRA is enabled.
        Without LoRA, AWQ uses its own apply() method.
        """
        # Only use modular kernels when LoRA is enabled
        # Without LoRA, AWQ's own apply() method works fine and is more efficient
        if not self.moe.is_lora_enabled:
            raise NotImplementedError(
                "AWQ-Marlin uses its own apply() method when LoRA is not enabled. "
                "Modular kernels are only used for LoRA support."
            )

        from vllm.model_executor.layers.fused_moe import modular_kernel as mk
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            BatchedMarlinExperts,
            MarlinExperts,
        )

        # Ensure quant config is initialized
        assert self.moe_quant_config is not None, (
            "moe_quant_config must be initialized before select_gemm_impl"
        )

        w13_g_idx = getattr(layer, "w13_g_idx", None)
        w2_g_idx = getattr(layer, "w2_g_idx", None)
        w13_g_idx_sort_indices = getattr(layer, "w13_g_idx_sort_indices", None)
        w2_g_idx_sort_indices = getattr(layer, "w2_g_idx_sort_indices", None)

        # Check if using batched expert format (for Expert Parallelism)
        if (
            prepare_finalize.activation_format
            == mk.FusedMoEActivationFormat.BatchedExperts
        ):
            # For batched format, use BatchedMarlinExperts
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None
            return BatchedMarlinExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                moe_config=self.moe,
                quant_config=self.moe_quant_config,
                w13_g_idx=w13_g_idx,
                w2_g_idx=w2_g_idx,
                w13_g_idx_sort_indices=w13_g_idx_sort_indices,
                w2_g_idx_sort_indices=w2_g_idx_sort_indices,
                is_k_full=self.is_k_full,
            )
        else:
            # Standard Marlin experts for AWQ
            return MarlinExperts(
                moe_config=self.moe,
                quant_config=self.moe_quant_config,
                w13_g_idx=w13_g_idx,
                w2_g_idx=w2_g_idx,
                w13_g_idx_sort_indices=w13_g_idx_sort_indices,
                w2_g_idx_sort_indices=w2_g_idx_sort_indices,
                is_k_full=self.is_k_full,
            )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        # return fused_marlin_moe(
        #     x,
        #     layer.w13_qweight,
        #     layer.w2_qweight,
        #     getattr(layer, "w13_bias", None),
        #     getattr(layer, "w2_bias", None),
        #     layer.w13_scales,
        #     layer.w2_scales,
        #     topk_weights,
        #     topk_ids,
        #     input_global_scale1=getattr(layer, "w13_input_global_scale", None),
        #     input_global_scale2=getattr(layer, "w2_input_global_scale", None),
        #     quant_type_id=self.quant_type.id,
        #     apply_router_weight_on_input=layer.apply_router_weight_on_input,
        #     global_num_experts=layer.global_num_experts,
        #     expert_map=layer.expert_map,
        #     w1_zeros=layer.w13_qzeros,
        #     w2_zeros=layer.w2_qzeros,
        #     workspace=layer.workspace,
        #     input_dtype=self.input_dtype,
        #     inplace=not self.moe.disable_inplace,
        # )

        num_tokens, num_experts = router_logits.shape
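        # NOTE: the path below relies on names that are not defined in this file as
        # committed: `use_ep`, `start_eid`/`end_eid`, `topk_weights`/`topk_ids`
        # (a top-k routing step over `router_logits`), and the `ixfops` custom-ops
        # module; they are assumed to be provided by the surrounding BI-V150 build.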
        if use_ep:
            hidden_size = x.shape[1]
            (
                src_to_dst,
                sorted_token_ids,
                expert_sizes_gpu,
                expert_sizes_cpu,
                expand_tokens,
            ) = ixfops.moe_compute_token_index_ep(
                topk_ids=topk_ids,
                num_experts=num_experts,
                start_expert_id=start_eid,
                end_expert_id=end_eid,
            )
            if expert_sizes_cpu.sum() == 0:
                return torch.zeros(
                    (num_tokens, hidden_size),
                    device=x.device,
                    dtype=x.dtype,
                )
        else:
            expand_tokens = num_tokens * top_k
            (
                src_to_dst,
                sorted_token_ids,
                expert_sizes_gpu,
                expert_sizes_cpu,
            ) = ixfops.moe_compute_token_index(
                topk_ids=topk_ids,
                num_experts=num_experts,
            )
            expert_sizes_cpu = expert_sizes_gpu.cpu()

        # expand + reorder
        # TODO use kernel
        expand_hidden_states = ixfops.moe_expand_input(
            hidden_states=x,
            dst_to_src=sorted_token_ids,
            dst_tokens=expand_tokens,
            topk=top_k,
            src_to_dst=src_to_dst,
        )

        # w4a16 group gemm 1
        # pt_output_1: (expand_tokens, 2n) dtype
        pt_output_1 = ixfops.moe_w4a16_group_gemm(
            input=expand_hidden_states,
            weight=layer.w13_qweight,
            w_scales=layer.w13_scales,
            quant_type="awq",
            tokens_per_experts=expert_sizes_cpu,
            w_zeros=layer.w13_qzeros,
            group_size=self.quant_config.group_size,
            dst_to_src=None,
            format="NN",
            tokens_per_experts_gpu=expert_sizes_gpu,
        )

        # act
        pt_output_2 = ixfops.silu_and_mul(pt_output_1)

        # w4a16 group gemm 2 + reorder
        # pt_output_3: (expand_tokens, k) dtype
        if use_ep:
            pt_output_3 = torch.empty(
                (num_tokens * top_k, hidden_size),
                device=x.device,
                dtype=x.dtype,
            )

            ixfops.moe_w4a16_group_gemm(
                input=pt_output_2,
                weight=layer.w2_qweight,
                w_scales=layer.w2_scales,
                quant_type="awq",
                tokens_per_experts=expert_sizes_cpu,
                w_zeros=layer.w2_qzeros,
                group_size=self.quant_config.group_size,
                dst_to_src=sorted_token_ids,
                format="NN",
                output=pt_output_3,
            )

            reduce_mask = src_to_dst == -1
            final_hidden_states = ixfops.moe_output_reduce_sum(
                input=pt_output_3.view(num_tokens, top_k, -1),
                topk_weight=topk_weights,
                scaling_factor=routed_scaling_factor,
                mask=reduce_mask,
            )
        else:
            pt_output_3 = ixfops.moe_w4a16_group_gemm(
                input=pt_output_2,
                weight=layer.w2_qweight,
                w_scales=layer.w2_scales,
                quant_type="awq",
                tokens_per_experts=expert_sizes_cpu,
                w_zeros=layer.w2_qzeros,
                group_size=self.quant_config.group_size,
                dst_to_src=sorted_token_ids,
                format="NN",
                tokens_per_experts_gpu=expert_sizes_gpu,
            )

            # mul + reduce_sum
            # final_hidden_states: (num_tokens, k)
            final_hidden_states = ixfops.moe_output_reduce_sum(
                input=pt_output_3.view(num_tokens, top_k, -1),
                topk_weight=topk_weights,
                scaling_factor=routed_scaling_factor,
            )
        return final_hidden_states
337
vllm/model_executor/layers/quantization/awq_triton.py
Normal file
@@ -0,0 +1,337 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.triton_utils import tl, triton

AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


@triton.jit
def awq_dequantize_kernel(
    qweight_ptr,  # quantized matrix
    scales_ptr,  # scales, per group
    zeros_ptr,  # zeros, per group
    group_size,  # Should always be one of the supported group sizes
    result_ptr,  # Output matrix
    num_cols,  # input num cols in qweight
    num_rows,  # input num rows in qweight
    BLOCK_SIZE_X: tl.constexpr,
    BLOCK_SIZE_Y: tl.constexpr,
):
    # Set up the pids.
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)

    # Compute offsets and masks for qweight_ptr.
    offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
    offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
    offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]

    masks_y = offsets_y < num_rows
    masks_x = offsets_x < num_cols

    masks = masks_y[:, None] & masks_x[None, :]

    # Compute offsets and masks for result output ptr.
    result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
    result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8)
    result_offsets = (
        8 * num_cols * result_offsets_y[:, None] + result_offsets_x[None, :]
    )

    result_masks_y = result_offsets_y < num_rows
    result_masks_x = result_offsets_x < num_cols * 8
    result_masks = result_masks_y[:, None] & result_masks_x[None, :]

    # Load the weights.
    iweights = tl.load(qweight_ptr + offsets, masks, 0.0)
    iweights = tl.interleave(iweights, iweights)
    iweights = tl.interleave(iweights, iweights)
    iweights = tl.interleave(iweights, iweights)

    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
    # that will map given indices to the correct order.
    reverse_awq_order_tensor = (
        (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]
    ).reshape(8)
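    # Worked example: AWQ packs eight 4-bit values into one int32 in nibble order
    # [0, 2, 4, 6, 1, 3, 5, 7]; multiplying the reverse-order tensor
    # [0, 4, 1, 5, 2, 6, 3, 7] by 4 yields the bit shifts
    # [0, 16, 4, 20, 8, 24, 12, 28] that pull the values back out in logical order.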

    # Use this to compute a set of shifts that can be used to unpack and
    # reorder the values in iweights and zeros.
    shifts = reverse_awq_order_tensor * 4
    shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
    shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

    # Unpack and reorder: shift out the correct 4-bit value and mask.
    iweights = (iweights >> shifts) & 0xF

    # Compute zero offsets and masks.
    zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
    zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
    zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]

    zero_masks_y = zero_offsets_y < num_rows // group_size
    zero_masks_x = zero_offsets_x < num_cols
    zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]

    # Load the zeros.
    zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0)
    zeros = tl.interleave(zeros, zeros)
    zeros = tl.interleave(zeros, zeros)
    zeros = tl.interleave(zeros, zeros)
    zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

    # Unpack and reorder: shift out the correct 4-bit value and mask.
    zeros = (zeros >> shifts) & 0xF

    # Compute scale offsets and masks.
    scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
    scale_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8)
    scale_offsets = num_cols * 8 * scale_offsets_y[:, None] + scale_offsets_x[None, :]
    scale_masks_y = scale_offsets_y < num_rows // group_size
    scale_masks_x = scale_offsets_x < num_cols * 8
    scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]

    # Load the scales.
    scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0)
    scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

    # Dequantize.
    iweights = (iweights - zeros) * scales
    iweights = iweights.to(result_ptr.type.element_ty)

    # Finally, store.
    tl.store(result_ptr + result_offsets, iweights, result_masks)


@triton.jit
def awq_gemm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    zeros_ptr,
    scales_ptr,
    M,
    N,
    K,
    group_size,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    pid_z = tl.program_id(1)

    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
    # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    accumulator_dtype = c_ptr.type.element_ty

    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
    # accumulator = tl.arange(0, BLOCK_SIZE_N)
    # accumulator = tl.broadcast_to(accumulator[None, :],
    #                               (BLOCK_SIZE_M, BLOCK_SIZE_N))
    # accumulator = accumulator & 0x0
    # accumulator = accumulator.to(accumulator_dtype)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)

    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
    # that will map given indices to the correct order.
    reverse_awq_order_tensor = (
        (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]
    ).reshape(8)

    # Create the necessary shifts to use to unpack.
    shifts = reverse_awq_order_tensor * 4
    shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
    shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))

    # Offsets and masks.
    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    masks_am = offsets_am < M

    offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
    masks_bn = offsets_bn < N // 8

    offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
    masks_zn = offsets_zn < N // 8

    offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    masks_sn = offsets_sn < N

    offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
    offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
    # block_offset = BLOCK_SIZE_K * SPLIT_K
    # for k in range(0, (K + block_offset - 1) // (block_offset)):
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
        masks_k = offsets_k < K
        masks_a = masks_am[:, None] & masks_k[None, :]
        a = tl.load(a_ptrs, mask=masks_a, other=0.0)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b, other=0.0)
        b = tl.interleave(b, b)
        b = tl.interleave(b, b)
        b = tl.interleave(b, b)

        # Dequantize b.
        offsets_szk = (
            BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K
        ) // group_size + tl.arange(0, 1)
        offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
        masks_zk = offsets_szk < K // group_size
        masks_z = masks_zk[:, None] & masks_zn[None, :]
        zeros_ptrs = zeros_ptr + offsets_z
        zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0)
        zeros = tl.interleave(zeros, zeros)
        zeros = tl.interleave(zeros, zeros)
        zeros = tl.interleave(zeros, zeros)
        zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))

        offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
        masks_sk = offsets_szk < K // group_size
        masks_s = masks_sk[:, None] & masks_sn[None, :]
        scales_ptrs = scales_ptr + offsets_s
        scales = tl.load(scales_ptrs, mask=masks_s, other=0.0)
        scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))

        b = (b >> shifts) & 0xF
        zeros = (zeros >> shifts) & 0xF
        b = (b - zeros) * scales
        b = b.to(c_ptr.type.element_ty)

        # Accumulate results.
        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

        offsets_k += BLOCK_SIZE_K * SPLIT_K
        a_ptrs += BLOCK_SIZE_K * SPLIT_K
        b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)

    c = accumulator.to(c_ptr.type.element_ty)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
# qweights - [K , M // 8], int32
|
||||
# scales - [K // G, M ], float16
|
||||
# zeros - [K // G, M // 8], int32
|
||||
def awq_dequantize_triton(
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
zeros: torch.Tensor,
|
||||
block_size_x: int = 32,
|
||||
block_size_y: int = 32,
|
||||
) -> torch.Tensor:
|
||||
K = qweight.shape[0]
|
||||
M = scales.shape[1]
|
||||
group_size = qweight.shape[0] // scales.shape[0]
|
||||
|
||||
assert K > 0 and M > 0
|
||||
assert scales.shape[0] == K // group_size and scales.shape[1] == M
|
||||
assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8
|
||||
assert group_size <= K
|
||||
assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
|
||||
|
||||
# Result tensor:
|
||||
# number of rows = same as input tensor
|
||||
# number of cols = 8 x input tensor num cols
|
||||
result = torch.empty(
|
||||
qweight.shape[0],
|
||||
qweight.shape[1] * 8,
|
||||
device=qweight.device,
|
||||
dtype=scales.dtype,
|
||||
)
|
||||
|
||||
Y = qweight.shape[0] # num rows
|
||||
X = qweight.shape[1] # num cols
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(X, META["BLOCK_SIZE_X"]),
|
||||
triton.cdiv(Y, META["BLOCK_SIZE_Y"]),
|
||||
)
|
||||
awq_dequantize_kernel[grid](
|
||||
qweight,
|
||||
scales,
|
||||
zeros,
|
||||
group_size,
|
||||
result,
|
||||
X,
|
||||
Y,
|
||||
BLOCK_SIZE_X=block_size_x,
|
||||
BLOCK_SIZE_Y=block_size_y,
|
||||
)
|
||||
|
||||
return result
|
||||
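# Usage sketch (illustrative, not part of the original file): a hypothetical
# helper showing the expected tensor shapes for awq_dequantize_triton. The
# sizes below are assumptions (group_size 128 assumed to be among the
# supported group sizes) and a CUDA device is required.
def _example_awq_dequantize_usage() -> None:
    K, M, group_size = 4096, 4096, 128
    int32_max = torch.iinfo(torch.int32).max
    # Random packed weights/zeros, purely for shape illustration.
    qweight = torch.randint(
        0, int32_max, (K, M // 8), dtype=torch.int32, device="cuda"
    )
    scales = torch.rand(K // group_size, M, dtype=torch.float16, device="cuda")
    zeros = torch.randint(
        0, int32_max, (K // group_size, M // 8), dtype=torch.int32, device="cuda"
    )
    weight = awq_dequantize_triton(qweight, scales, zeros)
    assert weight.shape == (K, M)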
|
||||
|
||||
# input - [M, K]
|
||||
# qweight - [K, N // 8]
|
||||
# qzeros - [K // G, N // 8]
|
||||
# scales - [K // G, N]
|
||||
# split_k_iters - parallelism along K-dimension, int, power of 2.
|
||||
def awq_gemm_triton(
|
||||
input: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
qzeros: torch.Tensor,
|
||||
split_k_iters: int,
|
||||
block_size_m: int = 32,
|
||||
block_size_n: int = 32,
|
||||
block_size_k: int = 32,
|
||||
) -> torch.Tensor:
|
||||
M, K = input.shape
|
||||
N = qweight.shape[1] * 8
|
||||
group_size = qweight.shape[0] // qzeros.shape[0]
|
||||
|
||||
assert N > 0 and K > 0 and M > 0
|
||||
assert qweight.shape[0] == K and qweight.shape[1] == N // 8
|
||||
assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8
|
||||
assert scales.shape[0] == K // group_size and scales.shape[1] == N
|
||||
assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0
|
||||
assert split_k_iters <= 32
|
||||
assert group_size <= K
|
||||
assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
split_k_iters,
|
||||
)
|
||||
|
||||
result = torch.zeros((split_k_iters, M, N), dtype=scales.dtype, device=input.device)
|
||||
|
||||
# A = input, B = qweight, C = result
|
||||
# A = M x K, B = K x N, C = M x N
|
||||
awq_gemm_kernel[grid](
|
||||
input,
|
||||
qweight,
|
||||
result,
|
||||
qzeros,
|
||||
scales,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
group_size,
|
||||
BLOCK_SIZE_M=block_size_m,
|
||||
BLOCK_SIZE_N=block_size_n,
|
||||
BLOCK_SIZE_K=block_size_k,
|
||||
SPLIT_K=split_k_iters,
|
||||
)
|
||||
|
||||
result = result.sum(0)
|
||||
|
||||
return result
|
||||
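# Usage sketch (illustrative, not part of the original file): a hypothetical
# helper showing how awq_gemm_triton is called. Shapes follow the comment
# above; the concrete sizes and split_k_iters value are assumptions and a
# CUDA device is required.
def _example_awq_gemm_usage() -> None:
    M, K, N, group_size = 16, 4096, 11008, 128
    int32_max = torch.iinfo(torch.int32).max
    x = torch.rand(M, K, dtype=torch.float16, device="cuda")
    qweight = torch.randint(
        0, int32_max, (K, N // 8), dtype=torch.int32, device="cuda"
    )
    qzeros = torch.randint(
        0, int32_max, (K // group_size, N // 8), dtype=torch.int32, device="cuda"
    )
    scales = torch.rand(K // group_size, N, dtype=torch.float16, device="cuda")
    y = awq_gemm_triton(x, qweight, scales, qzeros, split_k_iters=8)
    assert y.shape == (M, N)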
191
vllm/model_executor/layers/quantization/base_config.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
else:
|
||||
QuantizationMethods = str
|
||||
|
||||
|
||||
class QuantizeMethodBase(ABC):
|
||||
"""Base class for different quantized methods."""
|
||||
|
||||
# Whether this method creates weights on meta device for online quantization.
|
||||
# When True, weights are created on meta device and quantized layer-wise
|
||||
# in process_weights_after_loading, reducing peak memory during loading.
|
||||
uses_meta_device: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(
|
||||
self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs
|
||||
):
|
||||
"""Create weights for a layer.
|
||||
|
||||
The weights will be set as attributes of the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
|
||||
"""Apply the weights in layer to the input tensor.
|
||||
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
# Not required functions
|
||||
def embedding(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
|
||||
"""Gather embeddings in the layer based on indices in the input tensor.
|
||||
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
def process_weights_after_loading(self, layer: nn.Module) -> None:
|
||||
"""Process the weight after loading.
|
||||
|
||||
This can be used for example, to transpose weights for computation.
|
||||
"""
|
||||
return
|
||||
|
||||
|
||||
def method_has_implemented_embedding(method_class: type[QuantizeMethodBase]) -> bool:
|
||||
"""
|
||||
Not all quant methods have embedding implemented, so we need to check that
|
||||
it exists for our given method. We check this by making sure the function
|
||||
has been changed from the base implementation.
|
||||
"""
|
||||
base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
|
||||
class_embedding = inspect.getattr_static(method_class, "embedding", None)
|
||||
|
||||
return class_embedding is not None and class_embedding is not base_embedding
|
||||
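# Behavior sketch (illustrative, not part of the original file): for a
# hypothetical method that leaves `embedding` untouched, the check returns
# False; only classes that actually override it return True.
#
#   class _NoEmbeddingMethod(QuantizeMethodBase):
#       def create_weights(self, layer, *weight_args, **extra_weight_attrs): ...
#       def apply(self, layer, *args, **kwargs): ...
#
#   method_has_implemented_embedding(_NoEmbeddingMethod)  # -> False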
|
||||
|
||||
class QuantizationConfig(ABC):
|
||||
"""Base class for quantization configs."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# mapping is updated by models as they initialize
|
||||
self.packed_modules_mapping: dict[str, list[str]] = dict()
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
"""Name of the quantization method."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
"""List of supported activation dtypes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
"""Minimum GPU capability to support the quantization method.
|
||||
|
||||
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
|
||||
This requirement is due to the custom CUDA kernels used by the
|
||||
quantization method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
"""List of filenames to search for in the model directory."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":
|
||||
"""Create a config class from the model's quantization config."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> QuantizationMethods | None:
|
||||
"""
|
||||
Detects if this quantization method can support a given checkpoint
|
||||
format by overriding the user specified quantization method --
|
||||
this method should only be overwritten by subclasses in exceptional
|
||||
circumstances
|
||||
"""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
|
||||
"""Get a value from the model's quantization config."""
|
||||
for key in keys:
|
||||
if key in config:
|
||||
return config[key]
|
||||
raise ValueError(
|
||||
f"Cannot find any of {keys} in the model's quantization config."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys_or(config: dict[str, Any], keys: list[str], default: Any) -> Any:
|
||||
"""Get an optional value from the model's quantization config."""
|
||||
try:
|
||||
return QuantizationConfig.get_from_keys(config, keys)
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
@abstractmethod
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> QuantizeMethodBase | None:
|
||||
"""Get the quantize method to use for the quantized layer.
|
||||
|
||||
Args:
|
||||
layer: The layer for the quant method.
|
||||
prefix: The full name of the layer in the state dict
|
||||
Returns:
|
||||
The quantize method. None if the given layer doesn't support quant
|
||||
method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cache_scale(self, name: str) -> str | None:
|
||||
return None
|
||||
|
||||
def apply_vllm_mapper( # noqa: B027
|
||||
self, hf_to_vllm_mapper: "WeightsMapper"
|
||||
):
|
||||
"""
|
||||
Interface for models to update module names referenced in
|
||||
quantization configs in order to reflect the vllm model structure
|
||||
|
||||
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
|
||||
structure of the qconfig) to vllm model structure
|
||||
"""
|
||||
# TODO (@kylesayrs): add implementations for all subclasses
|
||||
pass
|
||||
|
||||
def maybe_update_config(self, model_name: str): # noqa: B027
|
||||
"""
|
||||
Interface to update values after config initialization.
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
|
||||
"""
|
||||
Determine if mxfp4 quantization will be used for this config.
|
||||
|
||||
This allows hidden_size rounding to happen before moe_config creation
|
||||
without needing to instantiate quant_method first.
|
||||
|
||||
Args:
|
||||
prefix: The layer prefix/name in the model
|
||||
layer: The layer module
|
||||
|
||||
Returns:
|
||||
True if this config uses MXFP4 quantization, False otherwise
|
||||
"""
|
||||
return False
|
||||
608
vllm/model_executor/layers/quantization/bitsandbytes.py
Normal file
@@ -0,0 +1,608 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Union
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
set_weight_attrs,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import (
|
||||
QuantizationConfig,
|
||||
QuantizationMethods,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
def _check_bitsandbytes_version():
|
||||
min_version = "0.49.2" if current_platform.is_rocm() else "0.48.1"
|
||||
try:
|
||||
import bitsandbytes
|
||||
|
||||
if version.parse(bitsandbytes.__version__) < version.parse(min_version):
|
||||
raise ImportError(
|
||||
"bitsandbytes version is wrong. Please "
|
||||
f"install bitsandbytes>={min_version}."
|
||||
)
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
f"Please install bitsandbytes>={min_version} via "
|
||||
f"`pip install bitsandbytes>={min_version}` to use "
|
||||
"bitsandbytes quantizer."
|
||||
) from err
|
||||
|
||||
|
||||
class BitsAndBytesConfig(QuantizationConfig):
|
||||
"""Config class for BitsAndBytes Quantization.
|
||||
|
||||
Reference: https://arxiv.org/abs/2305.14314
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
load_in_8bit: bool = False,
|
||||
load_in_4bit: bool = True,
|
||||
bnb_4bit_compute_dtype: str = "float32",
|
||||
bnb_4bit_quant_storage: str = "uint8",
|
||||
bnb_4bit_quant_type: str = "fp4",
|
||||
bnb_4bit_use_double_quant: bool = False,
|
||||
llm_int8_enable_fp32_cpu_offload: bool = False,
|
||||
llm_int8_has_fp16_weight: bool = False,
|
||||
llm_int8_skip_modules: list[str] | None = None,
|
||||
llm_int8_threshold: float = 6.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.load_in_8bit = load_in_8bit
|
||||
self.load_in_4bit = load_in_4bit
|
||||
self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
|
||||
self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
|
||||
self.bnb_4bit_quant_type = bnb_4bit_quant_type
|
||||
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
|
||||
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
|
||||
self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
|
||||
self.llm_int8_skip_modules = llm_int8_skip_modules or []
|
||||
self.llm_int8_threshold = llm_int8_threshold
|
||||
|
||||
if self.bnb_4bit_quant_storage not in ["uint8"]:
|
||||
raise ValueError(
|
||||
f"Unsupported bnb_4bit_quant_storage: {self.bnb_4bit_quant_storage}"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
|
||||
f"load_in_4bit={self.load_in_4bit}, "
|
||||
f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
|
||||
f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, "
|
||||
f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
|
||||
f"llm_int8_skip_modules={self.llm_int8_skip_modules})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "bitsandbytes"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig":
|
||||
def get_safe_value(config, keys, default_value=None):
|
||||
try:
|
||||
value = cls.get_from_keys(config, keys)
|
||||
return value if value is not None else default_value
|
||||
except ValueError:
|
||||
return default_value
|
||||
|
||||
load_in_8bit = get_safe_value(config, ["load_in_8bit"], default_value=False)
|
||||
load_in_4bit = get_safe_value(config, ["load_in_4bit"], default_value=True)
|
||||
bnb_4bit_compute_dtype = get_safe_value(
|
||||
config, ["bnb_4bit_compute_dtype"], default_value="float32"
|
||||
)
|
||||
bnb_4bit_quant_storage = get_safe_value(
|
||||
config, ["bnb_4bit_quant_storage"], default_value="uint8"
|
||||
)
|
||||
bnb_4bit_quant_type = get_safe_value(
|
||||
config, ["bnb_4bit_quant_type"], default_value="fp4"
|
||||
)
|
||||
bnb_4bit_use_double_quant = get_safe_value(
|
||||
config, ["bnb_4bit_use_double_quant"], default_value=False
|
||||
)
|
||||
llm_int8_enable_fp32_cpu_offload = get_safe_value(
|
||||
config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False
|
||||
)
|
||||
llm_int8_has_fp16_weight = get_safe_value(
|
||||
config, ["llm_int8_has_fp16_weight"], default_value=False
|
||||
)
|
||||
llm_int8_skip_modules = get_safe_value(
|
||||
config, ["llm_int8_skip_modules"], default_value=[]
|
||||
)
|
||||
llm_int8_threshold = get_safe_value(
|
||||
config, ["llm_int8_threshold"], default_value=6.0
|
||||
)
|
||||
|
||||
return cls(
|
||||
load_in_8bit=load_in_8bit,
|
||||
load_in_4bit=load_in_4bit,
|
||||
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
|
||||
bnb_4bit_quant_storage=bnb_4bit_quant_storage,
|
||||
bnb_4bit_quant_type=bnb_4bit_quant_type,
|
||||
bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
|
||||
llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload,
|
||||
llm_int8_has_fp16_weight=llm_int8_has_fp16_weight,
|
||||
llm_int8_skip_modules=llm_int8_skip_modules,
|
||||
llm_int8_threshold=llm_int8_threshold,
|
||||
)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Union["LinearMethodBase", "BitsAndBytesMoEMethod"] | None:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
|
||||
return UnquantizedLinearMethod()
|
||||
return BitsAndBytesLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return BitsAndBytesMoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
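# Usage sketch (illustrative, not part of the original file): a hypothetical
# helper building the config from an HF-style quantization_config dict. The
# field values are assumptions; missing keys fall back to the defaults above.
def _example_bitsandbytes_config() -> "BitsAndBytesConfig":
    return BitsAndBytesConfig.from_config(
        {
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": "bfloat16",
            "bnb_4bit_use_double_quant": True,
        }
    )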
|
||||
|
||||
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
|
||||
# Split the prefix into its dot-separated components
|
||||
components = prefix.split(".")
|
||||
|
||||
# Check if any of the skip modules exactly matches any component
|
||||
substr_check = any(
|
||||
module_name in components for module_name in llm_int8_skip_modules
|
||||
)
|
||||
|
||||
# Allow certain layers to not be quantized
|
||||
set_components = set(".".join(components[: i + 1]) for i in range(len(components)))
|
||||
set_llm_int8_skip_modules = set(llm_int8_skip_modules)
|
||||
prefix_check = len(set_llm_int8_skip_modules & set_components) != 0
|
||||
|
||||
return substr_check or prefix_check
|
||||
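# Behavior sketch (illustrative, not part of the original file), assuming a
# LLaMA-style module tree: a bare module name skips every occurrence of that
# component, while a dotted prefix skips the whole subtree.
#
#   is_layer_skipped_bnb("lm_head", ["lm_head"])                                   # True
#   is_layer_skipped_bnb("model.layers.0.mlp.down_proj", ["model.layers.0.mlp"])   # True
#   is_layer_skipped_bnb("model.layers.1.mlp.down_proj", ["model.layers.0.mlp"])   # False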
|
||||
|
||||
def calculate_quant_ratio(dtype):
|
||||
if dtype.is_floating_point:
|
||||
return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
|
||||
else:
|
||||
return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits
|
||||
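# Value sketch (illustrative, not part of the original file): the ratio is
# bits(params_dtype) // bits(uint8) and is used as the pack factor for the
# flattened 4-bit weight, e.g.
#
#   calculate_quant_ratio(torch.float16)   # 16 // 8 -> 2
#   calculate_quant_ratio(torch.bfloat16)  # 16 // 8 -> 2
#   calculate_quant_ratio(torch.float32)   # 32 // 8 -> 4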
|
||||
|
||||
class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
"""Linear method for BitsAndBytes.
|
||||
|
||||
Args:
|
||||
quant_config: The BitsAndBytes quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: BitsAndBytesConfig):
|
||||
_check_bitsandbytes_version()
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
from bitsandbytes.nn import Int8Params
|
||||
|
||||
def create_qweight_for_8bit():
|
||||
qweight = Int8Params(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=torch.int8,
|
||||
),
|
||||
has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight,
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qweight,
|
||||
{
|
||||
"input_dim": 0,
|
||||
"output_dim": 0,
|
||||
"pack_factor": 1,
|
||||
"use_bitsandbytes_8bit": True,
|
||||
"generation": 0,
|
||||
},
|
||||
)
|
||||
return qweight
|
||||
|
||||
def create_qweight_for_4bit():
|
||||
quant_ratio = calculate_quant_ratio(params_dtype)
|
||||
|
||||
total_size = input_size_per_partition * sum(output_partition_sizes)
|
||||
if total_size % quant_ratio != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized weight shape."
|
||||
)
|
||||
|
||||
qweight = torch.nn.Parameter(
|
||||
torch.empty(total_size // quant_ratio, 1, dtype=torch.uint8),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qweight,
|
||||
{
|
||||
"input_dim": 0,
|
||||
"output_dim": 0,
|
||||
"pack_factor": quant_ratio,
|
||||
"use_bitsandbytes_4bit": True,
|
||||
},
|
||||
)
|
||||
return qweight
|
||||
|
||||
if self.quant_config.load_in_8bit:
|
||||
qweight = create_qweight_for_8bit()
|
||||
else:
|
||||
qweight = create_qweight_for_4bit()
|
||||
# Enable parameters to have the same name as in the BNB
|
||||
# checkpoint format.
|
||||
layer.register_parameter("weight", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.quant_config.load_in_8bit:
|
||||
return self._apply_8bit_weight(layer, x, bias)
|
||||
else:
|
||||
return self._apply_4bit_weight(layer, x, bias)
|
||||
|
||||
def _apply_8bit_weight(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# only load the bitsandbytes module when needed
|
||||
from bitsandbytes import MatmulLtState, matmul
|
||||
|
||||
original_type = x.dtype
|
||||
original_shape = x.shape
|
||||
reshape_after_matmul = False
|
||||
if x.ndim > 2:
|
||||
x = x.reshape(-1, x.size(-1))
|
||||
reshape_after_matmul = True
|
||||
bf_x = x.to(torch.bfloat16)
|
||||
|
||||
qweight = layer.weight
|
||||
offsets = qweight.bnb_shard_offsets
|
||||
quant_states = qweight.bnb_quant_state
|
||||
matmul_states = qweight.matmul_state
|
||||
generation = qweight.generation
|
||||
|
||||
out_dim_0 = x.shape[0]
|
||||
out_dim_1 = sum(
|
||||
[quant_state[1].shape[0] for quant_state in quant_states.items()]
|
||||
)
|
||||
out = torch.empty(out_dim_0, out_dim_1, dtype=torch.float16, device=x.device)
|
||||
|
||||
current_index = 0
|
||||
for i in range(len(quant_states)):
|
||||
output_size = quant_states[i].shape[0]
|
||||
|
||||
# in profile_run or the first generation of inference,
|
||||
# create new matmul_states
|
||||
if generation == 0 or generation == 1:
|
||||
matmul_states[i] = MatmulLtState()
|
||||
matmul_states[i].CB = qweight[offsets[i] : offsets[i + 1]]
|
||||
matmul_states[i].SCB = quant_states[i].to(x.device)
|
||||
matmul_states[i].threshold = self.quant_config.llm_int8_threshold
|
||||
matmul_states[i].has_fp16_weights = self.quant_config.llm_int8_has_fp16_weight
|
||||
matmul_states[i].is_training = False
|
||||
if (
|
||||
matmul_states[i].threshold > 0.0
|
||||
and not matmul_states[i].has_fp16_weights
|
||||
):
|
||||
matmul_states[i].use_pool = True
|
||||
|
||||
new_x = bf_x.unsqueeze(0)
|
||||
|
||||
out[:, current_index : current_index + output_size] = matmul(
|
||||
new_x, qweight[offsets[i] : offsets[i + 1]], state=matmul_states[i]
|
||||
)
|
||||
|
||||
current_index += output_size
|
||||
|
||||
out = out.to(original_type)
|
||||
|
||||
if reshape_after_matmul:
|
||||
out = out.view(*original_shape[:-1], out.size(-1))
|
||||
|
||||
if bias is not None:
|
||||
out += bias
|
||||
|
||||
qweight.generation += 1
|
||||
|
||||
return out
|
||||
|
||||
def _apply_4bit_weight(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
original_type = x.dtype
|
||||
original_shape = x.shape
|
||||
reshape_after_matmul = False
|
||||
if x.ndim > 2:
|
||||
x = x.reshape(-1, x.size(-1))
|
||||
reshape_after_matmul = True
|
||||
bf_x = x.to(torch.bfloat16)
|
||||
|
||||
qweight = layer.weight
|
||||
quant_states = qweight.bnb_quant_state
|
||||
offsets = qweight.bnb_shard_offsets
|
||||
|
||||
out_dim_0 = x.shape[0]
|
||||
out_dim_1 = sum(
|
||||
[quant_state[1].shape[0] for quant_state in quant_states.items()]
|
||||
)
|
||||
out = torch.empty(out_dim_0, out_dim_1, dtype=torch.bfloat16, device=x.device)
|
||||
apply_bnb_4bit(bf_x, qweight, offsets, out)
|
||||
out = out.to(original_type)
|
||||
|
||||
if reshape_after_matmul:
|
||||
out = out.view(*original_shape[:-1], out.size(-1))
|
||||
|
||||
if bias is not None:
|
||||
out += bias
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _apply_bnb_4bit(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
offsets: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
# only load the bitsandbytes module when needed
|
||||
from bitsandbytes import matmul_4bit
|
||||
|
||||
quant_states = weight.bnb_quant_state
|
||||
current_index = 0
|
||||
for i in range(len(quant_states)):
|
||||
output_size = quant_states[i].shape[0]
|
||||
# It is more efficient to use out kwarg like
|
||||
# matmul_4bit(..., out = ...). Infeasible now due to the bug
|
||||
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
|
||||
# Need to change after the bug is fixed.
|
||||
out[:, current_index : current_index + output_size] = matmul_4bit(
|
||||
x, weight[offsets[i] : offsets[i + 1]].t(), quant_states[i]
|
||||
)
|
||||
current_index += output_size
|
||||
|
||||
|
||||
def _apply_bnb_4bit_fake(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
offsets: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="apply_bnb_4bit",
|
||||
op_func=_apply_bnb_4bit,
|
||||
mutates_args=["out"],
|
||||
fake_impl=_apply_bnb_4bit_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
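# Note (illustrative commentary, not part of the original file): once the op
# is registered, the 4-bit matmul in _apply_4bit_weight goes through the
# custom-op handle, so torch.compile traces the shape-only fake impl instead
# of the eager bitsandbytes path:
#
#   apply_bnb_4bit(bf_x, qweight, offsets, out)
#   # equivalent to torch.ops.vllm.apply_bnb_4bit(bf_x, qweight, offsets, out)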
|
||||
|
||||
class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for BitsAndBytes.
|
||||
|
||||
Args:
|
||||
quant_config: The BitsAndBytes quantization config.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: BitsAndBytesConfig,
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
_check_bitsandbytes_version()
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
if self.quant_config.load_in_8bit:
|
||||
call_fun = self._create_weights_8bit
|
||||
else:
|
||||
call_fun = self._create_weights_4bit
|
||||
call_fun(
|
||||
layer,
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
params_dtype,
|
||||
**extra_weight_attrs,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
# TODO(bnell): Do these need to be called on the hot path?
|
||||
if self.quant_config.load_in_8bit:
|
||||
w13, w2 = self._apply_8bit_dequant(layer)
|
||||
else:
|
||||
w13, w2 = self._apply_4bit_dequnt(layer)
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=w13,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
def _create_weights_4bit(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
quant_ratio = calculate_quant_ratio(params_dtype)
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_total_size = (
|
||||
hidden_size * 2 * intermediate_size_per_partition
|
||||
) // quant_ratio
|
||||
w13_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
w13_total_size,
|
||||
1,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_qweight)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
set_weight_attrs(
|
||||
w13_qweight,
|
||||
{
|
||||
"num_experts": num_experts,
|
||||
"input_dim": hidden_size,
|
||||
"output_dim": 2 * intermediate_size_per_partition,
|
||||
"experts_shape": (
|
||||
num_experts,
|
||||
intermediate_size_per_partition * 2,
|
||||
hidden_size,
|
||||
),
|
||||
"pack_factor": quant_ratio,
|
||||
"use_bitsandbytes_4bit": True,
|
||||
},
|
||||
)
|
||||
# down_proj (row parallel)
|
||||
w2_total_size = (hidden_size * intermediate_size_per_partition) // quant_ratio
|
||||
w2_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
w2_total_size,
|
||||
1,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
w2_qweight,
|
||||
{
|
||||
"num_experts": num_experts,
|
||||
"input_dim": intermediate_size_per_partition,
|
||||
"output_dim": hidden_size,
|
||||
"experts_shape": (
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
),
|
||||
"pack_factor": quant_ratio,
|
||||
"use_bitsandbytes_4bit": True,
|
||||
},
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_qweight)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
|
||||
def _create_weights_8bit(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def _apply_4bit_dequnt(
|
||||
self, layer: torch.nn.Module
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from bitsandbytes.functional import dequantize_4bit
|
||||
|
||||
w13 = dequantize_4bit(
|
||||
layer.w13_weight.reshape(-1, 1),
|
||||
layer.w13_weight.bnb_quant_state,
|
||||
)
|
||||
w2 = dequantize_4bit(
|
||||
layer.w2_weight.reshape(-1, 1),
|
||||
layer.w2_weight.bnb_quant_state,
|
||||
)
|
||||
w13 = w13.reshape(layer.w13_weight.experts_shape)
|
||||
w2 = w2.reshape(layer.w2_weight.experts_shape)
|
||||
return w13, w2
|
||||
|
||||
def _apply_8bit_dequant(
|
||||
self, layer: torch.nn.Module
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
|
||||
from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8
|
||||
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
|
||||
from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4
|
||||
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
|
||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8, CompressedTensorsW4A8Int8
|
||||
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
||||
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16
|
||||
|
||||
# This avoids circular import error
|
||||
from .compressed_tensors_24 import CompressedTensors24 # isort: skip
|
||||
|
||||
__all__ = [
|
||||
"CompressedTensorsScheme",
|
||||
"CompressedTensorsWNA16",
|
||||
"CompressedTensorsW8A16Fp8",
|
||||
"CompressedTensorsW8A8Int8",
|
||||
"CompressedTensorsW8A8Fp8",
|
||||
"WNA16_SUPPORTED_BITS",
|
||||
"CompressedTensors24",
|
||||
"CompressedTensorsW4A16Fp4",
|
||||
"CompressedTensorsW4A16Mxfp4",
|
||||
"CompressedTensorsW4A4Fp4",
|
||||
"CompressedTensorsW4A8Int",
|
||||
"CompressedTensorsW4A8Fp8",
|
||||
"CompressedTensorsW4A8Int8"
|
||||
]
|
||||
@@ -0,0 +1,392 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat, ModelCompressor
|
||||
from compressed_tensors.quantization import (
|
||||
QuantizationArgs,
|
||||
QuantizationStrategy,
|
||||
QuantizationType,
|
||||
)
|
||||
from compressed_tensors.utils import combine_shards
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise,
|
||||
sparse_cutlass_supported,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
__all__ = ["CompressedTensors24"]
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class CompressedTensors24(CompressedTensorsScheme):
|
||||
def __init__(
|
||||
self,
|
||||
quantized: bool = False,
|
||||
weight_quant: QuantizationArgs | None = None,
|
||||
input_quant: QuantizationArgs | None = None,
|
||||
model_compression_config: dict[str, Any] | None = None,
|
||||
):
|
||||
self.quantized = quantized
|
||||
self.weight_quant = weight_quant
|
||||
self.input_quant = input_quant
|
||||
model_compressor = ModelCompressor.from_compression_config(
|
||||
model_compression_config
|
||||
)
|
||||
self.do_sparse_decompress = (
|
||||
model_compressor is not None
|
||||
and model_compressor.sparsity_config.format
|
||||
== CompressionFormat.sparse_24_bitmask.value
|
||||
)
|
||||
if self.do_sparse_decompress:
|
||||
self.model_compressor = model_compressor
|
||||
|
||||
if (
|
||||
quantized
|
||||
and input_quant is not None
|
||||
and self._get_quant_dtype() == current_platform.fp8_dtype()
|
||||
):
|
||||
static = not input_quant.dynamic
|
||||
g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||
self.quant_fp8 = QuantFP8(static, g_shape)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# Only cutlass 3.x kernels are implemented so far
|
||||
return 90
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
if not sparse_cutlass_supported():
|
||||
raise ValueError(
|
||||
"Sparse CUTLASS not supported. vLLM must be built with "
|
||||
"CUDA 12.2 or later to use this feature"
|
||||
)
|
||||
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size = input_size
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
|
||||
|
||||
# parameter to store uncompressed weight
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=self.weights_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
if self.do_sparse_decompress:
|
||||
assert all(
|
||||
partition_size % 8 == 0 for partition_size in output_partition_sizes
|
||||
), "All partitions must be divisible by 8 for "
|
||||
"2:4 sparse compressed models"
|
||||
|
||||
shape = BasevLLMParameter(
|
||||
data=torch.empty(2, 1, dtype=torch.int64),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
compressed_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 2,
|
||||
dtype=self.weights_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
bitmask = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 8,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("shape", shape)
|
||||
layer.register_parameter("compressed", compressed_weight)
|
||||
layer.register_parameter("bitmask", bitmask)
|
||||
|
||||
# Check if quantized, not just 2:4 Sparse
|
||||
if self.quantized:
|
||||
if (
|
||||
self.weight_quant
|
||||
and self.weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
):
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
(sum(output_partition_sizes), 1), dtype=torch.float32
|
||||
),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
self.weight_quant
|
||||
and self.weight_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
)
|
||||
weight_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# input quant will be non-none
|
||||
if self.input_quant and not self.input_quant.dynamic:
|
||||
# register input quant scale
|
||||
assert self.input_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
input_scale = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
else:
|
||||
# for sparse-only, pass in 1 for weight/input scales
|
||||
weight_scale = torch.nn.Parameter(
|
||||
data=torch.ones(1, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
input_scale = torch.nn.Parameter(
|
||||
data=torch.ones(1, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
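# Shape sketch (illustrative, not part of the original file): for a single
# output partition of N rows and input_size_per_partition K = 4096, `weight`
# is [N, 4096], `compressed` keeps the two nonzeros of every block of four
# -> [N, 2048], and `bitmask` packs one bit per original element into uint8
# -> [N, 512].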
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
"""
|
||||
Compress weights after loading. Store compressed weight and meta
|
||||
tensor
|
||||
|
||||
:post-condition: layer.weight and layer.meta are
|
||||
set to the compressed weight and meta tensor in the
|
||||
format expected by the Cutlass kernels
|
||||
:param layer: The layer with the weights to be processed
|
||||
|
||||
"""
|
||||
if self.do_sparse_decompress:
|
||||
layer.weight.data = self._decompress_bitmask_compressed_weight(
|
||||
compressed=layer.compressed,
|
||||
bitmask=layer.bitmask,
|
||||
layer=layer,
|
||||
)
|
||||
|
||||
# compressed and bitmask tensors
|
||||
# are no longer needed after decompression
|
||||
del layer.compressed
|
||||
del layer.bitmask
|
||||
|
||||
# torch.compile workaround
|
||||
if hasattr(layer, "input_scale"):
|
||||
layer.input_scale = torch.nn.Parameter(
|
||||
layer.input_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
if self.weight_quant:
|
||||
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
convert_to_channelwise(
|
||||
weight_scale=layer.weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
# torch.compile workaround
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
# Set all negative zero values to 0 prior to compression
|
||||
if layer.weight.dtype.is_floating_point and layer.weight.dtype.itemsize >= 2:
|
||||
layer.weight.data[layer.weight.data == -0.0] = 0.0
|
||||
|
||||
w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
|
||||
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
|
||||
layer.meta = torch.nn.Parameter(meta, requires_grad=False)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns the output tensor for the layer with 2:4
|
||||
sparse compressed weights, given the input tensor
|
||||
and bias
|
||||
|
||||
:param layer: The layer with 2:4 sparse compressed
|
||||
weights to be used for the computation
|
||||
:param x: The input tensor to the layer
|
||||
:param bias: The bias to be added to the output tensor
|
||||
:return: The output tensor of the layer
|
||||
"""
|
||||
if self.quantized:
|
||||
scale = getattr(layer, "input_scale", None)
|
||||
|
||||
if self.weights_dtype == torch.int8:
|
||||
ops_output = ops.scaled_int8_quant(x, scale=scale)
|
||||
q_input = ops_output[0]
|
||||
input_scale = ops_output[1]
|
||||
else:
|
||||
assert self.weights_dtype == torch.float8_e4m3fn
|
||||
q_input, input_scale = self.quant_fp8(x, scale=scale)
|
||||
|
||||
else:
|
||||
# Not quantized, nothing to do with the input_scales, use as is
|
||||
input_scale = layer.input_scale
|
||||
q_input = x
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a=q_input,
|
||||
bt_nzs=layer.weight,
|
||||
bt_meta=layer.meta,
|
||||
scale_a=input_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
assert out.is_contiguous()
|
||||
return out
|
||||
|
||||
def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
|
||||
if not self.quantized:
|
||||
return params_dtype
|
||||
return self._get_quant_dtype()
|
||||
|
||||
def _get_quant_dtype(self) -> torch.dtype:
|
||||
assert self.quantized
|
||||
assert self.weight_quant is not None
|
||||
assert self.input_quant is not None
|
||||
|
||||
is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8
|
||||
|
||||
if not is_8_bits:
|
||||
raise ValueError("Cutlass only supports 8-bit quantization")
|
||||
|
||||
if (
|
||||
self.weight_quant.type == QuantizationType.FLOAT
|
||||
and self.input_quant.type == QuantizationType.FLOAT
|
||||
):
|
||||
return torch.float8_e4m3fn
|
||||
|
||||
if (
|
||||
self.weight_quant.type == QuantizationType.INT
|
||||
and self.input_quant.type == QuantizationType.INT
|
||||
):
|
||||
return torch.int8
|
||||
|
||||
raise ValueError("Quantization type not supported by Cutlass")
|
||||
|
||||
def _decompress_bitmask_compressed_weight(
|
||||
self,
|
||||
compressed: torch.Tensor,
|
||||
bitmask: torch.Tensor,
|
||||
layer: torch.nn.Module,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decompress a compressed 2:4 sparse weight tensor using the bitmask and
|
||||
return the result.
|
||||
|
||||
This function also supports sharded decompression.
|
||||
|
||||
:param compressed: The 2:4 sparse weight tensor compressed using the
|
||||
sparse-24-bitmask compressor. This is different from
|
||||
`cutlass_sparse_compress` which uses a different scheme (2 bits for
|
||||
every nonzero element that represent the coordinate within the block
|
||||
of 4). The bitmask compression here uses a bitmask to indicate the
|
||||
positions of non-zero elements.
|
||||
:param bitmask: The 2:4 bitmask associated with the compressed weights,
|
||||
representing the positions of non-zero elements in the compressed
|
||||
tensor.
|
||||
:param layer: The layer whose weights need to be processed after
|
||||
loading.
|
||||
:return: The decompressed 2:4 sparse weight tensor.
|
||||
"""
|
||||
|
||||
sparsity_compressor = self.model_compressor.sparsity_compressor
|
||||
|
||||
def _process_split(
|
||||
bitmask_compressed_weight: torch.Tensor,
|
||||
shape,
|
||||
bitmask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
weight_data = dict(
|
||||
compressed=bitmask_compressed_weight,
|
||||
shape=shape,
|
||||
bitmask=bitmask,
|
||||
)
|
||||
return sparsity_compressor.decompress_weight(weight_data)
|
||||
|
||||
split_weights: list[torch.Tensor] = []
|
||||
split_bitmask: list[torch.Tensor] = []
|
||||
split_shape: list[tuple[int, int]] = []
|
||||
|
||||
if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
|
||||
split_weights = torch.split(compressed, layer.logical_widths)
|
||||
split_bitmask = torch.split(bitmask, layer.logical_widths)
|
||||
split_shape = [
|
||||
(out, layer.input_size_per_partition) for out in layer.logical_widths
|
||||
]
|
||||
|
||||
if split_weights:
|
||||
decompressed_shards = [
|
||||
_process_split(compressed_weight, shape, bitmask)
|
||||
for compressed_weight, shape, bitmask in zip(
|
||||
split_weights, split_shape, split_bitmask
|
||||
)
|
||||
]
|
||||
decompressed = combine_shards(decompressed_shards)
|
||||
else:
|
||||
decompressed = sparsity_compressor.decompress_weight(
|
||||
dict(
|
||||
compressed=compressed,
|
||||
shape=(
|
||||
layer.logical_widths[0],
|
||||
layer.input_size_per_partition,
|
||||
),
|
||||
bitmask=bitmask,
|
||||
)
|
||||
)
|
||||
return decompressed
|
||||
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["CompressedTensorsScheme"]
|
||||
|
||||
|
||||
class CompressedTensorsScheme(ABC):
|
||||
"""
|
||||
Abstract class used to describe the weight creation and forward pass
|
||||
of different quantization schemes supported by CompressedTensors.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
"""
|
||||
Get minimum device capability.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, *args, **kwargs):
|
||||
"""
|
||||
Weight creation for the particular scheme. Inputs to this function depend on the concrete scheme implementation.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
):
|
||||
"""
|
||||
Run the forward pass for the particular scheme. This is where
|
||||
scheme-specific dequant/quant steps/kernels should be applied.
|
||||
|
||||
:param layer: torch.nn.Module with the registered weights and
|
||||
other parameters relevant to the particular scheme.
|
||||
:param x: input to the layer
|
||||
:param bias: bias parameter
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
"""
|
||||
Called after weight loading is complete for any cleanup that
|
||||
needs to occur.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
apply_fp4_marlin_linear,
|
||||
prepare_fp4_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A16Mxfp4"]
|
||||
|
||||
|
||||
class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme):
|
||||
"""
|
||||
Compressed tensors scheme for MXFP4 weight-only quantization.
|
||||
|
||||
Supports models quantized with the compressed-tensors mxfp4-pack-quantized
|
||||
format.
|
||||
|
||||
MXFP4 format:
|
||||
- 4-bit float weights (E2M1) packed into uint8
|
||||
- Per-group E8M0 scales with group_size=32
|
||||
- No global scale (unlike NVFP4)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.group_size = 32
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.params_dtype = params_dtype
|
||||
|
||||
# Packed FP4 weights (2 values per byte)
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // 2,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
|
||||
# Per-group E8M0 scales
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // self.group_size,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Rename weight_packed to weight that marlin expects
|
||||
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
|
||||
del layer.weight_packed
|
||||
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return apply_fp4_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_global_scale=None,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias,
|
||||
)
|
||||
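# Storage sketch (illustrative, not part of the original file): with
# group_size = 32, a row of K = 4096 input features stores K // 2 = 2048
# packed uint8 bytes (two E2M1 nibbles each) plus K // 32 = 128 E8M0 scale
# bytes, i.e. about 4.25 bits per weight before the marlin repack.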
@@ -0,0 +1,109 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
apply_fp4_marlin_linear,
|
||||
prepare_fp4_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A16Fp4"]
|
||||
|
||||
|
||||
class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
|
||||
def __init__(self):
|
||||
self.group_size = 16
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# don't restrict capability here, since emulation fallbacks exist
|
||||
return 75
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
|
||||
# Weight
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 2,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
|
||||
# Global Weight Scale
|
||||
weight_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_global_scale", weight_global_scale)
|
||||
|
||||
# Per Group Weight Scale
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // self.group_size,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Process parameters for marlin repacking
|
||||
|
||||
# Rename weight_packed to weight that marlin expects
|
||||
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
|
||||
del layer.weight_packed
|
||||
# ct stores the inverse of what is expected by the marlin kernel
|
||||
layer.weight_global_scale = Parameter(
|
||||
1.0 / layer.weight_global_scale.max().to(torch.float32), requires_grad=False
|
||||
)
|
||||
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return apply_fp4_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_global_scale=layer.weight_global_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias,
|
||||
)
|
||||
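# --- Illustrative sketch (appended for exposition only; not used by the scheme) ---
# Shows the storage layout create_weights assumes: two 4-bit FP4 codes share one
# uint8 byte along the input dimension, and every group of 16 inputs has one
# float8_e4m3fn scale. The sizes, the `codes` tensor, and the nibble order are
# made up for illustration; the actual bit packing is defined by the kernel.
_out, _in, _group = 4, 32, 16
codes = torch.randint(0, 16, (_out, _in), dtype=torch.uint8)       # 4-bit codes
packed = (codes[:, 1::2] << 4) | codes[:, 0::2]                    # (_out, _in // 2)
scales = torch.rand(_out, _in // _group).to(torch.float8_e4m3fn)   # one scale per group
assert packed.shape == (_out, _in // 2) and scales.shape == (_out, _in // _group)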
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
apply_nvfp4_linear,
|
||||
convert_to_nvfp4_linear_kernel_format,
|
||||
select_nvfp4_linear_backend,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A4Fp4"]
|
||||
|
||||
|
||||
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
def __init__(self):
|
||||
self.backend = select_nvfp4_linear_backend()
|
||||
self.group_size = 16
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
|
||||
# Weight
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 2,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
|
||||
# Global Weight Scale
|
||||
weight_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_global_scale", weight_global_scale)
|
||||
|
||||
# Per Group Weight Scale
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // self.group_size,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
input_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("input_global_scale", input_global_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Rename CT checkpoint names to standardized names
|
||||
layer.weight = layer.weight_packed
|
||||
del layer.weight_packed
|
||||
# Process global scales (CT stores as divisors, i.e. 1/scale)
|
||||
input_global_scale_inv = layer.input_global_scale.max().to(torch.float32)
|
||||
layer.input_global_scale = Parameter(
|
||||
(1.0 / input_global_scale_inv).to(torch.float32), requires_grad=False
|
||||
)
|
||||
weight_global_scale = layer.weight_global_scale.max().to(torch.float32)
|
||||
layer.weight_global_scale = Parameter(
|
||||
1.0 / weight_global_scale, requires_grad=False
|
||||
)
|
||||
|
||||
# Pre-compute alpha and inverse for runtime quantization
|
||||
layer.input_global_scale_inv = Parameter(
|
||||
input_global_scale_inv, requires_grad=False
|
||||
)
|
||||
layer.alpha = Parameter(
|
||||
layer.input_global_scale * layer.weight_global_scale, requires_grad=False
|
||||
)
|
||||
|
||||
# Convert layer to NVFP4 linear kernel format
|
||||
convert_to_nvfp4_linear_kernel_format(self.backend, layer)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return apply_nvfp4_linear(
|
||||
backend=self.backend,
|
||||
layer=layer,
|
||||
x=x,
|
||||
bias=bias,
|
||||
)
|
||||
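# --- Numeric sketch (appended for exposition only; not used by the scheme) ---
# Illustrates the scale handling in process_weights_after_loading above: the
# checkpoint serializes global scales as divisors (1/scale), so the maximum
# across the fused partitions is inverted before use, and alpha folds both
# inverted scales into one GEMM epilogue multiplier. Values are arbitrary.
ct_input_global = torch.tensor([448.0, 448.0])    # one entry per fused partition
ct_weight_global = torch.tensor([6.0, 4.0])
input_global_scale = 1.0 / ct_input_global.max()
weight_global_scale = 1.0 / ct_weight_global.max()
alpha = input_global_scale * weight_global_scale   # 1 / (448.0 * 6.0)
assert torch.isclose(alpha, torch.tensor(1.0 / (448.0 * 6.0)))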
@@ -0,0 +1,176 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import ActivationOrdering
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
MPLinearLayerConfig,
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_repeat_scales_on_all_ranks,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A8Fp8"]
|
||||
W4A8_SUPPORTED_TYPES_MAP = {
|
||||
4: scalar_types.int4,
|
||||
}
|
||||
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy: str,
|
||||
num_bits: int,
|
||||
group_size: int | None = None,
|
||||
symmetric: bool | None = True,
|
||||
actorder: ActivationOrdering | None = None,
|
||||
):
|
||||
self.pack_factor = 32 // num_bits
|
||||
self.strategy = strategy
|
||||
self.symmetric = symmetric
|
||||
self.group_size = -1 if group_size is None else group_size
|
||||
self.has_g_idx = actorder == ActivationOrdering.GROUP
|
||||
|
||||
if self.group_size != 128 or self.strategy != "group":
|
||||
raise ValueError(
|
||||
"W4A8 kernels require group quantization with group size 128"
|
||||
)
|
||||
|
||||
if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported num_bits = {num_bits}. "
|
||||
f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}"
|
||||
)
|
||||
|
||||
self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# hopper
|
||||
return 90
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_size: int,
|
||||
input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
),
|
||||
weight_type=self.quant_type,
|
||||
act_type=torch.float8_e4m3fn, # always use fp8(e4m3)
|
||||
group_size=self.group_size,
|
||||
zero_points=not self.symmetric,
|
||||
has_g_idx=self.has_g_idx,
|
||||
out_type=params_dtype,
|
||||
)
|
||||
|
||||
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for CompressedTensorsW4A8Fp8", kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# If group_size is -1, we are in channelwise case.
|
||||
group_size = self.group_size if self.group_size != -1 else input_size
|
||||
row_parallel = input_size != input_size_per_partition
|
||||
partition_scales = not marlin_repeat_scales_on_all_ranks(
|
||||
self.has_g_idx, self.group_size, row_parallel
|
||||
)
|
||||
|
||||
scales_and_zp_size = input_size // group_size
|
||||
|
||||
if partition_scales:
|
||||
assert input_size_per_partition % group_size == 0
|
||||
scales_and_zp_size = input_size_per_partition // group_size
|
||||
|
||||
weight = PackedvLLMParameter(
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
packed_factor=self.pack_factor,
|
||||
packed_dim=1,
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // self.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
)
|
||||
|
||||
# After loading, we will transform bf16 -> fp8 ->
|
||||
# expand by 8x via `cutlass_pack_scale_fp8`
|
||||
# and construct per-channel fp32 scales.
|
||||
weight_scale_args = {
|
||||
"weight_loader": weight_loader,
|
||||
"data": torch.empty(
|
||||
output_size_per_partition,
|
||||
scales_and_zp_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
}
|
||||
|
||||
if not partition_scales:
|
||||
weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args)
|
||||
else:
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
output_dim=0, input_dim=1, **weight_scale_args
|
||||
)
|
||||
|
||||
# A 2D array defining the original shape of the weights
|
||||
# before packing
|
||||
weight_shape = BasevLLMParameter(
|
||||
data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
layer.register_parameter("weight_shape", weight_shape)
|
||||
|
||||
self.kernel = kernel_type(
|
||||
mp_linear_kernel_config,
|
||||
w_q_param_name="weight_packed",
|
||||
w_s_param_name="weight_scale",
|
||||
w_zp_param_name="weight_zero_point",
|
||||
w_gidx_param_name="weight_g_idx",
|
||||
)
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
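# --- Illustrative sketch (appended for exposition only; not used by the scheme) ---
# Shows the 8x packing implied by pack_factor = 32 // num_bits: with 4-bit
# weights, eight codes share one 32-bit word along the input dimension, which
# is why the packed weight has input_size_per_partition // pack_factor columns.
# Accumulating in int64 here only avoids overflow in the demo; the real
# parameter is stored as int32 and the bit order is kernel-specific.
_num_bits = 4
_pack_factor = 32 // _num_bits                                      # 8
_codes = torch.randint(0, 16, (2, 16), dtype=torch.int64)           # (out, in) 4-bit codes
_packed = torch.zeros(2, 16 // _pack_factor, dtype=torch.int64)
for _i in range(_pack_factor):
    _packed |= _codes[:, _i::_pack_factor] << (_num_bits * _i)
assert _packed.shape == (2, 16 // _pack_factor)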
@@ -0,0 +1,153 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
MPLinearLayerConfig,
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A8Int"]
|
||||
W4A8_SUPPORTED_TYPES_MAP = {
|
||||
4: scalar_types.int4,
|
||||
}
|
||||
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsW4A8Int(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy: str,
|
||||
num_bits: int,
|
||||
group_size: int | None = None,
|
||||
is_static_input_scheme: bool = False,
|
||||
input_symmetric: bool = True,
|
||||
):
|
||||
self.strategy = strategy
|
||||
self.group_size = -1 if group_size is None else group_size
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
|
||||
if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported num_bits = {num_bits}."
|
||||
f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}"
|
||||
)
|
||||
self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 1
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_size: int,
|
||||
input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
row_parallel = input_size != input_size_per_partition
|
||||
|
||||
# Compute effective group_size
|
||||
if self.group_size == -1:
|
||||
effective_group_size = (
|
||||
input_size_per_partition if row_parallel else input_size
|
||||
)
|
||||
else:
|
||||
effective_group_size = self.group_size
|
||||
|
||||
# Ensure group_size divides input_size_per_partition
|
||||
assert input_size_per_partition % effective_group_size == 0, (
|
||||
f"input_size_per_partition {input_size_per_partition}"
|
||||
f" not divisible by group_size {effective_group_size}"
|
||||
)
|
||||
|
||||
# Determine scale partitioning
|
||||
is_channelwise = self.group_size == -1
|
||||
repeat_scales = is_channelwise and row_parallel
|
||||
partition_scales = not repeat_scales
|
||||
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
),
|
||||
weight_type=self.quant_type,
|
||||
act_type=params_dtype,
|
||||
group_size=effective_group_size,
|
||||
zero_points=False,
|
||||
has_g_idx=False,
|
||||
)
|
||||
|
||||
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for CompressedTensorsW4A8Int", kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
scales_and_zp_size = input_size_per_partition // effective_group_size
|
||||
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition, input_size_per_partition, dtype=torch.int8
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
weight_scale_args = {
|
||||
"weight_loader": weight_loader,
|
||||
"data": torch.empty(
|
||||
output_size_per_partition, scales_and_zp_size, dtype=params_dtype
|
||||
),
|
||||
}
|
||||
|
||||
if partition_scales:
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
output_dim=0, input_dim=1, **weight_scale_args
|
||||
)
|
||||
else:
|
||||
weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args)
|
||||
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
self.kernel = kernel_type(
|
||||
mp_linear_kernel_config,
|
||||
w_q_param_name="weight_packed",
|
||||
w_s_param_name="weight_scale",
|
||||
w_zp_param_name=None,
|
||||
w_gidx_param_name=None,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
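# --- Illustrative sketch (appended for exposition only; not used by the scheme) ---
# Restates the effective group size logic above: group_size == -1 means
# channelwise quantization, so the single scale group spans the whole local
# input dimension and scales_and_zp_size collapses to one column per output
# row. Sizes are arbitrary.
_input_size, _input_size_per_partition = 128, 64
_row_parallel = _input_size != _input_size_per_partition            # True here
_effective_group_size = _input_size_per_partition if _row_parallel else _input_size
_scales_and_zp_size = _input_size_per_partition // _effective_group_size
assert _scales_and_zp_size == 1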
@@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter,
|
||||
process_fp8_weight_block_strategy,
|
||||
validate_fp8_block_shape,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear,
|
||||
prepare_fp8_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter
|
||||
|
||||
__all__ = ["CompressedTensorsW8A16Fp8"]
|
||||
|
||||
strategy_to_parameter_type = {
|
||||
QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
|
||||
QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
|
||||
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
|
||||
}
|
||||
|
||||
|
||||
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
||||
def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
|
||||
self.weight_quant = weight_quant
|
||||
self.strategy = weight_quant.strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.weight_block_size = self.weight_quant.block_structure
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# turing and up
|
||||
return 75
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
if self.strategy == QuantizationStrategy.BLOCK:
|
||||
assert self.weight_block_size is not None
|
||||
layer.weight_block_size = self.weight_block_size
|
||||
# Validate block quantization shapes
|
||||
validate_fp8_block_shape(
|
||||
layer,
|
||||
input_size,
|
||||
output_size,
|
||||
input_size_per_partition,
|
||||
output_partition_sizes,
|
||||
self.weight_block_size,
|
||||
)
|
||||
|
||||
# WEIGHT
|
||||
weight = create_fp8_weight_parameter(
|
||||
output_size_per_partition, input_size_per_partition, weight_loader
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = create_fp8_scale_parameter(
|
||||
strategy_to_parameter_type[self.strategy],
|
||||
output_partition_sizes,
|
||||
input_size_per_partition,
|
||||
layer.weight_block_size,
|
||||
weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE (to deal with converted checkpoints)
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale
|
||||
size_k_first = True
|
||||
# TODO(rob): refactor block quant into separate class.
|
||||
if self.strategy == QuantizationStrategy.BLOCK:
|
||||
assert self.is_static_input_scheme is False
|
||||
size_k_first = False
|
||||
weight, weight_scale = process_fp8_weight_block_strategy(
|
||||
weight, weight_scale
|
||||
)
|
||||
else:
|
||||
# Weights must be transposed for marlin
|
||||
weight = weight.t()
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales,
|
||||
# we expand each scale to its shard's channels.
|
||||
weight_scale = convert_to_channelwise(
|
||||
weight_scale, layer.logical_widths
|
||||
)
|
||||
|
||||
# Update layer with new values
|
||||
replace_parameter(layer, "weight", weight.data)
|
||||
replace_parameter(layer, "weight_scale", weight_scale.data)
|
||||
|
||||
prepare_fp8_layer_for_marlin(layer, size_k_first=size_k_first)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias,
|
||||
)
|
||||
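# --- Illustrative sketch (appended for exposition only; not used by the scheme) ---
# Shows what the convert_to_channelwise call above conceptually does for a
# fused module (e.g. QKV) with per-tensor scales: the one scale per shard is
# repeated across that shard's output channels so a single channelwise scale
# vector covers the fused weight. The real helper lives in w8a8_utils; the
# sizes and values here are arbitrary.
_logical_widths = [8, 8, 16]                        # e.g. q, k, v output sizes
_per_shard_scales = torch.tensor([0.5, 0.25, 1.0])
_channelwise = torch.repeat_interleave(_per_shard_scales, torch.tensor(_logical_widths))
assert _channelwise.shape[0] == sum(_logical_widths)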
@@ -0,0 +1,206 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
create_fp8_input_scale,
|
||||
create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter,
|
||||
maybe_post_process_fp8_weight_block,
|
||||
process_fp8_weight_block_strategy,
|
||||
process_fp8_weight_channel_strategy,
|
||||
process_fp8_weight_tensor_strategy,
|
||||
validate_fp8_block_shape,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kFp8StaticTokenSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_block_fp8_supported,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
__all__ = ["CompressedTensorsW8A8Fp8"]
|
||||
|
||||
strategy_to_parameter_type = {
|
||||
QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
|
||||
QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
|
||||
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
|
||||
}
|
||||
|
||||
STATIC_QUANT = True
|
||||
DYNAMIC_QUANT = False
|
||||
activation_quant_key_mapping = {
|
||||
STATIC_QUANT: kFp8StaticTensorSym,
|
||||
DYNAMIC_QUANT: kFp8DynamicTokenSym,
|
||||
}
|
||||
weight_quant_key_mapping = {
|
||||
QuantizationStrategy.CHANNEL: kFp8StaticTokenSym,
|
||||
QuantizationStrategy.TENSOR: kFp8StaticTensorSym,
|
||||
}
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
|
||||
self.weight_quant = weight_quant
|
||||
self.strategy = weight_quant.strategy
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.weight_block_size = self.weight_quant.block_structure
|
||||
|
||||
if self.weight_block_size is not None:
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
assert not self.is_static_input_scheme
|
||||
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
)
|
||||
else:
|
||||
activation_quant_key = activation_quant_key_mapping[is_static_input_scheme]
|
||||
weight_quant_key = weight_quant_key_mapping[self.strategy]
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=activation_quant_key,
|
||||
weight_quant_key=weight_quant_key,
|
||||
out_dtype=self.out_dtype,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
return 89
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.weight_block_size = None
|
||||
layer.orig_dtype = params_dtype
|
||||
|
||||
if self.strategy == QuantizationStrategy.BLOCK:
|
||||
assert self.weight_block_size is not None
|
||||
layer.weight_block_size = self.weight_block_size
|
||||
# Validate block quantization shapes
|
||||
validate_fp8_block_shape(
|
||||
layer,
|
||||
input_size,
|
||||
output_size,
|
||||
input_size_per_partition,
|
||||
output_partition_sizes,
|
||||
self.weight_block_size,
|
||||
)
|
||||
|
||||
# WEIGHT
|
||||
weight = create_fp8_weight_parameter(
|
||||
output_size_per_partition, input_size_per_partition, weight_loader
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = create_fp8_scale_parameter(
|
||||
strategy_to_parameter_type[self.strategy],
|
||||
output_partition_sizes,
|
||||
input_size_per_partition,
|
||||
layer.weight_block_size,
|
||||
weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
|
||||
layer.weight,
|
||||
layer.weight_scale,
|
||||
layer.logical_widths,
|
||||
getattr(layer, "input_scale", None),
|
||||
)
|
||||
weight = weight.t()
|
||||
elif self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight, weight_scale, input_scale = process_fp8_weight_channel_strategy(
|
||||
layer.weight, layer.weight_scale, getattr(layer, "input_scale", None)
|
||||
)
|
||||
weight = weight.t()
|
||||
|
||||
elif self.strategy == QuantizationStrategy.BLOCK:
|
||||
assert self.is_static_input_scheme is False
|
||||
weight, weight_scale = process_fp8_weight_block_strategy(
|
||||
layer.weight, layer.weight_scale
|
||||
)
|
||||
input_scale = None
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown quantization strategy {self.strategy}: "
|
||||
f"should be one of {list(QuantizationStrategy)}"
|
||||
)
|
||||
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight = Parameter(weight.data, requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale.data, requires_grad=False)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme and hasattr(layer, "input_scale"):
|
||||
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
|
||||
else:
|
||||
layer.input_scale = None
|
||||
if self.strategy == QuantizationStrategy.BLOCK:
|
||||
maybe_post_process_fp8_weight_block(layer)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.weight_block_size is not None:
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
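# --- Illustrative sketch (appended for exposition only; not used by the scheme) ---
# Shows the idea behind the dynamic per-token activation quantization that the
# kFp8DynamicTokenSym key above refers to: each row of x gets its own scale so
# that the row maximum maps onto the FP8 E4M3 range. Purely illustrative; the
# real path goes through the selected fp8 linear kernel.
_x = torch.randn(4, 16)
_fp8_max = torch.finfo(torch.float8_e4m3fn).max                     # 448.0
_x_scale = (_x.abs().amax(dim=-1, keepdim=True) / _fp8_max).clamp(min=1e-12)
_x_fp8 = (_x / _x_scale).to(torch.float8_e4m3fn)
_x_dequant = _x_fp8.to(torch.float32) * _x_scale                    # approximately _x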
@@ -0,0 +1,212 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_int8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
def __init__(
|
||||
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
|
||||
):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# turing and up
|
||||
return 75
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
self.kernel = init_int8_linear_kernel(
|
||||
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
|
||||
is_static_input_scheme=self.is_static_input_scheme,
|
||||
input_symmetric=self.input_symmetric,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
assert self.strategy == QuantizationStrategy.TENSOR
|
||||
weight_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
input_zero_point = None
|
||||
input_scale = None
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||
)
|
||||
if not self.input_symmetric:
|
||||
# Note: compressed-tensors stores the zp using the same dtype
|
||||
# as the weights
|
||||
# AZP loaded as int8 but used as int32
|
||||
input_zero_point = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
||||
)
|
||||
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
if not hasattr(layer, "azp_adj"):
|
||||
layer.register_parameter("azp_adj", None)
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
|
||||
class CompressedTensorsW4A8Int8(CompressedTensorsScheme):
|
||||
def __init__(
|
||||
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
|
||||
):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# turing and up
|
||||
return 75
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
self.kernel = init_int8_linear_kernel(
|
||||
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
|
||||
is_static_input_scheme=self.is_static_input_scheme,
|
||||
input_symmetric=self.input_symmetric,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
# WEIGHT
|
||||
# NOTE: unlike the (out, in) int8 layout used by CompressedTensorsW8A8Int8
# above, the packed 4-bit weight is stored input-major, with two int4 values
# per int8 byte along the output dimension.
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
sum(output_partition_sizes) // 2,
|
||||
dtype=torch.int8
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
assert self.strategy == QuantizationStrategy.TENSOR
|
||||
weight_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
input_zero_point = None
|
||||
input_scale = None
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||
)
|
||||
if not self.input_symmetric:
|
||||
# Note: compressed-tensors stores the zp using the same dtype
|
||||
# as the weights
|
||||
# AZP loaded as int8 but used as int32
|
||||
input_zero_point = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
||||
)
|
||||
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
if not hasattr(layer, "azp_adj"):
|
||||
layer.register_parameter("azp_adj", None)
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
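# --- Illustrative sketch (appended for exposition only; not used by the schemes) ---
# Shows symmetric per-channel int8 weight quantization, i.e. the CHANNEL
# strategy above: one float32 scale per output channel, with weights rounded
# into [-127, 127]. Purely illustrative; checkpoints ship the scales already.
_w = torch.randn(8, 16)                                             # (out, in)
_w_scale = (_w.abs().amax(dim=1, keepdim=True) / 127.0).clamp(min=1e-8)
_w_int8 = torch.clamp(torch.round(_w / _w_scale), -127, 127).to(torch.int8)
_w_dequant = _w_int8.to(torch.float32) * _w_scale                   # approximately _w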
@@ -0,0 +1,228 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import ActivationOrdering
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
MarlinLinearKernel,
|
||||
MPLinearLayerConfig,
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
get_marlin_input_dtype,
|
||||
marlin_repeat_scales_on_all_ranks,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsWNA16"]
|
||||
WNA16_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4b8, 8: scalar_types.uint8b128}
|
||||
WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
|
||||
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy: str,
|
||||
num_bits: int,
|
||||
group_size: int | None = None,
|
||||
symmetric: bool | None = True,
|
||||
actorder: ActivationOrdering | None = None,
|
||||
layer_name: str | None = None,
|
||||
):
|
||||
self.pack_factor = 32 // num_bits
|
||||
self.strategy = strategy
|
||||
self.symmetric = symmetric
|
||||
self.group_size = -1 if group_size is None else group_size
|
||||
self.has_g_idx = actorder == ActivationOrdering.GROUP
|
||||
self.layer_name = layer_name
|
||||
|
||||
if self.group_size == -1 and self.strategy != "channel":
|
||||
raise ValueError(
|
||||
"Marlin kernels require group quantization or "
|
||||
"channelwise quantization, but found no group "
|
||||
"size and strategy is not channelwise."
|
||||
)
|
||||
|
||||
if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported num_bits = {num_bits}. "
|
||||
f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}"
|
||||
)
|
||||
|
||||
self.quant_type = (
|
||||
WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
|
||||
if not self.symmetric
|
||||
else WNA16_SUPPORTED_TYPES_MAP[num_bits]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# Turing and up
|
||||
return 75
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_size: int,
|
||||
input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
),
|
||||
weight_type=self.quant_type,
|
||||
act_type=params_dtype,
|
||||
group_size=self.group_size,
|
||||
zero_points=not self.symmetric,
|
||||
has_g_idx=self.has_g_idx,
|
||||
)
|
||||
|
||||
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for CompressedTensorsWNA16", kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
if kernel_type is MarlinLinearKernel:
|
||||
input_dtype = get_marlin_input_dtype(self.layer_name)
|
||||
if input_dtype is not None:
|
||||
mp_linear_kernel_config.act_type = input_dtype
|
||||
|
||||
# If group_size is -1, we are in channelwise case.
|
||||
group_size = self.group_size if self.group_size != -1 else input_size
|
||||
row_parallel = input_size != input_size_per_partition
|
||||
partition_scales = not marlin_repeat_scales_on_all_ranks(
|
||||
self.has_g_idx, self.group_size, row_parallel
|
||||
)
|
||||
|
||||
scales_and_zp_size = input_size // group_size
|
||||
|
||||
if partition_scales:
|
||||
assert input_size_per_partition % group_size == 0
|
||||
scales_and_zp_size = input_size_per_partition // group_size
|
||||
|
||||
weight = PackedvLLMParameter(
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
packed_factor=self.pack_factor,
|
||||
packed_dim=1,
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // self.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
)
|
||||
|
||||
weight_scale_args = {
|
||||
"weight_loader": weight_loader,
|
||||
"data": torch.empty(
|
||||
output_size_per_partition,
|
||||
scales_and_zp_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
}
|
||||
|
||||
zeros_args = {
|
||||
"weight_loader": weight_loader,
|
||||
"data": torch.zeros(
|
||||
output_size_per_partition // self.pack_factor,
|
||||
scales_and_zp_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
}
|
||||
|
||||
if not partition_scales:
|
||||
weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args)
|
||||
|
||||
if not self.symmetric:
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=0,
|
||||
packed_dim=0,
|
||||
packed_factor=self.pack_factor,
|
||||
**zeros_args,
|
||||
)
|
||||
else:
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
output_dim=0, input_dim=1, **weight_scale_args
|
||||
)
|
||||
if not self.symmetric:
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=0,
|
||||
packed_factor=self.pack_factor,
|
||||
**zeros_args,
|
||||
)
|
||||
|
||||
# A 2D array defining the original shape of the weights
|
||||
# before packing
|
||||
weight_shape = BasevLLMParameter(
|
||||
data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
layer.register_parameter("weight_shape", weight_shape)
|
||||
|
||||
if not self.symmetric:
|
||||
layer.register_parameter("weight_zero_point", qzeros)
|
||||
|
||||
# group index (for activation reordering)
|
||||
if self.has_g_idx:
|
||||
weight_g_idx = RowvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_g_idx", weight_g_idx)
|
||||
|
||||
self.kernel = kernel_type(
|
||||
mp_linear_kernel_config,
|
||||
w_q_param_name="weight_packed",
|
||||
w_s_param_name="weight_scale",
|
||||
w_zp_param_name="weight_zero_point",
|
||||
w_gidx_param_name="weight_g_idx",
|
||||
)
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
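# --- Illustrative sketch (appended for exposition only; not used by the scheme) ---
# Sketches the biased storage behind the uint4b8 type used for symmetric
# 4-bit weights above: signed values in [-8, 7] are stored offset by 8 as
# unsigned codes in [0, 15], so no explicit zero-point tensor is needed. Only
# the encoding is shown; the packed int32 layout is handled by the kernel.
_signed = torch.randint(-8, 8, (4, 32), dtype=torch.int8)
_stored = (_signed.to(torch.int16) + 8).to(torch.uint8)             # codes in [0, 15]
_recovered = (_stored.to(torch.int16) - 8).to(torch.int8)
assert torch.equal(_recovered, _signed)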
@@ -0,0 +1,260 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable, Generator
|
||||
from itertools import accumulate
|
||||
|
||||
import torch
|
||||
from compressed_tensors.transform import (
|
||||
TransformArgs,
|
||||
TransformConfig,
|
||||
TransformLocation,
|
||||
TransformScheme,
|
||||
)
|
||||
from compressed_tensors.utils import is_match
|
||||
|
||||
from vllm.model_executor.layers.linear import (
|
||||
WEIGHT_LOADER_V2_SUPPORTED,
|
||||
LinearMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501
|
||||
HadamardTransform,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
|
||||
TransformTuple,
|
||||
)
|
||||
|
||||
|
||||
class CompressedTensorsLinearTransformMethod(LinearMethodBase):
|
||||
"""
|
||||
Wraps `CompressedTensorsLinearMethod` or `UnquantizedLinearMethod` and adds
|
||||
input and output transforms to either side of the original apply method
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_schemes(
|
||||
cls,
|
||||
quant_method: LinearMethodBase,
|
||||
quant_scheme: CompressedTensorsScheme | None,
|
||||
input_tfms: dict[int, TransformTuple],
|
||||
output_tfms: dict[int, TransformTuple],
|
||||
) -> "CompressedTensorsLinearTransformMethod":
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.schemes.linear_qutlass_nvfp4 import ( # noqa: E501
|
||||
QutlassNvFP4LinearMethod,
|
||||
is_qutlass_fp4_scheme,
|
||||
)
|
||||
|
||||
assert input_tfms or output_tfms
|
||||
|
||||
if is_qutlass_fp4_scheme(quant_scheme, input_tfms):
|
||||
return QutlassNvFP4LinearMethod(quant_method, input_tfms, output_tfms)
|
||||
|
||||
# hadacore or dense gemm is selected by Transform module
|
||||
|
||||
return cls(quant_method, input_tfms, output_tfms)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_method: LinearMethodBase,
|
||||
input_tfms: dict[int, TransformTuple],
|
||||
output_tfms: dict[int, TransformTuple],
|
||||
):
|
||||
self.quant_method = quant_method
|
||||
self.input_tfms = input_tfms
|
||||
self.output_tfms = output_tfms
|
||||
|
||||
self.input_transform: HadamardTransform | None = None
|
||||
self.output_transform: HadamardTransform | None = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# get weight loader for transforms
|
||||
weight_loader: Callable = extra_weight_attrs.get("weight_loader") # type: ignore[assignment]
|
||||
|
||||
# HACK: UnquantizedLinearMethod does not support weight loader v2, but
|
||||
# transforms (specifically SharedWeightParameter) requires
|
||||
# weight loader v2. Until UnquantizedLinearMethod supports v2, we must
|
||||
# hack around this by getting weight loader v1 so ULM can load correctly
|
||||
quant_method_name = self.quant_method.__class__.__name__
|
||||
if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED:
|
||||
weight_loader_v1 = layer.weight_loader
|
||||
extra_weight_attrs["weight_loader"] = weight_loader_v1
|
||||
|
||||
self.quant_method.create_weights(
|
||||
layer=layer,
|
||||
input_size_per_partition=input_size_per_partition,
|
||||
output_partition_sizes=output_partition_sizes,
|
||||
input_size=input_size,
|
||||
output_size=output_size,
|
||||
params_dtype=params_dtype,
|
||||
**extra_weight_attrs,
|
||||
)
|
||||
|
||||
# validate schemes
|
||||
num_partitions = len(output_partition_sizes)
|
||||
self._validate_tfm_schemes(num_partitions)
|
||||
|
||||
# create submodules for weight loading
|
||||
if len(self.input_tfms) > 0:
|
||||
scheme_name = list(self.input_tfms.values())[0].scheme_name
|
||||
location = list(self.input_tfms.values())[0].args.location
|
||||
transform_name = f"{scheme_name}_{location}"
|
||||
|
||||
transform = HadamardTransform(
|
||||
self.input_tfms,
|
||||
layer,
|
||||
weight_loader,
|
||||
input_size_per_partition,
|
||||
output_partition_sizes,
|
||||
)
|
||||
layer.register_module(transform_name, transform)
|
||||
self.input_transform = transform
|
||||
|
||||
if len(self.output_tfms) > 0:
|
||||
scheme_name = list(self.output_tfms.values())[0].scheme_name
|
||||
location = list(self.output_tfms.values())[0].args.location
|
||||
transform_name = f"{scheme_name}_{location}"
|
||||
|
||||
transform = HadamardTransform(
|
||||
self.output_tfms,
|
||||
layer,
|
||||
weight_loader,
|
||||
input_size_per_partition,
|
||||
output_partition_sizes,
|
||||
)
|
||||
layer.register_module(transform_name, transform)
|
||||
self.output_transform = transform
|
||||
|
||||
# compute partition ranges for slicing activations
|
||||
starts = [0] + list(accumulate(output_partition_sizes))[:-1]
|
||||
self.partition_ranges = list(zip(starts, output_partition_sizes))
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
|
||||
for submodule in layer.children():
|
||||
if isinstance(submodule, HadamardTransform):
|
||||
submodule.process_weights_after_loading()
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.input_transform is not None:
|
||||
x = self.input_transform(x)
|
||||
|
||||
assert bias is None
|
||||
x = self.quant_method.apply(layer, x, bias)
|
||||
|
||||
# In most cases, input transforms are preferred over output transforms
|
||||
# TODO(@ksayers): confirm that this is done concurrently
|
||||
if self.output_transform is not None:
|
||||
for part_id, (start, length) in enumerate(self.partition_ranges):
|
||||
x[:, start : start + length] = self.output_transform(
|
||||
x[:, start : start + length].clone(), part_id=part_id
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
def _validate_tfm_schemes(self, num_partitions: int):
|
||||
if len(self.input_tfms) > 0:
|
||||
if 0 not in self.input_tfms:
|
||||
raise ValueError("Must have same input")
|
||||
|
||||
for part_index in range(num_partitions):
|
||||
if self.input_tfms.get(part_index) != self.input_tfms[0]:
raise ValueError(
"All partitions of a fused module must use the same input transform"
)
|
||||
|
||||
if len(self.output_tfms) > 0:
|
||||
scheme_name = list(self.output_tfms.values())[0].scheme_name
|
||||
location = list(self.output_tfms.values())[0].args.location
|
||||
|
||||
for tfm in self.output_tfms.values():
|
||||
if tfm.scheme_name != scheme_name:
|
||||
raise ValueError("Must have same scheme name")
|
||||
if tfm.args.location != location:
|
||||
raise ValueError("Must have same location")
|
||||
|
||||
return self.input_tfms, self.output_tfms
|
||||
|
||||
|
||||
def get_linear_transform_schemes(
|
||||
layer: torch.nn.Module,
|
||||
layer_name: str,
|
||||
transform_config: TransformConfig | None,
|
||||
packed_modules_mapping: dict[str, list[str]],
|
||||
) -> tuple[
|
||||
dict[int, TransformTuple], dict[int, TransformTuple]
|
||||
]:  # (input transforms by partition index, output transforms by partition index)
|
||||
# there can only be one transform input scheme per (fused) module
|
||||
input_tfms = {}
|
||||
output_tfms = {}
|
||||
|
||||
partition_names = get_layer_partition_names(layer_name, packed_modules_mapping)
|
||||
|
||||
for scheme_name, scheme, args in get_schemes_args(transform_config):
|
||||
for part_index, part_name in enumerate(partition_names):
|
||||
if (
|
||||
is_match(part_name, layer, args.targets, args.ignore)
|
||||
and args.is_online()
|
||||
):
|
||||
if args.location == TransformLocation.INPUT:
|
||||
input_tfms[part_index] = TransformTuple(scheme_name, scheme, args)
|
||||
|
||||
elif args.location == TransformLocation.OUTPUT:
|
||||
output_tfms[part_index] = TransformTuple(scheme_name, scheme, args)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot apply `{args.location}` transform to `{layer_name}`"
|
||||
)
|
||||
|
||||
return (input_tfms, output_tfms)
|
||||
|
||||
|
||||
def get_schemes_args(
|
||||
transform_config: TransformConfig | None,
|
||||
) -> Generator[tuple[str, TransformScheme, TransformArgs]]:
|
||||
if transform_config is None:
|
||||
return
|
||||
|
||||
for scheme_name, scheme in transform_config.config_groups.items():
|
||||
for args in scheme.apply:
|
||||
yield (scheme_name, scheme, args)
|
||||
|
||||
|
||||
def get_layer_partition_names(
|
||||
layer_name: str, packed_modules_mapping: dict[str, list[str]]
|
||||
) -> list[str]:
|
||||
"""
|
||||
Get all partition names associated with this layer.
|
||||
Names are returned in order of their partition indices.
|
||||
|
||||
```python
|
||||
mapping = {"gate_up_proj", "gate_proj", "up_proj"}
|
||||
|
||||
assert get_layer_partition_names("mlp.gate_up_proj", mapping) == [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
]
|
||||
assert get_layer_partition_names("mlp.down_proj", mapping) == ["down_proj"]"""
|
||||
for fused_suffix, part_suffixes in packed_modules_mapping.items():
|
||||
if layer_name.endswith(fused_suffix):
|
||||
return [
|
||||
layer_name.removesuffix(fused_suffix) + part_suffix
|
||||
for part_suffix in part_suffixes
|
||||
]
|
||||
|
||||
return [layer_name]
|
||||
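# --- Illustrative sketch (appended for exposition only; not used by the method) ---
# Restates the partition-range bookkeeping used in create_weights/apply above:
# the output columns of a fused layer are sliced per partition so each
# partition can receive its own output transform. Sizes are arbitrary.
_output_partition_sizes = [8, 8, 16]
_starts = [0] + list(accumulate(_output_partition_sizes))[:-1]
_partition_ranges = list(zip(_starts, _output_partition_sizes))
assert _partition_ranges == [(0, 8), (8, 8), (16, 16)]
_x = torch.randn(2, sum(_output_partition_sizes))
_slices = [_x[:, _s : _s + _n] for _s, _n in _partition_ranges]
assert [_t.shape[1] for _t in _slices] == _output_partition_sizes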
@@ -0,0 +1,173 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from collections.abc import Callable, Hashable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.transform import (
|
||||
TransformArgs,
|
||||
TransformLocation,
|
||||
TransformScheme,
|
||||
)
|
||||
from torch import Tensor
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
|
||||
TransformTuple,
|
||||
)
|
||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.parameter import SharedWeightParameter
|
||||
|
||||
|
||||
class HadamardTransform(torch.nn.Module):
|
||||
"""
|
||||
Class which handles weight loading, postprocessing, and application of
|
||||
transforms. Meant to be used with `CompressedTensorsLinearTransformMethod`
|
||||
and attention transforms method (not implemented yet)
|
||||
"""
|
||||
|
||||
transforms: dict[int, TransformTuple] # info parsed from transforms config
|
||||
weight: SharedWeightParameter # container for shared tensors
|
||||
|
||||
scales: dict[int, float]  # hadamard scale, 1 / sqrt(matrix.size(0))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transforms: dict[int, TransformTuple],
|
||||
layer: torch.nn.Module,
|
||||
weight_loader: Callable,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
):
|
||||
super().__init__()
|
||||
self.transforms = transforms
|
||||
self.scales = {}
|
||||
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
raise NotImplementedError(
|
||||
"Online transforms with tensor parallelism is not supported"
|
||||
)
|
||||
|
||||
# Similar to row/col parallel params, but tensors are separate
|
||||
# to allow for loading with shared memory
|
||||
self.weight = SharedWeightParameter(weight_loader=weight_loader)
|
||||
|
||||
# create shared partition data for each partition of the original weight
|
||||
input_size = input_size_per_partition
|
||||
for part_index, (_scheme_name, scheme, args) in self.transforms.items():
|
||||
output_size = output_partition_sizes[part_index]
|
||||
weight_size = self._get_weight_size(
|
||||
layer, scheme, args, input_size, output_size
|
||||
)
|
||||
|
||||
data_key = self._get_data_key(scheme, weight_size)
|
||||
self.weight.add_partition(
|
||||
part_index,
|
||||
data_key,
|
||||
size=(weight_size, weight_size),
|
||||
dtype=scheme.precision,
|
||||
)
|
||||
|
||||
# validate that shared tensors and schemes are correct
|
||||
self._validate_input_transforms()
|
||||
|
||||
def process_weights_after_loading(self):
|
||||
for part_id in self.weight.partitions:
|
||||
data = self.weight.partitions[part_id].data
|
||||
|
||||
# required by torch.compile
|
||||
self.weight.process_weights_after_loading()
|
||||
|
||||
# precompute scale as a runtime multiply, not division
|
||||
# do not fold into weight in order to utilize FWHT
|
||||
self.scales[part_id] = 1 / math.sqrt(data.size(0))
|
||||
|
||||
# FUTURE: avoid runtime transpose by processing weights
|
||||
# prior to apply
|
||||
|
||||
def forward(self, value: Tensor, part_id: int = 0) -> Tensor:
|
||||
if part_id not in self.weight.partitions:
|
||||
return value
|
||||
|
||||
# use hadacore if possible
|
||||
if self.transforms[part_id].scheme.type == "hadamard":
|
||||
if self.transforms[part_id].scheme.head_dim is not None:
|
||||
weight_size = self.transforms[part_id].scheme.head_dim
|
||||
value = value.unflatten(-1, (-1, weight_size))
|
||||
value = ops.hadacore_transform(value)
|
||||
value = value.flatten(-2, -1)
|
||||
|
||||
return value
|
||||
|
||||
# sylvester transforms are symmetric, inv => transpose => original
|
||||
return ops.hadacore_transform(value)
|
||||
|
||||
# fall back to dense
|
||||
else:
|
||||
weight = self.weight.partitions[part_id]
|
||||
weight = (
|
||||
weight if self.transforms[part_id].args.inverse else weight.T
|
||||
) # linear := x(W.T)
|
||||
scale = self.scales[part_id]
|
||||
|
||||
if self.transforms[part_id].scheme.head_dim is not None:
|
||||
value = value.unflatten(-1, (-1, weight.size(0)))
|
||||
value = (
|
||||
dispatch_unquantized_gemm()(
|
||||
self, value.to(weight.dtype), weight, None
|
||||
).to(value.dtype)
|
||||
* scale
|
||||
)
|
||||
value = value.flatten(-2, -1)
|
||||
|
||||
return value
|
||||
|
||||
return (
|
||||
dispatch_unquantized_gemm()(
|
||||
self, value.to(weight.dtype), weight, None
|
||||
).to(value.dtype)
|
||||
* scale
|
||||
)
|
||||
|
||||
def _get_data_key(self, scheme: TransformScheme, weight_size: int) -> Hashable:
|
||||
return (id(scheme), weight_size)
|
||||
|
||||
def _get_weight_size(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
scheme: TransformScheme,
|
||||
args: TransformArgs,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
) -> int:
|
||||
if scheme.head_dim is not None:
|
||||
return scheme.head_dim
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if args.location == TransformLocation.INPUT:
|
||||
return input_size
|
||||
|
||||
elif args.location == TransformLocation.OUTPUT:
|
||||
return output_size
|
||||
|
||||
elif isinstance(layer, VocabParallelEmbedding):
|
||||
if args.location == TransformLocation.INPUT:
|
||||
return output_size
|
||||
|
||||
elif args.location == TransformLocation.OUTPUT:
|
||||
return input_size
|
||||
|
||||
        raise ValueError(
            f"Unsupported layer type {type(layer)} for transform weight size lookup"
        )
|
||||
|
||||
def _validate_input_transforms(self):
|
||||
assert len(self.transforms) > 0
|
||||
location = list(self.transforms.values())[0].args.location
|
||||
|
||||
if location == TransformLocation.INPUT:
|
||||
first_data = self.weight.partitions[0].data
|
||||
for partition in self.weight.partitions.values():
|
||||
if partition.data.data_ptr() != first_data.data_ptr():
|
||||
raise ValueError("")
|
||||
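For intuition, a minimal sketch (illustrative only, not part of this diff) of the scale convention used in `process_weights_after_loading` above: the module stores `1 / sqrt(n)` and applies it as a runtime multiply, because a Sylvester-constructed Hadamard matrix scaled by `1 / sqrt(n)` is orthogonal, so applying the scaled transform twice recovers the input.

import math
import torch

def sylvester_hadamard(n: int) -> torch.Tensor:
    # H_{2k} = [[H_k, H_k], [H_k, -H_k]], starting from H_1 = [[1]].
    assert n > 0 and (n & (n - 1)) == 0, "n must be a power of two"
    h = torch.ones(1, 1)
    while h.size(0) < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h

n = 64
h = sylvester_hadamard(n)
scale = 1 / math.sqrt(n)  # same convention as self.scales[part_id]
x = torch.randn(4, n)

# Sylvester Hadamard matrices are symmetric, and H / sqrt(n) is orthogonal,
# so applying the scaled transform twice is the identity.
y = (x @ h) * scale
assert torch.allclose((y @ h) * scale, x, atol=1e-5)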
@@ -0,0 +1,64 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsScheme,
|
||||
CompressedTensorsW4A4Fp4,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
|
||||
CompressedTensorsLinearTransformMethod,
|
||||
TransformTuple,
|
||||
)
|
||||
|
||||
__all__ = ["is_qutlass_fp4_scheme", "QutlassNvFP4LinearMethod"]
|
||||
|
||||
|
||||
def is_qutlass_fp4_scheme(
|
||||
quant_scheme: CompressedTensorsScheme | None,
|
||||
input_tfms: dict[int, TransformTuple],
|
||||
) -> bool:
|
||||
return (
|
||||
isinstance(quant_scheme, (CompressedTensorsW4A4Fp4,))
|
||||
and len(input_tfms) == 1
|
||||
and input_tfms[0].scheme.head_dim == quant_scheme.group_size
|
||||
)
|
||||
|
||||
|
||||
class QutlassNvFP4LinearMethod(CompressedTensorsLinearTransformMethod):
|
||||
def create_weights(
|
||||
self,
|
||||
layer,
|
||||
input_size_per_partition,
|
||||
output_partition_sizes,
|
||||
input_size,
|
||||
output_size,
|
||||
params_dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# initializes fp4 qparams
|
||||
assert isinstance(layer.scheme, (CompressedTensorsW4A4Fp4,))
|
||||
ret = super().create_weights(
|
||||
layer,
|
||||
input_size_per_partition,
|
||||
output_partition_sizes,
|
||||
input_size,
|
||||
output_size,
|
||||
params_dtype,
|
||||
**extra_weight_attrs,
|
||||
)
|
||||
|
||||
assert self.input_transform is not None
|
||||
assert len(self.input_transform.weight) == 1
|
||||
assert self.input_transform.weight[0].size(0) == layer.scheme.group_size
|
||||
|
||||
return ret
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
@@ -0,0 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import NamedTuple
|
||||
|
||||
from compressed_tensors.transform import TransformArgs, TransformScheme
|
||||
|
||||
__all__ = ["TransformTuple"]
|
||||
|
||||
|
||||
class TransformTuple(NamedTuple):
|
||||
scheme_name: str
|
||||
scheme: TransformScheme
|
||||
args: TransformArgs
|
||||
@@ -0,0 +1,224 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
def is_weak_contiguous(x: torch.Tensor):
|
||||
strides = x.stride()
|
||||
sizes = x.shape
|
||||
is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
|
||||
is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1]))
|
||||
return is_transpose or is_not_transpose
|
||||
|
||||
|
||||
@triton.jit
|
||||
def scaled_mm_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
scale_a_ptr,
|
||||
scale_b_ptr,
|
||||
c_ptr,
|
||||
bias_ptr,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
ACCUMULATOR_DTYPE: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
BLOCK_SIZE_SCALE_A: tl.constexpr,
|
||||
BLOCK_SIZE_SCALE_B: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
|
||||
pid_m = pid // num_pid_n
|
||||
pid_n = pid % num_pid_n
|
||||
|
||||
accumulator_dtype = ACCUMULATOR_DTYPE
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)
|
||||
|
||||
# NOTE: Some tensor inputs are so large, they will cause int32 overflow
|
||||
# so it is necessary to use tl.int64 for all the offsets, else SEGV will
|
||||
# eventually occur.
|
||||
|
||||
# Offsets and masks.
|
||||
offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||
masks_am = offsets_am < M
|
||||
|
||||
offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
|
||||
masks_bn = offsets_bn < N
|
||||
|
||||
offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
|
||||
offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :]
|
||||
offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :]
|
||||
|
||||
# NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
|
||||
# appropriate offsets and masks for each case. Same goes for
|
||||
# BLOCK_SIZE_SCALE_B.
|
||||
offsets_scale_am = (
|
||||
tl.arange(0, BLOCK_SIZE_SCALE_A)
|
||||
+ (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M
|
||||
)
|
||||
masks_scale_am = offsets_scale_am < M
|
||||
|
||||
offsets_scale_bn = (
|
||||
tl.arange(0, BLOCK_SIZE_SCALE_B)
|
||||
+ (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N
|
||||
)
|
||||
masks_scale_bn = offsets_scale_bn < N
|
||||
|
||||
a_ptrs = a_ptr + offsets_a
|
||||
b_ptrs = b_ptr + offsets_b
|
||||
|
||||
scale_a_ptrs = scale_a_ptr + offsets_scale_am
|
||||
scale_b_ptrs = scale_b_ptr + offsets_scale_bn
|
||||
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
masks_k = offsets_k < K
|
||||
masks_a = masks_am[:, None] & masks_k[None, :]
|
||||
a = tl.load(a_ptrs, mask=masks_a)
|
||||
|
||||
masks_b = masks_k[:, None] & masks_bn[None, :]
|
||||
b = tl.load(b_ptrs, mask=masks_b)
|
||||
|
||||
# Accumulate results.
|
||||
accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
|
||||
|
||||
offsets_k += BLOCK_SIZE_K
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
# Apply scale at end.
|
||||
masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
|
||||
scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
|
||||
# Need to broadcast to the appropriate size, if scale_a is already
|
||||
# (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes
|
||||
# for scale_b below.
|
||||
scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
|
||||
accumulator = scale_a * accumulator.to(tl.float32)
|
||||
|
||||
masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
|
||||
scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
|
||||
scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
|
||||
accumulator = scale_b.T * accumulator.to(tl.float32)
|
||||
|
||||
# Convert to output format.
|
||||
c = accumulator.to(c_ptr.type.element_ty)
|
||||
|
||||
# Add bias, it's already in output format, so add it after conversion.
|
||||
if bias_ptr:
|
||||
offsets_bias = offsets_bn
|
||||
bias_ptrs = bias_ptr + offsets_bias
|
||||
bias_mask = offsets_bias < N
|
||||
bias = tl.load(bias_ptrs, bias_mask)
|
||||
c += bias
|
||||
|
||||
# Save output
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
|
||||
offs_cm = offs_cm.to(tl.int64)
|
||||
offs_cn = offs_cn.to(tl.int64)
|
||||
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
# input - [M, K]
|
||||
# weight - [K, N]
|
||||
def triton_scaled_mm(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: type[torch.dtype],
|
||||
bias: torch.Tensor | None = None,
|
||||
block_size_m: int = 32,
|
||||
block_size_n: int = 32,
|
||||
block_size_k: int = 32,
|
||||
use_heuristic=True,
|
||||
) -> torch.Tensor:
|
||||
M, K = input.shape
|
||||
N = weight.shape[1]
|
||||
|
||||
assert N > 0 and K > 0 and M > 0
|
||||
assert weight.shape[0] == K
|
||||
assert input.dtype == weight.dtype
|
||||
|
||||
scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
|
||||
scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
|
||||
|
||||
assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
|
||||
assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 or scale_a.shape[0] == M)
|
||||
assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 or scale_b.shape[0] == N)
|
||||
assert out_dtype.is_floating_point
|
||||
assert bias is None or bias.is_floating_point()
|
||||
assert is_weak_contiguous(input)
|
||||
assert is_weak_contiguous(weight)
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
result = torch.empty((M, N), dtype=out_dtype, device=input.device)
|
||||
|
||||
has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
|
||||
|
||||
if use_heuristic:
|
||||
is_small_N = N < 8192
|
||||
next_power_of_2_M = max(32, triton.next_power_of_2(M))
|
||||
if next_power_of_2_M <= 32:
|
||||
tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
|
||||
elif next_power_of_2_M <= 64:
|
||||
tile_shape = (64, 64, 256)
|
||||
elif next_power_of_2_M <= 128:
|
||||
tile_shape = (64, 128, 128)
|
||||
else:
|
||||
tile_shape = (128, 128, 128)
|
||||
|
||||
block_size_m, block_size_n, block_size_k = tile_shape
|
||||
|
||||
block_size_sa = 1 if has_scalar(scale_a) else block_size_m
|
||||
block_size_sb = 1 if has_scalar(scale_b) else block_size_n
|
||||
|
||||
accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32
|
||||
|
||||
# A = input, B = weight, C = result
|
||||
# A = M x K, B = K x N, C = M x N
|
||||
scaled_mm_kernel[grid](
|
||||
input,
|
||||
weight,
|
||||
scale_a,
|
||||
scale_b,
|
||||
result,
|
||||
bias,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
input.stride(0),
|
||||
input.stride(1),
|
||||
weight.stride(0),
|
||||
weight.stride(1),
|
||||
result.stride(0),
|
||||
result.stride(1),
|
||||
accumulator_dtype,
|
||||
BLOCK_SIZE_M=block_size_m,
|
||||
BLOCK_SIZE_N=block_size_n,
|
||||
BLOCK_SIZE_K=block_size_k,
|
||||
BLOCK_SIZE_SCALE_A=block_size_sa,
|
||||
BLOCK_SIZE_SCALE_B=block_size_sb,
|
||||
)
|
||||
|
||||
return result.to(out_dtype)
|
||||
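For reference, a rough pure-PyTorch sketch (illustrative only, not the Triton kernel) of what `triton_scaled_mm` computes: per-row scales on the input, per-column scales on the output, and an optional bias added after conversion to the output dtype.

import torch

def scaled_mm_reference(
    a: torch.Tensor,          # [M, K], int8 or fp8
    b: torch.Tensor,          # [K, N], same dtype as a
    scale_a: torch.Tensor,    # [M, 1] or [1, 1], float
    scale_b: torch.Tensor,    # [N, 1] or [1, 1], float
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,  # [N], already in out_dtype
) -> torch.Tensor:
    # Accumulate in high precision, then apply row and column scales.
    acc = a.to(torch.float32) @ b.to(torch.float32)
    acc = scale_a * acc    # broadcasts over rows of the output
    acc = scale_b.T * acc  # broadcasts over columns of the output
    out = acc.to(out_dtype)
    if bias is not None:
        out = out + bias
    return out

a = torch.randint(-128, 127, (16, 32), dtype=torch.int8)
b = torch.randint(-128, 127, (32, 8), dtype=torch.int8)
scale_a = torch.rand(16, 1)
scale_b = torch.rand(8, 1)
ref = scaled_mm_reference(a, b, scale_a, scale_b, torch.bfloat16)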
@@ -0,0 +1,216 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
from types import MappingProxyType
|
||||
|
||||
import regex as re
|
||||
from compressed_tensors import CompressionFormat
|
||||
from torch.nn import Module
|
||||
|
||||
|
||||
def is_activation_quantization_format(format: str) -> bool:
|
||||
_ACTIVATION_QUANTIZATION_FORMATS = [
|
||||
CompressionFormat.naive_quantized.value,
|
||||
CompressionFormat.int_quantized.value,
|
||||
CompressionFormat.float_quantized.value,
|
||||
CompressionFormat.nvfp4_pack_quantized.value,
|
||||
]
|
||||
return format in _ACTIVATION_QUANTIZATION_FORMATS
|
||||
|
||||
|
||||
def should_ignore_layer(
|
||||
layer_name: str | None,
|
||||
ignore: Iterable[str] = tuple(),
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
|
||||
) -> bool:
|
||||
if layer_name is None:
|
||||
return False
|
||||
|
||||
# layer_name = model.layers.0.self_attn.qkv_proj
|
||||
# proj_name = qkv_proj
|
||||
proj_name = layer_name.split(".")[-1]
|
||||
|
||||
# Fused layers like gate_up_proj or qkv_proj will not be fused
|
||||
# in the safetensors checkpoint. So, we convert the name
|
||||
# from the fused version to unfused + check to make sure that
|
||||
# each shard of the fused layer has the same scheme.
|
||||
if proj_name in fused_mapping and layer_name not in ignore:
|
||||
shard_proj_names = fused_mapping[proj_name]
|
||||
|
||||
# Convert fused_name --> [shard_names]
|
||||
shard_names = [
|
||||
layer_name.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in shard_proj_names
|
||||
]
|
||||
|
||||
# Layer should be ignored if shards are ignored.
|
||||
should_ignore_layer = None
|
||||
for shard_name in shard_names:
|
||||
should_ignore_shard = check_equal_or_regex_match(
|
||||
layer_name=shard_name, targets=ignore
|
||||
)
|
||||
|
||||
# If shard_idx=0, set layer ignore to match shard.
|
||||
if should_ignore_layer is None:
|
||||
should_ignore_layer = should_ignore_shard
|
||||
|
||||
# If shard_idx=1+ confirm scheme matches prior shards.
|
||||
elif should_ignore_shard != should_ignore_layer:
|
||||
raise ValueError(
|
||||
f"Found a different quantization schemes for "
|
||||
f"{shard_proj_names} in {layer_name}. vLLM "
|
||||
"requires all to use the same scheme."
|
||||
)
|
||||
|
||||
# Unfused layers like down_proj and o_proj will match
|
||||
# the safetensors checkpoint already.
|
||||
else:
|
||||
should_ignore_layer = check_equal_or_regex_match(
|
||||
layer_name=layer_name, targets=ignore
|
||||
)
|
||||
|
||||
assert should_ignore_layer is not None
|
||||
return should_ignore_layer
|
||||
|
||||
|
||||
def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
|
||||
"""
|
||||
    Checks whether layer_name is exactly equal to, or (when the target is
    prefixed with 're:') a regex match for, any target in the list.
|
||||
"""
|
||||
return any(_is_equal_or_regex_match(layer_name, target) for target in targets)
|
||||
|
||||
|
||||
def find_matched_target(
|
||||
layer_name: str | None,
|
||||
module: Module,
|
||||
targets: Iterable[str],
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
|
||||
) -> str:
|
||||
"""
|
||||
    Helper function to look up which "target" in the compressed-tensors
    config a layer corresponds to.
|
||||
|
||||
    Recall that a compressed-tensors config has a concept of
|
||||
config_groups, where each layer can be quantized with a different
|
||||
scheme.
|
||||
|
||||
targets in each config_group will be a list of either layer names
|
||||
(or regexes corresponding to layer names) or names of torch Modules.
|
||||
|
||||
First, we try to match the layer_name with a target
|
||||
Second, we try to match the module's name with a target
|
||||
Third, we try to map the layer_name to a list of fused module names.
|
||||
*All* component module names must match in order for a match to be
|
||||
successful. A successful match returns the first component target
|
||||
|
||||
:param layer_name: layer name
|
||||
:param module: torch.nn.Module
|
||||
:param targets: list of targets to match the layer against
|
||||
:param fused_mapping: map from fused layer names to its components
|
||||
"""
|
||||
|
||||
if layer_name is None:
|
||||
layer_name = ""
|
||||
|
||||
matched_target = (
|
||||
_find_first_match(layer_name, targets)
|
||||
or _find_first_match(module.__class__.__name__, targets, True)
|
||||
or _match_fused_layer(layer_name, targets, fused_mapping)
|
||||
)
|
||||
|
||||
if matched_target is None:
|
||||
raise ValueError(
|
||||
f"Unable to find matching target for {layer_name} in the "
|
||||
"compressed-tensors config."
|
||||
)
|
||||
|
||||
return matched_target
|
||||
|
||||
|
||||
def _find_first_match(
|
||||
value: str, targets: Iterable[str], check_contains: bool = False
|
||||
) -> str | None:
|
||||
"""
|
||||
Returns first element of target that matches value either
|
||||
exactly or as a regex after 're:'. If check_contains is set to True,
|
||||
additionally checks if the target string is contained within the value.
|
||||
|
||||
:param value: string to compare the list of targets against
|
||||
:param targets: list of targets to match the layer against
|
||||
:param check_contains: whether or not to do a substring match
|
||||
"""
|
||||
|
||||
for target in targets:
|
||||
if _is_equal_or_regex_match(value, target, check_contains=check_contains):
|
||||
return target
|
||||
return None
|
||||
|
||||
|
||||
def _is_equal_or_regex_match(
|
||||
value: str, target: str, check_contains: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
Checks whether a value is exactly equal or a regex match for target
|
||||
if target starts with 're:'. If check_contains is set to True,
|
||||
additionally checks if the target string is contained within the value.
|
||||
"""
|
||||
|
||||
if target.startswith("re:"):
|
||||
pattern = target[3:]
|
||||
if re.match(pattern, value):
|
||||
return True
|
||||
elif check_contains:
|
||||
if target.lower() in value.lower():
|
||||
return True
|
||||
elif target == value:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _match_fused_layer(
|
||||
layer_name: str,
|
||||
target_layers: Iterable[str],
|
||||
fused_mapping: Mapping[str, list[str]],
|
||||
) -> str | None:
|
||||
"""
|
||||
Match a fused layer name to its corresponding individual layer in
|
||||
target_layers. Returns first value in fused_mapping which matches targets
|
||||
|
||||
Implements an "all" matching strategy where a fused layer matches iff
|
||||
"all" of its components match
|
||||
|
||||
:param layer_name: layer name
|
||||
:param target_layers: list of targets to match the layer against
|
||||
:param fused_mapping: map from fused layer names to its components
|
||||
|
||||
Examples:
|
||||
layer_name = "model.layers.0.self_attn.qkv_proj"
|
||||
target_layers = ["model.layers.0.self_attn.q_proj",
|
||||
"model.layers.0.self_attn.k_proj",
|
||||
"model.layers.0.self_attn.v_proj"]
|
||||
"""
|
||||
# find layer_name in mapping
|
||||
fused = next((key for key in fused_mapping if layer_name.endswith(key)), None)
|
||||
if fused is None:
|
||||
return None
|
||||
|
||||
# expand path of unfused components
|
||||
unfused_paths = [
|
||||
layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
|
||||
]
|
||||
|
||||
# for each unfused component, find a match in targets
|
||||
unfused_matches: list[str | None] = []
|
||||
for unfused in unfused_paths:
|
||||
for target in target_layers:
|
||||
if _is_equal_or_regex_match(unfused, target):
|
||||
unfused_matches.append(target)
|
||||
break
|
||||
else:
|
||||
unfused_matches.append(None)
|
||||
|
||||
return unfused_matches[0] if all(unfused_matches) else None
|
||||
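As a quick illustration (hypothetical layer names, not from this commit) of the `re:` target convention that `check_equal_or_regex_match` and `_is_equal_or_regex_match` implement:

import re

def is_equal_or_regex_match(value: str, target: str) -> bool:
    # Mirrors _is_equal_or_regex_match without the substring option.
    if target.startswith("re:"):
        return re.match(target[3:], value) is not None
    return target == value

ignore = ["lm_head", r"re:.*mlp\.gate$"]
assert is_equal_or_regex_match("lm_head", ignore[0])
assert is_equal_or_regex_match("model.layers.3.mlp.gate", ignore[1])
assert not is_equal_or_regex_match("model.layers.3.mlp.gate_up_proj", ignore[1])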
299
vllm/model_executor/layers/quantization/cpu_wna16.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
|
||||
from vllm._custom_ops import (
|
||||
cpu_gemm_wna16,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped,
|
||||
pack_cols,
|
||||
unpack_cols,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CPUAWQConfig(QuantizationConfig):
|
||||
"""Config class for CPU AWQ"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
lm_head_quantized: bool,
|
||||
modules_to_not_convert: list[str] | None,
|
||||
full_config: dict[str, Any],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert weight_bits == 4
|
||||
self.pack_factor = 32 // weight_bits # packed into int32
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.weight_bits = weight_bits
|
||||
self.modules_to_not_convert = modules_to_not_convert or []
|
||||
self.full_config = full_config
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"AWQMarlinConfig("
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point}, "
|
||||
f"lm_head_quantized={self.lm_head_quantized}, "
|
||||
f"modules_to_not_convert={self.modules_to_not_convert})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> "QuantizationMethods":
|
||||
return "cpu_awq"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return -1
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "CPUAWQConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
|
||||
modules_to_not_convert = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None
|
||||
)
|
||||
return cls(
|
||||
weight_bits,
|
||||
group_size,
|
||||
zero_point,
|
||||
lm_head_quantized,
|
||||
modules_to_not_convert,
|
||||
config,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> "QuantizationMethods | None":
|
||||
quant_method = hf_quant_cfg.get("quant_method", "").lower()
|
||||
if current_platform.is_cpu() and (quant_method == "awq"):
|
||||
return cls.get_name()
|
||||
return None
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, LinearBase) or (
|
||||
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
|
||||
):
|
||||
if is_layer_skipped(
|
||||
prefix,
|
||||
self.modules_to_not_convert,
|
||||
self.packed_modules_mapping,
|
||||
skip_with_substr=True,
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
return CPUAWQLinearMethod(self)
|
||||
return None
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||
if self.modules_to_not_convert:
|
||||
self.modules_to_not_convert = hf_to_vllm_mapper.apply_list(
|
||||
self.modules_to_not_convert
|
||||
)
|
||||
|
||||
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||
if self.modules_to_not_convert:
|
||||
return
|
||||
|
||||
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
metadata = get_safetensors_params_metadata(model_name, revision=revision)
|
||||
layers = {param_name.rsplit(".", 1)[0] for param_name in metadata}
|
||||
quant_layers: set[str] = {
|
||||
param_name.rsplit(".", 1)[0]
|
||||
for param_name, info in metadata.items()
|
||||
if (dtype := info.get("dtype", None))
|
||||
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
|
||||
}
|
||||
self.modules_to_not_convert = list(layers - quant_layers)
|
||||
|
||||
|
||||
class CPUAWQLinearMethod(LinearMethodBase):
|
||||
"""Linear method for CPU AWQ.
|
||||
|
||||
Args:
|
||||
quant_config: The CPU AWQ quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: CPUAWQConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
assert self.quant_config.zero_point
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
del output_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
num_groups = input_size_per_partition // group_size
|
||||
|
||||
qzeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
num_groups,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
scales = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
num_groups,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
packed_weight = layer.qweight.data
|
||||
packed_zeros = layer.qzeros.data
|
||||
group_num = packed_zeros.size(0)
|
||||
bits = self.quant_config.weight_bits
|
||||
pack_factor = int(self.quant_config.pack_factor)
|
||||
input_size, packed_output_size = packed_weight.size()
|
||||
output_size = packed_output_size * pack_factor
|
||||
isa_hint = _get_isa_hint(layer.scales.dtype)
|
||||
layer.isa_hint = isa_hint
|
||||
|
||||
interleave_map = (0, 4, 1, 5, 2, 6, 3, 7)
|
||||
weight = unpack_cols(
|
||||
packed_weight,
|
||||
bits,
|
||||
input_size,
|
||||
output_size,
|
||||
)
|
||||
zeros = unpack_cols(
|
||||
packed_zeros,
|
||||
bits,
|
||||
group_num,
|
||||
output_size,
|
||||
)
|
||||
weight = (
|
||||
weight.view(input_size, -1, pack_factor)[:, :, interleave_map]
|
||||
.reshape(input_size, output_size)
|
||||
.contiguous()
|
||||
)
|
||||
zeros = (
|
||||
zeros.view(group_num, -1, pack_factor)[:, :, interleave_map]
|
||||
.reshape(group_num, output_size)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
zeros = pack_cols(zeros, bits, group_num, output_size).contiguous()
|
||||
        # group every 16 output channels into a block and transpose so that
        # each block is contiguous
|
||||
weight = pack_cols(weight, bits, input_size, output_size)
|
||||
weight = (
|
||||
weight.view(input_size, -1, 16 // pack_factor)
|
||||
.permute(1, 0, 2)
|
||||
.reshape(-1, input_size * 16 // pack_factor)
|
||||
.contiguous()
|
||||
)
|
||||
layer.qweight.data = weight
|
||||
layer.qzeros.data = zeros
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
x = cpu_gemm_wna16(
|
||||
input=x,
|
||||
q_weight=layer.qweight,
|
||||
scales=layer.scales,
|
||||
zeros=layer.qzeros,
|
||||
g_idx=None,
|
||||
bias=bias,
|
||||
pack_factor=8,
|
||||
isa_hint=layer.isa_hint,
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
def _get_isa_hint(dtype: torch.dtype) -> str:
|
||||
supports_amx = torch._C._cpu._is_amx_tile_supported()
|
||||
if supports_amx and dtype in (torch.bfloat16,):
|
||||
return "amx"
|
||||
else:
|
||||
return "vec"
|
||||
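A small sketch (illustrative only, ignoring the AWQ interleave order handled above) of the column packing that `pack_cols`/`unpack_cols` round-trip: eight 4-bit values share one int32, so a [rows, cols] tensor of nibbles packs to [rows, cols // 8].

import torch

def pack_int4_cols(vals: torch.Tensor) -> torch.Tensor:
    # vals: int32 tensor of 4-bit values in [0, 15], cols divisible by 8.
    rows, cols = vals.shape
    grouped = vals.view(rows, cols // 8, 8)
    packed = torch.zeros(rows, cols // 8, dtype=torch.int32)
    for i in range(8):
        packed |= grouped[:, :, i] << (4 * i)
    return packed

def unpack_int4_cols(packed: torch.Tensor) -> torch.Tensor:
    rows, pcols = packed.shape
    nibbles = [(packed >> (4 * i)) & 0xF for i in range(8)]
    return torch.stack(nibbles, dim=-1).view(rows, pcols * 8)

vals = torch.randint(0, 16, (4, 64), dtype=torch.int32)
assert torch.equal(unpack_int4_cols(pack_int4_cols(vals)), vals)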
204
vllm/model_executor/layers/quantization/experts_int8.py
Normal file
@@ -0,0 +1,204 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
FusedMoEConfig,
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
int8_w8a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class ExpertsInt8Config(QuantizationConfig):
|
||||
"""Config class for Int8 experts quantization."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "experts_int8"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "ExpertsInt8Config":
|
||||
return cls()
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, LinearBase):
|
||||
return UnquantizedLinearMethod()
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return ExpertsInt8MoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: ExpertsInt8Config,
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
int8_dtype = torch.int8
|
||||
|
||||
assert "weight_loader" in extra_weight_attrs
|
||||
weight_loader = extra_weight_attrs["weight_loader"]
|
||||
wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader(
|
||||
layer, weight_loader
|
||||
)
|
||||
extra_weight_attrs["weight_loader"] = wrapped_weight_loader
|
||||
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=int8_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=int8_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w13_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_scale", w13_scale)
|
||||
|
||||
w2_scale = torch.nn.Parameter(
|
||||
torch.zeros(num_experts, hidden_size, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_scale", w2_scale)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return int8_w8a16_moe_quant_config(
|
||||
w1_scale=layer.w13_scale, w2_scale=layer.w2_scale, w1_zp=None, w2_zp=None
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def quantizing_weight_loader(layer, weight_loader):
|
||||
def quantize_and_call_weight_loader(
|
||||
param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
weight_name: str,
|
||||
shard_id: int,
|
||||
expert_id: int,
|
||||
):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||
device = get_tp_group().device
|
||||
loaded_weight = loaded_weight.to(device)
|
||||
# w1, gate_proj case: Load into first shard of w13.
|
||||
if shard_id == "w1":
|
||||
scales = quantize_in_place_and_get_scales(loaded_weight[shard, :])
|
||||
layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:, 0])
|
||||
# w3, up_proj case: Load into second shard of w13.
|
||||
elif shard_id == "w3":
|
||||
scales = quantize_in_place_and_get_scales(loaded_weight[shard, :])
|
||||
layer.w13_scale.data[expert_id, shard_size : 2 * shard_size].copy_(
|
||||
scales[:, 0]
|
||||
)
|
||||
# w2, down_proj case: Load into only shard of w2.
|
||||
elif shard_id == "w2":
|
||||
scales = quantize_in_place_and_get_scales(loaded_weight[:, shard])
|
||||
layer.w2_scale.data[expert_id, :].copy_(scales[:, 0])
|
||||
else:
|
||||
raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}")
|
||||
weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)
|
||||
|
||||
return quantize_and_call_weight_loader
|
||||
|
||||
|
||||
def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor:
|
||||
vmax = torch.iinfo(torch.int8).max
|
||||
scales = torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax
|
||||
|
||||
weight.div_(scales)
|
||||
weight.round_()
|
||||
weight.clamp_(-vmax, vmax)
|
||||
|
||||
return scales
|
||||
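For intuition, a hedged sketch of the symmetric per-row int8 round-trip that `quantize_in_place_and_get_scales` performs (illustrative, working on a copy rather than in place):

import torch

def quantize_rows_int8(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-output-channel (row) absmax scales, symmetric quantization.
    vmax = torch.iinfo(torch.int8).max
    scales = weight.abs().amax(dim=1, keepdim=True) / vmax
    q = torch.clamp(torch.round(weight / scales), -vmax, vmax).to(torch.int8)
    return q, scales

w = torch.randn(128, 256, dtype=torch.float32)
q, scales = quantize_rows_int8(w)
w_back = q.to(torch.float32) * scales
# Reconstruction error is bounded by half a quantization step per element.
assert (w - w_back).abs().max() <= (scales.max() / 2) + 1e-6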
191
vllm/model_executor/layers/quantization/fbgemm_fp8.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear,
|
||||
prepare_fp8_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTokenSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FBGEMMFp8Config(QuantizationConfig):
|
||||
"""Config class for FBGEMM Fp8."""
|
||||
|
||||
def __init__(self, ignore_list: list[str], input_scale_ub: float):
|
||||
super().__init__()
|
||||
self.ignore_list = ignore_list if ignore_list else []
|
||||
self.input_scale_ub = input_scale_ub
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = not current_platform.has_device_capability(89)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "fbgemm_fp8"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.float16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "FBGEMMFp8Config":
|
||||
ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
|
||||
input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
|
||||
return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(
|
||||
prefix=prefix,
|
||||
ignored_layers=self.ignore_list,
|
||||
fused_mapping=self.packed_modules_mapping,
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
return FBGEMMFp8LinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: FBGEMMFp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=kFp8DynamicTokenSym,
|
||||
weight_quant_key=kFp8StaticTokenSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
del input_size, output_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE UPPER BOUND
|
||||
input_scale_ub = torch.nn.Parameter(
|
||||
torch.tensor((self.quant_config.input_scale_ub), dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.input_scale_ub = input_scale_ub
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# required by torch.compile
|
||||
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||
|
||||
weight = layer.weight
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight, weight_scale=layer.weight_scale, input_scale=None
|
||||
)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale, requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
if self.quant_config.use_marlin:
|
||||
prepare_fp8_layer_for_marlin(layer)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.input_scale_ub
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.quant_config.use_marlin:
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
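A minimal sketch (assumption: FP8 hardware support is gated at CUDA compute capability 8.9, as the `use_marlin` flag above implies) of the fallback decision:

import torch

def fp8_needs_marlin_fallback() -> bool:
    # GPUs older than compute capability 8.9 (Ada) lack native FP8 support,
    # so weight-only FP8 runs through the Marlin kernel instead.
    if not torch.cuda.is_available():
        return True
    return torch.cuda.get_device_capability() < (8, 9)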
1306
vllm/model_executor/layers/quantization/fp8.py
Normal file
File diff suppressed because it is too large
416
vllm/model_executor/layers/quantization/fp_quant.py
Normal file
@@ -0,0 +1,416 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Supports FP-Quant compression, see https://arxiv.org/abs/2509.23202
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm._custom_ops import (
|
||||
cutlass_scaled_fp4_mm,
|
||||
fusedQuantizeMx,
|
||||
fusedQuantizeNv,
|
||||
matmul_mxf4_bf16_tn,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
class FPQuantConfig(QuantizationConfig):
|
||||
"""Config class for FPQuant."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hadamard_group_size: int = 32,
|
||||
forward_dtype: str = "mxfp4",
|
||||
forward_method: str = "abs_max",
|
||||
pseudoquantization: bool = False,
|
||||
modules_to_not_convert: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hadamard_group_size = hadamard_group_size
|
||||
self.forward_dtype = forward_dtype
|
||||
self.forward_method = forward_method
|
||||
self.pseudoquantization = pseudoquantization
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
if pseudoquantization:
|
||||
raise ValueError("Pseudoquantization is not supported for vLLM")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"FPQuantConfig(hadamard_group_size={self.hadamard_group_size}, "
|
||||
f"forward_dtype={self.forward_dtype}, "
|
||||
f"forward_method={self.forward_method}, "
|
||||
f"pseudoquantization={self.pseudoquantization}, "
|
||||
f"modules_to_not_convert={self.modules_to_not_convert})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "fp_quant"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 100
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return [] # no extra configs.
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "FPQuantConfig":
|
||||
hadamard_group_size = cls.get_from_keys(config, ["hadamard_group_size"])
|
||||
forward_dtype = cls.get_from_keys(config, ["forward_dtype"])
|
||||
forward_method = cls.get_from_keys(config, ["forward_method"])
|
||||
pseudoquantization = cls.get_from_keys(config, ["pseudoquantization"])
|
||||
modules_to_not_convert = cls.get_from_keys(config, ["modules_to_not_convert"])
|
||||
return cls(
|
||||
hadamard_group_size,
|
||||
forward_dtype,
|
||||
forward_method,
|
||||
pseudoquantization,
|
||||
modules_to_not_convert,
|
||||
)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> LinearMethodBase | None:
|
||||
if self.modules_to_not_convert is not None and any(
|
||||
prefix.endswith(module) for module in self.modules_to_not_convert
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return FPQuantLinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class FPQuantLinearMethod(LinearMethodBase):
|
||||
"""Linear method for FPQuant.
|
||||
|
||||
Args:
|
||||
quant_config: The FPQuant quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: FPQuantConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
del input_size # Unused.
|
||||
|
||||
if params_dtype != torch.bfloat16:
|
||||
raise ValueError("Only bfloat16 is currently supported by FPQuant")
|
||||
if input_size_per_partition % self.quant_config.hadamard_group_size != 0: # noqa: E501
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size. Or other skill issues."
|
||||
)
|
||||
|
||||
assert self.quant_config.forward_dtype in ["mxfp4", "nvfp4"], (
|
||||
"Only mxfp4 and nvfp4 are supported for now"
|
||||
)
|
||||
if self.quant_config.forward_dtype == "mxfp4":
|
||||
group_size = 32
|
||||
elif self.quant_config.forward_dtype == "nvfp4":
|
||||
group_size = 16
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported forward_dtype: {self.quant_config.forward_dtype}"
|
||||
)
|
||||
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 2,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qweight,
|
||||
{
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": 2,
|
||||
}
|
||||
| extra_weight_attrs,
|
||||
)
|
||||
layer.register_parameter("qweight", qweight)
|
||||
|
||||
scales = Parameter(
|
||||
torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // group_size,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
scales,
|
||||
{
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": group_size,
|
||||
}
|
||||
| extra_weight_attrs,
|
||||
)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
weight_global_scale = Parameter(
|
||||
torch.empty(1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
weight_global_scale, {"ignore_warning": True} | extra_weight_attrs
|
||||
)
|
||||
layer.register_parameter("weight_global_scale", weight_global_scale)
|
||||
|
||||
act_global_scale = Parameter(
|
||||
torch.empty(1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
act_global_scale, {"ignore_warning": True} | extra_weight_attrs
|
||||
)
|
||||
layer.register_parameter("act_global_scale", act_global_scale)
|
||||
|
||||
forward_hadamard_matrix = Parameter(
|
||||
torch.empty(
|
||||
self.quant_config.hadamard_group_size,
|
||||
self.quant_config.hadamard_group_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
forward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs
|
||||
)
|
||||
layer.register_parameter("forward_hadamard_matrix", forward_hadamard_matrix)
|
||||
|
||||
backward_hadamard_matrix = Parameter(
|
||||
torch.empty(
|
||||
self.quant_config.hadamard_group_size,
|
||||
self.quant_config.hadamard_group_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
backward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs
|
||||
)
|
||||
layer.register_parameter("backward_hadamard_matrix", backward_hadamard_matrix)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return quantized_forward(
|
||||
x,
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
layer.weight_global_scale,
|
||||
layer.act_global_scale,
|
||||
bias,
|
||||
layer.forward_hadamard_matrix,
|
||||
self.quant_config.forward_method,
|
||||
self.quant_config.forward_dtype,
|
||||
)
|
||||
|
||||
|
||||
def fused_quantize_mx(
|
||||
x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, forward_method: str
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return fusedQuantizeMx(x_flat, hadamard_matrix, method=forward_method)
|
||||
|
||||
|
||||
def fused_quantize_mx_fake(x_flat, hadamard_matrix, forward_method):
|
||||
rows, cols = x_flat.size(0), x_flat.size(1) // 32
|
||||
padded_rows = ((rows + 128 - 1) // 128) * 128
|
||||
padded_cols = ((cols + 4 - 1) // 4) * 4
|
||||
|
||||
xh_e2m1 = torch.empty(
|
||||
x_flat.size(0), x_flat.size(1) // 2, dtype=torch.uint8, device=x_flat.device
|
||||
)
|
||||
xh_e8m0 = torch.empty(
|
||||
padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=x_flat.device
|
||||
)
|
||||
|
||||
return xh_e2m1, xh_e8m0
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="fused_quantize_mx",
|
||||
op_func=fused_quantize_mx,
|
||||
mutates_args=[],
|
||||
fake_impl=fused_quantize_mx_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
def matmul_mxf4_bf16(
|
||||
x: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
xs: torch.Tensor,
|
||||
ws: torch.Tensor,
|
||||
alpha: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return matmul_mxf4_bf16_tn(
|
||||
x,
|
||||
w,
|
||||
to_blocked(xs, backend="triton").view(torch.float8_e8m0fnu),
|
||||
to_blocked(ws, backend="triton").view(torch.float8_e8m0fnu),
|
||||
alpha,
|
||||
)
|
||||
|
||||
|
||||
def matmul_mxf4_bf16_fake(x, w, xs, ws, alpha):
|
||||
return torch.empty(*x.shape[:-1], w.shape[0], dtype=torch.bfloat16, device=x.device)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="matmul_mxf4_bf16",
|
||||
op_func=matmul_mxf4_bf16,
|
||||
mutates_args=[],
|
||||
fake_impl=matmul_mxf4_bf16_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
def fused_quantize_nv(
|
||||
x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, global_scale: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return fusedQuantizeNv(x_flat, hadamard_matrix, global_scale)
|
||||
|
||||
|
||||
def fused_quantize_nv_fake(x_flat, hadamard_matrix, global_scale):
|
||||
rows, cols = x_flat.size(0), x_flat.size(1) // 16
|
||||
padded_rows = ((rows + 128 - 1) // 128) * 128
|
||||
padded_cols = ((cols + 4 - 1) // 4) * 4
|
||||
|
||||
xh_e2m1 = torch.empty(
|
||||
x_flat.size(0), x_flat.size(1) // 2, dtype=torch.uint8, device=x_flat.device
|
||||
)
|
||||
xh_e8m0 = torch.empty(
|
||||
padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=x_flat.device
|
||||
)
|
||||
|
||||
return xh_e2m1, xh_e8m0
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="fused_quantize_nv",
|
||||
op_func=fused_quantize_nv,
|
||||
mutates_args=[],
|
||||
fake_impl=fused_quantize_nv_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
def matmul_nvf4_bf16(
|
||||
x: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
xs: torch.Tensor,
|
||||
ws: torch.Tensor,
|
||||
alpha: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return cutlass_scaled_fp4_mm(
|
||||
x,
|
||||
w,
|
||||
to_blocked(xs, backend="triton")
|
||||
.view(torch.float8_e4m3fn)
|
||||
.view(-1, x.shape[1] // 8), # *2//16
|
||||
to_blocked(ws, backend="triton")
|
||||
.view(torch.float8_e4m3fn)
|
||||
.view(-1, x.shape[1] // 8),
|
||||
alpha,
|
||||
torch.bfloat16,
|
||||
)
|
||||
|
||||
|
||||
def matmul_nvf4_bf16_fake(x, w, xs, ws, alpha):
|
||||
return torch.empty(*x.shape[:-1], w.shape[0], dtype=torch.bfloat16, device=x.device)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="matmul_nvf4_bf16",
|
||||
op_func=matmul_nvf4_bf16,
|
||||
mutates_args=[],
|
||||
fake_impl=matmul_nvf4_bf16_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
def quantized_forward(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
weight_scales: torch.Tensor,
|
||||
weight_global_scale: torch.Tensor,
|
||||
act_global_scale: torch.Tensor,
|
||||
bias: torch.Tensor | None,
|
||||
forward_hadamard_matrix: torch.Tensor,
|
||||
forward_method: str,
|
||||
forward_dtype: str,
|
||||
) -> torch.Tensor:
|
||||
x_flat = x.contiguous().flatten(end_dim=-2)
|
||||
|
||||
if forward_dtype == "mxfp4":
|
||||
x_flat_q, x_flat_scales = torch.ops.vllm.fused_quantize_mx(
|
||||
x_flat, forward_hadamard_matrix, forward_method
|
||||
)
|
||||
y = torch.ops.vllm.matmul_mxf4_bf16(
|
||||
x_flat_q,
|
||||
qweight,
|
||||
x_flat_scales,
|
||||
weight_scales,
|
||||
1 / (weight_global_scale * act_global_scale),
|
||||
)
|
||||
elif forward_dtype == "nvfp4":
|
||||
x_flat_q, x_flat_scales = torch.ops.vllm.fused_quantize_nv(
|
||||
x_flat, forward_hadamard_matrix, act_global_scale
|
||||
)
|
||||
y = torch.ops.vllm.matmul_nvf4_bf16(
|
||||
x_flat_q,
|
||||
qweight,
|
||||
x_flat_scales,
|
||||
weight_scales,
|
||||
1 / (weight_global_scale * act_global_scale),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported forward_dtype: {forward_dtype}")
|
||||
|
||||
y = y.view(*x.shape[:-1], y.shape[-1])
|
||||
if bias is not None:
|
||||
y += bias
|
||||
|
||||
return y
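# Illustrative call sketch (hypothetical tensor names and forward_method value,
# not part of this module): the activation keeps its leading dims and only the
# last dim changes from in_features to out_features.
#
#   x = torch.randn(batch, seq, hidden, dtype=torch.bfloat16, device="cuda")
#   y = quantized_forward(
#       x, qweight, weight_scales, weight_global_scale, act_global_scale,
#       bias=None, forward_hadamard_matrix=had, forward_method="abs_max",
#       forward_dtype="mxfp4",
#   )  # y: (batch, seq, out_features), bfloat16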
|
||||
678
vllm/model_executor/layers/quantization/gguf.py
Normal file
@@ -0,0 +1,678 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
from gguf import GGMLQuantizationType as WeightType
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import (
|
||||
MoEActivation,
|
||||
apply_moe_activation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
UnquantizedEmbeddingMethod,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GGUFConfig(QuantizationConfig):
|
||||
"""Config class for GGUF."""
|
||||
|
||||
def __init__(self, unquantized_modules: list[str] | None = None) -> None:
|
||||
super().__init__()
|
||||
self.unquantized_modules = unquantized_modules or []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "GGUFConfig()"
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "gguf"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
# GGUF dequantization kernels use half precision (fp16) internally.
|
||||
# bfloat16 has precision issues on Blackwell devices.
|
||||
if current_platform.has_device_capability(100):
|
||||
logger.warning_once("GGUF has precision issues with bfloat16 on Blackwell.")
|
||||
return [torch.half, torch.float32]
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return [] # no extra configs.
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
|
||||
return cls()
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_gguf(
|
||||
prefix, self.unquantized_modules, self.packed_modules_mapping
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
return GGUFLinearMethod(self)
|
||||
elif isinstance(layer, VocabParallelEmbedding):
|
||||
if is_layer_skipped_gguf(
|
||||
prefix, self.unquantized_modules, self.packed_modules_mapping
|
||||
):
|
||||
return UnquantizedEmbeddingMethod()
|
||||
return GGUFEmbeddingMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
# TODO: Select UnquantizedFusedMoEMethod on unquantized layers.
|
||||
return GGUFMoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||
"""
|
||||
Interface for models to update module names referenced in
|
||||
quantization configs in order to reflect the vllm model structure
|
||||
|
||||
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
|
||||
structure of the qconfig) to vllm model structure
|
||||
"""
|
||||
if self.unquantized_modules is not None:
|
||||
self.unquantized_modules = hf_to_vllm_mapper.apply_list(
|
||||
self.unquantized_modules
|
||||
)
|
||||
|
||||
|
||||
def is_layer_skipped_gguf(
|
||||
prefix: str,
|
||||
unquantized_modules: list[str],
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
|
||||
):
|
||||
# Fused layers like gate_up_proj or qkv_proj will not be fused
|
||||
# in the safetensors checkpoint. So, we convert the name
|
||||
# from the fused version to unfused + check to make sure that
|
||||
# each shard of the fused layer has the same scheme.
|
||||
proj_name = prefix.split(".")[-1]
|
||||
if proj_name in fused_mapping:
|
||||
shard_prefixes = [
|
||||
prefix.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in fused_mapping[proj_name]
|
||||
]
|
||||
|
||||
is_skipped = None
|
||||
for shard_prefix in shard_prefixes:
|
||||
is_shard_skipped = any(
|
||||
shard_prefix in module_name for module_name in unquantized_modules
|
||||
)
|
||||
|
||||
if is_skipped is None:
|
||||
is_skipped = is_shard_skipped
|
||||
elif is_shard_skipped != is_skipped:
|
||||
raise ValueError(
|
||||
f"Detected some but not all shards of {prefix} "
|
||||
"are quantized. All shards of fused layers "
|
||||
"to have the same precision."
|
||||
)
|
||||
else:
|
||||
is_skipped = any(module_name in prefix for module_name in unquantized_modules)
|
||||
|
||||
assert is_skipped is not None
|
||||
return is_skipped
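# Example (hypothetical names): a prefix "model.layers.0.self_attn.qkv_proj"
# with fused_mapping {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} expands to
# the three per-shard prefixes; each shard counts as skipped if its prefix
# occurs in some unquantized_modules entry, and a ValueError is raised when
# the shards disagree.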
|
||||
|
||||
|
||||
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
|
||||
STANDARD_QUANT_TYPES = {
|
||||
WeightType.Q4_0,
|
||||
WeightType.Q4_1,
|
||||
WeightType.Q5_0,
|
||||
WeightType.Q5_1,
|
||||
WeightType.Q8_0,
|
||||
WeightType.Q8_1,
|
||||
}
|
||||
KQUANT_TYPES = {
|
||||
WeightType.Q2_K,
|
||||
WeightType.Q3_K,
|
||||
WeightType.Q4_K,
|
||||
WeightType.Q5_K,
|
||||
WeightType.Q6_K,
|
||||
}
|
||||
IMATRIX_QUANT_TYPES = {
|
||||
WeightType.IQ1_M,
|
||||
WeightType.IQ1_S,
|
||||
WeightType.IQ2_XXS,
|
||||
WeightType.IQ2_XS,
|
||||
WeightType.IQ2_S,
|
||||
WeightType.IQ3_XXS,
|
||||
WeightType.IQ3_S,
|
||||
WeightType.IQ4_XS,
|
||||
WeightType.IQ4_NL,
|
||||
}
|
||||
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
|
||||
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
|
||||
# MMQ kernel for I-Matrix quantization.
|
||||
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
|
||||
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
|
||||
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
|
||||
|
||||
|
||||
def _fused_mul_mat_gguf(
|
||||
x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
|
||||
) -> torch.Tensor:
|
||||
if qweight_type in IMATRIX_QUANT_TYPES:
|
||||
mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
|
||||
else:
|
||||
mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
|
||||
# HACK: during chunked prefill no output tokens are generated, so the input
# to the logits generator can be empty, which would pass invalid parameters
# to the kernels below.
|
||||
if x.shape[0] == 0:
|
||||
return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
|
||||
# there is no need to call any kernel for fp16/bf16
|
||||
if qweight_type in UNQUANTIZED_TYPES:
|
||||
return x @ qweight.T
|
||||
# enable MMVQ in contiguous batching with batch_size=1
|
||||
if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
|
||||
y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
|
||||
# Use MMQ Kernel if it's available (standard + k-quants)
|
||||
elif qweight_type in MMQ_QUANT_TYPES:
|
||||
y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
|
||||
# If there is no available MMQ kernel, fallback to dequantize
|
||||
elif qweight_type in DEQUANT_TYPES:
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
|
||||
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
|
||||
weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
|
||||
y = x @ weight.T
|
||||
else:
|
||||
# Raise an error if the quantization type is not supported.
|
||||
# Might be useful if llama.cpp adds a new quantization type.
|
||||
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
|
||||
qweight_type = WeightType(qweight_type)
|
||||
raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
|
||||
return y
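# Dequant-path shape note (example values from gguf's tables): for Q4_K,
# GGML_QUANT_SIZES gives (block_size=256, type_size=144), so a qweight of
# shape (out_features, in_bytes) dequantizes to
# (out_features, in_bytes // 144 * 256), matching x's hidden size.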
|
||||
|
||||
|
||||
def _fused_mul_mat_gguf_fake(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
qweight_type: int,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="_fused_mul_mat_gguf",
|
||||
op_func=_fused_mul_mat_gguf,
|
||||
fake_impl=_fused_mul_mat_gguf_fake,
|
||||
)
|
||||
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
|
||||
def _fused_moe_gguf(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
qweight_type: int,
|
||||
qweight_type2: int,
|
||||
activation: str,
|
||||
) -> torch.Tensor:
|
||||
activation_enum = MoEActivation.from_str(activation)
|
||||
|
||||
def act(x: torch.Tensor):
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = x.shape[:-1] + (d,)
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
apply_moe_activation(activation_enum, out, x)
|
||||
return out
|
||||
|
||||
# lazy import to avoid triggering triton import in CPU backend
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
|
||||
|
||||
out_hidden_states = torch.empty_like(x)
|
||||
# unless we get decent expert reuse, we are better off running the moe_vec kernel
|
||||
if (
|
||||
qweight_type2 in MMQ_QUANT_TYPES
|
||||
and qweight_type in MMQ_QUANT_TYPES
|
||||
and x.shape[0] > 64
|
||||
):
|
||||
num_tokens, _ = x.shape
|
||||
E, N, _ = w1.shape
|
||||
top_k = topk_ids.shape[1]
|
||||
BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type)
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
topk_ids, BLOCK_SIZE, E
|
||||
)
|
||||
out = ops.ggml_moe_a8(
|
||||
x,
|
||||
w1,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
qweight_type,
|
||||
N,
|
||||
top_k,
|
||||
num_tokens,
|
||||
)
|
||||
out = act(out)
|
||||
out = ops.ggml_moe_a8(
|
||||
out,
|
||||
w2,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
qweight_type2,
|
||||
w2.shape[1],
|
||||
1,
|
||||
num_tokens * top_k,
|
||||
)
|
||||
out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
|
||||
topk_weights.view(num_tokens, top_k, 1)
|
||||
)
|
||||
ops.moe_sum(out, out_hidden_states)
|
||||
elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
|
||||
num_tokens, _ = x.shape
|
||||
E, N, _ = w1.shape
|
||||
top_k = topk_ids.shape[1]
|
||||
|
||||
out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N, num_tokens)
|
||||
out = act(out)
|
||||
|
||||
out = ops.ggml_moe_a8_vec(
|
||||
out, w2, topk_ids, 1, qweight_type2, w2.shape[1], num_tokens * top_k
|
||||
)
|
||||
out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
|
||||
topk_weights.view(num_tokens, top_k, 1)
|
||||
)
|
||||
ops.moe_sum(out, out_hidden_states)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"There is no support for fast MoE kernel "
|
||||
"for current quantization method. "
|
||||
"Falling back to slow implementation. "
|
||||
)
|
||||
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
|
||||
inp = x[tok].reshape((1,) + x.shape[1:])
|
||||
current_hidden_state = None
|
||||
for ww, ii in zip(w, idx):
|
||||
expert_up = w1[ii]
|
||||
|
||||
out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
|
||||
out = act(out)
|
||||
|
||||
expert_down = w2[ii]
|
||||
current_state = fused_mul_mat_gguf(
|
||||
out, expert_down, qweight_type2
|
||||
).mul_(ww)
|
||||
if current_hidden_state is None:
|
||||
current_hidden_state = current_state
|
||||
else:
|
||||
current_hidden_state.add_(current_state)
|
||||
out_hidden_states[tok] = current_hidden_state
|
||||
return out_hidden_states
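# Dispatch note: the batched MMQ expert kernels are only used once enough
# tokens are present to amortize expert reuse (the x.shape[0] > 64 heuristic
# above); otherwise the MMVQ expert kernels are used when both weight types
# support them, and failing that the slow per-token fallback loop runs.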
|
||||
|
||||
|
||||
def _fused_moe_gguf_fake(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
qweight_type: int,
|
||||
qweight_type2: int,
|
||||
activation: str,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="_fused_moe_gguf",
|
||||
op_func=_fused_moe_gguf,
|
||||
fake_impl=_fused_moe_gguf_fake,
|
||||
)
|
||||
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
|
||||
def _apply_gguf_embedding(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
qweight_type: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> torch.Tensor:
|
||||
if qweight_type in UNQUANTIZED_TYPES:
|
||||
return torch.embedding(qweight, x)
|
||||
elif qweight_type in DEQUANT_TYPES:
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
|
||||
x_flat = x.flatten()
|
||||
assert hidden_size == qweight.shape[1] // type_size * block_size
|
||||
quant = torch.index_select(qweight, dim=0, index=x_flat)
|
||||
dequant = ops.ggml_dequantize(
|
||||
quant, qweight_type, hidden_size, x_flat.shape[0], dtype
|
||||
)
|
||||
return dequant.view(*x.shape, hidden_size)
|
||||
else:
|
||||
qweight_type = WeightType(qweight_type)
|
||||
raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
|
||||
|
||||
|
||||
def _apply_gguf_embedding_fake(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
qweight_type: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="_apply_gguf_embedding",
|
||||
op_func=_apply_gguf_embedding,
|
||||
fake_impl=_apply_gguf_embedding_fake,
|
||||
)
|
||||
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
|
||||
class GGUFLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GGUF.
|
||||
|
||||
Args:
|
||||
quant_config: The GGUF quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GGUFConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
self.params_dtype = params_dtype
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
tensor_shape = (output_size_per_partition, input_size_per_partition)
|
||||
qweight = GGUFUninitializedParameter(requires_grad=False)
|
||||
set_weight_attrs(
|
||||
qweight,
|
||||
{
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"tensor_shape": tensor_shape,
|
||||
"is_gguf_weight": True,
|
||||
"data_container": [],
|
||||
"shard_id": [],
|
||||
"shard_id_map": {},
|
||||
},
|
||||
)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("qweight", qweight)
|
||||
|
||||
qweight_type = Parameter(
|
||||
torch.empty(len(output_partition_sizes), dtype=torch.uint8),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qweight_type,
|
||||
{
|
||||
"is_gguf_weight_type": True,
|
||||
"weight_type": 0,
|
||||
"shard_weight_type": {},
|
||||
"ignore_warning": True,
|
||||
},
|
||||
)
|
||||
set_weight_attrs(qweight_type, extra_weight_attrs)
|
||||
layer.register_parameter("qweight_type", qweight_type)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
if not (qweight_type in UNQUANTIZED_TYPES or qweight_type in DEQUANT_TYPES):
|
||||
qweight_type = WeightType(qweight_type)
|
||||
raise ValueError(
|
||||
f"Unsupported GGUF quantization type {qweight_type} in layer {layer}."
|
||||
)
|
||||
# For MergedColumnParallelLinear and QKVParallelLinear, we need to
|
||||
# materialize the padded weight parameter for CUDA Graph compatibility.
|
||||
self._create_padded_weight_param(layer)
|
||||
|
||||
def _create_padded_weight_param(self, layer: torch.nn.Module):
|
||||
"""Create padded weight parameter for GGUF MergedLinear layer."""
|
||||
qweight = layer.qweight
|
||||
shard_id_map = qweight.shard_id_map
|
||||
shard_id = qweight.shard_id
|
||||
if len(data_container := qweight.data_container) > 1:
|
||||
dtype = {data.dtype for data in data_container}
|
||||
assert len(dtype) == 1, ValueError(
|
||||
f"Data container has mixed dtypes: {dtype}"
|
||||
)
|
||||
dtype = next(iter(dtype))
|
||||
# concat dim0 and pad dim1
|
||||
padded_side = max(x.size(1) for x in data_container)
|
||||
concat_side = sum(x.size(0) for x in data_container)
|
||||
# Pad the quantized weights to dense tensor, and create a map
|
||||
# with the location of each shard in the padded tensor.
|
||||
padded_data = torch.zeros(
|
||||
(concat_side, padded_side), dtype=dtype, device=qweight.device
|
||||
)
|
||||
# (dim0_start, dim0_end, dim1_size)
|
||||
shard_offset_map = dict[str, tuple[int, int, int]]()
|
||||
for idx in shard_id:
|
||||
id_in_container = shard_id_map[idx]
|
||||
start = sum(x.size(0) for x in data_container[:id_in_container])
|
||||
end = start + data_container[id_in_container].size(0)
|
||||
size = data_container[id_in_container].size(1)
|
||||
padded_data[start:end, :size] = data_container[id_in_container]
|
||||
shard_offset_map[idx] = (start, end, size)
|
||||
qweight.data_container.clear()
|
||||
padded_param = Parameter(padded_data, requires_grad=False)
|
||||
set_weight_attrs(padded_param, vars(qweight))
|
||||
set_weight_attrs(padded_param, {"shard_offset_map": shard_offset_map})
|
||||
layer.register_parameter("qweight", padded_param)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
shard_id = layer.qweight.shard_id
|
||||
|
||||
if shard_id:
|
||||
# dequantize shard weights respectively
|
||||
shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
|
||||
qweight = layer.qweight
|
||||
result = []
|
||||
for idx in shard_id:
|
||||
start, end, offset = layer.qweight.shard_offset_map[idx]
|
||||
qweight_type = layer.qweight_type.shard_weight_type[idx]
|
||||
result.append(
|
||||
fused_mul_mat_gguf(
|
||||
x, qweight[start:end, :offset].contiguous(), qweight_type
|
||||
)
|
||||
)
|
||||
out = torch.cat(result, axis=1)
|
||||
else:
|
||||
qweight = layer.qweight
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
out = fused_mul_mat_gguf(x, qweight, qweight_type)
|
||||
if bias is not None:
|
||||
out.add_(bias)
|
||||
return out
|
||||
|
||||
|
||||
class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for GGUF.
|
||||
|
||||
Args:
|
||||
quant_config: The GGUF quantization config.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: GGUFConfig,
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
tensor_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)
|
||||
# gate up proj
|
||||
w13_qweight = GGUFUninitializedParameter(requires_grad=False)
|
||||
set_weight_attrs(
|
||||
w13_qweight,
|
||||
{
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"tensor_shape": tensor_shape,
|
||||
"is_gguf_weight": True,
|
||||
"data_container": [],
|
||||
},
|
||||
)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
|
||||
w13_qweight_type = Parameter(
|
||||
torch.empty(1, dtype=torch.uint8), requires_grad=False
|
||||
)
|
||||
set_weight_attrs(
|
||||
w13_qweight_type,
|
||||
{"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
|
||||
)
|
||||
set_weight_attrs(w13_qweight_type, extra_weight_attrs)
|
||||
layer.register_parameter("w13_qweight_type", w13_qweight_type)
|
||||
|
||||
tensor_shape = (num_experts, intermediate_size_per_partition, hidden_size)
|
||||
# gate down proj
|
||||
w2_qweight = GGUFUninitializedParameter(requires_grad=False)
|
||||
set_weight_attrs(
|
||||
w2_qweight,
|
||||
{
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"tensor_shape": tensor_shape,
|
||||
"is_gguf_weight": True,
|
||||
"data_container": [],
|
||||
},
|
||||
)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
|
||||
w2_qweight_type = Parameter(
|
||||
torch.empty(1, dtype=torch.uint8), requires_grad=False
|
||||
)
|
||||
set_weight_attrs(
|
||||
w2_qweight_type,
|
||||
{"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
|
||||
)
|
||||
|
||||
set_weight_attrs(w2_qweight_type, extra_weight_attrs)
|
||||
layer.register_parameter("w2_qweight_type", w2_qweight_type)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if layer.apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for"
|
||||
"fused GGUF MoE method."
|
||||
)
|
||||
|
||||
return fused_moe_gguf(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
layer.w13_qweight_type.weight_type,
|
||||
layer.w2_qweight_type.weight_type,
|
||||
layer.activation.value,
|
||||
)
|
||||
|
||||
|
||||
class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
"""Embedding method for GGUF.
|
||||
|
||||
Args:
|
||||
quant_config: The GGUF quantization config.
|
||||
"""
|
||||
|
||||
def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
hidden_size = qweight.tensor_shape[1]
|
||||
|
||||
return apply_gguf_embedding(
|
||||
x, qweight, qweight_type, hidden_size, dtype=self.params_dtype
|
||||
)
|
||||
|
||||
|
||||
class GGUFUninitializedParameter(UninitializedParameter):
|
||||
cls_to_become = Parameter
|
||||
data_container: list[torch.Tensor]
|
||||
393
vllm/model_executor/layers/quantization/gptq.py
Normal file
@@ -0,0 +1,393 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter,
|
||||
)
|
||||
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
else:
|
||||
QuantizationMethods = str
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPTQConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ.
|
||||
|
||||
Reference: https://arxiv.org/abs/2210.17323
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: dict[str, dict[str, int | bool]],
|
||||
autoround_version: str = "",
|
||||
modules_in_block_to_quantize: list[str] | None = None,
|
||||
checkpoint_format: str = "",
|
||||
) -> None:
|
||||
# GPTQModel use `dynamic` config property to allow per module
|
||||
# quantization config so each module can be individually optimized.
|
||||
# Format is dict[str, dict] where key is a regex string that can
|
||||
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
|
||||
# matching of a module.
|
||||
# Default to positive match, override base quant config mode, if no
|
||||
# prefix is used. Value is in dict format of field key and override
|
||||
# value.
|
||||
# Negative matching will skip quantization init for this module
|
||||
# entirely:
|
||||
# non-quantized inference. More details and quantization examples can be
|
||||
# found at: https://github.com/ModelCloud/GPTQModel
|
||||
# Example:
|
||||
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
|
||||
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
|
||||
# dynamic = {
|
||||
# #`.*\.` matches the layers_node prefix
|
||||
# # positive match layer 10-15
|
||||
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
|
||||
# # positive match layer 16-21
|
||||
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
|
||||
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
|
||||
# }
|
||||
super().__init__()
|
||||
self.dynamic = dynamic
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.pack_factor = Fraction(32, self.weight_bits)
|
||||
if self.weight_bits not in [2, 3, 4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 2/3/4/8-bit weight quantization is "
|
||||
f"supported for GPTQ, but got {self.weight_bits} bits."
|
||||
)
|
||||
# The 4-bit gptq_gemm kernel is currently buggy; until it is fixed,
# emit a warning here, since gptq_marlin is used by default anyway.
|
||||
if self.weight_bits == 4:
|
||||
logger.warning_once(
|
||||
"Currently, the 4-bit gptq_gemm kernel for GPTQ is buggy. "
|
||||
"Please switch to gptq_marlin."
|
||||
)
|
||||
|
||||
self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
|
||||
|
||||
# used to identify GPTQ model quantized by autoround
|
||||
self.autoround_version = autoround_version
|
||||
|
||||
# GPTQ v1 and v2 format deals with zero points differently.
|
||||
# Currently GPTQModel stores v1 format checkpoints by default,
|
||||
# but provides the option to set `format="gptq_v2"` in `QuantizeConfig`.
|
||||
self.checkpoint_format = checkpoint_format
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"GPTQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}), "
|
||||
f"lm_head_quantized={self.lm_head_quantized}, "
|
||||
f"dynamic={self.dynamic}, "
|
||||
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize}), "
|
||||
f"checkpoint_format={self.checkpoint_format})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
# TODO: confirm the actual minimum compute capability required.
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQConfig":
|
||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||
dynamic = {} if dynamic is None else dynamic
|
||||
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
|
||||
autoround_version = cls.get_from_keys_or(
|
||||
config, ["autoround_version"], default=""
|
||||
)
|
||||
modules_in_block_to_quantize = cls.get_from_keys_or(
|
||||
config, ["modules_in_block_to_quantize"], default=None
|
||||
)
|
||||
checkpoint_format = cls.get_from_keys_or(
|
||||
config, ["checkpoint_format"], default=""
|
||||
)
|
||||
return cls(
|
||||
weight_bits,
|
||||
group_size,
|
||||
desc_act,
|
||||
lm_head_quantized,
|
||||
dynamic,
|
||||
autoround_version,
|
||||
modules_in_block_to_quantize,
|
||||
checkpoint_format,
|
||||
)
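# Minimal example of the keys consumed above (a sketch; real
# quantize_config.json files typically carry more fields):
#
#   {"bits": 4, "group_size": 128, "desc_act": false, "lm_head": false}
#
# yields GPTQConfig(weight_bits=4, group_size=128, desc_act=False,
# lm_head_quantized=False, dynamic={}).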
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Union["GPTQLinearMethod", "QuantizeMethodBase"] | None:
|
||||
if isinstance(layer, FusedMoE):
|
||||
# GPTQ MoE support: fall back to MoeWNA16 for broad compatibility
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
|
||||
# TODO: maybe update this for GPTQv2 format checkpoints
|
||||
config = {
|
||||
"quant_method": "gptq",
|
||||
"bits": self.weight_bits,
|
||||
"group_size": self.group_size,
|
||||
"sym": True, # GPTQ typically uses symmetric quantization
|
||||
"lm_head": False,
|
||||
}
|
||||
return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)
|
||||
|
||||
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||
if self.modules_in_block_to_quantize is not None:
|
||||
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
|
||||
self.modules_in_block_to_quantize
|
||||
)
|
||||
|
||||
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||
if self.modules_in_block_to_quantize:
|
||||
if is_list_of(self.modules_in_block_to_quantize, list):
|
||||
# original modules_in_block_to_quantize: list[list[str]]
|
||||
# flatten original modules_in_block_to_quantize
|
||||
self.modules_in_block_to_quantize = [
|
||||
item
|
||||
for sublist in self.modules_in_block_to_quantize
|
||||
for item in sublist
|
||||
]
|
||||
return
|
||||
|
||||
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
metadata = get_safetensors_params_metadata(model_name, revision=revision)
|
||||
quant_layers: set[str] = {
|
||||
param_name.rsplit(".", 1)[0]
|
||||
for param_name, info in metadata.items()
|
||||
if (dtype := info.get("dtype", None))
|
||||
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
|
||||
}
|
||||
self.modules_in_block_to_quantize = list(quant_layers)
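# Example: a metadata entry like
#   {"model.layers.0.mlp.down_proj.qweight": {"dtype": "I32", ...}}
# contributes "model.layers.0.mlp.down_proj" to quant_layers, while params
# stored in fp16/bf16/fp32 are treated as unquantized and excluded.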
|
||||
|
||||
|
||||
class ExllamaState(Enum):
|
||||
UNUSED = enum.auto()
|
||||
UNINITIALIZED = enum.auto()
|
||||
READY = enum.auto()
|
||||
|
||||
|
||||
class GPTQLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQ.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQ quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GPTQConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
# GPTQ v1 and v2 format deals with zero points differently
|
||||
self.use_v2_format = quant_config.checkpoint_format == "gptq_v2"
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size."
|
||||
)
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size."
|
||||
)
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
exllama_state = ExllamaState.UNINITIALIZED
|
||||
scale_and_zero_size = input_size // group_size
|
||||
scale_and_zero_input_dim = None
|
||||
if (
|
||||
input_size != input_size_per_partition
|
||||
and self.quant_config.group_size != -1
|
||||
):
|
||||
# For act-order models, we cannot use Exllama for row parallel layer
|
||||
if self.quant_config.desc_act:
|
||||
exllama_state = ExllamaState.UNUSED
|
||||
else:
|
||||
# we need to partition qzeros and scales for exllama kernel
|
||||
scale_and_zero_size = input_size_per_partition // group_size
|
||||
scale_and_zero_input_dim = 0
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
g_idx = RowvLLMParameter(
|
||||
data=torch.tensor(
|
||||
[
|
||||
i // self.quant_config.group_size
|
||||
for i in range(input_size_per_partition)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
qzeros_args = {
|
||||
"data": torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data": torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
if scale_and_zero_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args,
|
||||
)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(
|
||||
output_dim=1, input_dim=0, **weight_scale_args
|
||||
)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args,
|
||||
)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
layer.exllama_state = exllama_state
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# for torch.compile
|
||||
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
|
||||
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
|
||||
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
|
||||
layer.scales = Parameter(layer.scales.data, requires_grad=False)
|
||||
|
||||
# exllama needs to shuffle the weight after the weight is loaded
|
||||
# here we do the shuffle on first forward pass
|
||||
if layer.exllama_state == ExllamaState.UNINITIALIZED:
|
||||
if self.quant_config.desc_act:
|
||||
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
|
||||
else:
|
||||
layer.g_idx.data = torch.empty(
|
||||
(0,), dtype=torch.int, device=layer.g_idx.device
|
||||
)
|
||||
layer.exllama_state = ExllamaState.READY
|
||||
ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
# GPTQ v1 and v2 format checkpoints deals with zero points differently,
|
||||
# and require different gemm kernels.
|
||||
output = ops.gptq_gemm(
|
||||
reshaped_x,
|
||||
layer.qweight,
|
||||
layer.qzeros,
|
||||
layer.scales,
|
||||
layer.g_idx,
|
||||
layer.exllama_state == ExllamaState.READY,
|
||||
self.use_v2_format,
|
||||
self.quant_config.weight_bits,
|
||||
)
|
||||
if bias is not None:
|
||||
output.add_(bias)
|
||||
return output.reshape(out_shape)
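# Shape note: x of shape (batch, seq, in_features) is flattened to
# (batch * seq, in_features) for the GEMM and reshaped back to
# (batch, seq, out_features) before returning.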
|
||||
929
vllm/model_executor/layers/quantization/gptq_marlin.py
Normal file
@@ -0,0 +1,929 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
MPLinearLayerConfig,
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
UnquantizedFusedMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_dynamic_override,
|
||||
get_linear_quant_method,
|
||||
override_config,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported,
|
||||
check_moe_marlin_supports_layer,
|
||||
get_marlin_input_dtype,
|
||||
marlin_act_int8_process_scales,
|
||||
marlin_make_workspace_new,
|
||||
marlin_moe_permute_scales,
|
||||
marlin_permute_bias,
|
||||
marlin_repeat_scales_on_all_ranks,
|
||||
verify_marlin_supported,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_moe_quant_method(
|
||||
config: "GPTQMarlinConfig",
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
moe_method_cls: type,
|
||||
):
|
||||
cloned_config = deepcopy(config)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
# False = skip module, None = no override, else = Positive match
|
||||
if (
|
||||
get_dynamic_override( # noqa: E712
|
||||
cloned_config, # noqa: E712
|
||||
layer_name=prefix,
|
||||
)
|
||||
== False
|
||||
): # noqa: E712
|
||||
return UnquantizedFusedMoEMethod(layer.moe_config)
|
||||
|
||||
if prefix:
|
||||
# Dynamic per module/layer rules may override base config
|
||||
override_config(cloned_config, prefix=prefix)
|
||||
|
||||
return moe_method_cls(cloned_config, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
class GPTQMarlinConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ Marlin"""
|
||||
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
is_sym: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: dict[str, dict[str, int | bool]],
|
||||
full_config: dict[str, Any],
|
||||
modules_in_block_to_quantize: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if desc_act and group_size == -1:
|
||||
# In this case, act_order == True is the same as act_order == False
|
||||
# (since we have only one group per output channel)
|
||||
desc_act = False
|
||||
|
||||
# GPTQModel use `dynamic` config property to allow per module
|
||||
# quantization config so each module can be individually optimized.
|
||||
# Format is dict[str, dict] where key is a regex string that can
|
||||
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
|
||||
# matching of a module.
|
||||
# Default to positive match, override base quant config mode, if no
|
||||
# prefix is used. Value is in dict format of field key and override
|
||||
# value.
|
||||
# Negative matching will skip quantization init for this module
|
||||
# entirely:
|
||||
# non-quantized inference. More details and quantization examples can be
|
||||
# found at: https://github.com/ModelCloud/GPTQModel
|
||||
# Example:
|
||||
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
|
||||
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
|
||||
# dynamic = {
|
||||
# #`.*\.` matches the layers_node prefix
|
||||
# # positive match layer 10-15
|
||||
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
|
||||
# # positive match layer 16-21
|
||||
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
|
||||
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
|
||||
# }
|
||||
self.dynamic = dynamic
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.is_sym = is_sym
|
||||
|
||||
self.pack_factor = 32 // weight_bits # packed into int32
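# e.g. 4-bit weights pack 8 values per int32 word (pack_factor = 8),
# 8-bit weights pack 4 (pack_factor = 4).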
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.full_config = full_config
|
||||
|
||||
if (weight_bits, is_sym) not in self.TYPE_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported quantization config: bits={weight_bits}, sym={is_sym}"
|
||||
)
|
||||
|
||||
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
|
||||
|
||||
self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
|
||||
# used to identify GPTQ model quantized by autoround
|
||||
self.autoround_version = full_config.get("autoround_version", "")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"GPTQMarlinConfig(quant_type={self.quant_type}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}, "
|
||||
f"lm_head_quantized={self.lm_head_quantized}, "
|
||||
f"dynamic={self.dynamic}, "
|
||||
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq_marlin"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig":
|
||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||
dynamic = {} if dynamic is None else dynamic
|
||||
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
is_sym = cls.get_from_keys(config, ["sym"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
|
||||
modules_in_block_to_quantize = cls.get_from_keys_or(
|
||||
config, ["modules_in_block_to_quantize"], default=None
|
||||
)
|
||||
return cls(
|
||||
weight_bits,
|
||||
group_size,
|
||||
desc_act,
|
||||
is_sym,
|
||||
lm_head_quantized,
|
||||
dynamic,
|
||||
config,
|
||||
modules_in_block_to_quantize,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> QuantizationMethods | None:
|
||||
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (
|
||||
user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
|
||||
)
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = (
|
||||
"The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name())
|
||||
)
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
if can_convert and user_quant == "gptq":
|
||||
logger.info(
|
||||
"Detected that the model can run with gptq_marlin"
|
||||
", however you specified quantization=gptq explicitly,"
|
||||
" so forcing gptq. Use quantization=gptq_marlin for"
|
||||
" faster inference"
|
||||
)
|
||||
return None
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, FusedMoE):
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
|
||||
if not check_moe_marlin_supports_layer(layer, self.group_size):
|
||||
logger.warning_once(
|
||||
f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
|
||||
"Falling back to Moe WNA16 kernels."
|
||||
)
|
||||
return MoeWNA16Config.from_config(self.full_config).get_quant_method(
|
||||
layer, prefix
|
||||
)
|
||||
moe_quant_method = get_moe_quant_method(
|
||||
self, layer, prefix, GPTQMarlinMoEMethod
|
||||
)
|
||||
if moe_quant_method is None:
|
||||
return None
|
||||
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
|
||||
return moe_quant_method
|
||||
|
||||
quant_method = get_linear_quant_method(
|
||||
self, layer, prefix, GPTQMarlinLinearMethod
|
||||
)
|
||||
if quant_method is None:
|
||||
return None
|
||||
quant_method.input_dtype = get_marlin_input_dtype(prefix)
|
||||
return quant_method
|
||||
|
||||
@classmethod
|
||||
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
group_size = quant_config.get("group_size")
|
||||
sym = quant_config.get("sym")
|
||||
desc_act = quant_config.get("desc_act")
|
||||
|
||||
if not (current_platform.is_cuda() or current_platform.is_cpu()):
|
||||
return False
|
||||
|
||||
if quant_method != "gptq":
|
||||
return False
|
||||
|
||||
# Marlin conversion is only valid if required properties are found
|
||||
if num_bits is None or group_size is None or sym is None or desc_act is None:
|
||||
return False
|
||||
|
||||
if (num_bits, sym) not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
return check_marlin_supported(
|
||||
quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
|
||||
)
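# Example of a compatible HF quant config (values only, a sketch):
#   {"quant_method": "gptq", "bits": 4, "group_size": 128,
#    "sym": true, "desc_act": false}
# maps to scalar type uint4b8 and can be upgraded to gptq_marlin at runtime
# on supported platforms.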
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper):
|
||||
if self.modules_in_block_to_quantize is not None:
|
||||
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
|
||||
self.modules_in_block_to_quantize
|
||||
)
|
||||
|
||||
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||
if self.modules_in_block_to_quantize:
|
||||
if is_list_of(self.modules_in_block_to_quantize, list):
|
||||
# original modules_in_block_to_quantize: list[list[str]]
|
||||
# flatten original modules_in_block_to_quantize
|
||||
self.modules_in_block_to_quantize = [
|
||||
item
|
||||
for sublist in self.modules_in_block_to_quantize
|
||||
for item in sublist
|
||||
]
|
||||
return
|
||||
|
||||
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
metadata = get_safetensors_params_metadata(model_name, revision=revision)
|
||||
quant_layers: set[str] = {
|
||||
param_name.rsplit(".", 1)[0]
|
||||
for param_name, info in metadata.items()
|
||||
if (dtype := info.get("dtype", None))
|
||||
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
|
||||
}
|
||||
self.modules_in_block_to_quantize = list(quant_layers)
|
||||
|
||||
|
||||
class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQ Marlin.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQ Marlin quantization config.
|
||||
"""
|
||||
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
self.input_dtype = None
|
||||
self.quant_type = self.quant_config.quant_type
|
||||
|
||||
# Verify supported on platform.
|
||||
verify_marlin_supported(
|
||||
quant_type=self.quant_config.quant_type,
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
is_row_parallel = input_size != input_size_per_partition
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
input_dtype = self.input_dtype
|
||||
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
),
|
||||
weight_type=self.quant_config.quant_type,
|
||||
act_type=params_dtype if input_dtype is None else input_dtype,
|
||||
group_size=self.quant_config.group_size,
|
||||
zero_points=False,
|
||||
has_g_idx=self.quant_config.desc_act,
|
||||
)
|
||||
|
||||
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for GPTQMarlinLinearMethod", kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
# Determine sharding
|
||||
if marlin_repeat_scales_on_all_ranks(
|
||||
self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel
|
||||
):
|
||||
# By setting scale_dim == None, weight_loader will
|
||||
# repeat the scales on each GPU in TP>1 case.
|
||||
scales_and_zp_input_dim = None
|
||||
scales_and_zp_size = input_size // group_size
|
||||
else:
|
||||
# By setting scale_dim == 0, weight_loader will
|
||||
# shard the scales in TP>1 case.
|
||||
scales_and_zp_input_dim = 0
|
||||
scales_and_zp_size = input_size_per_partition // group_size
|
||||
|
||||
# Quantized weights
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
# Activation order
|
||||
g_idx = RowvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
qzeros_args = {
|
||||
"data": torch.empty(
|
||||
scales_and_zp_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data": torch.empty(
|
||||
scales_and_zp_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
|
||||
if scales_and_zp_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args,
|
||||
)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(
|
||||
output_dim=1, input_dim=0, **weight_scale_args
|
||||
)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args,
|
||||
)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("scales", scales)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
|
||||
self.kernel = kernel_type(
|
||||
mp_linear_kernel_config,
|
||||
w_q_param_name="qweight",
|
||||
w_s_param_name="scales",
|
||||
w_zp_param_name="qzeros",
|
||||
w_gidx_param_name="g_idx",
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
"""MoE Marlin method with quantization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: GPTQMarlinConfig,
|
||||
moe: FusedMoEConfig,
|
||||
) -> None:
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
if self.quant_config.quant_type.size_bits == 4:
|
||||
self.quant_type = scalar_types.uint4b8
|
||||
elif self.quant_config.quant_type.size_bits == 8:
|
||||
self.quant_type = scalar_types.uint8b128
|
||||
else:
|
||||
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
|
||||
self.input_dtype = None
|
||||
self.use_marlin = True
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
layer.input_dtype = self.input_dtype
|
||||
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
|
||||
|
||||
if is_a_8bit:
|
||||
assert self.quant_type == scalar_types.uint4b8, (
|
||||
"W8A8-INT8 is not supported by marlin kernel."
|
||||
)
|
||||
|
||||
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
|
||||
|
||||
self.is_k_full = (not self.quant_config.desc_act) or (
|
||||
intermediate_size_per_partition == intermediate_size_full
|
||||
)
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
scales_size13 = hidden_size // self.quant_config.group_size
|
||||
w2_scales_size = (
|
||||
intermediate_size_full
|
||||
if self.quant_config.desc_act
|
||||
else intermediate_size_per_partition
|
||||
)
|
||||
scales_size2 = w2_scales_size // self.quant_config.group_size
|
||||
strategy = FusedMoeWeightScaleSupported.GROUP.value
|
||||
else:
|
||||
scales_size13 = 1
|
||||
scales_size2 = 1
|
||||
strategy = FusedMoeWeightScaleSupported.CHANNEL.value
|
||||
|
||||
layer.num_groups_w13 = scales_size13
|
||||
layer.num_groups_w2 = scales_size2
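# Illustrative worked example (hypothetical sizes): with hidden_size=4096,
# intermediate_size_per_partition=5504, group_size=128 and desc_act=False,
# scales_size13 = 4096 // 128 = 32 and scales_size2 = 5504 // 128 = 43.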
|
||||
|
||||
extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size // self.quant_config.pack_factor,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
# down_proj (row parallel)
|
||||
w2_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size_per_partition // self.quant_config.pack_factor,
|
||||
hidden_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
# up_proj scales
|
||||
w13_scales = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
scales_size13,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_scales", w13_scales)
|
||||
set_weight_attrs(w13_scales, extra_weight_attrs)
|
||||
# down_proj scales
|
||||
w2_scales = torch.nn.Parameter(
|
||||
torch.empty(num_experts, scales_size2, hidden_size, dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_scales", w2_scales)
|
||||
set_weight_attrs(w2_scales, extra_weight_attrs)
|
||||
# don't shard the w2 scales when running act order
|
||||
set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act})
|
||||
# up_proj qzeros
|
||||
w13_qzeros = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
scales_size13,
|
||||
2 * intermediate_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_qzeros", w13_qzeros)
|
||||
set_weight_attrs(w13_qzeros, extra_weight_attrs)
|
||||
# down_proj qzeros
|
||||
w2_qzeros = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
scales_size2,
|
||||
hidden_size // self.quant_config.pack_factor,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_qzeros", w2_qzeros)
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
# don't shard the w2 qzeros when running act order
|
||||
set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act})
|
||||
w13_g_idx = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_g_idx", w13_g_idx)
|
||||
set_weight_attrs(w13_g_idx, extra_weight_attrs)
|
||||
w2_g_idx = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_g_idx", w2_g_idx)
|
||||
set_weight_attrs(w2_g_idx, extra_weight_attrs)
|
||||
w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
|
||||
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
|
||||
w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
|
||||
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
|
||||
|
||||
device = layer.w13_qweight.device
|
||||
layer.workspace = marlin_make_workspace_new(device, 4)
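# Illustrative note (hypothetical sizes, not part of the vLLM sources): with
# num_experts=8, hidden_size=4096, intermediate_size_per_partition=5504 and
# 4-bit weights (pack_factor=8), the parameters created above get shapes
# w13_qweight=(8, 512, 11008) and w2_qweight=(8, 688, 4096).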
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
|
||||
|
||||
if is_a_8bit:
|
||||
assert self.quant_type == scalar_types.uint4b8, (
|
||||
"W8A8-INT8 is not supported by marlin kernel."
|
||||
)
|
||||
|
||||
if self.input_dtype == torch.float8_e4m3fn:
|
||||
ops.marlin_int4_fp8_preprocess(layer.w13_qweight, inplace=True)
|
||||
ops.marlin_int4_fp8_preprocess(layer.w2_qweight, inplace=True)
|
||||
layer.w13_scales.data = layer.w13_scales.data * 512
|
||||
layer.w2_scales.data = layer.w2_scales.data * 512
|
||||
|
||||
# Process act_order
|
||||
if self.quant_config.desc_act:
|
||||
# Get sorting based on g_idx
|
||||
num_experts = layer.w13_g_idx.shape[0]
|
||||
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
|
||||
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
|
||||
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
|
||||
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
|
||||
for e in range(num_experts):
|
||||
w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
|
||||
torch.int32
|
||||
)
|
||||
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
|
||||
torch.int32
|
||||
)
|
||||
w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
|
||||
w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
|
||||
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
|
||||
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
|
||||
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
|
||||
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
|
||||
else:
|
||||
# Reset g_idx related tensors
|
||||
num_experts = layer.w13_g_idx.shape[0]
|
||||
device = layer.w13_g_idx.device
|
||||
layer.w13_g_idx = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
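# Illustrative note (not part of the vLLM sources): for a single expert with
# g_idx = [2, 0, 1, 0], the act-order branch above computes
# sort_indices = argsort(g_idx) and stores the sorted groups [0, 0, 1, 2];
# the Marlin kernel then consumes the input channels in that permuted order.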
|
||||
# Repack weights
|
||||
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
|
||||
layer.w13_qweight,
|
||||
layer.w13_g_idx_sort_indices,
|
||||
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w13_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
|
||||
layer.w2_qweight,
|
||||
layer.w2_g_idx_sort_indices,
|
||||
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w2_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
|
||||
# The modular kernel expects w13_weight and w2_weight,
|
||||
# but GPTQ uses w13_qweight and w2_qweight
|
||||
# Alias for modular kernel
|
||||
layer.w13_weight = layer.w13_qweight
|
||||
# Alias for modular kernel
|
||||
layer.w2_weight = layer.w2_qweight
|
||||
|
||||
# Repack scales
|
||||
marlin_w13_scales = marlin_moe_permute_scales(
|
||||
s=layer.w13_scales,
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w13_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
|
||||
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_w13_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"w13_input_global_scale",
|
||||
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w13_scales", marlin_w13_scales)
|
||||
marlin_w2_scales = marlin_moe_permute_scales(
|
||||
s=layer.w2_scales,
|
||||
size_k=layer.w2_scales.shape[1]
|
||||
* (
|
||||
self.quant_config.group_size
|
||||
if self.quant_config.group_size != -1
|
||||
else self.quant_config.pack_factor
|
||||
),
|
||||
size_n=layer.w2_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
|
||||
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_w2_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"w2_input_global_scale",
|
||||
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
|
||||
layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
|
||||
|
||||
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
||||
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
gptq_marlin_moe_quant_config,
|
||||
)
|
||||
|
||||
return gptq_marlin_moe_quant_config(
|
||||
w1_scale=layer.w13_scales,
|
||||
w2_scale=layer.w2_scales,
|
||||
weight_bits=self.quant_config.weight_bits,
|
||||
group_size=self.quant_config.group_size,
|
||||
w1_zp=getattr(layer, "w13_qzeros", None)
|
||||
if not self.quant_config.is_sym
|
||||
else None,
|
||||
w2_zp=getattr(layer, "w2_qzeros", None)
|
||||
if not self.quant_config.is_sym
|
||||
else None,
|
||||
w1_bias=getattr(layer, "w13_bias", None),
|
||||
w2_bias=getattr(layer, "w2_bias", None),
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize,
|
||||
layer: torch.nn.Module,
|
||||
):
|
||||
"""
|
||||
Select the GEMM implementation for GPTQ-Marlin MoE.
|
||||
|
||||
Returns MarlinExperts configured for GPTQ quantization.
|
||||
This is ONLY used when LoRA is enabled.
|
||||
Without LoRA, GPTQ uses its own apply() method.
|
||||
"""
|
||||
# Only use modular kernels when LoRA is enabled
|
||||
# Without LoRA, GPTQ's own apply() method works fine and is more efficient
|
||||
if not self.moe.is_lora_enabled:
|
||||
raise NotImplementedError(
|
||||
"GPTQ-Marlin uses its own apply() method when LoRA is not enabled. "
|
||||
"Modular kernels are only used for LoRA support."
|
||||
)
|
||||
|
||||
# The modular marlin kernels do not support 8-bit weights.
|
||||
if self.quant_config.weight_bits == 8:
|
||||
raise NotImplementedError(
|
||||
"GPTQ-Marlin kernel does not support 8-bit weights."
|
||||
)
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
BatchedMarlinExperts,
|
||||
MarlinExperts,
|
||||
)
|
||||
|
||||
# Ensure quant config is initialized
|
||||
assert self.moe_quant_config is not None, (
|
||||
"moe_quant_config must be initialized before select_gemm_impl"
|
||||
)
|
||||
|
||||
w13_g_idx = (
|
||||
getattr(layer, "w13_g_idx", None) if self.quant_config.desc_act else None
|
||||
)
|
||||
w2_g_idx = (
|
||||
getattr(layer, "w2_g_idx", None) if self.quant_config.desc_act else None
|
||||
)
|
||||
w13_g_idx_sort_indices = (
|
||||
getattr(layer, "w13_g_idx_sort_indices", None)
|
||||
if self.quant_config.desc_act
|
||||
else None
|
||||
)
|
||||
w2_g_idx_sort_indices = (
|
||||
getattr(layer, "w2_g_idx_sort_indices", None)
|
||||
if self.quant_config.desc_act
|
||||
else None
|
||||
)
|
||||
|
||||
# Check if using batched expert format (for Expert Parallelism)
|
||||
if (
|
||||
prepare_finalize.activation_format
|
||||
== mk.FusedMoEActivationFormat.BatchedExperts
|
||||
):
|
||||
# For batched format, use BatchedMarlinExperts
|
||||
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
|
||||
assert max_num_tokens_per_rank is not None
|
||||
return BatchedMarlinExperts(
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
moe_config=self.moe,
|
||||
quant_config=self.moe_quant_config,
|
||||
w13_g_idx=w13_g_idx,
|
||||
w2_g_idx=w2_g_idx,
|
||||
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
|
||||
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
|
||||
is_k_full=self.is_k_full,
|
||||
)
|
||||
else:
|
||||
# Standard Marlin experts for GPTQ
|
||||
return MarlinExperts(
|
||||
moe_config=self.moe,
|
||||
quant_config=self.moe_quant_config,
|
||||
w13_g_idx=w13_g_idx,
|
||||
w2_g_idx=w2_g_idx,
|
||||
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
|
||||
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
|
||||
is_k_full=self.is_k_full,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
getattr(layer, "w13_bias", None),
|
||||
getattr(layer, "w2_bias", None),
|
||||
layer.w13_scales,
|
||||
layer.w2_scales,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
|
||||
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
g_idx1=layer.w13_g_idx,
|
||||
g_idx2=layer.w2_g_idx,
|
||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
workspace=layer.workspace,
|
||||
is_k_full=self.is_k_full,
|
||||
input_dtype=self.input_dtype,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
)
|
||||
442
vllm/model_executor/layers/quantization/inc.py
Normal file
@@ -0,0 +1,442 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from fractions import Fraction
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization import (
|
||||
QuantizationConfig,
|
||||
QuantizationMethods,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class INCConfig(QuantizationConfig):
|
||||
"""Config class for Intel Neural Compressor (INC).
|
||||
Repo: https://github.com/intel/neural-compressor
|
||||
"""
|
||||
|
||||
SUPPORTED_BITS = {2, 3, 4, 8}
|
||||
SUPPORTED_DTYPES = {"int"}
|
||||
SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
|
||||
SUPPORTED_BACKENDS = {
|
||||
"auto",
|
||||
"gptq",
|
||||
"gptq:marlin",
|
||||
"awq",
|
||||
"awq:marlin",
|
||||
"marlin",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
sym: bool = True,
|
||||
packing_format: str = "auto_round:auto_gptq",
|
||||
block_name_to_quantize: str | list[str] | None = None,
|
||||
extra_config: dict[str, Any] | None = None,
|
||||
data_type: str = "int",
|
||||
backend: str = "auto",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if weight_bits not in self.SUPPORTED_BITS:
|
||||
raise ValueError(
|
||||
f"Unsupported weight_bits: {weight_bits}, "
|
||||
f"currently only support {self.SUPPORTED_BITS}."
|
||||
)
|
||||
if data_type not in self.SUPPORTED_DTYPES:
|
||||
raise ValueError(
|
||||
f"Unsupported data_type: {data_type},"
|
||||
f" currently only support {self.SUPPORTED_DTYPES}."
|
||||
)
|
||||
if packing_format not in self.SUPPORTED_FORMATS:
|
||||
raise ValueError(
|
||||
f"Unsupported packing_format: {packing_format}, "
|
||||
f"currently only support {self.SUPPORTED_FORMATS}."
|
||||
)
|
||||
if backend not in self.SUPPORTED_BACKENDS:
|
||||
raise ValueError(
|
||||
f"Unsupported backend: {backend}, "
|
||||
f"currently only support {self.SUPPORTED_BACKENDS}."
|
||||
)
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.sym = sym
|
||||
self.packing_format = packing_format
|
||||
self.block_name_to_quantize = (
|
||||
block_name_to_quantize.split(",")
|
||||
if isinstance(block_name_to_quantize, str)
|
||||
else block_name_to_quantize
|
||||
)
|
||||
self.extra_config = extra_config
|
||||
self.data_type = data_type
|
||||
self.backend = backend
|
||||
self.pack_factor = Fraction(32, weight_bits)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"INCConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, sym={self.sym})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "inc"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantization_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "INCConfig":
|
||||
return cls(
|
||||
weight_bits=cls.get_from_keys(config, ["bits"]),
|
||||
group_size=cls.get_from_keys(config, ["group_size"]),
|
||||
sym=cls.get_from_keys(config, ["sym"]),
|
||||
packing_format=cls.get_from_keys_or(
|
||||
config, ["packing_format"], "auto_round:auto_gptq"
|
||||
),
|
||||
block_name_to_quantize=cls.get_from_keys_or(
|
||||
config, ["block_name_to_quantize", "to_quant_block_names"], None
|
||||
),
|
||||
extra_config=cls.get_from_keys_or(config, ["extra_config"], None),
|
||||
data_type=cls.get_from_keys_or(config, ["data_type"], "int"),
|
||||
backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"], "auto"),
|
||||
)
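# Illustrative sketch (hypothetical values, not part of the vLLM sources):
#
#   cfg = INCConfig.from_config({"bits": 4, "group_size": 128, "sym": True})
#   assert cfg.weight_bits == 4 and cfg.group_size == 128
#   assert cfg.packing_format == "auto_round:auto_gptq"   # default when absent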
|
||||
|
||||
def get_layer_config(self, layer, layer_name: str):
|
||||
def get_config(name: str, quantized: bool = True):
|
||||
if not self.extra_config:
|
||||
return (
|
||||
self.weight_bits if quantized else 16,
|
||||
self.group_size if quantized else -1,
|
||||
self.sym if quantized else True,
|
||||
)
|
||||
|
||||
# exact match first
|
||||
if name in self.extra_config:
|
||||
cfg = self.extra_config[name]
|
||||
return (
|
||||
cfg.get("bits", self.weight_bits if quantized else 16),
|
||||
cfg.get("group_size", self.group_size if quantized else -1),
|
||||
cfg.get("sym", self.sym if quantized else True),
|
||||
)
|
||||
|
||||
REGEX_SPECIAL_CHARS = set(r"*+?^$()[]{}|\\")
|
||||
for pattern, cfg in self.extra_config.items():
|
||||
if not isinstance(pattern, str) or not any(
|
||||
c in REGEX_SPECIAL_CHARS for c in pattern
|
||||
):
|
||||
continue
|
||||
|
||||
try:
|
||||
if re.search(re.compile(pattern), name) is not None:
|
||||
return (
|
||||
cfg.get("bits", self.weight_bits if quantized else 16),
|
||||
cfg.get("group_size", self.group_size if quantized else -1),
|
||||
cfg.get("sym", self.sym if quantized else True),
|
||||
)
|
||||
except re.error:
|
||||
# Invalid regex, ignore.
|
||||
continue
|
||||
|
||||
return (
|
||||
self.weight_bits if quantized else 16,
|
||||
self.group_size if quantized else -1,
|
||||
self.sym if quantized else True,
|
||||
)
|
||||
|
||||
# 1. Exact match from config
|
||||
if self.extra_config and layer_name in self.extra_config:
|
||||
return get_config(layer_name)
|
||||
|
||||
# 2. Determine whether layer should be quantized
|
||||
quantized = not isinstance(layer, ParallelLMHead)
|
||||
if self.block_name_to_quantize:
|
||||
quantized = any(
|
||||
layer_name.startswith(name) for name in self.block_name_to_quantize
|
||||
)
|
||||
|
||||
# 3. Handle fused MoE
|
||||
if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower():
|
||||
moe_configs = [
|
||||
get_config(name, quantized)
|
||||
for name in self.extra_config
|
||||
if name.startswith(layer_name)
|
||||
]
|
||||
if moe_configs:
|
||||
if len(set(moe_configs)) == 1:
|
||||
return moe_configs[0]
|
||||
raise ValueError(
|
||||
f"Fused MoE layer '{layer_name}' requires "
|
||||
f"consistent quant config for all sub-layers"
|
||||
)
|
||||
|
||||
# 4. Handle fused QKV or other patterns
|
||||
if self.extra_config:
|
||||
for fusion_key, sub_keys in self.packed_modules_mapping.items():
|
||||
if fusion_key in layer_name and layer_name.count(fusion_key) == 1:
|
||||
sub_names = [
|
||||
layer_name.replace(fusion_key, sub_key) for sub_key in sub_keys
|
||||
]
|
||||
sub_configs = [get_config(name, quantized) for name in sub_names]
|
||||
if len(set(sub_configs)) == 1:
|
||||
return sub_configs[0]
|
||||
raise ValueError(
|
||||
f"Fused module '{layer_name}' requires "
|
||||
f"consistent quant config for {sub_names}"
|
||||
)
|
||||
|
||||
# 5. Fall back to the defaults or try a regular-expression match
|
||||
return get_config(layer_name, quantized)
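# Illustrative note (hypothetical extra_config, not part of the vLLM sources):
# exact layer names take priority over regex patterns, e.g. with
#   extra_config = {
#       "model.layers.0.mlp.down_proj": {"bits": 8},
#       r"model\.layers\.\d+\.mlp\..*": {"bits": 4},
#   }
# layer 0's down_proj resolves to 8-bit, every other matching MLP projection
# to 4-bit, and non-matching layers fall back to the global
# (weight_bits, group_size, sym) defaults.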
|
||||
|
||||
def check_quantized(self, weight_bits: int) -> bool:
|
||||
return weight_bits < 16
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||
if self.block_name_to_quantize is not None:
|
||||
self.block_name_to_quantize = hf_to_vllm_mapper.apply_list(
|
||||
self.block_name_to_quantize
|
||||
)
|
||||
if self.extra_config is not None:
|
||||
self.extra_config = hf_to_vllm_mapper.apply_dict(self.extra_config)
|
||||
|
||||
def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported,
|
||||
check_moe_marlin_supports_layer,
|
||||
)
|
||||
|
||||
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
|
||||
if not self.check_quantized(weight_bits):
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
return UnquantizedLinearMethod()
|
||||
else:
|
||||
return None
|
||||
|
||||
logger.debug(
|
||||
"[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
|
||||
prefix,
|
||||
layer.__class__.__name__,
|
||||
weight_bits,
|
||||
group_size,
|
||||
sym,
|
||||
)
|
||||
if backend == "auto" or "marlin" in backend:
|
||||
AWQ_TYPE_MAP = {
|
||||
4: scalar_types.uint4,
|
||||
8: scalar_types.uint8,
|
||||
}
|
||||
use_marlin = (weight_bits in AWQ_TYPE_MAP) and check_marlin_supported(
|
||||
AWQ_TYPE_MAP[weight_bits], group_size, not sym
|
||||
)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
use_marlin = use_marlin and check_moe_marlin_supports_layer(
|
||||
layer, group_size
|
||||
)
|
||||
|
||||
else:
|
||||
use_marlin = False
|
||||
if use_marlin:
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinConfig,
|
||||
AWQMarlinLinearMethod,
|
||||
AWQMarlinMoEMethod,
|
||||
)
|
||||
|
||||
quant_args_marlin = AWQMarlinConfig(
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
zero_point=not sym,
|
||||
lm_head_quantized=False,
|
||||
full_config={},
|
||||
modules_to_not_convert=[],
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.quantization.awq import (
|
||||
AWQConfig,
|
||||
AWQLinearMethod,
|
||||
)
|
||||
|
||||
quant_args = AWQConfig(
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
zero_point=not sym,
|
||||
)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
if use_marlin:
|
||||
return AWQMarlinMoEMethod(quant_args_marlin, layer.moe_config)
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
|
||||
config = {
|
||||
"quant_method": "awq",
|
||||
"bits": weight_bits,
|
||||
"group_size": group_size,
|
||||
"zero_point": not sym,
|
||||
"lm_head": False,
|
||||
}
|
||||
return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)
|
||||
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
if use_marlin:
|
||||
return AWQMarlinLinearMethod(quant_args_marlin)
|
||||
else:
|
||||
return AWQLinearMethod(quant_args)
|
||||
return None
|
||||
|
||||
def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported,
|
||||
check_moe_marlin_supports_layer,
|
||||
)
|
||||
|
||||
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
|
||||
if not self.check_quantized(weight_bits):
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
return UnquantizedLinearMethod()
|
||||
else:
|
||||
return None
|
||||
|
||||
logger.debug(
|
||||
"[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
|
||||
prefix,
|
||||
layer.__class__.__name__,
|
||||
weight_bits,
|
||||
group_size,
|
||||
sym,
|
||||
)
|
||||
if backend == "auto" or "marlin" in backend:
|
||||
GPTQ_TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
}
|
||||
use_marlin = (weight_bits, sym) in GPTQ_TYPE_MAP and check_marlin_supported(
|
||||
GPTQ_TYPE_MAP[(weight_bits, sym)], group_size, has_zp=not sym
|
||||
)
|
||||
if isinstance(layer, FusedMoE):
|
||||
use_marlin = use_marlin and check_moe_marlin_supports_layer(
|
||||
layer, group_size
|
||||
)
|
||||
else:
|
||||
use_marlin = False
|
||||
if use_marlin:
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinConfig,
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
)
|
||||
|
||||
quant_args_marlin = GPTQMarlinConfig(
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
is_sym=sym,
|
||||
lm_head_quantized=False,
|
||||
desc_act=False,
|
||||
dynamic={},
|
||||
full_config={},
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.quantization.gptq import (
|
||||
GPTQConfig,
|
||||
GPTQLinearMethod,
|
||||
)
|
||||
|
||||
quant_args = GPTQConfig(
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
lm_head_quantized=False,
|
||||
desc_act=False,
|
||||
dynamic={},
|
||||
)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
if use_marlin:
|
||||
return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe_config)
|
||||
else:
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import (
|
||||
MoeWNA16Config,
|
||||
)
|
||||
|
||||
config = {
|
||||
"quant_method": "gptq",
|
||||
"bits": weight_bits,
|
||||
"group_size": group_size,
|
||||
"sym": sym,
|
||||
"lm_head": False,
|
||||
}
|
||||
return MoeWNA16Config.from_config(config).get_quant_method(
|
||||
layer, prefix
|
||||
)
|
||||
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
if use_marlin:
|
||||
return GPTQMarlinLinearMethod(quant_args_marlin)
|
||||
else:
|
||||
return GPTQLinearMethod(quant_args)
|
||||
|
||||
return None
|
||||
|
||||
def apply_ipex_quant_layer(self, layer, prefix: str):
|
||||
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
|
||||
if not self.check_quantized(weight_bits):
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
return UnquantizedLinearMethod()
|
||||
else:
|
||||
return None
|
||||
raise NotImplementedError(
|
||||
"INC quantization is not supported during xpu kernel migration."
|
||||
)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module, prefix: str):
|
||||
if prefix and self.extra_config:
|
||||
for layer_name in self.extra_config:
|
||||
if (
|
||||
layer_name == prefix or layer_name == f"model.{prefix}"
|
||||
) and self.extra_config[layer_name].get("bits", 16) >= 16:
|
||||
return UnquantizedLinearMethod()
|
||||
if (
|
||||
current_platform.is_cpu()
|
||||
or current_platform.is_xpu()
|
||||
or self.backend == "ipex"
|
||||
):
|
||||
return self.apply_ipex_quant_layer(layer, prefix)
|
||||
if "gptq" in self.packing_format or "gptq" in self.backend:
|
||||
return self.apply_gptq_quant_layer(layer, prefix)
|
||||
if "awq" in self.packing_format or "awq" in self.backend:
|
||||
return self.apply_awq_quant_layer(layer, prefix)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> "QuantizationMethods | None":
|
||||
"""Override the `auto-round` method to `inc`."""
|
||||
is_auto_round_format = hf_quant_cfg.get("quant_method", None) == "auto-round"
|
||||
if is_auto_round_format:
|
||||
return cls.get_name()
|
||||
return None
|
||||
250
vllm/model_executor/layers/quantization/input_quant_fp8.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
get_fp8_min_max,
|
||||
group_broadcast,
|
||||
prep_scale_for_group_broadcast,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
DeepGemmQuantScaleFMT,
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
|
||||
_FP8_DTYPE = current_platform.fp8_dtype()
|
||||
_FP8_MIN, _FP8_MAX = get_fp8_min_max()
|
||||
_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
|
||||
|
||||
|
||||
# --8<-- [start:quant_fp8]
|
||||
@CustomOp.register("quant_fp8")
|
||||
class QuantFP8(CustomOp):
|
||||
"""
|
||||
Quantize input tensor to FP8 (per-tensor, per-token, per-channel, or per-group).
|
||||
This CustomOp supports both static and dynamic quantization.
|
||||
"""
|
||||
|
||||
# --8<-- [end:quant_fp8]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
static: bool,
|
||||
group_shape: GroupShape,
|
||||
num_token_padding: int | None = None,
|
||||
column_major_scales: bool = False,
|
||||
tma_aligned_scales: bool = False,
|
||||
use_ue8m0: bool | None = None, # for Torch compile
|
||||
compile_native: bool = True,
|
||||
):
|
||||
"""
|
||||
:param static: static or dynamic quantization
|
||||
:param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
|
||||
PER_CHANNEL, or arbitrary block size)
|
||||
:param num_token_padding: Pad the token dimension of output to this
|
||||
size
|
||||
:param tma_aligned_scales: For group quantization, output scales in
|
||||
TMA-aligned layout
|
||||
:param column_major_scales: For group quantization, output scales in
|
||||
column major format
|
||||
:param compile_native: Manually compile forward_native when a compilation mode is enabled
|
||||
"""
|
||||
super().__init__(compile_native=compile_native)
|
||||
self.static = static
|
||||
self.group_shape = group_shape
|
||||
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
|
||||
self.num_token_padding = num_token_padding
|
||||
self.column_major_scales = column_major_scales
|
||||
self.tma_aligned_scales = tma_aligned_scales
|
||||
self.use_ue8m0 = is_deep_gemm_e8m0_used() if use_ue8m0 is None else use_ue8m0
|
||||
self.use_deep_gemm_supported = is_deep_gemm_supported()
|
||||
|
||||
self.use_aiter = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
|
||||
self.is_group_quant = group_shape.is_per_group()
|
||||
if self.is_group_quant:
|
||||
self.group_size = group_shape.col
|
||||
else:
|
||||
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
|
||||
if not static:
|
||||
assert group_shape in (GroupShape.PER_TOKEN, GroupShape.PER_TENSOR), (
|
||||
"Only per-token or per-tensor scales are supported for dynamic "
|
||||
"non-group quantization."
|
||||
)
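# Illustrative usage sketch (not part of the vLLM sources); on a CUDA build the
# op dispatches to forward_cuda(), the reference path is shown here only for
# clarity and the shapes are made up:
#
#   q = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)
#   x_q, x_scale = q.forward_native(torch.randn(4, 64))
#   # x_q has the platform fp8 dtype; x_scale is float32 with shape (4, 1)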
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
scale: torch.Tensor | None = None,
|
||||
scale_ub: torch.Tensor | None = None,
|
||||
use_triton: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.quantization.utils import fp8_utils
|
||||
|
||||
if (
|
||||
self.is_group_quant
|
||||
and self.use_deep_gemm_supported
|
||||
and (DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0)
|
||||
):
|
||||
return fp8_utils.per_token_group_quant_fp8_packed_for_deepgemm(
|
||||
x,
|
||||
group_size=self.group_size,
|
||||
use_ue8m0=True,
|
||||
)
|
||||
|
||||
if self.is_group_quant and not self.static:
|
||||
assert scale is None, "Dynamic group quantization does not use scale"
|
||||
|
||||
return fp8_utils.per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size=self.group_size,
|
||||
column_major_scales=self.column_major_scales,
|
||||
tma_aligned_scales=self.tma_aligned_scales,
|
||||
dtype=_FP8_DTYPE,
|
||||
use_ue8m0=self.use_ue8m0,
|
||||
)
|
||||
|
||||
assert (scale is not None) == self.static
|
||||
assert scale_ub is None or (
|
||||
not self.static
|
||||
and self.group_shape == GroupShape.PER_TOKEN
|
||||
and scale_ub.numel() == 1
|
||||
)
|
||||
|
||||
return ops.scaled_fp8_quant(
|
||||
x,
|
||||
scale,
|
||||
num_token_padding=self.num_token_padding,
|
||||
scale_ub=scale_ub,
|
||||
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
|
||||
group_shape=(self.group_shape.row, self.group_shape.col)
|
||||
if self.static
|
||||
else None,
|
||||
)
|
||||
|
||||
def forward_hip(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
scale: torch.Tensor | None = None,
|
||||
scale_ub: torch.Tensor | None = None,
|
||||
use_triton: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.is_group_quant and use_triton:
|
||||
assert scale is None, "Dynamic group quantization does not use scale"
|
||||
|
||||
return torch.ops.vllm.triton_per_token_group_quant_fp8(x, self.group_size)
|
||||
|
||||
use_aiter_quant = self.use_aiter and scale_ub is None and x.is_contiguous()
|
||||
use_aiter_per_tensor_quant = (
|
||||
use_aiter_quant and self.group_shape.is_per_tensor()
|
||||
)
|
||||
use_aiter_per_token_quant = use_aiter_quant and self.group_shape.is_per_token()
|
||||
|
||||
use_aiter_per_group_quant = use_aiter_quant and self.group_shape.is_per_group()
|
||||
|
||||
if use_aiter_per_group_quant:
|
||||
return rocm_aiter_ops.group_fp8_quant(x, self.group_size)
|
||||
if use_aiter_per_tensor_quant:
|
||||
return rocm_aiter_ops.per_tensor_quant(x, _FP8_DTYPE, scale)
|
||||
if use_aiter_per_token_quant:
|
||||
return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale)
|
||||
|
||||
# Fallback to native implementation for group quantization.
|
||||
if self.is_group_quant:
|
||||
assert scale is None, "Dynamic group quantization does not use scale"
|
||||
return self._quantize_group_native(x)
|
||||
|
||||
# Fallback to CUDA implementation
|
||||
return self.forward_cuda(x, scale, scale_ub)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
scale: torch.Tensor | None = None,
|
||||
scale_ub: torch.Tensor | None = None,
|
||||
use_triton: bool = False,
|
||||
):
|
||||
if self.is_group_quant and not self.static:
|
||||
assert scale is None, "Dynamic group quantization does not use scale"
|
||||
return self._quantize_group_native(x)
|
||||
|
||||
assert (scale is not None) == self.static
|
||||
assert scale_ub is None or (
|
||||
not self.static
|
||||
and self.group_shape == GroupShape.PER_TOKEN
|
||||
and scale_ub.numel() == 1
|
||||
)
|
||||
|
||||
if scale is None:
|
||||
if self.group_shape == GroupShape.PER_TOKEN:
|
||||
x_max, _ = x.abs().max(dim=-1)
|
||||
x_max = x_max.unsqueeze(-1).to(torch.float32)
|
||||
if scale_ub is not None:
|
||||
x_max = x_max.clamp(max=scale_ub)
|
||||
else:
|
||||
x_max = x.abs().max().unsqueeze(-1).to(torch.float32)
|
||||
|
||||
scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
|
||||
else:
|
||||
scale = prep_scale_for_group_broadcast(scale, x, self.group_shape)
|
||||
|
||||
# Even for dynamic per-token scales,
|
||||
# reciprocal performs slightly better than division
|
||||
out = (
|
||||
x.to(torch.float32)
|
||||
* group_broadcast(scale.to(torch.float32), x.shape[-2:]).reciprocal()
|
||||
)
|
||||
out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
|
||||
|
||||
# This currently generates an extra Triton kernel in compilation.
|
||||
# Fortunately, we don't use padding if compiling.
|
||||
# TODO(luka): benchmark torch._scaled_mm to hopefully remove padding
|
||||
# in general.
|
||||
if self.num_token_padding is not None:
|
||||
padding = max(self.num_token_padding - out.size(0), 0)
|
||||
out = F.pad(out, (0, 0, 0, padding), "constant", 0.0)
|
||||
|
||||
return out, scale
|
||||
|
||||
def _quantize_group_native(
|
||||
self, x: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
orig_shape = x.shape
|
||||
hidden_dim = x.shape[-1]
|
||||
num_groups = (hidden_dim + self.group_size - 1) // self.group_size
|
||||
padded_dim = num_groups * self.group_size
|
||||
|
||||
if padded_dim != hidden_dim:
|
||||
padding = padded_dim - hidden_dim
|
||||
x = F.pad(x, (0, padding), mode="constant", value=0.0)
|
||||
|
||||
x_grouped = x.view(-1, num_groups, self.group_size)
|
||||
absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
|
||||
scales_raw = absmax / _FP8_MAX
|
||||
if self.use_ue8m0:
|
||||
scales_raw = torch.exp2(torch.ceil(torch.log2(scales_raw)))
|
||||
scales = (scales_raw).clamp(min=_FP8_MIN_SCALING_FACTOR)
|
||||
|
||||
x_scaled = x_grouped / scales
|
||||
x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
|
||||
|
||||
x_quant = x_quant.view(-1, padded_dim)
|
||||
if padded_dim != hidden_dim:
|
||||
x_quant = x_quant[..., :hidden_dim]
|
||||
x_quant = x_quant.view(orig_shape)
|
||||
|
||||
scales = scales.squeeze(-1)
|
||||
scales = scales.reshape(orig_shape[:-1] + (num_groups,))
|
||||
|
||||
if self.column_major_scales:
|
||||
scales = scales.transpose(-2, -1).contiguous().transpose(-1, -2)
|
||||
|
||||
return x_quant, scales
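# Illustrative sketch (not part of the vLLM sources): the group-quantization
# math used above, on a toy tensor with group_size=4 and the fp8 e4m3 max 448.0.
import torch

x = torch.tensor([[0.5, -1.0, 2.0, 4.0, 0.1, 0.2, -0.3, 0.4]])
groups = x.view(-1, 2, 4)                        # 2 groups of 4 values each
scales = groups.abs().amax(dim=-1, keepdim=True) / 448.0
q = (groups / scales).clamp(-448.0, 448.0)       # largest entry maps to ~448
assert torch.isclose(q.abs().amax(), torch.tensor(448.0))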
|
||||
157
vllm/model_executor/layers/quantization/kv_cache.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backend import is_quantized_kv_cache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BaseKVCacheMethod(QuantizeMethodBase):
|
||||
"""
|
||||
Quant method that adds `_k_scale` and `_v_scale` attributes to the
|
||||
Attention layer to support loading those scaling factors from checkpoints.
|
||||
The k/v_scale will be used to:
|
||||
- quantize k/v_cache entries before saving them to the cache
|
||||
- dequantize k/v_cache entries before fetching them from the cache
|
||||
|
||||
:param quant_config: the appropriate QuantizationConfig
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: QuantizationConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module):
|
||||
"""
|
||||
Create "weight" (aka q_scale, k_scale and v_scale)
|
||||
for an attention layer.
|
||||
"""
|
||||
# Initialize the Q and KV cache scales to -1.0, an invalid value.
|
||||
# If the q and k/v_scales appear in the checkpoint, it will be
|
||||
# overwritten when loading weights.
|
||||
layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
# Initialize P = softmax(QK^T) scales
|
||||
layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
|
||||
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
|
||||
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# skip if there are no weights to process (for example, weight reloading)
|
||||
if not hasattr(layer, "q_scale"):
|
||||
assert not hasattr(layer, "k_scale")
|
||||
assert not hasattr(layer, "v_scale")
|
||||
assert not hasattr(layer, "prob_scale")
|
||||
return
|
||||
|
||||
# If the kv-cache is not quantized, we enforce the k/v_scale to be 1.0
|
||||
# regardless of whether the kv-scale is available in the checkpoint.
|
||||
# No need to process kv scales after loading if we are going to
|
||||
# calculate them on the fly.
|
||||
if (
|
||||
is_quantized_kv_cache(layer.kv_cache_dtype)
|
||||
and not layer.calculate_kv_scales
|
||||
):
|
||||
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
|
||||
# We prefer to use separate k_scale and v_scale if present
|
||||
k_scale = layer.k_scale.to("cpu").tolist()
|
||||
v_scale = layer.v_scale.to("cpu").tolist()
|
||||
if current_platform.is_fp8_fnuz():
|
||||
k_scale *= 2
|
||||
v_scale *= 2
|
||||
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
|
||||
# If no scales were loaded (both scales are invalid negative
|
||||
# values), use the default value of 1.0
|
||||
k_scale = 1.0
|
||||
v_scale = 1.0
|
||||
else:
|
||||
# If we find a single kv_scale in the checkpoint, we remap
|
||||
# kv_scale to k_scale during weight loading, and duplicate
|
||||
# k_scale to v_scale here
|
||||
assert layer.k_scale > 0.0
|
||||
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
|
||||
k_scale = scale_to_duplicate.to("cpu").tolist()
|
||||
v_scale = scale_to_duplicate.to("cpu").tolist()
|
||||
if current_platform.is_fp8_fnuz():
|
||||
k_scale *= 2
|
||||
v_scale *= 2
|
||||
|
||||
if not isinstance(k_scale, float) or not isinstance(v_scale, float):
|
||||
raise ValueError(
|
||||
"Only support per-tensor scaling factor for fp8 KV cache"
|
||||
)
|
||||
|
||||
if layer.q_scale < 0.0:
|
||||
logger.warning_once(
|
||||
"Checkpoint does not provide a q scaling factor. "
|
||||
"Setting it to k_scale. This only matters for "
|
||||
"FP8 Attention backends (flash-attn or flashinfer)."
|
||||
)
|
||||
layer._q_scale.copy_(k_scale)
|
||||
layer._q_scale_float = k_scale
|
||||
|
||||
# These are used in the final Attention.forward()
|
||||
layer._k_scale.copy_(k_scale)
|
||||
layer._v_scale.copy_(v_scale)
|
||||
layer._k_scale_float = k_scale
|
||||
layer._v_scale_float = v_scale
|
||||
if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
|
||||
logger.warning_once(
|
||||
"Using KV cache scaling factor 1.0 for fp8_e4m3. "
|
||||
"If this is unintended, verify that k/v_scale "
|
||||
"scaling factors are properly set in the checkpoint."
|
||||
)
|
||||
|
||||
if layer.q_scale > 0.0:
|
||||
q_scale = layer.q_scale
|
||||
if current_platform.is_fp8_fnuz():
|
||||
q_scale *= 2
|
||||
layer.calculate_kv_scales = False
|
||||
else:
|
||||
q_scale = 1.0
|
||||
if layer.prob_scale > 0.0:
|
||||
prob_scale = layer.prob_scale
|
||||
if current_platform.is_fp8_fnuz():
|
||||
prob_scale *= 2
|
||||
else:
|
||||
prob_scale = 1.0
|
||||
|
||||
is_singleton_float = (
|
||||
lambda x: isinstance(x, float)
|
||||
or isinstance(x, torch.Tensor)
|
||||
and x.numel() == 1
|
||||
and x.is_floating_point()
|
||||
)
|
||||
if not is_singleton_float(q_scale) or not is_singleton_float(prob_scale):
|
||||
raise ValueError(
|
||||
"Only support per-tensor scaling factorfor fp8-quantized Q/prob"
|
||||
)
|
||||
|
||||
# These are used in the final Attention.forward()
|
||||
layer._q_scale.copy_(q_scale)
|
||||
layer._q_scale_float = (
|
||||
q_scale.item() if isinstance(q_scale, torch.Tensor) else q_scale
|
||||
)
|
||||
|
||||
layer._prob_scale.copy_(prob_scale)
|
||||
if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 or prob_scale == 1.0):
|
||||
logger.warning_once(
|
||||
f"Using uncalibrated q_scale {q_scale} and/or prob_scale "
|
||||
f"{prob_scale} with fp8 attention. This may cause accuracy "
|
||||
"issues. Please make sure q/prob scaling factors are "
|
||||
"available in the fp8 checkpoint."
|
||||
)
|
||||
|
||||
del layer.k_scale
|
||||
del layer.v_scale
|
||||
del layer.q_scale
|
||||
del layer.prob_scale
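# Illustrative sketch (not part of the vLLM sources): the checkpoint-scale
# resolution rules above, condensed into a tiny helper (the ROCm fp8-fnuz
# doubling is intentionally omitted); the input values are hypothetical.
def resolve_kv_scales(k: float, v: float) -> tuple[float, float]:
    if k > 0.0 and v > 0.0:   # separate k_scale / v_scale found in the checkpoint
        return k, v
    if k < 0.0 and v < 0.0:   # nothing found: fall back to 1.0
        return 1.0, 1.0
    s = max(k, v)             # a single kv_scale was remapped to k_scale: duplicate
    return s, s


assert resolve_kv_scales(-1.0, -1.0) == (1.0, 1.0)
assert resolve_kv_scales(0.02, -1.0) == (0.02, 0.02)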
|
||||
1843
vllm/model_executor/layers/quantization/modelopt.py
Normal file
File diff suppressed because it is too large
517
vllm/model_executor/layers/quantization/moe_wna16.py
Normal file
@@ -0,0 +1,517 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
int4_w4a16_moe_quant_config,
|
||||
int8_w8a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE,
|
||||
FusedMoEConfig,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
|
||||
UnquantizedFusedMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supports_layer,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class MoeWNA16Config(QuantizationConfig):
|
||||
"""Config class for MOE WNA16 (W8A16/W4A16) quantization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
linear_quant_method: str,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
has_zp: bool,
|
||||
lm_head_quantized: bool,
|
||||
modules_to_not_convert: list[str] | None,
|
||||
full_config: dict[str, Any],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.has_zp = has_zp
|
||||
self.bit8_pack_factor = 8 // self.weight_bits
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.linear_quant_method = linear_quant_method
|
||||
self.full_config = full_config
|
||||
self.use_marlin = False
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
|
||||
|
||||
if self.linear_quant_method == "gptq":
|
||||
self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config)
|
||||
elif self.linear_quant_method in ("awq", "awq_marlin"):
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (
|
||||
-1 if capability_tuple is None else capability_tuple.to_int()
|
||||
)
|
||||
awq_min_capability = AWQConfig.get_min_capability()
|
||||
if device_capability < awq_min_capability:
|
||||
raise ValueError(
|
||||
"The quantization method moe_wna16 + awq is not supported "
|
||||
"for the current GPU. "
|
||||
f"Minimum capability: {awq_min_capability}. "
|
||||
f"Current capability: {device_capability}."
|
||||
)
|
||||
self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible(full_config)
|
||||
else:
|
||||
raise ValueError("moe_wna16 only support gptq and awq.")
|
||||
|
||||
if modules_to_not_convert is None:
|
||||
self.modules_to_not_convert = []
|
||||
else:
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "moe_wna16"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config":
|
||||
linear_quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
|
||||
if linear_quant_method == "gptq":
|
||||
has_zp = not cls.get_from_keys(config, ["sym"])
|
||||
modules_to_not_convert = []
|
||||
elif linear_quant_method in ("awq", "awq_marlin"):
|
||||
has_zp = cls.get_from_keys(config, ["zero_point"])
|
||||
modules_to_not_convert = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None
|
||||
)
|
||||
else:
|
||||
raise ValueError("moe_wna16 only support gptq and awq.")
|
||||
|
||||
return cls(
|
||||
linear_quant_method,
|
||||
weight_bits,
|
||||
group_size,
|
||||
has_zp,
|
||||
lm_head_quantized,
|
||||
modules_to_not_convert,
|
||||
config,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> QuantizationMethods | None:
|
||||
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
|
||||
if can_convert and user_quant == "moe_wna16":
|
||||
return cls.get_name()
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
desc_act = quant_config.get("desc_act")
|
||||
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (
|
||||
-1 if capability_tuple is None else capability_tuple.to_int()
|
||||
)
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
|
||||
awq_min_capability = AWQConfig.get_min_capability()
|
||||
|
||||
gptq_compatible = quant_method == "gptq" and not desc_act and num_bits in [4, 8]
|
||||
awq_compatible = (
|
||||
quant_method == "awq"
|
||||
and num_bits == 4
|
||||
and device_capability >= awq_min_capability
|
||||
)
|
||||
|
||||
return gptq_compatible or awq_compatible
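# Illustrative sketch (hypothetical configs, not part of the vLLM sources):
#
#   MoeWNA16Config.is_moe_wna16_compatible(
#       {"quant_method": "gptq", "bits": 4, "desc_act": False})   # -> True
#   MoeWNA16Config.is_moe_wna16_compatible(
#       {"quant_method": "gptq", "bits": 3, "desc_act": False})   # -> False (bits)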
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
|
||||
if isinstance(layer, FusedMoE):
|
||||
return UnquantizedFusedMoEMethod(layer.moe_config)
|
||||
return UnquantizedLinearMethod()
|
||||
elif isinstance(layer, LinearBase):
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinConfig,
|
||||
)
|
||||
|
||||
if self.linear_quant_method == "gptq":
|
||||
if self.use_marlin:
|
||||
return GPTQMarlinConfig.from_config(
|
||||
self.full_config
|
||||
).get_quant_method(layer, prefix)
|
||||
else:
|
||||
return GPTQConfig.from_config(self.full_config).get_quant_method(
|
||||
layer, prefix
|
||||
)
|
||||
elif self.linear_quant_method in ("awq", "awq_marlin"):
|
||||
if self.use_marlin and check_marlin_supports_layer(
|
||||
layer, self.group_size
|
||||
):
|
||||
return AWQMarlinConfig.from_config(
|
||||
self.full_config
|
||||
).get_quant_method(layer, prefix)
|
||||
else:
|
||||
return AWQConfig.from_config(self.full_config).get_quant_method(
|
||||
layer, prefix
|
||||
)
|
||||
else:
|
||||
raise ValueError("moe_wna16 only support gptq and awq.")
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return MoeWNA16Method(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: list[str]):
|
||||
return any(module_name in prefix for module_name in modules_to_not_convert)
|
||||
|
||||
|
||||
class MoeWNA16Method(FusedMoEMethodBase):
|
||||
"""Linear method for MOE WNA16 (W8A16/W4A16) quantization.
|
||||
|
||||
Args:
|
||||
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: MoeWNA16Config, moe: "FusedMoEConfig") -> None:
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
layer.quant_config = self.quant_config
|
||||
bit8_pack_factor = self.quant_config.bit8_pack_factor
|
||||
group_size = self.quant_config.group_size
|
||||
group_size_div_factor = 1
|
||||
|
||||
# Make intermediate_size and hidden_size divisible by group_size:
# halve the group size until both are divisible, and repeat the
# loaded weights later to compensate.
|
||||
while intermediate_size_per_partition % group_size or hidden_size % group_size:
|
||||
group_size = group_size // 2
|
||||
group_size_div_factor *= 2
|
||||
assert group_size >= 32
|
||||
layer.group_size = group_size
|
||||
layer.group_size_div_factor = group_size_div_factor
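# Illustrative worked example (hypothetical sizes): with group_size=128 and
# intermediate_size_per_partition=1376 (e.g. 11008 split over TP=8), the loop
# above halves the group size twice (128 -> 64 -> 32), so layer.group_size
# becomes 32 and group_size_div_factor becomes 4; the loaded scales/zeros are
# then repeated accordingly by the wrapped weight loader below.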
|
||||
|
||||
strategy = FusedMoeWeightScaleSupported.GROUP.value
|
||||
extra_weight_attrs.update({"quant_method": strategy, "is_transposed": False})
|
||||
|
||||
assert "weight_loader" in extra_weight_attrs
|
||||
weight_loader = extra_weight_attrs["weight_loader"]
|
||||
wrapped_weight_loader = MoeWNA16Method.get_weight_loader(layer, weight_loader)
|
||||
extra_weight_attrs["weight_loader"] = wrapped_weight_loader
|
||||
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // bit8_pack_factor,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // bit8_pack_factor,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
|
||||
w13_scales = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // group_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_scales", w13_scales)
|
||||
set_weight_attrs(w13_scales, extra_weight_attrs)
|
||||
|
||||
w2_scales = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // group_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_scales", w2_scales)
|
||||
set_weight_attrs(w2_scales, extra_weight_attrs)
|
||||
|
||||
if self.quant_config.has_zp:
|
||||
w13_qzeros = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition // bit8_pack_factor,
|
||||
hidden_size // group_size,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_qzeros", w13_qzeros)
|
||||
set_weight_attrs(w13_qzeros, extra_weight_attrs)
|
||||
|
||||
w2_qzeros = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size // bit8_pack_factor,
|
||||
intermediate_size_per_partition // group_size,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_qzeros", w2_qzeros)
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
|
||||
if self.quant_config.linear_quant_method == "gptq":
|
||||
# Some parameters are unused, but they still need to be
# initialized so that checkpoint weights can be loaded.
|
||||
invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
|
||||
if not self.quant_config.has_zp:
|
||||
invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
|
||||
for key in invalid_param_keys:
|
||||
param = torch.nn.Parameter(
|
||||
torch.empty((0,), dtype=torch.int32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter(key, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
has_zp = self.quant_config.has_zp
|
||||
assert weight_bits == 4 or weight_bits == 8
|
||||
config_builder = (
|
||||
int4_w4a16_moe_quant_config
|
||||
if weight_bits == 4
|
||||
else int8_w8a16_moe_quant_config
|
||||
)
|
||||
|
||||
return config_builder(
|
||||
w1_scale=layer.w13_scales,
|
||||
w2_scale=layer.w2_scales,
|
||||
w1_zp=layer.w13_qzeros if has_zp else None,
|
||||
w2_zp=layer.w2_qzeros if has_zp else None,
|
||||
block_shape=[0, layer.group_size],
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
assert layer.activation == MoEActivation.SILU, (
|
||||
f"Only SiLU activation is supported, not {layer.activation}."
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_weight_loader(layer, weight_loader):
|
||||
def convert_awq_tensor(tensor, tensor_type):
|
||||
# convert awq qweight/qzeros to a standard format (assume int4)
|
||||
# qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
|
||||
# qzeros: (k // group_size, n // pack_factor_bit32) ->
|
||||
# (n // pack_factor_bit8, k // group_size)
|
||||
# pack_factor_bit32 = 32 // weight_bits
|
||||
# pack_factor_bit8 = 8 // weight_bits
|
||||
|
||||
# 0. suppose origin shape (a, b), dtype int32
|
||||
# 1. convert to uint8, shape (a, b) -> (a, 4 * b)
|
||||
size0 = tensor.size(0)
|
||||
tensor = tensor.view(torch.uint8)
|
||||
|
||||
# 2. unpack to uint4 (only when weight_bits == 4)
|
||||
# shape (a, 4 * b) -> (a, 4 * b, 2)
|
||||
shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
|
||||
tensor = (tensor[:, :, None] >> shifter) & 0xF
|
||||
|
||||
# 3. change order, see
|
||||
# https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
|
||||
# shape -> (a, 4 * b * pack_factor_bit8)
|
||||
reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
|
||||
tensor = tensor.view(size0, -1)
|
||||
|
||||
# 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
|
||||
tensor = tensor.T.contiguous()
|
||||
|
||||
# 5. repack (only when weight_bits == 4)
|
||||
# qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
|
||||
# qzeros shape -> (4 * b, a)
|
||||
|
||||
if tensor_type == "qweight":
|
||||
tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
|
||||
elif tensor_type == "qzeros":
|
||||
tensor = tensor[1::2, :] * 16 + tensor[::2, :]
|
||||
return tensor
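# Illustrative example (per the AutoAWQ packing referenced above): one
# packed int32 holds eight int4 values whose nibble order is
# [v0, v2, v4, v6, v1, v3, v5, v7]; indexing the unpacked nibbles with
# reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7] restores the logical
# order v0..v7 before step 5 re-packs two values per uint8.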
|
||||
|
||||
def convert_gptq_int4_qzeros(tensor):
|
||||
tensor = tensor.view(torch.uint8)
|
||||
shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
|
||||
tensor = (tensor[:, :, None] >> shifter) & 0xF
|
||||
tensor = tensor + 1
|
||||
tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
|
||||
return tensor
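# Note: GPTQ checkpoints store zero-points offset by -1, so the "+ 1"
# above (and the ".T + 1" in the 8-bit branch of the loader below)
# restores the actual zero-points and aligns them with the AWQ convention.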
|
||||
|
||||
def moe_wna16_weight_loader(
|
||||
param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
weight_name: str,
|
||||
shard_id: str,
|
||||
expert_id: int,
|
||||
return_success: bool = False,
|
||||
):
|
||||
if "g_idx" in weight_name:
|
||||
return False if return_success else None
|
||||
if not layer.quant_config.has_zp and "qzeros" in weight_name:
|
||||
return False if return_success else None
|
||||
|
||||
device = get_tp_group().device
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
loaded_weight = loaded_weight.to(device)
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
|
||||
# convert gptq and awq weight to a standard format
|
||||
# awq_marlin uses the same weight format as awq
|
||||
if layer.quant_config.linear_quant_method in ("awq", "awq_marlin"):
|
||||
assert layer.quant_config.weight_bits == 4
|
||||
if "weight" in weight_name:
|
||||
loaded_weight = convert_awq_tensor(loaded_weight, "qweight")
|
||||
elif "zeros" in weight_name:
|
||||
loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
|
||||
else:
|
||||
loaded_weight = loaded_weight.T
|
||||
elif layer.quant_config.linear_quant_method == "gptq":
|
||||
assert layer.quant_config.weight_bits in [4, 8]
|
||||
if "weight" in weight_name:
|
||||
loaded_weight = loaded_weight.T.contiguous().view(torch.uint8)
|
||||
elif "zeros" in weight_name:
|
||||
# add 1 to gptq qzeros to align with awq
|
||||
loaded_weight = loaded_weight.view(torch.uint8)
|
||||
if layer.quant_config.weight_bits == 4:
|
||||
loaded_weight = convert_gptq_int4_qzeros(loaded_weight).T
|
||||
else:
|
||||
loaded_weight = loaded_weight.T + 1
|
||||
else:
|
||||
loaded_weight = loaded_weight.T
|
||||
|
||||
# repeat the qzeros/scales to fit new group size
|
||||
if layer.group_size_div_factor > 1 and (
"qzeros" in weight_name or "scales" in weight_name
):
|
||||
loaded_weight = loaded_weight.repeat_interleave(
|
||||
layer.group_size_div_factor, 1
|
||||
)
|
||||
|
||||
if "w13_qzeros" in weight_name:
|
||||
tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[
|
||||
tp_rank
|
||||
]
|
||||
if shard_id == "w1":
|
||||
param.data[expert_id, : shard_size // 2] = tensor
|
||||
else:
|
||||
param.data[expert_id, shard_size // 2 :] = tensor
|
||||
return True if return_success else None
|
||||
elif "w2_qzeros" in weight_name:
|
||||
param.data[expert_id] = loaded_weight.view(
|
||||
loaded_weight.size(0), layer.tp_size, -1
|
||||
)[:, tp_rank]
|
||||
return True if return_success else None
|
||||
else:
|
||||
# Delegate to the original loader, passing return_success
|
||||
return weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
shard_id,
|
||||
expert_id,
|
||||
return_success=return_success,
|
||||
)
|
||||
|
||||
return moe_wna16_weight_loader
|
||||
1168
vllm/model_executor/layers/quantization/mxfp4.py
Normal file
File diff suppressed because it is too large
319
vllm/model_executor/layers/quantization/petit.py
Normal file
@@ -0,0 +1,319 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
|
||||
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.petit_utils import (
|
||||
apply_petit_nvfp4_linear,
|
||||
prepare_nvfp4_layer_for_petit,
|
||||
verify_petit_nvfp4_supported,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
|
||||
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Initialize logger for the module
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Configuration class to support the NVFP4 quantized model
|
||||
# generated by the ModelOpt quantization tool
|
||||
class PetitNvFp4Config(QuantizationConfig):
|
||||
"""Config class for Petit FP4."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_checkpoint_nvfp4_serialized: bool = False,
|
||||
kv_cache_quant_algo: str | None = None,
|
||||
group_size: int | None = None,
|
||||
exclude_modules: list[str] | None = None,
|
||||
) -> None:
|
||||
self._check_hardware_support()
|
||||
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
|
||||
if is_checkpoint_nvfp4_serialized:
|
||||
logger.warning(
|
||||
"Detected nvfp4 checkpoint. Please note that the "
|
||||
"format is experimental and subject to change."
|
||||
)
|
||||
self.group_size = group_size
|
||||
self.kv_cache_quant_algo = kv_cache_quant_algo
|
||||
self.exclude_modules = exclude_modules
|
||||
|
||||
def _check_hardware_support(self) -> None:
|
||||
"""
|
||||
Verifies that the current hardware is supported by the Petit backend.
|
||||
This backend is specifically designed for AMD GPUs and is not
|
||||
supported on the CUDA platform.
|
||||
"""
|
||||
# This check ensures the code is NOT running on an NVIDIA GPU.
|
||||
if current_platform.is_cuda():
|
||||
raise ValueError(
|
||||
"The 'petit' quantization backend is designed for AMD GPUs "
|
||||
"and is not supported on the CUDA platform. For NVIDIA GPUs, "
|
||||
"please use a different quantization method such as FP8, AWQ, "
|
||||
"or GPTQ."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "petit_nvfp4"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# Petit supports the gfx90a and gfx942 GPUs
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config":
|
||||
qc = cls.get_from_keys(config, ["quantization"])
|
||||
|
||||
quant_method_raw = qc.get("quant_algo")
|
||||
if not isinstance(quant_method_raw, str) or not quant_method_raw:
|
||||
raise ValueError("Missing or invalid 'quant_algo' in quantization config.")
|
||||
quant_method = quant_method_raw.upper()
|
||||
|
||||
group_size_raw = qc.get("group_size")
|
||||
if not isinstance(group_size_raw, int):
|
||||
raise ValueError(
|
||||
"Missing or invalid 'group_size' (int) in hf_quant_config.json."
|
||||
)
|
||||
group_size = group_size_raw
|
||||
|
||||
verify_petit_nvfp4_supported(quant_method, group_size)
|
||||
|
||||
kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto"
|
||||
if not isinstance(kv_cache_quant_algo_raw, str):
|
||||
raise ValueError("'kv_cache_quant_algo' must be a string if provided.")
|
||||
kv_cache_quant_algo = kv_cache_quant_algo_raw
|
||||
|
||||
exclude_raw = qc.get("exclude_modules", [])
|
||||
if exclude_raw is None:
|
||||
exclude_modules: list[str] = []
|
||||
elif isinstance(exclude_raw, list) and all(
|
||||
isinstance(x, str) for x in exclude_raw
|
||||
):
|
||||
exclude_modules = exclude_raw
|
||||
else:
|
||||
raise ValueError("'exclude_modules' must be a list[str] (or omitted).")
|
||||
|
||||
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
|
||||
|
||||
return cls(
|
||||
is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized,
|
||||
kv_cache_quant_algo=kv_cache_quant_algo,
|
||||
group_size=group_size,
|
||||
exclude_modules=exclude_modules,
|
||||
)
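# A minimal sketch (hypothetical values) of the "quantization" section of
# hf_quant_config.json accepted by this parser:
#   {"quantization": {"quant_algo": "NVFP4", "group_size": 16,
#                     "kv_cache_quant_algo": "auto",
#                     "exclude_modules": ["lm_head"]}}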
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> QuantizationMethods | None:
|
||||
if not current_platform.is_rocm():
|
||||
return None
|
||||
|
||||
qc = hf_quant_cfg.get("quantization", hf_quant_cfg)
|
||||
algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
|
||||
if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"):
|
||||
return cls.get_name() # "petit_nvfp4"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_petit_nvfp4_compatible(cls, quant_config: dict[str, Any]) -> bool:
|
||||
qc = quant_config.get("quantization", quant_config)
|
||||
algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
|
||||
return algo == "NVFP4"
|
||||
|
||||
def is_layer_excluded(self, prefix: str, exclude_modules: list[str]) -> bool:
|
||||
for pattern in exclude_modules:
|
||||
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
|
||||
if re.fullmatch(regex_str, prefix):
|
||||
return True
|
||||
return False
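# Example (hypothetical module name): the pattern "model.layers.*.mlp"
# becomes the regex r"model\.layers\..*\.mlp", which fullmatch()es a
# prefix such as "model.layers.7.mlp", so that layer is excluded.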
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
exclude = self.require_exclude_modules()
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, exclude) or self.is_layer_excluded(
|
||||
prefix, exclude
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
return PetitNvFp4LinearMethod(self)
|
||||
elif isinstance(layer, Attention):
|
||||
return PetitFp8KVCacheMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def require_group_size(self) -> int:
|
||||
if self.group_size is None:
|
||||
logger.warning("group_size not set; defaulting to 16 for NVFP4.")
|
||||
return 16
|
||||
return self.group_size
|
||||
|
||||
def require_kv_cache_quant_algo(self) -> str:
|
||||
return self.kv_cache_quant_algo or "auto"
|
||||
|
||||
def require_exclude_modules(self) -> list[str]:
|
||||
return list(self.exclude_modules or [])
|
||||
|
||||
|
||||
class PetitFp8KVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: PetitNvFp4Config):
|
||||
super().__init__(quant_config)
|
||||
|
||||
|
||||
class PetitNvFp4LinearMethod(LinearMethodBase):
|
||||
"""Linear method for NVFP4.
|
||||
Supports loading NVFP4 checkpoints with the following structure:
|
||||
|
||||
| Tensor Name    | datatype       | shape       |
|----------------|----------------|-------------|
| input_scale    | torch.float32  | scalar      |
| weight         | NVFP4(SE2M1)   | [1, X, y/2] |
| weight_scale   | FP8-E4M3       | [X, Y]      |
| weight_scale_2 | torch.float32  | scalar      |
|
||||
|
||||
The weights are quantized per block of 16 elements.
|
||||
Args: quant_config: The Petit NVFP4 quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: PetitNvFp4Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del input_size, output_size
|
||||
if not self.quant_config.is_checkpoint_nvfp4_serialized:
|
||||
raise ValueError(
|
||||
"NVFP4 quantization was selected, "
|
||||
" dynamic quantization is not supported."
|
||||
)
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
if input_size_per_partition % 16 != 0:
|
||||
raise ValueError(
|
||||
"Unsupported model when in features size is not multiple of 16"
|
||||
)
|
||||
|
||||
weight_dtype = (
|
||||
torch.float8_e4m3fn
|
||||
if self.quant_config.is_checkpoint_nvfp4_serialized
|
||||
else params_dtype
|
||||
)
|
||||
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
# 2 fp4 data is packed in one uint8 in the input dimension
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // 2,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
input_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
weight_scale_2 = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale_2", weight_scale_2)
|
||||
|
||||
group_size = self.quant_config.require_group_size()
|
||||
weight_scale = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // group_size,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
input_scale_2 = layer.input_scale.max().to(torch.float32)
|
||||
weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
|
||||
layer.input_scale = Parameter(input_scale_2, requires_grad=False)
|
||||
layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
|
||||
layer.alpha = Parameter(
|
||||
layer.input_scale * layer.weight_scale_2, requires_grad=False
|
||||
)
|
||||
|
||||
prepare_nvfp4_layer_for_petit(layer)
|
||||
del layer.input_scale
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return apply_petit_nvfp4_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale_2=layer.weight_scale_2,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias,
|
||||
)
|
||||
137
vllm/model_executor/layers/quantization/ptpc_fp8.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||
from vllm.model_executor.layers.quantization.fp8 import (
|
||||
Fp8Config,
|
||||
Fp8KVCacheMethod,
|
||||
Fp8LinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped,
|
||||
kFp8DynamicTokenSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class PTPCFp8Config(Fp8Config):
|
||||
"""Config class for Per-Token-Per-Channel Dynamic Quantization Fp8."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation_scheme: str = "dynamic",
|
||||
ignored_layers: list[str] | None = None,
|
||||
) -> None:
|
||||
if not current_platform.is_rocm():
|
||||
raise ValueError("ptpc_fp8 quantization is supported only on ROCm.")
|
||||
|
||||
if not current_platform.has_device_capability(94):
|
||||
raise ValueError(
|
||||
"ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer." # noqa: E501
|
||||
)
|
||||
if activation_scheme == "static":
|
||||
raise ValueError("ptpc_fp8 as of now only support dynamic quantization.")
|
||||
|
||||
super().__init__(
|
||||
is_checkpoint_fp8_serialized=False,
|
||||
activation_scheme=activation_scheme,
|
||||
ignored_layers=ignored_layers,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "ptpc_fp8"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "PTPCFp8Config":
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
|
||||
return cls(activation_scheme=activation_scheme, ignored_layers=ignored_layers)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, self.ignored_layers):
|
||||
return UnquantizedLinearMethod()
|
||||
return PTPCFp8LinearMethod(self)
|
||||
elif isinstance(layer, Attention):
|
||||
return Fp8KVCacheMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class PTPCFp8LinearMethod(Fp8LinearMethod):
|
||||
"""Linear method for Per-Token and Per-Channel FP8 Quantization.
|
||||
Only supports loading BF16 model checkpoints with dynamic
activation scaling. To load FP16 checkpoints, the user must explicitly
request that the weights be converted to BF16 at load time.
|
||||
The weight scaling factor will be initialized after
|
||||
the model weights are loaded.
|
||||
|
||||
Limitations:
|
||||
1. Only support float8_e4m3fnuz data type due to the limitation of
|
||||
torch._scaled_mm (https://github.com/ROCm/pytorch/blob/8c0504d7f3fb0ee4c278c096a5c3caedb01129fa/aten/src/ATen/native/cuda/Blas.cpp#L1041)
|
||||
|
||||
Args:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: PTPCFp8Config):
|
||||
assert current_platform.is_rocm(), (
|
||||
"PTPCFp8LinearMethod is only supported on ROCm."
|
||||
)
|
||||
super().__init__(quant_config=quant_config)
|
||||
# Force weight quantization
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=kFp8DynamicTokenSym,
|
||||
weight_quant_key=kFp8DynamicTokenSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
assert layer.weight.data.dtype not in (torch.float16, torch.float32), (
|
||||
"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support "
|
||||
f"output dtype of bfloat16. {layer.weight.data.dtype} is specified."
|
||||
)
|
||||
|
||||
if layer.weight.data.dtype == torch.bfloat16:
|
||||
# Quantize the weights.
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(
|
||||
layer.weight, scale=None, use_per_token_if_dynamic=True
|
||||
)
|
||||
|
||||
# Update the layer with the new values.
|
||||
layer.weight = Parameter(
|
||||
qweight.t(), requires_grad=False
|
||||
) # Pretranspose the weight
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
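# Note: the (out_features, in_features) weight is quantized with
# use_per_token_if_dynamic=True, so weight_scale effectively holds one
# scale per output channel (per row before the transpose above).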
|
||||
else:
|
||||
assert layer.weight.data.dtype == current_platform.fp8_dtype()
|
||||
assert getattr(layer, "weight_scale", None) is not None
|
||||
layer.input_scale = None
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
601
vllm/model_executor/layers/quantization/quark/quark.py
Normal file
@@ -0,0 +1,601 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import fnmatch
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
|
||||
QuarkMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import (
|
||||
QuarkOCP_MX,
|
||||
QuarkScheme,
|
||||
QuarkW8A8Fp8,
|
||||
QuarkW8A8Int8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.utils import (
|
||||
deep_compare,
|
||||
should_ignore_layer,
|
||||
)
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
|
||||
__all__ = ["QuarkLinearMethod"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class QuarkConfig(QuantizationConfig):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: dict[str, Any],
|
||||
kv_cache_group: list[str] | None = None,
|
||||
kv_cache_config: dict[str, Any] | None = None,
|
||||
pack_method: str = "reorder",
|
||||
):
|
||||
super().__init__()
|
||||
if kv_cache_group is None:
|
||||
kv_cache_group = []
|
||||
self.quant_config = quant_config
|
||||
self.kv_cache_group = kv_cache_group
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.pack_method = pack_method
|
||||
|
||||
def get_linear_method(self) -> "QuarkLinearMethod":
|
||||
return QuarkLinearMethod(self)
|
||||
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "quark"
|
||||
|
||||
def apply_vllm_mapper( # noqa: B027
|
||||
self, hf_to_vllm_mapper: "WeightsMapper"
|
||||
):
|
||||
"""
|
||||
Interface for models to update module names referenced in
|
||||
quantization configs in order to reflect the vllm model structure
|
||||
|
||||
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
|
||||
structure of the qconfig) to vllm model structure
|
||||
"""
|
||||
quant_config_with_hf_to_vllm_mapper = {}
|
||||
|
||||
for k, v in self.quant_config.items():
|
||||
if isinstance(v, list):
|
||||
quant_config_with_hf_to_vllm_mapper[k] = hf_to_vllm_mapper.apply_list(v)
|
||||
elif isinstance(v, dict):
|
||||
quant_config_with_hf_to_vllm_mapper[k] = hf_to_vllm_mapper.apply_dict(v)
|
||||
else:
|
||||
if isinstance(v, str):
|
||||
mapped_v_list = hf_to_vllm_mapper.apply_list([v])
|
||||
if mapped_v_list:
|
||||
quant_config_with_hf_to_vllm_mapper[k] = mapped_v_list[0]
|
||||
else:
|
||||
quant_config_with_hf_to_vllm_mapper[k] = v
|
||||
|
||||
self.quant_config = quant_config_with_hf_to_vllm_mapper
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
# Check if the layer is skipped for quantization.
|
||||
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
|
||||
if should_ignore_layer(
|
||||
prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
if isinstance(layer, LinearBase):
|
||||
scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||
layer.scheme = scheme
|
||||
return QuarkLinearMethod(self)
|
||||
if isinstance(layer, Attention):
|
||||
return QuarkKVCacheMethod(self)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
|
||||
export_config = config.get("export")
|
||||
if export_config is None:
|
||||
raise ValueError(
|
||||
"The export key should be included in "
|
||||
"the configurations of Quark quantized model"
|
||||
)
|
||||
kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
|
||||
pack_method = cast(str, export_config.get("pack_method"))
|
||||
|
||||
# In a Quark export, the kv_cache quantization configuration is stored
# in layer_quant_config. First check whether kv_cache_group is set, then
# check whether layer_quant_config contains a quantization configuration
# that matches the kv_cache layers.
|
||||
if len(kv_cache_group) == 0:
|
||||
kv_cache_config = None
|
||||
else:
|
||||
kv_cache_set = set(kv_cache_group)
|
||||
layer_quant_config = cast(dict[str, Any], config.get("layer_quant_config"))
|
||||
layer_quant_names = list(layer_quant_config.keys())
|
||||
layer_quant_set = set(layer_quant_names)
|
||||
|
||||
if not (
|
||||
kv_cache_set.issubset(layer_quant_set)
|
||||
or any(
|
||||
fnmatch.fnmatchcase(layer_quant, pat)
|
||||
for layer_quant in list(layer_quant_set)
|
||||
for pat in list(kv_cache_set)
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"The Quark quantized model has the "
|
||||
"kv_cache_group parameter setting, "
|
||||
"but no kv_cache quantization settings "
|
||||
"were found in the quantization "
|
||||
"configuration."
|
||||
)
|
||||
|
||||
q_configs = [
|
||||
quant_cfg
|
||||
for name, quant_cfg in layer_quant_config.items()
|
||||
if any(fnmatch.fnmatchcase(name, pattern) for pattern in kv_cache_group)
|
||||
]
|
||||
|
||||
if not all(
|
||||
deep_compare(q_config["output_tensors"], q_configs[0]["output_tensors"])
|
||||
for q_config in q_configs
|
||||
):
|
||||
raise ValueError(
|
||||
"The quantization method used for kv_cache should "
|
||||
"be the same, but the quantization method for the "
|
||||
"kv_cache layer in the config is different."
|
||||
)
|
||||
kv_cache_config = q_configs[0].get("output_tensors")
|
||||
if kv_cache_config is None:
|
||||
raise ValueError("The kv_cache quantization configuration is empty.")
|
||||
|
||||
# Since we have already set kv_cache quantization configurations,
|
||||
# we will remove the quantization configuration for the
|
||||
# output_tensors corresponding to the kv_cache layer.
|
||||
for q_config in q_configs:
|
||||
q_config["output_tensors"] = None
|
||||
|
||||
# In case q_proj output is also quantized, remove the configuration
|
||||
# to keep qkv consistency.
|
||||
q_proj_q_config = cast(dict[str, Any], layer_quant_config.get("*q_proj"))
|
||||
if q_proj_q_config is not None:
|
||||
q_proj_q_config["output_tensors"] = None
|
||||
|
||||
return cls(
|
||||
quant_config=config,
|
||||
kv_cache_group=kv_cache_group,
|
||||
kv_cache_config=kv_cache_config,
|
||||
pack_method=pack_method,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
|
||||
if capability_tuple is not None:
|
||||
capability = capability_tuple.to_int()
|
||||
supported = capability >= min_capability
|
||||
if error and not supported:
|
||||
raise RuntimeError(
"Quantization scheme is not supported for "
f"the current GPU. Min capability: {min_capability}. "
f"Current capability: {capability}."
)
|
||||
return supported
|
||||
else:
|
||||
return False
|
||||
|
||||
def _is_fp8_w4a8(
|
||||
self,
|
||||
weight_quant: list[dict[str, Any]] | None,
|
||||
input_quant: dict[str, Any] | None,
|
||||
) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
if not isinstance(weight_quant, list) or len(weight_quant) != 2:
|
||||
return False
|
||||
|
||||
# Confirm weight scheme is supported
|
||||
is_w4a8_dtype = (
|
||||
weight_quant[0].get("dtype") == "fp8_e4m3"
|
||||
and weight_quant[1].get("dtype") == "int4"
|
||||
and input_quant.get("dtype") == "fp8_e4m3"
|
||||
)
|
||||
is_static_weight = not weight_quant[0].get("is_dynamic") and not weight_quant[
|
||||
1
|
||||
].get("is_dynamic")
|
||||
is_per_tensor_fp8_and_per_channel_int4_weight = (
|
||||
weight_quant[0].get("qscheme") == "per_tensor"
|
||||
and weight_quant[1].get("qscheme") == "per_channel"
|
||||
and weight_quant[1].get("symmetric") is True
|
||||
and weight_quant[1].get("ch_axis") == 0
|
||||
)
|
||||
|
||||
if not (
|
||||
is_w4a8_dtype
|
||||
and is_static_weight
|
||||
and is_per_tensor_fp8_and_per_channel_int4_weight
|
||||
):
|
||||
return False
|
||||
|
||||
# Dynamic quantization is always supported if weights supported.
|
||||
if input_quant.get("is_dynamic"):
|
||||
return True
|
||||
|
||||
# Confirm activation scheme is supported.
|
||||
is_per_tensor_activation = input_quant.get("qscheme") == "per_tensor"
|
||||
return is_per_tensor_activation
|
||||
|
||||
def _is_fp8_w8a8(
|
||||
self,
|
||||
weight_quant: dict[str, Any] | None,
|
||||
input_quant: dict[str, Any] | None,
|
||||
) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
# Confirm weight scheme is supported
|
||||
is_fp8_dtype = (
|
||||
weight_quant.get("dtype") == "fp8_e4m3"
|
||||
and input_quant.get("dtype") == "fp8_e4m3"
|
||||
)
|
||||
is_static_weight = not weight_quant.get("is_dynamic")
|
||||
is_per_tensor_or_channel_weight = weight_quant.get("qscheme") in [
|
||||
"per_tensor",
|
||||
"per_channel",
|
||||
]
|
||||
|
||||
if not (is_fp8_dtype and is_static_weight and is_per_tensor_or_channel_weight):
|
||||
return False
|
||||
|
||||
# Dynamic quantization is always supported if weights supported.
|
||||
if input_quant.get("is_dynamic"):
|
||||
return True
|
||||
|
||||
# Confirm activation scheme is supported.
|
||||
is_per_tensor_activation = input_quant.get("qscheme") == "per_tensor"
|
||||
return is_per_tensor_activation
|
||||
|
||||
def _is_static_tensor_w8a8(
|
||||
self,
|
||||
weight_quant: dict[str, Any] | None,
|
||||
input_quant: dict[str, Any] | None,
|
||||
) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
is_int8_dtype = (
|
||||
weight_quant.get("dtype") == "int8" and input_quant.get("dtype") == "int8"
|
||||
)
|
||||
|
||||
is_tensor = (
|
||||
weight_quant.get("qscheme") in ["per_tensor", "per_channel"]
|
||||
and input_quant.get("qscheme") == "per_tensor"
|
||||
)
|
||||
|
||||
is_static = not weight_quant.get("is_dynamic") and not input_quant.get(
|
||||
"is_dynamic"
|
||||
)
|
||||
|
||||
is_weight_symmetric = weight_quant.get("symmetric") is True
|
||||
|
||||
# Both symmetric and asymmetric input quantization supported.
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
|
||||
|
||||
def _is_w_ocp_mx_a_x(
|
||||
self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None
|
||||
) -> bool:
|
||||
"""
|
||||
This check returns True only if it is an OCP-MX weight quantization.
|
||||
The activation can be any data type (e.g., FP16/BF16, FP8, or OCP-MX format).
|
||||
The rationale for checking only the weight type is that
model loading is primarily concerned with the weights themselves.
|
||||
"""
|
||||
# Confirm weights quantized.
|
||||
if weight_quant is None:
|
||||
logger.debug(
|
||||
"Quark model's weight quantization is incompatible with OCP_MX format: "
|
||||
"weight_quant is not set."
|
||||
)
|
||||
return False
|
||||
|
||||
if isinstance(weight_quant, list):
|
||||
logger.debug(
|
||||
"Quark model's weight quantization is incompatible with OCP_MX format: "
|
||||
"weight_quant is a list (e.g. fp8_w4a8), OCP_MX requires a single dict."
|
||||
)
|
||||
return False
|
||||
|
||||
# Input and weight qscheme needs to be per group.
|
||||
if weight_quant.get("qscheme") != "per_group":
|
||||
logger.debug(
|
||||
"Quark model's weight quantization is incompatible with OCP MX format: "
|
||||
"weight is not per_group."
|
||||
)
|
||||
return False
|
||||
|
||||
# Input and weight group size needs to be 32.
|
||||
if weight_quant.get("group_size") != 32:
|
||||
logger.debug(
|
||||
"Quark model's weight quantization is incompatible with OCP MX format: "
|
||||
"group_size of weight is not 32."
|
||||
)
|
||||
return False
|
||||
|
||||
# Activations and weight scales need to be in e8m0 format.
|
||||
if weight_quant.get("scale_format") != "e8m0":
|
||||
logger.debug(
|
||||
"Quark model's weight quantization is incompatible with OCP MX format: "
|
||||
"scale_format of weight is not e8m0."
|
||||
)
|
||||
return False
|
||||
|
||||
# Input and weight dtypes need to be any of fp4,
# fp6_e3m2 or fp6_e2m3, possibly mixed.
|
||||
if weight_quant.get("dtype") not in {
|
||||
"fp4",
|
||||
"fp6_e3m2",
|
||||
"fp6_e2m3",
|
||||
}:
|
||||
logger.debug(
|
||||
"Quark model's weight quantization is incompatible with OCP MX format: "
|
||||
"dtype is not in {fp4, fp6_e3m2, fp6_e2m3}."
|
||||
)
|
||||
return False
|
||||
|
||||
return True
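# A weight spec that passes the checks above looks roughly like
# (hypothetical values):
#   {"dtype": "fp4", "qscheme": "per_group", "group_size": 32,
#    "scale_format": "e8m0"}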
|
||||
|
||||
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
|
||||
"""
|
||||
For Quark, determine if it's OCP MXFP4 by checking config directly.
|
||||
This allows hidden_size rounding to happen before moe_config creation.
|
||||
"""
|
||||
layer_quant_config = self._find_matched_config(prefix, layer)
|
||||
weight_config = layer_quant_config.get("weight")
|
||||
input_config = layer_quant_config.get("input_tensors")
|
||||
|
||||
return (
|
||||
self._is_w_ocp_mx_a_x(weight_config, input_config)
|
||||
and weight_config is not None
|
||||
and weight_config.get("dtype") == "fp4"
|
||||
and getattr(torch, "float4_e2m1fn_x2", None) is not None
|
||||
)
|
||||
|
||||
def _find_matched_config(
|
||||
self, layer_name: str, module: torch.nn.Module
|
||||
) -> dict[str, Any]:
|
||||
proj_name = layer_name.split(".")[-1]
|
||||
if proj_name in self.packed_modules_mapping:
|
||||
shard_proj_names = self.packed_modules_mapping[proj_name]
|
||||
|
||||
# Convert fused_name --> [shard_names]
|
||||
shard_names = [
|
||||
layer_name.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in shard_proj_names
|
||||
]
|
||||
shard_configs = [
|
||||
self._find_matched_config(shard_name, module)
|
||||
for shard_name in shard_names
|
||||
]
|
||||
if not all(
|
||||
deep_compare(q_config, shard_configs[0]) for q_config in shard_configs
|
||||
):
|
||||
raise ValueError(
|
||||
f"Found a different quantization configuration for "
|
||||
f"{shard_proj_names} in {layer_name}. vLLM "
|
||||
"requires all to use the same scheme."
|
||||
)
|
||||
return shard_configs[0]
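# Example: for a fused "qkv_proj" module, packed_modules_mapping typically
# expands the name into its "q_proj", "k_proj" and "v_proj" shards; each
# shard's config is looked up individually and all must be identical.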
|
||||
else:
|
||||
layer_quant_config = cast(
|
||||
dict[str, Any], self.quant_config.get("layer_quant_config")
|
||||
)
|
||||
|
||||
def _matches_pattern(layer_name, pattern):
|
||||
if "*" not in pattern:
|
||||
return layer_name in pattern
|
||||
return fnmatch.fnmatch(layer_name, pattern)
|
||||
|
||||
for name_pattern, config in layer_quant_config.items():
|
||||
if _matches_pattern(layer_name, name_pattern):
|
||||
return config
|
||||
|
||||
layer_type = cast(str, type(module))
|
||||
layer_type_quant_config = cast(
|
||||
dict[str, Any], self.quant_config.get("layer_type_quant_config")
|
||||
)
|
||||
if layer_type in layer_type_quant_config:
|
||||
return layer_type_quant_config[layer_type]
|
||||
|
||||
global_quant_config = cast(
|
||||
dict[str, Any], self.quant_config.get("global_quant_config")
|
||||
)
|
||||
return global_quant_config
|
||||
|
||||
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
|
||||
if config.get("output_tensors") or config.get("bias"):
|
||||
raise NotImplementedError(
|
||||
"Currently, Quark models with output_tensors "
|
||||
"and bias quantized are not supported"
|
||||
)
|
||||
weight_config = cast(dict[str, Any], config.get("weight"))
|
||||
input_config = cast(dict[str, Any], config.get("input_tensors"))
|
||||
|
||||
if self._is_fp8_w8a8(weight_config, input_config):
|
||||
is_fp8_w8a8_supported = self._check_scheme_supported(
|
||||
QuarkW8A8Fp8.get_min_capability(), error=False
|
||||
)
|
||||
if is_fp8_w8a8_supported:
|
||||
return QuarkW8A8Fp8(weight_config, input_config)
|
||||
elif self._is_static_tensor_w8a8(weight_config, input_config):
|
||||
weight_qscheme = cast(str, weight_config.get("qscheme"))
|
||||
return QuarkW8A8Int8(
|
||||
qscheme=weight_qscheme,
|
||||
is_static_input_scheme=True,
|
||||
input_symmetric=input_config.get("symmetric"),
|
||||
)
|
||||
elif self._is_w_ocp_mx_a_x(weight_config, input_config):
|
||||
return QuarkOCP_MX(weight_config, input_config)
|
||||
|
||||
raise NotImplementedError(
|
||||
"No quark compatible scheme was found. "
|
||||
f"Weight config: {weight_config}, "
|
||||
f"Input config: {input_config}"
|
||||
)
|
||||
|
||||
def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
|
||||
layer_quant_config = self._find_matched_config(layer_name, layer)
|
||||
|
||||
# Find the quant_scheme
|
||||
scheme = self._get_scheme_from_config(layer_quant_config)
|
||||
# Raise error if device does not support the scheme
|
||||
# (e.g. fp8 needs ada lovelace)
|
||||
self._check_scheme_supported(scheme.get_min_capability())
|
||||
|
||||
return scheme
|
||||
|
||||
def get_cache_scale(self, name: str) -> str | None:
|
||||
"""
|
||||
Check whether the param name matches the format for k/v cache scales
|
||||
in quark. If this is the case, return its equivalent param name
|
||||
expected by vLLM
|
||||
|
||||
:param name: param name
|
||||
:return: matching param name for KV cache scale in vLLM
|
||||
"""
|
||||
if name.endswith(".output_scale") and ".k_proj" in name:
|
||||
return name.replace(".k_proj.output_scale", ".attn.k_scale")
|
||||
if name.endswith(".output_scale") and ".v_proj" in name:
|
||||
return name.replace(".v_proj.output_scale", ".attn.v_scale")
|
||||
if name.endswith(".output_scale") and ".q_proj" in name:
|
||||
return name.replace(".q_proj.output_scale", ".attn.q_scale")
|
||||
if name.endswith("self_attn.prob_output_scale"):
|
||||
return name.replace(".prob_output_scale", ".attn.prob_scale")
|
||||
|
||||
# If no matches, return None
|
||||
return None
|
||||
|
||||
|
||||
class QuarkLinearMethod(LinearMethodBase):
|
||||
def __init__(self, quantization_config: QuarkConfig):
|
||||
self.quantization_config = quantization_config
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.scheme.process_weights_after_loading(layer)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
"""
|
||||
Use the CompressedTensorsScheme associated with each layer to create
|
||||
the necessary parameters for the layer. See LinearMethodBase for param
|
||||
details
|
||||
"""
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.scheme.create_weights(
|
||||
layer=layer,
|
||||
input_size=input_size,
|
||||
input_size_per_partition=input_size_per_partition,
|
||||
output_partition_sizes=output_partition_sizes,
|
||||
output_size=output_size,
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
"""
|
||||
Use the output of create_weights and the CompressedTensorsScheme
|
||||
associated with the layer to apply the forward pass with the
|
||||
layer input. See LinearMethodBase for param details
|
||||
|
||||
"""
|
||||
scheme = layer.scheme
|
||||
if scheme is None:
|
||||
raise ValueError("A scheme must be defined for each layer")
|
||||
|
||||
return scheme.apply_weights(layer, x, bias=bias)
|
||||
|
||||
|
||||
class QuarkKVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from quark checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: QuarkConfig):
|
||||
self.validate_kv_cache_config(quant_config.kv_cache_config)
|
||||
super().__init__(quant_config)
|
||||
|
||||
@staticmethod
|
||||
def validate_kv_cache_config(kv_cache_config: dict[str, Any] | None):
|
||||
"""
|
||||
Validator for the kv cache configuration. Useful for controlling the
|
||||
kv cache quantization schemes, that are being supported in vLLM
|
||||
:param kv_cache_config: the quark kv cache scheme
|
||||
"""
|
||||
if kv_cache_config is None:
|
||||
return
|
||||
|
||||
dtype = kv_cache_config.get("dtype")
|
||||
if dtype != "fp8_e4m3":
|
||||
raise NotImplementedError(
|
||||
"Currently supported kv cache quantization is "
|
||||
f"dtype=fp8_e4m3, however received {dtype}"
|
||||
)
|
||||
|
||||
qscheme = kv_cache_config.get("qscheme")
|
||||
if qscheme != "per_tensor":
|
||||
raise NotImplementedError(
|
||||
"Only support per-tensor scaling factor "
|
||||
"for quark KV cache. "
|
||||
f"Expected qscheme: per_tensor, found qscheme: {qscheme}"
|
||||
)
|
||||
1033
vllm/model_executor/layers/quantization/quark/quark_moe.py
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .quark_ocp_mx import QuarkOCP_MX
|
||||
from .quark_scheme import QuarkScheme
|
||||
from .quark_w8a8_fp8 import QuarkW8A8Fp8
|
||||
from .quark_w8a8_int8 import QuarkW8A8Int8
|
||||
|
||||
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkOCP_MX"]
|
||||
@@ -0,0 +1,353 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from fractions import Fraction
|
||||
from functools import cache, partial
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
dequant_mxfp4,
|
||||
quant_dequant_mxfp4,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import (
|
||||
dequant_mxfp6,
|
||||
quant_dequant_mxfp6,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
OCP_MX_BLOCK_SIZE,
|
||||
OCP_MX_Scheme,
|
||||
)
|
||||
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .quark_scheme import QuarkScheme
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# TODO: move the registration of this custom op to aiter_ops.py
# (`from vllm._aiter_ops import rocm_aiter_ops`) and use
# `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()` for the env
# checks, which no longer requires @cache. The Triton kernel is
# torch.compile compatible and does not require direct registration;
# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_quant`.
|
||||
@cache
|
||||
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
|
||||
return (
|
||||
current_platform.is_rocm()
|
||||
and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
from aiter.ops.triton.gemm_afp4wfp4 import (
|
||||
gemm_afp4wfp4,
|
||||
gemm_afp4wfp4_preshuffled_weight_scales,
|
||||
)
|
||||
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
||||
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
if is_rocm_aiter_fp4_asm_gemm_enabled():
|
||||
from aiter import gemm_a4w4, per_1x32_f4_quant_hip
|
||||
|
||||
def gemm_with_dynamic_quant(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
rocm_use_aiter_fp4_asm_gemm: bool = False,
|
||||
out_dtype: torch.dtype | None = torch.bfloat16,
|
||||
x_scales: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
M = x.shape[0]
|
||||
N = weight.shape[0]
|
||||
K = weight.shape[1]
|
||||
if rocm_use_aiter_fp4_asm_gemm:
|
||||
if M <= 64 and rocm_aiter_ops.is_triton_gemm_afp4wfp4_presh_ws_tuned(N, K):
|
||||
if x_scales is None:
|
||||
# use hip quant kernel for performance
|
||||
if M >= 32:
|
||||
x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
|
||||
else:
|
||||
x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=False)
|
||||
else:
|
||||
x_q = x
|
||||
x_s = x_scales
|
||||
|
||||
if M >= 32:
|
||||
x_s = x_s.view(torch.uint8).view(x_s.shape[0] // 32, -1)
|
||||
else:
|
||||
x_s = x_s[:M, ...].view(torch.uint8)
|
||||
|
||||
y = torch.empty(M, N, device=x_q.device, dtype=out_dtype)
|
||||
gemm_afp4wfp4_preshuffled_weight_scales(
|
||||
x_q.view(torch.uint8),
|
||||
weight.view(torch.uint8).view(weight.shape[0] // 16, -1),
|
||||
x_s,
|
||||
weight_scale.view(torch.uint8).view(
|
||||
weight_scale.shape[0] // 32, -1
|
||||
),
|
||||
out_dtype,
|
||||
y,
|
||||
)
|
||||
else:
|
||||
if x_scales is None:
|
||||
# use hip quant kernel for performance
|
||||
x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
|
||||
else:
|
||||
x_q = x
|
||||
x_s = x_scales
|
||||
|
||||
# 32 alignment is enough for dim0 padding of output for
|
||||
# gemm_a4w4 kernel
|
||||
y = torch.empty(
|
||||
(M + 31) // 32 * 32,
|
||||
weight.shape[0],
|
||||
device=x_q.device,
|
||||
dtype=out_dtype,
|
||||
)
|
||||
|
||||
gemm_a4w4(
|
||||
x_q,
|
||||
weight.view(x_q.dtype),
|
||||
x_s,
|
||||
weight_scale.view(x_s.dtype),
|
||||
y,
|
||||
bpreshuffle=True,
|
||||
)
|
||||
return y[:M]
|
||||
else:
|
||||
if x_scales is None:
|
||||
x_q, x_s = dynamic_mxfp4_quant(x)
|
||||
else:
|
||||
x_q = x
|
||||
x_s = x_scales
|
||||
y = torch.empty(
|
||||
x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype
|
||||
)
|
||||
|
||||
gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
|
||||
return y
|
||||
|
||||
def gemm_with_dynamic_quant_fake(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
x_scales: torch.Tensor = None,
|
||||
rocm_use_aiter_fp4_asm_gemm: bool = False,
|
||||
out_dtype: torch.dtype | None = torch.bfloat16,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(
|
||||
(*x.shape[:-1], weight.shape[0]), dtype=out_dtype, device=x.device
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="gemm_with_dynamic_quant",
|
||||
op_func=gemm_with_dynamic_quant,
|
||||
mutates_args=[],
|
||||
fake_impl=gemm_with_dynamic_quant_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
except (ImportError, AttributeError, RuntimeError):
|
||||
if current_platform.is_rocm():
|
||||
logger.warning(
|
||||
"AITER is not found or QuarkOCP_MX is not supported on the current "
|
||||
"platform. QuarkOCP_MX quantization will not be available."
|
||||
)
|
||||
dynamic_mxfp4_quant = gemm_afp4wfp4 = None
|
||||
|
||||
|
||||
class QuarkOCP_MX(QuarkScheme):
|
||||
def __init__(
|
||||
self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
|
||||
):
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.qscheme = "per_group"
|
||||
self.weight_quant_spec = weight_quant_spec
|
||||
self.input_quant_spec = input_quant_spec
|
||||
|
||||
self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
|
||||
self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
|
||||
|
||||
self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
|
||||
self.input_dtype, self.weight_dtype
|
||||
)
|
||||
|
||||
if self.weight_dtype == "mxfp4":
|
||||
self.packed_factor: int | Fraction = 2
|
||||
self.dequant_func = dequant_mxfp4
|
||||
else:
|
||||
self.packed_factor = Fraction(numerator=8, denominator=6)
|
||||
self.dequant_func = partial(
|
||||
dequant_mxfp6, quant_dtype=self.weight_dtype.replace("mx", "")
|
||||
)
|
||||
|
||||
if self.input_dtype == "mxfp4":
|
||||
self.quant_dequant_func = quant_dequant_mxfp4
|
||||
else:
|
||||
self.quant_dequant_func = partial(
|
||||
quant_dequant_mxfp6, quant_dtype=self.input_dtype.replace("mx", "")
|
||||
)
|
||||
|
||||
self.static_input_scales = not input_quant_spec.get("is_dynamic")
|
||||
|
||||
if self.static_input_scales:
|
||||
raise NotImplementedError(
|
||||
"QuarkOCP_MX with static input scales is currently not "
|
||||
"implemented. Please open an issue."
|
||||
)
|
||||
|
||||
# TODO: integrate (or test) mixed-precision kernel.
|
||||
self.emulate = not current_platform.supports_mx() or (
|
||||
self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
|
||||
)
|
||||
|
||||
self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
|
||||
|
||||
if not self.emulate and (dynamic_mxfp4_quant is None or gemm_afp4wfp4 is None):
|
||||
# Currently need these kernels if not emulating
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} requires AITER to be installed "
|
||||
"for non-emulation mode! Please refer to "
|
||||
"https://github.com/ROCm/aiter for installation details."
|
||||
)
|
||||
|
||||
if not current_platform.supports_mx():
|
||||
logger.warning_once(
|
||||
"The current platform does not support native MXFP4/MXFP6 "
|
||||
"computation. Simulated weight dequantization and activation "
|
||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
||||
"layers computed in high precision."
|
||||
)
|
||||
|
||||
if current_platform.supports_mx() and (
|
||||
self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
|
||||
):
|
||||
logger.warning_once(
|
||||
"The current platform supports native MXFP4/MXFP6 "
|
||||
f"computation, but kernels for input_dtype={self.input_dtype} "
|
||||
f"and weight_dtype={self.weight_dtype} are not yet integrated "
|
||||
"in vLLM. Simulated weight dequantization and activation "
|
||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
||||
"layers computed in high precision."
|
||||
)
|
||||
|
||||
def get_packed_dim(self, dim: int, quant_dtype: str):
|
||||
if quant_dtype == "mxfp4":
|
||||
assert dim % 2 == 0
|
||||
return dim // 2
|
||||
elif quant_dtype in {"mxfp6_e3m2", "mxfp6_e2m3"}:
|
||||
# FP6 packs 4 * 6 = 24 bits into 3 bytes.
|
||||
assert (dim * 3) % 4 == 0
|
||||
return (dim * 3) // 4
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unsupported quant_dtype in QuarkOCP_MX.get_packed_dim, "
|
||||
f"got quant_dtype={quant_dtype}. Something is wrong, please "
|
||||
"open an issue."
|
||||
)
|
||||
|
||||
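# Illustrative sketch, not part of the original diff: how get_packed_dim maps a
# logical input dimension to its packed uint8 dimension, assuming a hypothetical
# QuarkOCP_MX instance `scheme`.
#
#   scheme.get_packed_dim(128, "mxfp4")       # -> 64  (two FP4 values per byte)
#   scheme.get_packed_dim(128, "mxfp6_e2m3")  # -> 96  (4 FP6 values per 3 bytes)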
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
||||
|
||||
if self.emulate:
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data, requires_grad=False
|
||||
)
|
||||
else:
|
||||
if self.rocm_use_aiter_fp4_asm_gemm:
|
||||
# shuffle weight scale
|
||||
weight_scale_shuffle = layer.weight_scale.data
|
||||
sm, sn = weight_scale_shuffle.shape
|
||||
weight_scale_shuffle = weight_scale_shuffle.view(
|
||||
sm // 32, 2, 16, sn // 8, 2, 4, 1
|
||||
)
|
||||
weight_scale_shuffle = weight_scale_shuffle.permute(
|
||||
0, 3, 5, 2, 4, 1, 6
|
||||
).contiguous()
|
||||
weight_scale_shuffle = weight_scale_shuffle.view(sm, sn)
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
weight_scale_shuffle, requires_grad=False
|
||||
)
|
||||
|
||||
# shuffle weight
|
||||
weight_shuffle = layer.weight.data
|
||||
weight_shuffle = shuffle_weight(weight_shuffle, layout=(16, 16))
|
||||
layer.weight = torch.nn.Parameter(weight_shuffle, requires_grad=False)
|
||||
else:
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data.T.contiguous(), requires_grad=False
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT
|
||||
weight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
self.get_packed_dim(input_size_per_partition, self.weight_dtype),
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=1,
|
||||
packed_factor=self.packed_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // OCP_MX_BLOCK_SIZE,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.emulate:
|
||||
dq_w = self.dequant_func(layer.weight, layer.weight_scale, x.dtype)
|
||||
qdq_x = self.quant_dequant_func(x)
|
||||
return F.linear(qdq_x, dq_w, bias)
|
||||
else:
|
||||
return torch.ops.vllm.gemm_with_dynamic_quant(
|
||||
x,
|
||||
layer.weight,
|
||||
layer.weight_scale,
|
||||
self.rocm_use_aiter_fp4_asm_gemm,
|
||||
self.out_dtype,
|
||||
)
|
||||
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["QuarkScheme"]
|
||||
|
||||
|
||||
class QuarkScheme(ABC):
|
||||
"""
|
||||
Abstract class used to describe the weight creation and forward pass
|
||||
of different quantization schemes supported by Quark.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
"""
|
||||
Get minimum device capability.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, *args, **kwargs):
|
||||
"""
|
||||
Weight creation for the particular scheme. Inputs to this function include the layer module, its partition sizes, the parameter dtype, and the weight loader.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
):
|
||||
"""
|
||||
Run the forward pass for the particular scheme. This is where
|
||||
scheme-specific dequant/quant steps/kernels should be applied.
|
||||
|
||||
:param layer: torch.nn.Module with the registered weights and
|
||||
other parameters relevant to the particular scheme.
|
||||
:param x: input to the layer
|
||||
:param bias: bias parameter
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
"""
|
||||
Called after weight loading is complete for any cleanup that
|
||||
needs to occur.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
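# Minimal sketch, not part of the original diff: a hypothetical concrete
# QuarkScheme shown only to illustrate the abstract interface above.
# `ToyUnquantizedScheme` performs no quantization at all.
import torch.nn.functional as F


class ToyUnquantizedScheme(QuarkScheme):
    """Hypothetical scheme: stores a plain dense weight and runs F.linear."""

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def create_weights(self, layer, output_partition_sizes,
                       input_size_per_partition, params_dtype,
                       weight_loader=None, **kwargs):
        # Register one dense weight covering all output partitions.
        weight = torch.nn.Parameter(
            torch.empty(sum(output_partition_sizes),
                        input_size_per_partition, dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer):
        # Nothing to repack for a plain dense weight.
        pass

    def apply_weights(self, layer, x, bias=None):
        return F.linear(x, layer.weight, bias)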
@@ -0,0 +1,188 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kFp8StaticTokenSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
__all__ = ["QuarkW8A8Fp8"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class QuarkW8A8Fp8(QuarkScheme):
|
||||
def __init__(
|
||||
self, weight_config: dict[str, Any], input_config: dict[str, Any] | None
|
||||
):
|
||||
self.weight_qscheme = cast(str, weight_config.get("qscheme"))
|
||||
self.is_static_input_scheme: bool = False
|
||||
self.input_qscheme: str | None = None
|
||||
if input_config is not None:
|
||||
self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic"))
|
||||
self.input_qscheme = cast(str, input_config.get("qscheme"))
|
||||
|
||||
per_token_activation = (
|
||||
not self.is_static_input_scheme and self.input_qscheme == "per_channel"
|
||||
)
|
||||
per_token_weight = self.weight_qscheme == "per_channel"
|
||||
|
||||
self.activation_quant_key = (
|
||||
kFp8DynamicTokenSym if per_token_activation else kFp8StaticTensorSym
|
||||
)
|
||||
self.weight_quant_key = (
|
||||
kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym
|
||||
)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# Lovelace and up
|
||||
return 89
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
# If per tensor, when we have a fused module (e.g. QKV) with per
|
||||
# tensor scales (thus N scales being passed to the kernel),
|
||||
# requantize so we can always run per tensor
|
||||
if self.weight_qscheme == "per_tensor":
|
||||
if current_platform.is_fp8_fnuz():
|
||||
input_scale = getattr(layer, "input_scale", None)
|
||||
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=input_scale,
|
||||
)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale, requires_grad=False)
|
||||
else:
|
||||
max_w_scale = layer.weight_scale
|
||||
weight = layer.weight
|
||||
|
||||
max_w_scale, weight = requantize_with_max_scale(
|
||||
weight=weight,
|
||||
weight_scale=max_w_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
)
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
|
||||
# If channelwise, scales are already lined up, so just transpose.
|
||||
elif self.weight_qscheme == "per_channel":
|
||||
weight = layer.weight
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
input_scale = getattr(layer, "input_scale", None)
|
||||
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=input_scale,
|
||||
)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale, requires_grad=False)
|
||||
else:
|
||||
weight_scale = layer.weight_scale.data
|
||||
if self.activation_quant_key.scale.group_shape == GroupShape.PER_TOKEN:
|
||||
weight_scale = weight_scale.view(-1, 1)
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown quantization scheme {self.weight_qscheme}")
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
# TODO: update create_xxx_parameter functions to return
|
||||
# the newly added parameters
|
||||
if self.weight_qscheme == "per_channel":
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
assert self.weight_qscheme == "per_tensor"
|
||||
weight_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
# min requirement for fp8 kernels
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
input_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
@@ -0,0 +1,127 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_int8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class QuarkW8A8Int8(QuarkScheme):
|
||||
def __init__(
|
||||
self,
|
||||
qscheme: str,
|
||||
is_static_input_scheme: bool | None,
|
||||
input_symmetric: bool | None,
|
||||
):
|
||||
self.qscheme = qscheme
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# Turing and up
|
||||
return 75
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
self.kernel = init_int8_linear_kernel(
|
||||
is_channelwise=(self.qscheme == "per_channel"),
|
||||
is_static_input_scheme=(self.is_static_input_scheme is True),
|
||||
input_symmetric=(self.input_symmetric is True),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.qscheme == "per_channel":
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
ChannelQuantZPParameter = ChannelQuantScaleParameter
|
||||
weight_zero_point = ChannelQuantZPParameter(
|
||||
data=torch.empty((sum(output_partition_sizes)), dtype=torch.int8),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
assert self.qscheme == "per_tensor"
|
||||
weight_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
PerTensorZPParameter = PerTensorScaleParameter
|
||||
weight_zero_point = PerTensorZPParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.int8),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
layer.register_parameter("weight_zero_point", weight_zero_point)
|
||||
|
||||
# INPUT SCALE
|
||||
input_zero_point = None
|
||||
input_scale = None
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||
)
|
||||
|
||||
input_zero_point = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
||||
)
|
||||
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
if not hasattr(layer, "azp_adj"):
|
||||
layer.register_parameter("azp_adj", None)
|
||||
|
||||
# Checkpoints are serialized in quark format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.register_parameter("weight_zero_point", None)
|
||||
delattr(layer, "weight_zero_point")
|
||||
if self.input_symmetric:
|
||||
layer.register_parameter("input_zero_point", None)
|
||||
delattr(layer, "input_zero_point")
|
||||
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
119
vllm/model_executor/layers/quantization/quark/utils.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
|
||||
def deep_compare(dict1: Any, dict2: Any) -> bool:
|
||||
if type(dict1) is not type(dict2):
|
||||
return False
|
||||
if isinstance(dict1, dict):
|
||||
if dict1.keys() != dict2.keys():
|
||||
return False
|
||||
return all(deep_compare(dict1[k], dict2[k]) for k in dict1)
|
||||
elif isinstance(dict1, list):
|
||||
return set(dict1) == set(dict2)
|
||||
else:
|
||||
return dict1 == dict2
|
||||
|
||||
|
||||
def should_ignore_layer(
|
||||
layer_name: str | None,
|
||||
ignore: Iterable[str],
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
|
||||
) -> bool:
|
||||
if layer_name is None:
|
||||
return False
|
||||
|
||||
# layer_name = model.layers.0.self_attn.qkv_proj
|
||||
# proj_name = qkv_proj
|
||||
proj_name = layer_name.split(".")[-1]
|
||||
|
||||
# Fused layers like gate_up_proj or qkv_proj will not be fused
|
||||
# in the safetensors checkpoint. So, we convert the name
|
||||
# from the fused version to unfused + check to make sure that
|
||||
# each shard of the fused layer has the same scheme.
|
||||
if proj_name in fused_mapping:
|
||||
shard_proj_names = fused_mapping[proj_name]
|
||||
|
||||
# Convert fused_name --> [shard_names]
|
||||
shard_names = [
|
||||
layer_name.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in shard_proj_names
|
||||
]
|
||||
|
||||
# Layer should be ignored if shards are ignored.
|
||||
should_ignore_layer = None
|
||||
for shard_name in shard_names:
|
||||
should_ignore_shard = check_equal_or_regex_match(
|
||||
layer_name=shard_name, targets=ignore
|
||||
)
|
||||
|
||||
# If shard_idx=0, set layer ignore to match shard.
|
||||
if should_ignore_layer is None:
|
||||
should_ignore_layer = should_ignore_shard
|
||||
|
||||
# If shard_idx=1+ confirm scheme matches prior shards.
|
||||
elif should_ignore_shard != should_ignore_layer:
|
||||
raise ValueError(
|
||||
f"Found a different quantization schemes for "
|
||||
f"{shard_proj_names} in {layer_name}. vLLM "
|
||||
"requires all to use the same scheme."
|
||||
)
|
||||
|
||||
# Unfused layers like down_proj and o_proj will match
|
||||
# the safetensors checkpoint already.
|
||||
else:
|
||||
should_ignore_layer = check_equal_or_regex_match(
|
||||
layer_name=layer_name, targets=ignore
|
||||
)
|
||||
|
||||
assert should_ignore_layer is not None
|
||||
return should_ignore_layer
|
||||
|
||||
|
||||
def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
|
||||
"""
|
||||
Checks whether a layer_name exactly matches any target in the list,
|
||||
or matches a target as a regex when that target starts with 're:'.
|
||||
"""
|
||||
return any(_is_equal_or_regex_match(layer_name, target) for target in targets)
|
||||
|
||||
|
||||
def _is_equal_or_regex_match(
|
||||
value: str, target: str, check_contains: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
Checks whether a value is exactly equal or a regex match for target
|
||||
if target starts with 're:'. If check_contains is set to True,
|
||||
additionally checks if the target string is contained within the value.
|
||||
"""
|
||||
|
||||
if target.startswith("re:"):
|
||||
pattern = target[3:]
|
||||
if re.match(pattern, value):
|
||||
return True
|
||||
elif check_contains:
|
||||
if target.lower() in value.lower():
|
||||
return True
|
||||
elif target == value:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
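# Illustrative sketch, not part of the original diff: how the matching helpers
# above behave, using hypothetical layer names and ignore targets.
def _example_ignore_matching() -> None:
    # Exact string match.
    assert check_equal_or_regex_match(
        "model.layers.0.mlp.down_proj", ["model.layers.0.mlp.down_proj"]
    )
    # Regex target, marked by the 're:' prefix and matched from the string start.
    assert check_equal_or_regex_match(
        "model.layers.12.self_attn.qkv_proj", ["re:.*qkv_proj$"]
    )
    # No target matches.
    assert not check_equal_or_regex_match("lm_head", ["re:.*proj$"])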
# utility for tensor dims > 2 cases
|
||||
def quark_quantize_weight_to_mxfp4(w: torch.Tensor):
|
||||
assert w.dtype == torch.bfloat16, (
|
||||
"Quark dynamic quantization is supported only for fp16 weights and only to MXF4"
|
||||
)
|
||||
|
||||
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
||||
|
||||
*dims, d = w.shape
|
||||
w, w_scales = dynamic_mxfp4_quant(w.reshape(-1, d))
|
||||
return w.view(*dims, d // 2), w_scales.view(*dims, d // 32)
|
||||
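# Shape sketch, not part of the original diff, assuming AITER is installed and a
# CUDA device is available. A hypothetical 3-D bfloat16 weight of shape
# (num_experts, N, K) is flattened to 2-D for the AITER kernel and reshaped back:
#
#   w = torch.randn(8, 4096, 14336, dtype=torch.bfloat16, device="cuda")
#   w_q, w_scales = quark_quantize_weight_to_mxfp4(w)
#   # w_q:      (8, 4096, 7168) uint8, two FP4 values packed per byte
#   # w_scales: (8, 4096, 448)  one E8M0 scale per 32-element group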
182
vllm/model_executor/layers/quantization/qutlass_utils.py
Normal file
@@ -0,0 +1,182 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
#
|
||||
# Modified by Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
|
||||
#
|
||||
# Copied from https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats
|
||||
#
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from torch.library import wrap_triton
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
|
||||
@triton.jit
|
||||
def triton_scale_swizzle(
|
||||
scale_ptr: torch.Tensor,
|
||||
scale_rows: int,
|
||||
scale_cols: int,
|
||||
output_ptr: torch.Tensor,
|
||||
input_row_stride: int,
|
||||
output_block_stride: int,
|
||||
BLOCK_ROWS: tl.constexpr,
|
||||
BLOCK_COLS: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Rearranges tensor data from row-major to block-scaled swizzle format.
|
||||
|
||||
Args:
|
||||
scale_ptr: Pointer to the input scale tensor
|
||||
scale_rows: Number of rows in the scale tensor
|
||||
scale_cols: Number of columns in the scale tensor
|
||||
output_ptr: Pointer to the output tensor
|
||||
input_row_stride: Stride between rows in the input tensor
|
||||
output_block_stride: Stride between blocks in the output tensor
|
||||
BLOCK_ROWS: Number of rows in a tile (compile-time constant)
|
||||
BLOCK_COLS: Number of columns in a tile (compile-time constant)
|
||||
"""
|
||||
pid_row = tl.program_id(0)
|
||||
pid_col = tl.program_id(1)
|
||||
|
||||
rows = tl.arange(0, BLOCK_ROWS)[:, None]
|
||||
cols = tl.arange(0, BLOCK_COLS)[None, :]
|
||||
|
||||
# Calculate starting row and column for this tile
|
||||
start_row = pid_row * BLOCK_ROWS
|
||||
start_col = pid_col * BLOCK_COLS
|
||||
global_rows = start_row + rows
|
||||
global_cols = start_col + cols
|
||||
|
||||
mask = (global_rows < scale_rows) & (global_cols < scale_cols)
|
||||
|
||||
input_scales = tl.load(
|
||||
scale_ptr + global_rows * input_row_stride + global_cols,
|
||||
mask=mask,
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
r_div_32 = rows // 32
|
||||
r_mod_32 = rows % 32
|
||||
|
||||
# Rearrange to (32, 4, 4) then to final (32, 16) coordinates
|
||||
dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols
|
||||
|
||||
# Flatten
|
||||
dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
|
||||
scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS))
|
||||
|
||||
# Calculate block offset using provided output block stride
|
||||
LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
|
||||
block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)
|
||||
|
||||
tl.store(
|
||||
output_ptr + block_offset + dest_indices_flat,
|
||||
scales_flat,
|
||||
)
|
||||
|
||||
|
||||
def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Rearranges an E8M0 tensor scale from row-major format to
|
||||
block-scaled swizzle format.
|
||||
|
||||
This format is suitable for Tmem as described in NVIDIA documentation:
|
||||
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
|
||||
|
||||
Args:
|
||||
scale_tensor: Input tensor in row-major format with 8-bit elements
|
||||
|
||||
Returns:
|
||||
Rearranged tensor in block-scaled swizzle format
|
||||
"""
|
||||
assert scale_tensor.element_size() == 1, (
|
||||
"Expected element size to be 1 byte (8 bits)"
|
||||
)
|
||||
assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"
|
||||
|
||||
rows, cols = scale_tensor.shape
|
||||
|
||||
# Calculate blocks needed
|
||||
n_row_blocks = triton.cdiv(rows, 128)
|
||||
n_col_blocks = triton.cdiv(cols, 4)
|
||||
padded_rows = n_row_blocks * 128
|
||||
padded_cols = n_col_blocks * 4
|
||||
|
||||
out = scale_tensor.new_empty((padded_rows, padded_cols))
|
||||
|
||||
# Input stride (for row-major format)
|
||||
input_row_stride = cols
|
||||
|
||||
# We probably want to handle multiple blocks per tile, but
|
||||
# for now keep it simple
|
||||
BLOCK_ROWS, BLOCK_COLS = 128, 4
|
||||
|
||||
# Output block stride for the rearranged format
|
||||
output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(padded_rows, BLOCK_ROWS),
|
||||
triton.cdiv(padded_cols, BLOCK_COLS),
|
||||
)
|
||||
|
||||
wrap_triton(triton_scale_swizzle)[grid](
|
||||
scale_tensor.view(torch.uint8),
|
||||
rows,
|
||||
cols,
|
||||
out.view(torch.uint8),
|
||||
input_row_stride,
|
||||
output_block_stride,
|
||||
BLOCK_ROWS=BLOCK_ROWS,
|
||||
BLOCK_COLS=BLOCK_COLS,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def to_blocked(
|
||||
input_matrix: torch.Tensor, backend: Literal["torch", "triton"] = "triton"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Rearrange a large matrix by breaking it into blocks and applying
|
||||
the rearrangement pattern.
|
||||
|
||||
See:
|
||||
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
|
||||
|
||||
Args:
|
||||
input_matrix: Input tensor of shape (H, W)
|
||||
backend: "torch" (PyTorch path) or "triton" (Triton kernel)
|
||||
|
||||
Returns:
|
||||
Rearranged tensor of shape (32*cdiv(H,128), 16*cdiv(W,4))
|
||||
"""
|
||||
if backend == "triton":
|
||||
return triton_mx_block_rearrange(input_matrix).flatten()
|
||||
elif backend != "torch":
|
||||
raise ValueError(f'backend must be "torch" or "triton", got {backend!r}')
|
||||
|
||||
rows, cols = input_matrix.shape
|
||||
n_row_blocks = cdiv(rows, 128)
|
||||
n_col_blocks = cdiv(cols, 4)
|
||||
|
||||
# Calculate the padded shape
|
||||
padded_rows = n_row_blocks * 128
|
||||
padded_cols = n_col_blocks * 4
|
||||
|
||||
padded = input_matrix
|
||||
assert (rows, cols) == (padded_rows, padded_cols)
|
||||
|
||||
# Rearrange the blocks
|
||||
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
|
||||
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
||||
|
||||
return rearranged.flatten()
|
||||
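# Worked example, not part of the original diff: an E8M0 scale tensor of shape
# (256, 8) splits into 2 row blocks x 2 column blocks of (128, 4) tiles, each
# rearranged to (32, 16); the flattened result keeps all 256 * 8 elements.
if __name__ == "__main__":
    scales = torch.zeros(256, 8, dtype=torch.uint8)
    blocked = to_blocked(scales, backend="torch")
    assert blocked.shape == (256 * 8,)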
90
vllm/model_executor/layers/quantization/schema.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file contains the Pydantic schemas for various quantization-related
|
||||
parameters. When a relevant quantization technique is specified, these
|
||||
parameters are loaded in the form of a JSON alongside the model weights
|
||||
and augment the model with additional information needed for use of that
|
||||
technique. The format of this JSON should be specified by one or more
|
||||
schemas contained here.
|
||||
|
||||
For example, when the KV cache is quantized to FP8-E4M3 (currently only
|
||||
possible on ROCm), the model can be optionally augmented with KV cache
|
||||
scaling factors.
|
||||
"""
|
||||
|
||||
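# Illustrative example, not part of the original diff, of the JSON these schemas
# validate, for a hypothetical 2-layer model served with TP size 1:
#
#   {
#     "model_type": "llama",
#     "kv_cache": {
#       "dtype": "float8_e4m3fn",
#       "scaling_factor": {
#         "0": {"0": 0.021, "1": 0.034}
#       }
#     }
#   }
#
# The outer keys of "scaling_factor" are TP ranks; the inner keys are layer
# indices, each mapped to that layer's per-tensor KV cache scaling factor.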
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
|
||||
|
||||
|
||||
class KVCacheQuantSchema(BaseModel):
|
||||
dtype: str
|
||||
# Each key is a TP rank. Each value is a dictionary mapping a TP rank's
|
||||
# layer indices to their per-tensor KV cache scaling factor.
|
||||
# TODO: Consider pulling this and its validation methods out into its
|
||||
# own schema class (tricky as its members are variable)
|
||||
scaling_factor: dict[int, dict[int, float]]
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_is_fp8(self) -> "KVCacheQuantSchema":
|
||||
assert self.dtype == "float8_e4m3fn", (
|
||||
"Loaded scaling factors intended for KV cache dtype = "
|
||||
f"{self.dtype} rather than float8_e4m3fn!"
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
|
||||
context = info.context
|
||||
if context:
|
||||
tp_size = context["tp_size"]
|
||||
num_hidden_layers = context["num_hidden_layers"]
|
||||
assert len(self.scaling_factor) == tp_size, (
|
||||
f"Loaded dictionary has TP size {len(self.scaling_factor)} "
|
||||
f"but LLM engine is currently running with TP size {tp_size}."
|
||||
)
|
||||
for tp_rank, layer_maps in self.scaling_factor.items():
|
||||
assert len(layer_maps) == num_hidden_layers, (
|
||||
f"KV cache scales map for TP rank {tp_rank} is malformed. "
|
||||
f"Expected {num_hidden_layers} layers, got "
|
||||
f"{len(layer_maps)}."
|
||||
)
|
||||
for i in range(tp_size):
|
||||
assert i in self.scaling_factor, (
|
||||
f"KV cache scales map for TP rank {i} not found."
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
|
||||
context = info.context
|
||||
if context:
|
||||
tp_rank = context["tp_rank"]
|
||||
num_hidden_layers = context["num_hidden_layers"]
|
||||
layer_scales_map = self.scaling_factor[tp_rank]
|
||||
for i in range(num_hidden_layers):
|
||||
assert i in layer_scales_map, (
|
||||
f"Could not find KV cache scales for layer {i} in "
|
||||
f"TP rank {tp_rank}."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class QuantParamSchema(BaseModel):
|
||||
# TODO: Generalize and extend with more fields
|
||||
# (e.g. weights/activations params) once functionality is enabled
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
model_type: str | None
|
||||
kv_cache: KVCacheQuantSchema
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
|
||||
context = info.context
|
||||
if context:
|
||||
model_type = context.get("model_type", None)
|
||||
if model_type is not None:
|
||||
assert model_type == self.model_type, (
|
||||
f"Model type is {model_type} but loaded "
|
||||
f"scaling factors belonging to different "
|
||||
f"model type {self.model_type}!"
|
||||
)
|
||||
return self
|
||||
366
vllm/model_executor/layers/quantization/torchao.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib
|
||||
import json
|
||||
import types
|
||||
from importlib.util import find_spec
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from packaging import version
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _bond_method_to_cls(func, obj):
|
||||
if hasattr(func, "__self__") or not callable(func):
|
||||
# If the function is already bound to an instance, return it as is
|
||||
return func
|
||||
else:
|
||||
return types.MethodType(func, obj)
|
||||
|
||||
|
||||
def _get_weight_attrs(param):
|
||||
# record attributes attached to the weight, so we can
|
||||
# recover later
|
||||
recorded_weight_attr = {}
|
||||
for key in param.__dict__:
|
||||
if hasattr(param, key):
|
||||
attr = getattr(param, key)
|
||||
if not callable(attr):
|
||||
recorded_weight_attr[key] = attr
|
||||
elif hasattr(attr, "__self__") and param is attr.__self__:
|
||||
# if attr is a bound method for an instance, and
|
||||
# attr.__self__ points to the instance (param)
|
||||
# we'll record the underlying function object
|
||||
recorded_weight_attr[key] = attr.__func__
|
||||
else:
|
||||
recorded_weight_attr[key] = attr
|
||||
return recorded_weight_attr
|
||||
|
||||
|
||||
def _restore_weight_attrs(param, recorded_weight_attr):
|
||||
for attr_name, attr in recorded_weight_attr.items():
|
||||
if not hasattr(param, attr_name):
|
||||
setattr(param, attr_name, _bond_method_to_cls(attr, param))
|
||||
|
||||
|
||||
def torchao_version_at_least(torchao_version: str) -> bool:
|
||||
if find_spec("torchao"):
|
||||
try:
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.parse(
|
||||
torchao_version
|
||||
):
|
||||
return True
|
||||
except (ImportError, version.InvalidVersion):
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def should_skip(prefix: str, skip_modules: list[str]) -> bool:
|
||||
"""
|
||||
Robust skipping logic:
|
||||
should_skip("model.model.layers.1.q_proj",
|
||||
["model.model.layers.1.q_proj"]) # True
|
||||
should_skip("model.model.layers.10.o_proj", ["o_proj"]) -> True
|
||||
should_skip("visual.model.layers.1.q_proj", ["visual"]) -> True
|
||||
should_skip("model.model.layers.1.q_proj", ["layers.1"]) -> True
|
||||
should_skip("model.model.layers.11.q_proj", ["layers.1"]) -> False
|
||||
"""
|
||||
for s in skip_modules:
|
||||
if prefix == s:
|
||||
return True
|
||||
if f".{s}." in f".{prefix}.":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
if torchao_version_at_least("0.15.0"):
|
||||
from torchao.prototype.tensor_conversion.api import (
|
||||
convert_to_packed_tensor_based_on_current_hardware,
|
||||
)
|
||||
else:
|
||||
convert_to_packed_tensor_based_on_current_hardware = lambda t: t
|
||||
|
||||
|
||||
class TorchAOConfig(QuantizationConfig):
|
||||
"""Config class for torchao."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
torchao_config,
|
||||
skip_modules: list[str] | None = None,
|
||||
is_checkpoint_torchao_serialized: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.torchao_config = torchao_config
|
||||
self.skip_modules = skip_modules or []
|
||||
self.is_checkpoint_torchao_serialized = is_checkpoint_torchao_serialized
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"TorchAOConfig({self.torchao_config=}, {self.skip_modules=}, "
|
||||
f"{self.is_checkpoint_torchao_serialized=})"
|
||||
)
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "torchao"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
"""torchao doesn't require additional config files, we use
|
||||
`config.json` from huggingface: `model_config.hf_config`
|
||||
"""
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
|
||||
"""Create the quant config from an hf model config"""
|
||||
try:
|
||||
from torchao.core.config import config_from_dict
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install torchao>=0.10.0 via "
|
||||
"`pip install torchao>=0.10.0` to use torchao quantization."
|
||||
) from err
|
||||
|
||||
quant_method = cls.get_from_keys_or(config, ["quant_method"], None)
|
||||
is_checkpoint_torchao_serialized = (
|
||||
quant_method is not None and "torchao" in quant_method
|
||||
)
|
||||
|
||||
hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
|
||||
assert hf_config is not None, "quant_type must be specified"
|
||||
assert len(hf_config) == 1 and "default" in hf_config, (
|
||||
"Expected only one key 'default' in quant_type dictionary"
|
||||
)
|
||||
quant_type = hf_config["default"]
|
||||
ao_config = config_from_dict(quant_type)
|
||||
|
||||
# Adds skipped modules defined in "modules_to_not_convert"
|
||||
skip_modules = config.get("modules_to_not_convert", []) or []
|
||||
|
||||
# Adds skipped modules defined in "module_fqn_to_config"
|
||||
_data = quant_type.get("_data", {})
|
||||
if not isinstance(_data, dict):
|
||||
_data = {}
|
||||
|
||||
module_fqn = _data.get("module_fqn_to_config", {})
|
||||
if not isinstance(module_fqn, dict):
|
||||
module_fqn = {}
|
||||
|
||||
for layer, layer_cfg in module_fqn.items():
|
||||
if layer_cfg is None:
|
||||
skip_modules.append(layer)
|
||||
|
||||
return cls(ao_config, skip_modules, is_checkpoint_torchao_serialized)
|
||||
|
||||
@classmethod
|
||||
def from_config_file(cls, config_file: str) -> "TorchAOConfig":
|
||||
"""Initialize class from a config file. Example:
|
||||
```
|
||||
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
|
||||
fn = "torchao_config.json"
|
||||
|
||||
with open(fn, "w") as f:
|
||||
f.write(json.dumps(config_to_dict(config)))
|
||||
```
|
||||
"""
|
||||
with open(config_file) as f:
|
||||
f.seek(0)
|
||||
f_read = f.read()
|
||||
config_dict = json.loads(f_read)
|
||||
|
||||
hf_config = {"quant_type": {"default": config_dict}}
|
||||
return cls.from_config(hf_config)
|
||||
|
||||
@classmethod
|
||||
def from_config_dict_json(cls, config_dict_json: str) -> "TorchAOConfig":
|
||||
"""Iniitalize class from a config_dict json string, got from
|
||||
torchao_config_object = some AOBaseConfig object
|
||||
json.dumps(config_to_dict(torchao_config_object))
|
||||
"""
|
||||
config_dict = json.loads(config_dict_json)
|
||||
hf_config = {"quant_type": {"default": config_dict}}
|
||||
return cls.from_config(hf_config)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if not isinstance(layer, LinearBase):
|
||||
return None
|
||||
|
||||
from torchao.quantization import ModuleFqnToConfig
|
||||
|
||||
if should_skip(prefix, self.skip_modules):
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
module_fqn = prefix
|
||||
if isinstance(self.torchao_config, ModuleFqnToConfig):
|
||||
module_fqn_to_config = self.torchao_config.module_fqn_to_config
|
||||
c = None
|
||||
if module_fqn in module_fqn_to_config:
|
||||
assert not module_fqn.startswith("re:"), (
|
||||
"module fqn should not start with"
|
||||
"`re:`, which is used for specifying regex"
|
||||
)
|
||||
c = module_fqn_to_config[module_fqn]
|
||||
else:
|
||||
for maybe_module_fqn_pattern in module_fqn_to_config:
|
||||
if not maybe_module_fqn_pattern.startswith("re:"):
|
||||
continue
|
||||
elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
|
||||
# we'll apply the config for first fully matched pattern
|
||||
c = module_fqn_to_config[maybe_module_fqn_pattern]
|
||||
break
|
||||
else:
|
||||
# fallback to use default if no module specific
|
||||
# config is provided
|
||||
c = module_fqn_to_config.get("_default", None)
|
||||
|
||||
if c is not None:
|
||||
current_torchao_config = TorchAOConfig(
|
||||
c, self.skip_modules, self.is_checkpoint_torchao_serialized
|
||||
)
|
||||
return TorchAOLinearMethod(current_torchao_config)
|
||||
else:
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
return TorchAOLinearMethod(self)
|
||||
|
||||
def get_scaled_act_names(self) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
def torchao_quantize_param_data(
|
||||
param: torch.Tensor, torchao_config: Any
|
||||
) -> torch.nn.Parameter:
|
||||
"""Quantize a Tensor with torchao quantization specified by torchao_config
|
||||
|
||||
Args:
|
||||
param: weight parameter of the linear module
|
||||
torchao_config: type of quantization and their arguments we want to
|
||||
use to quantize the Tensor
|
||||
"""
|
||||
from torchao.core.config import AOBaseConfig
|
||||
from torchao.quantization import quantize_
|
||||
|
||||
assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
|
||||
"""
|
||||
Avoid real weight allocation for faster load, since we will
|
||||
end up setting it to param.
|
||||
"""
|
||||
with torch.device("meta"):
|
||||
# linear can't be top level module since quantize_ is inplace
|
||||
# while some of our configs need to do module swap, and only non-top
|
||||
# level modules support module swap
|
||||
dummy_linear = torch.nn.Sequential(
|
||||
torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
|
||||
)
|
||||
|
||||
dummy_linear[0].weight = param
|
||||
quantize_(dummy_linear, torchao_config)
|
||||
return dummy_linear[0].weight
|
||||
|
||||
|
||||
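# Usage sketch, not part of the original diff, assuming torchao is installed and
# a GPU is available. The config mirrors the one shown in `from_config_file`
# above; the import path is an assumption based on current torchao releases.
#
#   from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
#
#   weight = torch.nn.Parameter(
#       torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda"),
#       requires_grad=False,
#   )
#   ao_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
#   q_weight = torchao_quantize_param_data(weight, ao_config)
#   # q_weight is a torch.nn.Parameter wrapping a torchao quantized tensor subclass.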
class TorchAOLinearMethod(LinearMethodBase):
|
||||
"""Linear method for torchao.
|
||||
|
||||
Args:
|
||||
quant_config: The torchao quantization config, a string that encodes
|
||||
the type of quantization and all relevant arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: TorchAOConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
weight = Parameter(
|
||||
torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
if self.quant_config.is_checkpoint_torchao_serialized:
|
||||
weight = torchao_quantize_param_data(
|
||||
weight, self.quant_config.torchao_config
|
||||
)
|
||||
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return F.linear(x, layer.weight, bias)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if self.quant_config.is_checkpoint_torchao_serialized:
|
||||
if not hasattr(layer, "weight"):
|
||||
return
|
||||
|
||||
# record attributes attached to the weight, so we can
|
||||
# recover later
|
||||
recorded_weight_attr = _get_weight_attrs(layer.weight)
|
||||
|
||||
layer.weight = Parameter(
|
||||
convert_to_packed_tensor_based_on_current_hardware(layer.weight),
|
||||
requires_grad=layer.weight.requires_grad,
|
||||
)
|
||||
|
||||
_restore_weight_attrs(layer.weight, recorded_weight_attr)
|
||||
return
|
||||
|
||||
# online quantize the weight if the checkpoint is not already
|
||||
# quantized by torchao
|
||||
recorded_weight_attr = _get_weight_attrs(layer.weight)
|
||||
|
||||
weight = torchao_quantize_param_data(
|
||||
layer.weight, self.quant_config.torchao_config
|
||||
)
|
||||
weight = torch.nn.Parameter(
|
||||
convert_to_packed_tensor_based_on_current_hardware(weight),
|
||||
weight.requires_grad,
|
||||
)
|
||||
|
||||
_restore_weight_attrs(weight, recorded_weight_attr)
|
||||
layer.register_parameter("weight", weight)
|
||||
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .layer_utils import replace_parameter, update_tensor_inplace
|
||||
|
||||
__all__ = ["update_tensor_inplace", "replace_parameter"]
|
||||
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD = 1024
|
||||
ALLSPARK_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128]
|
||||
ALLSPARK_AMPERE_N_ALIGN = 16
|
||||
ALLSPARK_AMPERE_K_ALIGN = 16
|
||||
|
||||
|
||||
def check_allspark_supported_dtype_shape(
|
||||
input_size_per_partition: int,
|
||||
output_size_per_partition: int,
|
||||
group_size: int,
|
||||
weight_dtype: ScalarType,
|
||||
act_dtype: torch.dtype,
|
||||
):
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||
|
||||
# For Ampere GPU
|
||||
if device_capability >= 80 and device_capability < 90:
|
||||
if group_size != -1:
|
||||
return (
|
||||
False,
|
||||
"For Ampere GPU, AllSpark does not support group_size "
|
||||
f"= {group_size}. Only group_size = -1 are supported.",
|
||||
)
|
||||
|
||||
if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES:
|
||||
return (
|
||||
False,
|
||||
"For Ampere GPU, AllSpark does not support "
|
||||
f"quant type ({weight_dtype}). Only quant type "
|
||||
f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported.",
|
||||
)
|
||||
|
||||
if (
|
||||
input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0
|
||||
or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0
|
||||
):
|
||||
return (
|
||||
False,
|
||||
"AllSpark needs input_size_per_partition % "
|
||||
f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "
|
||||
f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "
|
||||
"for Ampere GPU optimized kernels.",
|
||||
)
|
||||
|
||||
if act_dtype != torch.float16 and act_dtype != torch.bfloat16:
|
||||
return (
|
||||
False,
|
||||
"AllSpark only supports act_dtype = float16 or bfloat16,"
|
||||
f"for Ampere GPU, but got act_dtype = {act_dtype}.",
|
||||
)
|
||||
else:
|
||||
return (
|
||||
False,
|
||||
"AllSpark currently does not support "
|
||||
f"device_capability = {device_capability}.",
|
||||
)
|
||||
|
||||
return True, None
|
||||
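# Usage sketch, not part of the original diff: validating a hypothetical
# K=4096 x N=8192 uint8 channel-wise layer for the Ampere path.
#
#   ok, reason = check_allspark_supported_dtype_shape(
#       input_size_per_partition=4096,
#       output_size_per_partition=8192,
#       group_size=-1,
#       weight_dtype=scalar_types.uint8b128,
#       act_dtype=torch.float16,
#   )
#   # ok is True on compute capability 8.x GPUs; otherwise `reason` explains why.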
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 2
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 2
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "128": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 2
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "512": {
        "BLOCK_SIZE_M": 256,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 5
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "32": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "48": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "64": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "96": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "24": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "32": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "48": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "64": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "96": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,164 @@
{
    "1": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "GROUP_SIZE_M": 1,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "2": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "GROUP_SIZE_M": 8,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "4": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "GROUP_SIZE_M": 16,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "8": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "GROUP_SIZE_M": 16,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "16": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "GROUP_SIZE_M": 8,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "24": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "GROUP_SIZE_M": 8,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "32": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "GROUP_SIZE_M": 16,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "48": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "GROUP_SIZE_M": 8,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "64": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "GROUP_SIZE_M": 8,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "96": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "GROUP_SIZE_M": 16,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "128": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "GROUP_SIZE_M": 16,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "256": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "GROUP_SIZE_M": 8,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "512": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "GROUP_SIZE_M": 1,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "1024": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "GROUP_SIZE_M": 8,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "1536": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "GROUP_SIZE_M": 1,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "2048": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "GROUP_SIZE_M": 8,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "3072": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "GROUP_SIZE_M": 32,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    },
    "4096": {
        "BLOCK_SIZE_K": 128,
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "GROUP_SIZE_M": 32,
        "kpack": 1,
        "matrix_instr_nonkdim": 16,
        "num_warps": 4
    }
}
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
{
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5},
  "2": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "4": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "8": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "48": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 5},
  "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3}
}
@@ -0,0 +1,26 @@
{
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}
}
@@ -0,0 +1,26 @@
{
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3}
}
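The JSON files added in these hunks are auto-tuned Triton kernel configurations for the fused-MoE path: each top-level key is a benchmarked token count (batch size M), and its value is the tile and pipeline configuration to use at that size. As a rough illustration only, the sketch below shows how such a per-batch-size table is typically consumed (vLLM's fused_moe selects the entry whose key is closest to the runtime M); the load_tuned_config helper and path argument are illustrative, not part of this commit.

    import json
    from typing import Any

    def load_tuned_config(path: str, num_tokens: int) -> dict[str, Any]:
        # Keys in these files are stringified batch sizes ("1", "2", ..., "4096").
        with open(path) as f:
            configs = {int(k): v for k, v in json.load(f).items()}
        # Pick the benchmarked batch size nearest to the runtime token count.
        best = min(configs, key=lambda k: abs(k - num_tokens))
        return configs[best]

For example, with num_tokens=2500 the small table above would resolve to its "2048" entry.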
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2},
  "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2},
  "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5},
  "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 2},
  "48": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 2},
  "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2},
  "96": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2},
  "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2},
  "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2},
  "512": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2},
  "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 3},
  "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 2},
  "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2},
  "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2}
}
@@ -0,0 +1,164 @@
{
  "1": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "2": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "4": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "8": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "16": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "24": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "32": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "48": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "64": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "96": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "128": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "256": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "512": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "1024": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "1536": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "2048": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "3072": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "4096": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4}
}
@@ -0,0 +1,164 @@
{
  "1": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "2": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "4": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "8": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "16": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "24": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "32": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "48": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "64": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "96": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "128": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "256": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "512": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "1024": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "1536": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "2048": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "3072": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "4096": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4}
}
@@ -0,0 +1,164 @@
{
  "1": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "2": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "4": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "8": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "16": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "24": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "32": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "48": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "64": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "96": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "128": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "256": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "512": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "1024": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "1536": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "2048": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "3072": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "4096": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4}
}
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "48": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 5},
  "96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "128": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2},
  "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}
}
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2},
  "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
  "8": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 3},
  "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5},
  "24": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 5},
  "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "48": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 2},
  "96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 5},
  "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 2},
  "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3}
}
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "2": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 5},
  "4": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "8": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "48": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3}
}
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
  "2": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 3},
  "4": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
  "8": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2},
  "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "48": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 5},
  "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 2},
  "96": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2},
  "128": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2},
  "256": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 5},
  "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2},
  "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 2}
}
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "2": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 5},
  "4": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "8": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "48": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3}
}
@@ -0,0 +1,26 @@
{
  "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 2},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3}
}
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5},
  "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5},
  "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5},
  "48": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5},
  "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 5},
  "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5},
  "1024": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 5},
  "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}
}
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "48": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5},
  "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4}
}
@@ -0,0 +1,164 @@
{
  "1": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "2": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "4": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "8": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "16": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "24": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "32": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "48": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "64": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "96": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "128": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "256": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "512": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "1024": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "1536": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "2048": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "3072": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "4096": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4}
}
@@ -0,0 +1,164 @@
{
  "1": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "2": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "4": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "8": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "16": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "24": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "32": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "48": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "64": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "96": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "128": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "256": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "512": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "1024": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "1536": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "2048": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "3072": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "4096": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4}
}
@@ -0,0 +1,164 @@
{
  "1": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "2": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "4": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "8": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "16": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "24": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "32": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "48": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "64": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 8, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "96": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "128": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "256": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "512": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "1024": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 16, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "1536": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 1, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "2048": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "3072": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4},
  "4096": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 32, "kpack": 1, "matrix_instr_nonkdim": 16, "num_warps": 4}
}
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 5},
  "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "32": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "48": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
  "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "128": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "64": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "8": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "24": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "32": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "48": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "64": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 5
    },
    "128": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    }
}
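The JSON files above are per-size tuning tables for the Triton fused-MoE kernels: each top-level key is a token/batch size M, and the entry records the tile shape (BLOCK_SIZE_M/N/K), GROUP_SIZE_M, and launch parameters (num_warps, num_stages, plus kpack and matrix_instr_nonkdim in the ROCm-style tables) that benchmarked fastest at that size. As a rough illustration of how such a table can be consumed, the following is a minimal sketch, not the exact vLLM code path; the file name and helper names are hypothetical, and the only assumption is the nearest-batch-size lookup over the string keys shown above.

import json

def load_moe_configs(path: str) -> dict[int, dict[str, int]]:
    # The JSON keys are strings ("1", "2", ..., "4096"); convert them to ints
    # so a nearest-neighbour lookup on the batch size is possible.
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}

def pick_config(configs: dict[int, dict[str, int]], m: int) -> dict[str, int]:
    # The tables only cover a discrete grid of sizes, so an intermediate M
    # falls back to the entry tuned for the closest benchmarked size.
    nearest = min(configs, key=lambda key: abs(key - m))
    return configs[nearest]

# Hypothetical usage (file name is illustrative only):
# cfg = pick_config(load_moe_configs("fused_moe_config.json"), m=300)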
Some files were not shown because too many files have changed in this diff.