Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
import torch
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.scalar_type import ScalarType
@dataclass
class MPLinearLayerConfig:
    """Static description of a mixed-precision (quantized-weight) linear layer,
    used by kernels to decide whether/how they can implement it."""

    full_weight_shape: tuple[int, int]  # [in, out]
    partition_weight_shape: tuple[int, int]  # this TP partition's [in, out]
    weight_type: ScalarType  # quantized weight type (e.g. uint4b8)
    act_type: torch.dtype  # activation dtype
    group_size: int  # quantization group size along the input dim
    zero_points: bool  # True if the layer carries zero-point tensors
    has_g_idx: bool  # True if the layer carries an activation-reorder g_idx
    # NOTE(review): presumably None means "same as act_type" — confirm with
    # consumers before relying on it.
    out_type: torch.dtype | None = None
class MPLinearKernel(ABC):
    """Abstract base class for mixed-precision linear kernels.

    A concrete kernel declares its minimum device capability and which layer
    configurations it supports, repacks the layer's quantized parameters into
    its preferred layout after checkpoint load, and performs the forward
    matmul.
    """

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """Minimum device compute capability (e.g. 80) this kernel requires."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Return (True, None) if `c` is supported, else (False, reason)."""
        raise NotImplementedError

    def __init__(
        self,
        c: MPLinearLayerConfig,
        w_q_param_name: str,
        w_s_param_name: str,
        w_zp_param_name: str | None = None,
        w_gidx_param_name: str | None = None,
    ) -> None:
        """Bind the kernel to a layer config and its parameter attribute names.

        Args:
            c: Description of the linear layer being implemented.
            w_q_param_name: Attribute name of the quantized weight on the layer.
            w_s_param_name: Attribute name of the scales.
            w_zp_param_name: Attribute name of the zero points; required when
                ``c.zero_points`` is True.
            w_gidx_param_name: Attribute name of the g_idx tensor; required
                when ``c.has_g_idx`` is True.
        """
        # BUGFIX: `can_implement` returns a (bool, reason) tuple, and a
        # non-empty tuple is always truthy, so the original
        # `assert self.can_implement(c)` could never fail. Unpack and assert
        # on the boolean so misconfigurations are actually caught.
        can_implement, failure_reason = self.can_implement(c)
        assert can_implement, failure_reason
        self.config = c
        self.w_q_name = w_q_param_name
        self.w_s_name = w_s_param_name
        if c.zero_points:
            assert w_zp_param_name is not None
        if c.has_g_idx:
            assert w_gidx_param_name is not None
        self.w_zp_name = w_zp_param_name
        self.w_gidx_name = w_gidx_param_name

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack `layer`'s parameters into this kernel's layout (in place)."""
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the linear transform of `x` (+ `bias`) using the repacked
        parameters registered on `layer`."""
        raise NotImplementedError

    def _transform_param(
        self, layer: torch.nn.Module, name: str | None, fn: Callable
    ) -> None:
        """Apply `fn` to parameter `name` on `layer` (no-op if absent) and
        re-register the result."""
        if name is not None and getattr(layer, name, None) is not None:
            old_param = getattr(layer, name)
            new_param = fn(old_param)
            # replace the parameter with torch.nn.Parameter for TorchDynamo
            # compatibility
            replace_parameter(
                layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
            )

    def _get_weight_params(
        self, layer: torch.nn.Module
    ) -> tuple[
        torch.Tensor,  # w_q
        torch.Tensor,  # w_s
        torch.Tensor | None,  # w_zp,
        torch.Tensor | None,  # w_gidx
    ]:
        """Fetch (w_q, w_s, w_zp, w_gidx) from `layer`; the optional tensors
        come back as None when their name is unset or the attr is missing."""
        return (
            getattr(layer, self.w_q_name),
            getattr(layer, self.w_s_name),
            getattr(layer, self.w_zp_name or "", None),
            getattr(layer, self.w_gidx_name or "", None),
        )

View File

@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
AllSparkLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501
BitBLASLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501
ConchLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import ( # noqa: E501
CutlassW4A8LinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501
Dynamic4bitLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
ExllamaLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
MacheteLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501
MarlinLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501
MPLinearKernel,
MPLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.xpu import ( # noqa: E501
XPUwNa16LinearKernel,
)
from vllm.platforms import current_platform
# in priority/performance order (when available): `choose_mp_linear_kernel`
# walks this list and returns the first kernel that is enabled, meets the
# capability requirement, and reports it can implement the config.
_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
    CutlassW4A8LinearKernel,
    MacheteLinearKernel,
    AllSparkLinearKernel,
    MarlinLinearKernel,
    Dynamic4bitLinearKernel,
    BitBLASLinearKernel,
    ConchLinearKernel,
    ExllamaLinearKernel,
    XPUwNa16LinearKernel,
]
def choose_mp_linear_kernel(
    config: MPLinearLayerConfig, compute_capability: int | None = None
) -> type[MPLinearKernel]:
    """Select the best available MPLinearKernel for ``config``.

    Kernels are tried in the priority order of ``_POSSIBLE_KERNELS``; the
    first one that is not disabled via environment, meets the capability
    requirement, and can implement ``config`` is returned.

    Args:
        config: Description of the linear layer to be implemented.
        compute_capability: Compute capability of the target device; when
            None it is queried from ``current_platform``.

    Raises:
        ValueError: If no kernel can implement the given config.

    Returns:
        type[MPLinearKernel]: The chosen kernel class.
    """
    if compute_capability is None:
        if current_platform is None:
            raise ValueError("Cannot determine compute capability")
        device_capability = current_platform.get_device_capability()
        if device_capability is not None:
            compute_capability = device_capability[0] * 10 + device_capability[1]

    failure_reasons: list[str] = []
    for kernel in _POSSIBLE_KERNELS:
        kernel_name = kernel.__name__

        # Explicitly disabled by the user?
        if kernel_name in envs.VLLM_DISABLED_KERNELS:
            failure_reasons.append(
                f" {kernel_name} disabled by environment variable"
            )
            continue

        # Device too old for this kernel? (skipped when capability unknown)
        min_capability = kernel.get_min_capability()
        if compute_capability is not None and min_capability > compute_capability:
            failure_reasons.append(
                f"{kernel_name} requires capability "
                f"{min_capability}, current compute "
                f" capability is {compute_capability}"
            )
            continue

        can_implement, failure_reason = kernel.can_implement(config)
        if not can_implement:
            failure_reasons.append(
                f" {kernel_name} cannot implement due to: {failure_reason}"
            )
            continue

        return kernel

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "WNA16 linear layer. Reasons: \n" + "\n".join(failure_reasons)
    )

View File

@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
check_allspark_supported_dtype_shape,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class AllSparkLinearKernel(MPLinearKernel):
    """Mixed-precision linear kernel using the AllSpark W8A16 GEMM custom op."""

    @classmethod
    def get_min_capability(cls) -> int:
        # Requires SM80 (Ampere) or newer; the repack path below targets
        # the Ampere W8A16 layout.
        return 80

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Return (ok, reason) for whether AllSpark supports config `c`."""
        if c.has_g_idx:
            return False, "Act reordering currently not supported by AllSpark"
        if c.zero_points:
            return False, "Zero points currently not supported by AllSpark"
        # Remaining dtype/shape/group-size checks live in the shared helper.
        return check_allspark_supported_dtype_shape(
            c.partition_weight_shape[0],  # in_features
            c.partition_weight_shape[1],  # out_features
            c.group_size,
            c.weight_type,
            c.act_type,
        )

    # note assumes that
    # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    # `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack the loaded weight/scale into AllSpark's N32K16 layout and
        cache per-device GEMM arguments (SM count/version)."""
        device = getattr(layer, self.w_q_name).device
        c = self.config

        # prepare the parameters required for the kernel
        properties = torch.cuda.get_device_properties(device.index)
        sm_count = properties.multi_processor_count
        sm_version = properties.major * 10 + properties.minor

        gemm_args = {}
        gemm_args["sm_count"] = sm_count
        gemm_args["sm_version"] = sm_version
        # Stored for use at forward time in apply_weights.
        self.gemm_args = gemm_args

        # transform param weight, scale
        old_weight_param = getattr(layer, self.w_q_name)
        old_scale_param = getattr(layer, self.w_s_name)

        assert isinstance(old_weight_param, BasevLLMParameter)
        permute_param_layout_(old_weight_param, input_dim=0, output_dim=1, packed_dim=0)

        assert isinstance(old_scale_param, BasevLLMParameter)
        permute_param_layout_(old_scale_param, input_dim=0, output_dim=1)

        # unpack weight from K / 4 x N int32 to K x N uint8
        new_weight_param = torch.nn.Parameter(
            old_weight_param.data, requires_grad=False
        )
        # The t()/view/t() dance reinterprets the packed int32 storage as
        # uint8 along the K dimension.
        new_weight_param.data = (
            new_weight_param.data.t().contiguous().view(dtype=torch.uint8)
        )
        new_weight_param.data = new_weight_param.data.t().contiguous()

        new_scale_param = torch.nn.Parameter(old_scale_param.data, requires_grad=False)

        # reorder K x N weight as N32K16 format for Ampere W8A16
        new_weight_param.data, new_scale_param.data, _ = ops.allspark_repack_weight(
            new_weight_param.data, new_scale_param.data, None, c.zero_points
        )

        replace_parameter(layer, self.w_q_name, new_weight_param.data)
        replace_parameter(layer, self.w_s_name, new_scale_param.data)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the AllSpark W8A16 GEMM on `x` (flattened to 2D), add `bias`
        in place if given, and restore the original batch dims."""
        c = self.config
        gemm_args = self.gemm_args
        w_q, w_s, _, _ = self._get_weight_params(layer)

        reshaped_x = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        output = ops.allspark_w8a16_gemm(
            a=reshaped_x,
            b_qweight=w_q,
            b_scales=w_s,
            b_qzeros=None,
            n=c.partition_weight_shape[1],
            group_size=c.group_size,
            sm_count=gemm_args["sm_count"],
            sm_version=gemm_args["sm_version"],
            CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
            has_zp=c.zero_points,
            n32k16_reorder=True,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)

View File

@@ -0,0 +1,323 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from packaging import version
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_OPTIMIZE_FEATURES,
BITBLAS_SUPPORTED_GROUP_SIZES,
MINIMUM_BITBLAS_VERSION,
bitblas_make_empty_g_idx,
bitblas_sort_g_idx,
check_bitblas_supports_shape,
query_bitblas_supported_quant_types,
unpack_gptq_qweight,
unpack_gptq_qzeros,
)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
# Module-level logger (vLLM convention).
logger = init_logger(__name__)
class BitBLASLinearKernel(MPLinearKernel):
    """Mixed-precision linear kernel backed by the BitBLAS library.

    Repacks GPTQ-style checkpoints (qweight/scales/qzeros) into BitBLAS'
    native layout and dispatches the matmul through a cached, optionally
    hardware-tuned, BitBLAS operator.
    """

    # Candidate M sizes the BitBLAS operator is optimized for.
    OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES
    # Run hardware-aware tuning when a new operator is created (results are
    # persisted to the BitBLAS database).
    ENABLE_TUNING: bool = True
    MATMUL_LAYOUT: str = "nt"
    # torch dtype -> BitBLAS dtype-string mapping.
    BITBLAS_DTYPES: dict[torch.dtype, str] = {
        torch.float32: "float32",
        torch.float16: "float16",
        torch.bfloat16: "bfloat16",
        torch.half: "float16",
        torch.int8: "int8",
    }
    # Lazily created via configure_bitblas_matmul().
    bitblas_matmul: object = None

    def __init__(
        self,
        c: MPLinearLayerConfig,
        w_q_param_name: str,
        w_s_param_name: str,
        w_zp_param_name: str | None = None,
        w_gidx_param_name: str | None = None,
        bitblas_quant_config: QuantizationConfig | None = None,
    ):
        """Same as MPLinearKernel.__init__ plus the quantization config that
        describes the GPTQ checkpoint being repacked."""
        # Stash the quant config before delegating to the base initializer.
        self.quant_config = bitblas_quant_config
        super().__init__(
            c, w_q_param_name, w_s_param_name, w_zp_param_name, w_gidx_param_name
        )

    def repack_bitblas_from_gptq(
        self,
        b_q_weight: torch.Tensor,
        scales: torch.Tensor,
        qzeros: torch.Tensor | None = None,
    ):
        """Convert GPTQ-format tensors into BitBLAS' expected layout.

        Returns:
            (qweight, scales, zeros); `zeros` is None when `qzeros` is None
            (symmetric quantization).
        """
        from bitblas.quantization.utils import general_compress

        assert self.bitblas_matmul is not None, "bitblas_matmul is None"
        quant_config = self.quant_config
        # qweight in gptq old quant linear stored with
        # (outfeatures, infeatures), should be transposed.
        qweight = b_q_weight.T.contiguous().view(quant_config.torch_storage_dtype)  # type: ignore[union-attr]
        intweight = unpack_gptq_qweight(qweight, quant_config.weight_bits).contiguous()  # type: ignore[union-attr]
        if self.bitblas_matmul.weight_transform is not None:  # type: ignore[attr-defined]
            qweight = self.bitblas_matmul.weight_transform(  # type: ignore[attr-defined]
                intweight.cpu()
            ).cuda()
        # scales in gptq old quant linear stored with
        # (infeatures // group_size, outfeatures), should be transposed.
        scales = scales.T.contiguous()
        if qzeros is None:
            return qweight, scales, None

        # qzeros should be de-quantized to int zeros.
        weight_bits = quant_config.weight_bits  # type: ignore[union-attr]
        intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous()
        zeros: torch.Tensor | None = None
        zeros_mode = self.bitblas_matmul.config.zeros_mode  # type: ignore[attr-defined]
        if zeros_mode == "original":
            zeros = intzeros.to(torch.float16).contiguous()
        elif zeros_mode == "rescale":
            # BUGFIX: this branch previously asserted `zeros is not None`
            # immediately after `zeros` was initialized to None, so
            # "rescale" mode could never execute. In rescale mode the zero
            # points are stored pre-multiplied by the scales.
            zeros = (intzeros.to(torch.float16) * scales).contiguous()
        elif zeros_mode == "quantized":
            zeros = (
                torch.Tensor(
                    general_compress(
                        intzeros.T.contiguous().cpu().numpy(),
                        weight_bits,
                    )
                )
                .to(qweight.device)
                .to(
                    quant_config.torch_storage_dtype  # type: ignore[union-attr]
                )
                .contiguous()
            )
        else:
            raise ValueError("Unsupported zeros type: {}".format(zeros_mode))

        return qweight, scales, zeros

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Return (ok, reason): checks bitblas install/version, quant type,
        and group size/shape support."""
        is_bitblas_installed = True

        try:
            import bitblas

            if version.parse(bitblas.__version__) < version.parse(
                MINIMUM_BITBLAS_VERSION
            ):
                raise ImportError(
                    "bitblas version is wrong. Please "
                    f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
                )
        except ImportError:
            is_bitblas_installed = False

        if not is_bitblas_installed:
            return (
                False,
                "bitblas is not installed. Please install bitblas "
                "by running `pip install bitblas>="
                f"{MINIMUM_BITBLAS_VERSION}`",
            )

        quant_types = query_bitblas_supported_quant_types(c.zero_points)
        if c.weight_type not in quant_types:
            return False, (
                f"Quant type ({c.weight_type}) not supported by"
                f" BitBLAS, supported types are: {quant_types}"
            )

        if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES:
            return False, (
                f"Group size ({c.group_size}) not supported by "
                "BitBLAS, supported group sizes are: "
                f"{BITBLAS_SUPPORTED_GROUP_SIZES}"
            )

        return check_bitblas_supports_shape(
            c.partition_weight_shape[1],  # out_features
            c.partition_weight_shape[0],  # in_features
            c.full_weight_shape[0],  # in_features
            c.group_size,
        )

    # note assumes that
    # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    # `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Sort g_idx (if present), create the empty placeholder tensors
        BitBLAS requires, and repack qweight/scales/qzeros."""
        device = getattr(layer, self.w_q_name).device
        c = self.config
        quant_config = self.quant_config

        # Default names since bitblas requires empty parameters for these,
        # TODO: remove this requirement from bitblas (allow optional tensors)
        if getattr(self, "w_gidx_name", None) is None:
            self.w_gidx_name: str = "g_idx"
        if getattr(self, "w_zp_name", None) is None:
            self.w_zp_name: str = "qzeros"

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = bitblas_sort_g_idx(
                getattr(layer, self.w_gidx_name)
            )
            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, bitblas_make_empty_g_idx(device))
            layer.g_idx_sort_indices = bitblas_make_empty_g_idx(device)

        if c.zero_points:
            raise NotImplementedError("Zero points not supported by BitBLAS")
        else:
            setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device))

        # Repack weights
        bitblas_qweight, bitblas_scales, bitblas_qzeros = self.repack_bitblas_from_gptq(
            layer.qweight,
            layer.scales,
            None if quant_config.is_sym else layer.qzeros,  # type: ignore[union-attr]
        )
        replace_parameter(layer, self.w_q_name, bitblas_qweight)
        replace_parameter(layer, self.w_s_name, bitblas_scales)
        if bitblas_qzeros is not None:
            replace_parameter(layer, self.w_zp_name, bitblas_qzeros)

    def configure_bitblas_matmul(
        self,
        infeatures: int,
        outfeatures: int,
        params_dtype: torch.dtype,
        bias: bool,
    ) -> None:
        """Public entry point: build the BitBLAS matmul operator for the
        given problem size/dtype using the class tuning defaults."""
        enable_tuning = self.ENABLE_TUNING
        layout = self.MATMUL_LAYOUT
        bits = self.quant_config.weight_bits  # type: ignore[union-attr]
        self._configure_bitblas_matmul(
            infeatures,
            outfeatures,
            params_dtype,
            enable_tuning,
            bias,
            layout,
            bits,
        )

    def _configure_bitblas_matmul(
        self,
        infeatures,
        outfeatures,
        params_dtype,
        enable_tuning,
        bias,
        layout,
        bits,
    ):
        """Translate the quant config into a bitblas.MatmulConfig and fetch
        (or create) the corresponding operator."""
        from bitblas import MatmulConfig

        bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
        quant_config = self.quant_config
        with_scaling = False
        with_zeros = False
        group_size = quant_config.group_size  # type: ignore[union-attr]
        zeros_mode = quant_config.zeros_mode  # type: ignore[union-attr]
        if quant_config.quant_method == "gptq":  # type: ignore[union-attr]
            with_scaling = True
            with_zeros = True
            W_dtype = f"uint{bits}"
            # Symmetric GPTQ has no zero points and uses signed weights.
            if quant_config.is_sym:  # type: ignore[union-attr]
                with_zeros = False
                W_dtype = f"int{bits}"
        else:
            raise ValueError(
                f"Unsupported quant_method {quant_config.quant_method}"  # type: ignore[union-attr]
            )

        matmul_config = MatmulConfig(
            M=self.OPT_FEATURES,
            N=outfeatures,
            K=infeatures,
            A_dtype=bitblas_dtype,
            W_dtype=W_dtype,
            out_dtype=bitblas_dtype,
            accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
            storage_dtype=quant_config.storage_dtype,  # type: ignore[union-attr]
            with_scaling=with_scaling,
            with_zeros=with_zeros,
            group_size=group_size,
            with_bias=bias,
            layout=layout,
            zeros_mode=zeros_mode,
        )
        self.bitblas_matmul = self._get_or_create_bitblas_operator(
            matmul_config, enable_tuning
        )

    def _get_or_create_bitblas_operator(self, config, enable_tuning):
        """Look the operator up in the global BitBLAS cache; on a miss,
        create it (optionally tuning) and persist it to the database."""
        from bitblas import Matmul, auto_detect_nvidia_target
        from bitblas.cache import get_database_path, global_operator_cache

        BITBLAS_DATABASE_PATH = get_database_path()
        BITBLAS_TARGET = auto_detect_nvidia_target()

        # Populate the in-memory cache from disk on first use.
        if global_operator_cache.size() == 0:
            global_operator_cache.load_from_database(
                BITBLAS_DATABASE_PATH, BITBLAS_TARGET
            )

        bitblas_matmul = global_operator_cache.get(config)
        if bitblas_matmul is None:
            bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False)
            if enable_tuning:
                bitblas_matmul.hardware_aware_finetune(topk=20)
                global_operator_cache.add(config, bitblas_matmul)
                global_operator_cache.save_into_database(
                    BITBLAS_DATABASE_PATH, BITBLAS_TARGET
                )
                TUNING_MESSAGE = (
                    f"BitBLAS Operator {config} tuned and saved to database."
                )
                logger.info(TUNING_MESSAGE)
            else:
                _message = f"BitBLAS Operator {config} created without tuning. "
                logger.info(_message)
        else:
            _message = f"BitBLAS Operator {config} retrieved from cache."
            logger.info(_message)
        return bitblas_matmul

    def apply_gptq_bitblas_linear(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass through the configured BitBLAS operator; bias-less
        zeros are appended only when the operator expects them."""
        output_size_per_partition = self.config.partition_weight_shape[1]
        out_shape = x.shape[:-1] + (output_size_per_partition,)
        args = [x, layer.qweight, layer.scales]
        if self.bitblas_matmul.config.with_zeros:  # type: ignore[attr-defined]
            args.append(layer.qzeros)
        output = self.bitblas_matmul(*args)  # type: ignore[operator]

        return output.view(out_shape)

    def apply_weights(self, layer, x, bias=None):
        """Intentionally unimplemented: callers must use
        apply_gptq_bitblas_linear instead."""
        NOT_IMPLEMENT_MESSAGE = (
            f"{self.__class__.__name__}.apply_weights is not implemented. "
            "Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead"
        )
        raise NotImplementedError(NOT_IMPLEMENT_MESSAGE)

View File

@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from importlib.util import find_spec
from typing import Final
import torch
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
# Quantized weight types accepted by the conch GEMM (4/8-bit, with and
# without the b8/b128 bias offsets).
_CONCH_SUPPORTED_WEIGHT_TYPES: Final = [
    scalar_types.uint4,
    scalar_types.uint8,
    scalar_types.uint4b8,
    scalar_types.uint8b128,
]
# NOTE(review): -1 presumably means channelwise (single group) — confirm.
_CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128]
class ConchLinearKernel(MPLinearKernel):
    """Mixed-precision linear kernel dispatching to conch-triton-kernels."""

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Check weight type, group size, and that `conch` is importable."""
        if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
            return (
                False,
                f"Weight type ({c.weight_type}) not supported by "
                "ConchLinearKernel, supported types are: "
                f"{_CONCH_SUPPORTED_WEIGHT_TYPES}",
            )
        if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES:
            return (
                False,
                f"Group size ({c.group_size}) not supported by "
                "ConchLinearKernel, supported group sizes are: "
                f"{_CONCH_SUPPORTED_GROUP_SIZES}",
            )
        if find_spec("conch") is None:
            return (
                False,
                "conch-triton-kernels is not installed, please "
                "install it via `pip install conch-triton-kernels` "
                "and try again!",
            )
        return True, None

    # note assumes that
    # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    # `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Normalize weight/scale parameter layouts and make them contiguous."""

        def _normalize(param, *, packed: bool):
            # Weight and scale share the (input_dim=0, output_dim=1) layout;
            # only the weight additionally has a packed dimension.
            assert isinstance(param, BasevLLMParameter)
            if packed:
                permute_param_layout_(param, input_dim=0, output_dim=1, packed_dim=0)
            else:
                permute_param_layout_(param, input_dim=0, output_dim=1)
            param.data = param.data.contiguous()
            return param

        self._transform_param(
            layer, self.w_q_name, lambda p: _normalize(p, packed=True)
        )
        self._transform_param(
            layer, self.w_s_name, lambda p: _normalize(p, packed=False)
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run conch's mixed-precision GEMM, adding `bias` in place if given."""
        from conch.ops.quantization.gemm import mixed_precision_gemm

        w_q, w_s, w_zp, _ = self._get_weight_params(layer)
        cfg = self.config

        result = mixed_precision_gemm(
            x=x,
            w_q_packed=w_q.data,
            w_s=w_s.data,
            w_zp=w_zp.data if w_zp is not None else None,
            weight_size_bits=cfg.weight_type.size_bits,
            weight_bias=cfg.weight_type.bias,
            group_size=cfg.group_size,
        )

        if bias is not None:
            result.add_(bias)  # In-place add

        return result

View File

@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class CutlassW4A8LinearKernel(MPLinearKernel):
    """W4A8 linear kernel using CUTLASS on Hopper (SM90).

    Weights are signed int4 with group (and derived per-channel) scales;
    activations are dynamically quantized per token to FP8 (e4m3).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # dynamic per-tok fp8 activation quantization
        self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Return (ok, reason) for whether CUTLASS W4A8 supports `c`."""
        if not current_platform.is_cuda():
            return False, "CUTLASS only supported on CUDA"

        if not current_platform.is_device_capability(90):
            return False, "CUTLASS W4A8 requires compute capability of 90 (Hopper)"

        if c.act_type != torch.float8_e4m3fn:
            return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations"

        if c.has_g_idx:
            return False, "Act reordering not supported by CUTLASS W4A8"

        if c.zero_points:
            return False, "Zero points not supported by CUTLASS W4A8"

        if c.weight_type != scalar_types.int4:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "CUTLASS W4A8, only supported int4",
            )

        if c.group_size != 128:
            return False, "Only group_size 128 is supported"

        in_features, out_features = c.partition_weight_shape
        if in_features % 128 or out_features % 128:
            return (
                False,
                f"K and N must be divisible by 128, got {c.partition_weight_shape}",
            )

        if c.out_type != torch.bfloat16:
            # BUGFIX: the message previously read "supportedgot" — the two
            # fragments were concatenated without a separator.
            return (
                False,
                f"Only bfloat16 output type currently supported, got {c.out_type=}",
            )

        return True, None

    # note assumes that
    # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    # `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Convert loaded int4 weights and scales into CUTLASS W4A8's packed
        layout, and register the derived per-channel scales on the layer."""

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            # uint4b8-packed storage -> signed int4, then CUTLASS encoding.
            convert_packed_uint4b8_to_signed_int4_inplace(x.data)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous().to(torch.float8_e4m3fn)
            x.data = ops.cutlass_pack_scale_fp8(x.data)
            return x

        # Split the bf16 group scales into fp8 group scales + per-channel
        # scales before the packing transforms run.
        w_s = getattr(layer, self.w_s_name)
        fp8_scales, chan_scales = convert_bf16_scales_to_fp8(self.quant_fp8, w_s.data)
        w_s.data = fp8_scales

        # register per-channel scales
        layer.register_parameter(
            "weight_chan_scale", torch.nn.Parameter(chan_scales, requires_grad=False)
        )

        # Encode/reorder weights and pack scales
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize `x` per token to FP8 and run the CUTLASS W4A8 matmul;
        `bias` (if any) is added in place before reshaping back."""
        c = self.config
        w_q, w_s, _, _ = self._get_weight_params(layer)
        w_ch_s = layer.weight_chan_scale

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        x_2d, act_scales = self.quant_fp8(x_2d)
        output = ops.cutlass_w4a8_mm(
            a=x_2d,
            b_q=w_q,
            b_group_scales=w_s,
            b_group_size=c.group_size,
            a_token_scales=act_scales,
            b_channel_scales=w_ch_s,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)

View File

@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.platforms import CpuArchEnum, current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class Dynamic4bitLinearKernel(MPLinearKernel):
    """CPU kernel for 4-bit weights with dynamically quantized activations,
    via torch.ops.aten._dyn_quant_{pack_4bit_weight,matmul_4bit}."""

    SUPPORTED_QUANT_TYPES = [scalar_types.int4]

    @classmethod
    def get_min_capability(cls) -> int:
        # CPU-only kernel; GPU capability gating does not apply.
        return 1

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Return (ok, reason): CPU-only, int4 weights, and on ARM requires
        float32 activations plus aten._dyn_quant_matmul_4bit support."""
        if not current_platform.is_cpu():
            return False, "Only CPU is supported"
        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return False, f"Unsupported quant type {c.weight_type}"
        if (
            current_platform.get_cpu_architecture() == CpuArchEnum.ARM
            and c.act_type
            not in [
                torch.float32,
            ]
        ):
            return False, "Dynamic4bitLinearKernel on Arm requires Float32 activations"
        if c.full_weight_shape[0] % c.group_size != 0:
            return (
                False,
                f"Group size ({c.group_size}) does not evenly divide"
                " the number of input features "
                f"({c.full_weight_shape[0]})",
            )
        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
            try:
                # Attempt to retrieve the operation
                _ = torch.ops.aten._dyn_quant_matmul_4bit
            except AttributeError:
                return (
                    False,
                    f"PyTorch {torch.__version__} does not support"
                    " _dyn_quant_matmul_4bit. Install a newer version",
                )
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Pack int4 weights together with scales (and bias) into the layout
        expected by aten._dyn_quant_matmul_4bit."""
        c = self.config
        packed_weight = getattr(layer, self.w_q_name)
        # Shift from signed int4 range [-8, 7] to unsigned [0, 15].
        packed_weight = packed_weight.add(8)
        # Two 4-bit values per byte: odd columns go in the high nibble.
        uint8_packed = (packed_weight[::, 1::2] << 4 | packed_weight[::, ::2]).to(
            torch.uint8
        )

        scales = getattr(layer, self.w_s_name)
        block_size = c.group_size

        # Handle scaling factors for partitioned weights
        if block_size == c.partition_weight_shape[0]:
            # Single group spanning all input features -> channelwise scheme.
            scales = scales.to(
                torch.float32
            )  # Float32 & Bfloat16 variants requires float32 scales
            scales = scales.view(-1, 1)  # Channel-wise scales
            if layer.bias is not None:
                layer.bias = layer.bias.to(
                    torch.float32
                )  # Float32 & Bfloat16 variants requires float32 bias
        else:
            # KleidiAI kernel requires bfloat16 scales with groupwise scheme
            scales = scales.to(torch.bfloat16)

        # Repack weights as per kernel requirement
        w = torch.ops.aten._dyn_quant_pack_4bit_weight(
            uint8_packed,
            scales,
            layer.bias,
            block_size,
            c.partition_weight_shape[0],
            c.partition_weight_shape[1],
        )
        replace_parameter(
            layer, self.w_q_name, torch.nn.Parameter(w, requires_grad=False)
        )
        # Scales (and bias) are now folded into the packed weight.
        setattr(layer, self.w_s_name, None)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Matmul with dynamically quantized activations.

        NOTE(review): the `bias` argument is ignored here — bias appears to be
        fused into the packed weight in process_weights_after_loading; confirm
        callers never pass a separate bias.
        """
        c = self.config
        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        w_q = getattr(layer, self.w_q_name)
        output = torch.ops.aten._dyn_quant_matmul_4bit(
            x_2d,
            w_q,
            c.group_size,
            c.partition_weight_shape[0],
            c.partition_weight_shape[1],
        )
        return output.reshape(out_shape)

View File

@@ -0,0 +1,161 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class ExllamaLinearKernel(MPLinearKernel):
    """Mixed-precision linear kernel using the Exllama GPTQ GEMM ops."""

    SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
    # In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but
    # currently untested so not added to the list

    @classmethod
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Return (ok, reason) for whether Exllama supports config `c`."""
        if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return (
                False,
                "Act reordering currently not supported by Exllama, "
                "when the input features are partitioned across "
                "devices",
            )

        if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0:
            return (
                False,
                "Output features must be a multiple of the pack "
                "factor (32 / num_bits) so that we can correctly "
                "pack the zero points",
            )

        if c.act_type != torch.float16:
            return False, "Exllama only supports float16 activations"

        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "Exllama, supported types are: "
                f"{cls.SUPPORTED_QUANT_TYPES}",
            )

        if c.full_weight_shape[0] % c.group_size != 0:
            return (
                False,
                f"Group size ({c.group_size}) does not evenly divide"
                " the number of input features "
                f"({c.full_weight_shape[0]})",
            )

        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Create the zero-point/g_idx tensors Exllama expects and shuffle
        the quantized weight into its layout."""
        c = self.config

        # For Exllama, we need to set a zero-point tensor if there is not one
        if not c.zero_points:
            self.w_zp_name = "qzeros"
            device = getattr(layer, self.w_q_name).device
            groups = c.partition_weight_shape[0] // c.group_size
            out_features = c.partition_weight_shape[1]

            if c.weight_type.has_bias():
                # if the type has a bias we have to create a zeros tensor that
                # contains the bias values repeated for each group (-1 due to
                # a bug in the original GPTQ checkpoint format leading to
                # exllama kernel adding 1 to the zero points during inference)
                # Documentation of the bug can be found here:
                #  https://garden.danieldk.eu/GPTQ-Checkpoint-Format
                zeros = torch.full(
                    (groups, out_features),
                    c.weight_type.bias - 1,
                    dtype=torch.int32,
                    device=device,
                )
            else:
                raise NotImplementedError(
                    "A 0 zero-point is not supported by Exllama due to "
                    "a bug in the original GPTQ checkpoint format leading to "
                    "exllama kernel adding 1 to the zero points during "
                    "inference"
                )
            zeros = pack_quantized_values_into_int32(zeros, c.weight_type, packed_dim=1)
            setattr(
                layer, self.w_zp_name, torch.nn.Parameter(zeros, requires_grad=False)
            )

        if c.has_g_idx:

            def transform_w_g_idx(x):
                # Exllama wants the permutation array instead of the group
                # indices
                return torch.argsort(x).to(torch.int)

            self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)
        else:
            # No act reordering: register an empty g_idx placeholder.
            self.w_gidx_name = "g_idx"
            empty_g_idx = torch.nn.Parameter(
                torch.empty((0,), dtype=torch.int, device=device), requires_grad=False
            )
            setattr(layer, self.w_gidx_name, empty_g_idx)

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            assert self.w_gidx_name is not None
            g_idx = getattr(layer, self.w_gidx_name)

            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x_cont = x.data.contiguous()
            # In-place shuffle into Exllama's weight layout.
            ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits)
            return x_cont

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x.to(dtype=c.act_type)

        # Repack weights and scales for Exllama
        # (fixed: comment previously said "for Machete" — copy/paste slip)
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the Exllama GPTQ GEMM on `x` flattened to 2D, then restore
        the original batch dims (adding `bias` in place if given)."""
        c = self.config

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)

        # gptq_gemm supports GPTQv2 format by passing use_v2_format=True.
        # However, the MPLinearLayerConfig doesn't contain format info.
        # So hardcode GPTQv1 format here, to keep its behavior unchanged.
        use_v2_format = False

        assert w_zp is not None, "Zero points are required by Exllama"
        assert w_g_idx is not None, "Group index is required by Exllama"
        output = ops.gptq_gemm(
            x_2d, w_q, w_zp, w_s, w_g_idx, True, use_v2_format, c.weight_type.size_bits
        )

        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)

View File

@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import partial
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
check_machete_supports_shape,
query_machete_supported_group_sizes,
query_machete_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32,
unpack_quantized_values_into_int32,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MacheteLinearKernel(MPLinearKernel):
    """Mixed-precision linear kernel backed by the Machete (CUTLASS) GEMM.

    Requires an NVIDIA Hopper (SM90) GPU. Quantized weights, scales and
    zero-points are repacked into Machete's layout once in
    ``process_weights_after_loading`` and consumed by ``ops.machete_mm``
    at apply time.
    """
    @classmethod
    def get_min_capability(cls) -> int:
        # Hopper (SM90) is the minimum architecture Machete targets.
        return 90
    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Return (True, None) if this config is supported, else (False, why)."""
        # Machete uses CUTLASS, so it can only be compatible with Nvidia
        if not current_platform.is_cuda():
            return False, "Machete only supported on CUDA"
        if not current_platform.is_device_capability(90):
            return False, "Machete requires compute capability of 90 (Hopper)"
        # Activation reordering needs the full input dim on a single device.
        if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return (
                False,
                "Act reordering currently not supported by Machete, "
                "when the input features are partitioned across "
                "devices",
            )
        if c.weight_type not in query_machete_supported_quant_types(c.zero_points):
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "Machete, supported types are: "
                f"{query_machete_supported_quant_types(c.zero_points)}",
            )
        if c.group_size not in query_machete_supported_group_sizes(c.act_type):
            return (
                False,
                f"Group size ({c.group_size}) not supported by "
                "Machete, supported group sizes are: "
                f"{query_machete_supported_group_sizes(c.act_type)}",
            )
        return check_machete_supports_shape(
            c.partition_weight_shape[0], c.partition_weight_shape[1]
        )
    # note assumes that
    #   `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #   `weight_scale`  is: {input_dim = 0, output_dim = 1}
    #   `weight_zp`     is: {input_dim = 0, output_dim = 1, packed_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack the layer's quantized weights/scales/zero-points for Machete."""
        c = self.config
        if c.has_g_idx:
            assert self.w_gidx_name is not None
            # Turn group indices into a column permutation that is applied to
            # the activations; weight rows are permuted to match below.
            perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int)
            self.act_perm = lambda x: x[:, perm]
            # use `ops.permute_cols` if possible
            if (
                c.act_type in [torch.float16, torch.bfloat16]
                and c.partition_weight_shape[0] % 8 == 0
            ):
                self.act_perm = partial(ops.permute_cols, perm=perm)
        def transform_w_q(x):
            # Permute the weight rows to match the activation permutation
            # (if any), then prepack into Machete's B-operand layout.
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            if c.has_g_idx:
                x_unpacked = unpack_quantized_values_into_int32(
                    x.data, c.weight_type, packed_dim=0
                )
                x_perm = x_unpacked[perm, :]
                x.data = pack_quantized_values_into_int32(
                    x_perm, c.weight_type, packed_dim=0
                )
            x.data = ops.machete_prepack_B(
                x.data.t().contiguous().t(),
                a_type=c.act_type,
                b_type=c.weight_type,
                group_scales_type=c.act_type,
            )
            return x
        def transform_w_s(x):
            # Scales only need the canonical [in, out] contiguous layout.
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x
        def transform_w_zp(x):
            # NOTE: reads the already-transformed scales, so this must run
            # after transform_w_s (see the call order below).
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1)
            x_unpacked = unpack_quantized_values_into_int32(
                x.data, c.weight_type, packed_dim=1
            )
            w_s = getattr(layer, self.w_s_name).data
            # pre-apply scales to zero-points
            x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous()
            return x
        # Repack weights and scales for Machete
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)
        if c.zero_points:
            self._transform_param(layer, self.w_zp_name, transform_w_zp)
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the quantized matmul via machete_mm, plus optional bias."""
        c = self.config
        w_q, w_s, w_zp, _ = self._get_weight_params(layer)
        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
        if c.has_g_idx:
            # Apply the activation permutation set up during weight loading.
            x_2d = self.act_perm(x_2d)
        if c.zero_points:
            assert w_zp is not None
        else:
            w_zp = None
        output = ops.machete_mm(
            a=x_2d,
            b_q=w_q,
            b_type=c.weight_type,
            b_group_zeros=w_zp,
            b_group_scales=w_s,
            b_group_size=c.group_size,
        )
        if bias is not None:
            output.add_(bias)  # In-place add
        return output.reshape(out_shape)

View File

@@ -0,0 +1,200 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES,
apply_gptq_marlin_linear,
check_marlin_supports_shape,
marlin_act_int8_process_scales,
marlin_is_k_full,
marlin_make_empty_g_idx,
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
marlin_sort_g_idx,
marlin_zero_points,
query_marlin_supported_quant_types,
unpack_cols,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MarlinLinearKernel(MPLinearKernel):
    """Mixed-precision linear kernel backed by the Marlin GEMM (CUDA only).

    Repacks GPTQ-style weights/scales/zero-points into Marlin's layout in
    ``process_weights_after_loading`` and dispatches to
    ``apply_gptq_marlin_linear`` at apply time. Supports optional activation
    reordering (g_idx), zero-points, and 8-bit (int8/fp8) activations for
    uint4b8 weights.
    """
    @classmethod
    def get_min_capability(cls) -> int:
        # Ampere (SM80) or newer.
        return 80
    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Return (True, None) if this config is supported, else (False, why)."""
        # Marlin uses inline PTX, so it can only be compatible with Nvidia
        if not current_platform.is_cuda():
            return False, "Marlin only supported on CUDA"
        quant_types = query_marlin_supported_quant_types(c.zero_points)
        if c.weight_type not in quant_types:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by"
                f" Marlin, supported types are: {quant_types}",
            )
        if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
            return (
                False,
                f"Group size ({c.group_size}) not supported by "
                "Marlin, supported group sizes are: "
                f"{MARLIN_SUPPORTED_GROUP_SIZES}",
            )
        return check_marlin_supports_shape(
            c.partition_weight_shape[1],  # out_features
            c.partition_weight_shape[0],  # in_features
            c.full_weight_shape[0],  # in_features
            c.group_size,
        )
    # note assumes that
    #   `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #   `weight_scale`  is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack weights/scales/zero-points into Marlin's layout in-place."""
        device = getattr(layer, self.w_q_name).device
        c = self.config
        # 8-bit activations (int8 / fp8) are only allowed for uint4b8 weights.
        is_a_8bit = c.act_type is not None and c.act_type.itemsize == 1
        if is_a_8bit:
            assert c.weight_type == scalar_types.uint4b8, (
                "W8A8 is not supported by marlin kernel."
            )
        if c.act_type == torch.float8_e4m3fn:
            # NOTE(review): the scales are multiplied by 512 to compensate the
            # in-place fp8 preprocess of the int4 weights — confirm the factor
            # against ops.marlin_int4_fp8_preprocess's contract.
            ops.marlin_int4_fp8_preprocess(getattr(layer, self.w_q_name), inplace=True)
            getattr(layer, self.w_s_name).data = (
                getattr(layer, self.w_s_name).data * 512
            )
        row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace_new(device)
        # Default names since marlin requires empty parameters for these,
        # TODO: remove this requirement from marlin (allow optional tensors)
        if self.w_gidx_name is None:
            self.w_gidx_name = "g_idx"
        if self.w_zp_name is None:
            self.w_zp_name = "w_zp"
        def transform_w_q(x):
            # Repack the packed quantized weight, applying the g_idx sort
            # permutation prepared below (empty when no act reordering).
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.gptq_marlin_repack(
                x.data.contiguous(),
                perm=layer.g_idx_sort_indices,
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                num_bits=c.weight_type.size_bits,
                is_a_8bit=is_a_8bit,
            )
            return x
        def transform_w_s(x):
            # Permute the scales into Marlin's layout; for grouped int8
            # activations also derive a global input scale parameter.
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(
                x.data.contiguous(),
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                group_size=c.group_size,
                is_a_8bit=is_a_8bit,
            )
            if c.group_size == -1:
                num_groups = 1
            else:
                num_groups = c.partition_weight_shape[0] // c.group_size
            if c.act_type == torch.int8 and num_groups > 1:
                x.data, input_global_scale = marlin_act_int8_process_scales(x.data)
                layer.register_parameter(
                    "input_global_scale",
                    torch.nn.Parameter(input_global_scale, requires_grad=False),
                )
            else:
                layer.input_global_scale = None
            return x
        if c.has_g_idx:
            # Sort g_idx once; the sort indices are reused when repacking w_q.
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name)
            )
            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
        if c.zero_points:
            grouped_k = (
                c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
            )
            self._transform_param(
                layer,
                self.w_zp_name,
                lambda x: marlin_zero_points(
                    unpack_cols(
                        x.t(),
                        c.weight_type.size_bits,
                        grouped_k,
                        c.partition_weight_shape[1],
                    ),
                    size_k=grouped_k,
                    size_n=c.partition_weight_shape[1],
                    num_bits=c.weight_type.size_bits,
                    is_a_8bit=is_a_8bit,
                ),
            )
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
        # w_q must be transformed after g_idx_sort_indices is set above.
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)
        if hasattr(layer, "bias") and layer.bias is not None:
            # Marlin expects the bias in its permuted output layout.
            layer.bias.data = marlin_permute_bias(layer.bias)
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the Marlin GPTQ linear op on ``x`` (bias handled in-kernel)."""
        c = self.config
        w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        # None for marlin
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            input_global_scale=getattr(layer, "input_global_scale", None),
            bias=bias,
            input_dtype=c.act_type,
        )

View File

@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.platforms import current_platform
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class XPUwNa16LinearKernel(MPLinearKernel):
    """Weight-only (wNa16) linear kernel for XPU backed by IPEX WOQ linear."""
    @classmethod
    def get_min_capability(cls) -> int:
        # Capability gating is not used for XPU devices.
        return 0
    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Return (True, None) if this config is supported, else (False, why)."""
        if not current_platform.is_xpu():
            return False, "IPEX wNa16 only supported on XPU/CPU devices"
        # TODO: (yiliu30) relax these restrictions in later PRs
        if c.zero_points:
            return False, "Zero points not supported for Now"
        return True, None
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Build an IPEX weight-only-quantized linear module from the layer's
        GPTQ-style parameters and cache it on the layer as ``ipex_qlinear``.

        Raises:
            ImportError: if intel_extension_for_pytorch is missing or older
                than MIN_IPEX_VERSION.
        """
        from packaging import version
        MIN_IPEX_VERSION = "2.6.0"
        # Bias is folded into the IPEX module here, unless the layer defers it.
        bias = layer.bias if not layer.skip_bias_add else None
        try:
            import intel_extension_for_pytorch as ipex
            if version.parse(ipex.__version__) < version.parse(MIN_IPEX_VERSION):
                raise ImportError(
                    "intel_extension_for_pytorch version is "
                    "wrong. Please install "
                    f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}."
                )
        except ImportError as err:
            raise ImportError(
                "Please install "
                f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
                f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
                " to use IPEX-AWQ linear method."
            ) from err
        # Using the compute dtype (lowp_mode) as INT8 to leverage instructions
        # with better performance.
        lowp_mode = ipex.quantization.WoqLowpMode.INT8
        # The weight will be de-packed from INT4 to INT8.
        weight_dtype = ipex.quantization.WoqWeightDtype.INT4
        # The float activation will be quantized (dynamic, per-token) to INT8.
        act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH
        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
            weight_dtype=weight_dtype,
            lowp_mode=lowp_mode,
            act_quant_mode=act_quant_mode,
            group_size=self.config.group_size,
            weight_qscheme=ipex.quantization.WoqWeightQScheme.SYMMETRIC,
        )
        qweight = layer.weight_packed
        g_idx = layer.weight_g_idx if self.config.has_g_idx else None
        scales = layer.weight_scale
        qzeros = None
        if self.config.zero_points:
            qzeros = layer.weight_zero_point.contiguous()
        # Transpose to the layout IPEX's from_weight expects (presumably
        # output-major) — TODO confirm against the IPEX WOQ API docs.
        qweight = qweight.t().contiguous()
        scales = scales.t().contiguous()
        layer.ipex_output_size = self.config.partition_weight_shape[1]
        layer.ipex_qlinear = (
            ipex.llm.quantization.woq_linear.IPEXWeightOnlyQuantizedLinear.from_weight(
                qweight,
                scales,
                qzeros,
                in_features=self.config.partition_weight_shape[0],
                out_features=self.config.partition_weight_shape[1],
                qconfig=qconfig,
                g_idx=g_idx,
                bias=bias,
                group_size=self.config.group_size,
                quant_method=0,  # `0` stands for the IPEX GPTQ
            )
        )
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the cached IPEX WOQ linear. The ``bias`` argument is unused
        here because bias was already baked into the IPEX module during
        ``process_weights_after_loading``."""
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = layer.ipex_qlinear(reshaped_x)
        return out.reshape(x.shape[:-1] + (layer.ipex_output_size,))

View File

@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
@dataclass
class ScaledMMLinearLayerConfig:
    """Configuration describing a w8a8 scaled-mm linear layer."""
    # True if weight scales are per-output-channel (vs. per-tensor).
    is_channelwise: bool
    # True if the input (activation) scale is static, i.e. known at load time.
    is_static_input_scheme: bool
    # True if input quantization is symmetric (no zero-point).
    input_symmetric: bool
class ScaledMMLinearKernel(ABC):
    """Abstract base class for w8a8 scaled-mm linear kernels.

    A concrete kernel is constructed with the attribute names under which a
    layer stores its quantized weight, weight scale, input scale, input
    zero-point and AZP-adjustment parameters. Subclasses repack those
    parameters in ``process_weights_after_loading`` and execute the
    quantized matmul in ``apply_weights``.
    """
    @classmethod
    @abstractmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Return (supported, reason-if-not) for the current platform."""
        raise NotImplementedError
    @classmethod
    @abstractmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Return (implementable, reason-if-not) for the given config."""
        raise NotImplementedError
    def __init__(
        self,
        c: ScaledMMLinearLayerConfig,
        w_q_param_name: str,
        w_s_param_name: str,
        i_s_param_name: str,
        i_zp_param_name: str,
        azp_adj_param_name: str,
    ) -> None:
        # can_implement()/is_supported() return (ok, reason) tuples; a bare
        # `assert <tuple>` is always truthy and would never fire, so unpack
        # the flag and surface the reason on failure.
        can_implement, reason = self.can_implement(c)
        assert can_implement, reason
        is_supported, reason = self.is_supported()
        assert is_supported, reason
        self.config = c
        # Attribute names on the target layer for each parameter.
        self.w_q_name = w_q_param_name
        self.w_s_name = w_s_param_name
        self.i_s_name = i_s_param_name
        self.i_zp_name = i_zp_param_name
        self.azp_adj_name = azp_adj_param_name
    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack/transform the layer's parameters after checkpoint load."""
        raise NotImplementedError
    @abstractmethod
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the quantized linear transform of ``x`` (plus optional bias)."""
        raise NotImplementedError
    def _get_weight_params(
        self, layer: torch.nn.Module
    ) -> tuple[
        torch.Tensor,  # weight
        torch.Tensor,  # weight_scale
        torch.Tensor | None,  # input_scale,
        torch.Tensor | None,  # input_zp
        torch.Tensor | None,  # azp_adj
    ]:
        """Fetch the (possibly repacked) parameters by their configured names."""
        return (
            getattr(layer, self.w_q_name),
            getattr(layer, self.w_s_name),
            getattr(layer, self.i_s_name),
            getattr(layer, self.i_zp_name),
            getattr(layer, self.azp_adj_name),
        )

View File

@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
CPUScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
XLAScaledMMLinearKernel,
)
from vllm.platforms import PlatformEnum, current_platform
# in priority/performance order (when available);
# choose_scaled_mm_linear_kernel returns the first entry that both
# is_supported() and can_implement() the requested config.
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
    PlatformEnum.CPU: [CPUScaledMMLinearKernel],
    PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
    PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
    PlatformEnum.TPU: [XLAScaledMMLinearKernel],
}
def choose_scaled_mm_linear_kernel(
    config: ScaledMMLinearLayerConfig, compute_capability: int | None = None
) -> type[ScaledMMLinearKernel]:
    """
    Choose an ScaledMMLinearKernel that can implement the given config for the
    given compute capability. Attempts to choose the best kernel in terms of
    performance.

    Args:
        config (ScaledMMLinearLayerConfig): Description of the linear layer
            to be implemented.
        compute_capability (Optional[int], optional): The compute capability of
            the target device, if None uses `current_platform` to get the
            compute capability. Defaults to None.

    Raises:
        ValueError: If no kernel can implement the given config.

    Returns:
        type[ScaledMMLinearKernel]: Chosen kernel.
    """
    failure_reasons = []
    # Read the disable-list once instead of once per candidate kernel.
    disabled_kernels = os.environ.get("VLLM_DISABLED_KERNELS", "").split(",")
    # Use .get() so a platform with no registered kernels raises the
    # documented ValueError below instead of an opaque KeyError.
    for kernel in _POSSIBLE_KERNELS.get(current_platform._enum, []):
        if kernel.__name__ in disabled_kernels:
            failure_reasons.append(f"{kernel.__name__}: disabled by env var")
            continue
        # If the current platform uses compute_capability,
        # make sure the kernel supports the compute capability.
        is_supported, reason = kernel.is_supported(compute_capability)
        if not is_supported:
            failure_reasons.append(f"{kernel.__name__}: {reason}")
            continue
        can_implement, reason = kernel.can_implement(config)
        if not can_implement:
            failure_reasons.append(f"{kernel.__name__}: {reason}")
            continue
        return kernel
    if not failure_reasons:
        failure_reasons.append(
            f"no ScaledMM kernels registered for platform {current_platform._enum}"
        )
    raise ValueError(
        "Failed to find a kernel that can implement the "
        "ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reasons)
    )

View File

@@ -0,0 +1,128 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms import current_platform
from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
    """w8a8 scaled-mm kernel backed by ROCm AITER's a8w8 GEMM.

    Reuses the Cutlass kernel's weight preprocessing (transposed weight,
    normalized scales) and swaps in ``rocm_aiter_ops.gemm_a8w8`` at apply
    time. Symmetric input quantization only.
    """
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Requires ROCm with capability >= 90, an importable `aiter`
        package, and the AITER linear path enabled via env vars."""
        if not current_platform.is_rocm():
            return (
                False,
                "AiterScaledMMLinearKernel requires `aiter` which is not "
                + "currently supported on non-ROCm platform.",
            )
        if compute_capability is None:
            _cc = current_platform.get_device_capability()
            if _cc is not None:
                compute_capability = _cc.major * 10 + _cc.minor
        if compute_capability is not None and compute_capability < 90:
            return False, f"requires capability 90, got {compute_capability}"
        try:
            import aiter  # noqa: F401 # deliberately attempt to import aiter
        except Exception:
            return (
                False,
                "AiterScaledMMLinearKernel requires `aiter` which is not "
                + "installed on ROCm.",
            )
        if not rocm_aiter_ops.is_linear_enabled():
            return (
                False,
                "AiterScaledMMLinearKernel is disabled. "
                + "Enable by setting `VLLM_ROCM_USE_AITER=1` "
                + "and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
                + "`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
            )
        return True, None
    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        if not c.input_symmetric:
            return (
                False,
                "AiterScaledMMLinearKernel only supports symmetric " + "quantization.",
            )
        return True, None
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Weight preprocessing is identical to the Cutlass kernel's.
        super().process_weights_after_loading(layer)
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        `AiterScaledMMLinearKernel` implements a fused version of
            `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
        where scale_a * a and scale_b * b are implemented using numpy-style
        broadcasting.
        Currently only support per-tensor-per-tensor GEMM
        and per-token-per-channel GEMM through AITER
        w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
        ATIER block scaled GEMM and mix-precision GEMM.
        """
        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
        # ops.scaled_int8_quant supports both dynamic and static quant:
        # * dynamic, i_s is None and x_s computed from x.
        # * static, i_s is scalar and x_s is i_s.
        symmetric = azp_adj is None
        assert symmetric, (
            "AiterScaledMMLinearKernel only supports symmetric quantization."
        )
        x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric)
        assert x_zp is None, (
            "AiterScaledMMLinearKernel only supports symmetric quantization."
        )
        out_dtype = x.dtype
        # AITER kernel shape/dtype requirements.
        assert w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0
        assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
        assert bias is None or bias.shape[0] == w_q.shape[1] and bias.dtype == out_dtype
        m = x_q.shape[0]  # a
        n = w_q.shape[1]  # b
        per_tensor_scale_a = x_s.numel() == 1
        per_tensor_scale_b = w_s.numel() == 1
        per_token_scale_a = x_s.numel() == m
        per_channel_scale_b = w_s.numel() == n
        # @TODO:
        # Maybe broadcast the per-tensor-scale into per-channel-scale
        # if one of the scale is a per-channel-scale.
        # For now, it only supports:
        # - per-tensor-per-tensor a8w8 scaled GEMM, and
        # - per-token-per-channel a8w8 scaled GEMM
        assert (per_tensor_scale_a and per_tensor_scale_b) or (
            per_token_scale_a and per_channel_scale_b
        ), (
            "Currently only support per-tensor-per-tensor GEMM "
            + " and per-token-per-channel GEMM through AITER"
            " w8a8 scaled gemm. `AiterScaledMMLinearKernel` "
            + "does not support AITER block scaled GEMM."
        )
        # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
        # a to be [M, K]
        # b to be [N, K]
        # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
        return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)

View File

@@ -0,0 +1,220 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
    """w8a8 scaled-mm kernel for CPU.

    Chooses between two backends at weight-processing time:
    the SGL kernel path (x86 + VLLM_CPU_SGL_KERNEL + symmetric input +
    supported shape/dtype) and the oneDNN path (everything else). The
    chosen apply function is stored in ``self.linear_method``.
    """
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_cpu():
            return False, "Requires CPU."
        return True, None
    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        # All supported config combinations are handled by one of the
        # two backends selected in process_weights_after_loading.
        return True, None
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Pick the SGL or oneDNN backend and preprocess weights for it."""
        weight = getattr(layer, self.w_q_name)
        dtype = weight.dtype
        N, K = weight.size()
        if (
            current_platform.get_cpu_architecture() == CpuArchEnum.X86
            and envs.VLLM_CPU_SGL_KERNEL
            and self.config.input_symmetric
            and check_cpu_sgl_kernel(N, K, dtype)
        ):
            self.linear_method = self._apply_weights_sgl
            self.process_weights_for_sgl(layer)
        else:
            self.linear_method = self._apply_weights_onednn
            self.process_weights_for_onednn(layer)
    def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
        """Prepare weight/scale/zero-point parameters and build the oneDNN
        handler; the original weight is released afterwards."""
        # WEIGHT
        # Transpose to [K, N] for convenience
        weight = getattr(layer, self.w_q_name)
        replace_parameter(
            layer,
            self.w_q_name,
            torch.nn.Parameter(weight.t().data, requires_grad=False),
        )
        # WEIGHT SCALE
        # oneDNN kernels support only per-tensor and per-channel.
        # If we have a fused module (QKV, MLP) with per tensor scales (thus N
        # scales being passed to the kernel), convert to the per-channel case.
        is_fused_module = len(layer.logical_widths) > 1
        weight_scale = getattr(layer, self.w_s_name)
        if is_fused_module and not self.config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
        replace_parameter(
            layer,
            self.w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False),
        )
        # INPUT SCALE
        if self.config.is_static_input_scheme:
            input_scale = getattr(layer, self.i_s_name)
            if self.config.input_symmetric:
                replace_parameter(
                    layer,
                    self.i_s_name,
                    torch.nn.Parameter(input_scale.max(), requires_grad=False),
                )
                setattr(layer, self.i_zp_name, None)
            else:
                input_zero_point = getattr(layer, self.i_zp_name)
                # reconstruct the ranges
                int8_traits = torch.iinfo(torch.int8)
                azps = input_zero_point.to(dtype=torch.int32)
                range_max = (input_scale * (int8_traits.max - azps)).max()
                range_min = (input_scale * (int8_traits.min - azps)).min()
                scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
                replace_parameter(
                    layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
                )
                azp = (
                    (int8_traits.min - range_min / scale).round().to(dtype=torch.int32)
                )
                replace_parameter(
                    layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
                )
        else:
            setattr(layer, self.i_s_name, None)
            setattr(layer, self.i_zp_name, None)
        # Different from cutlass, oneDNN kernels only need the AZP adjustment
        # term for dynamic quantization. And s_b should be folded into the
        # term. Such as:
        # s_a * s_b * [(A - zp_a)B] + bias =
        # s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
        # s_a * GEMM_output - s_a * zp_a * adj + bias
        if not (self.config.input_symmetric and self.config.is_static_input_scheme):
            weight = getattr(layer, self.w_q_name)
            weight_scale = getattr(layer, self.w_s_name)
            azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
            azp_adj = azp_adj * weight_scale.squeeze()
            setattr(
                layer,
                self.azp_adj_name,
                torch.nn.Parameter(azp_adj, requires_grad=False),
            )
        else:
            setattr(layer, self.azp_adj_name, None)
        weight = getattr(layer, self.w_q_name)
        self.dnnl_handler = ops.create_onednn_scaled_mm(
            weight,
            getattr(layer, self.w_s_name),
            torch.get_default_dtype(),
            getattr(layer, self.i_s_name) is None,
            not self.config.input_symmetric,
            32,
        )
        # weight is prepacked and maintained by the dnnl_handler,
        # release the original weight
        setattr(layer, self.w_q_name, None)
        del weight
    def process_weights_for_sgl(self, layer: torch.nn.Module) -> None:
        """Pack the weight for the SGL kernel and normalize scales; the SGL
        path quantizes inputs dynamically, so input scale/zp/azp are unset."""
        # WEIGHT
        weight = getattr(layer, self.w_q_name)
        packed_weight = torch.ops._C.convert_weight_packed(weight)
        replace_parameter(
            layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
        )
        # The SGL matmul consumes an fp32 bias, precompute it once.
        if layer.bias is not None:
            bias = layer.bias
            layer.register_parameter(
                "bias_fp32", torch.nn.Parameter(bias.float().data, requires_grad=False)
            )
        # WEIGHT SCALE
        # CPU SGL kernels only support per-channel.
        # For per-tensor quant, convert to the per-channel case.
        weight_scale = getattr(layer, self.w_s_name)
        if not self.config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
        replace_parameter(
            layer,
            self.w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False),
        )
        setattr(layer, self.i_s_name, None)
        setattr(layer, self.i_zp_name, None)
        setattr(layer, self.azp_adj_name, None)
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Dispatch to the backend chosen in process_weights_after_loading."""
        return self.linear_method(
            layer,
            x,
            bias,
        )
    def _apply_weights_onednn(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` and run the prepacked oneDNN scaled matmul."""
        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
        # ops.scaled_int8_quant supports both dynamic and static quant:
        # * dynamic, i_s is None and x_s computed from x.
        # * static, i_s is scalar and x_s is i_s.
        x_q, x_s, x_zp = ops.onednn_scaled_int8_quant(
            x, i_s, i_zp, self.config.input_symmetric
        )
        m = x.size(0)
        n = self.dnnl_handler.n
        out = torch.empty((m, n), dtype=x.dtype)
        ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, bias)
        return out
    def _apply_weights_sgl(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the fused SGL int8 matmul (quantizes the input internally)."""
        w_q, w_s, _, _, _ = self._get_weight_params(layer)
        return torch.ops._C.int8_scaled_mm_with_quant(
            x,
            w_q,
            w_s,
            layer.bias_fp32 if bias is not None else None,
            x.dtype,
            True,
        )

View File

@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
    """w8a8 scaled-mm kernel backed by vLLM's CUTLASS GEMMs (CUDA).

    Weights are transposed and scales normalized once in
    ``process_weights_after_loading``; ``apply_weights`` quantizes the
    input to int8 and dispatches to ``cutlass_scaled_mm`` (or the AZP
    variant for asymmetric input quantization).
    """
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Supported on CUDA devices with compute capability >= 75."""
        if not current_platform.is_cuda():
            return False, "Requires CUDA."
        if compute_capability is None:
            _cc = current_platform.get_device_capability()
            if _cc is not None:
                compute_capability = _cc.major * 10 + _cc.minor
        if compute_capability is not None and compute_capability < 75:
            return False, f"requires capability 75, got {compute_capability}"
        return True, None
    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        # All supported config combinations are handled below.
        return True, None
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Prepare weight/scale/zero-point parameters for the CUTLASS kernels."""
        # WEIGHT
        # Cutlass kernels need transposed weight.
        weight = getattr(layer, self.w_q_name)
        replace_parameter(
            layer,
            self.w_q_name,
            torch.nn.Parameter(weight.t().data, requires_grad=False),
        )
        # WEIGHT SCALE
        # Cutlass kernels support only per-tensor and per-channel.
        # If we have a fused module (QKV, MLP) with per tensor scales (thus N
        # scales being passed to the kernel), convert to the per-channel case.
        is_fused_module = len(layer.logical_widths) > 1
        weight_scale = getattr(layer, self.w_s_name)
        if is_fused_module and not self.config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
        replace_parameter(
            layer,
            self.w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False),
        )
        # INPUT SCALE
        if self.config.is_static_input_scheme:
            input_scale = getattr(layer, self.i_s_name)
            if self.config.input_symmetric:
                replace_parameter(
                    layer,
                    self.i_s_name,
                    torch.nn.Parameter(input_scale.max(), requires_grad=False),
                )
                setattr(layer, self.i_zp_name, None)
            else:
                input_zero_point = getattr(layer, self.i_zp_name)
                # reconstruct the ranges
                int8_traits = torch.iinfo(torch.int8)
                azps = input_zero_point.to(dtype=torch.int32)
                range_max = (input_scale * (int8_traits.max - azps)).max()
                range_min = (input_scale * (int8_traits.min - azps)).min()
                scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
                replace_parameter(
                    layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
                )
                # AZP loaded as int8 but used as int32
                azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
                replace_parameter(
                    layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
                )
        else:
            setattr(layer, self.i_s_name, None)
            setattr(layer, self.i_zp_name, None)
        # azp_adj is the AZP adjustment term, used to account for weights.
        # It does not depend on scales or azp, so it is the same for
        # static and dynamic quantization.
        # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
        # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
        if not self.config.input_symmetric:
            weight = getattr(layer, self.w_q_name)
            azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
            if self.config.is_static_input_scheme:
                # cutlass_w8a8 requires azp to be folded into azp_adj
                # in the per-tensor case
                azp_adj = getattr(layer, self.i_zp_name) * azp_adj
            setattr(
                layer,
                self.azp_adj_name,
                torch.nn.Parameter(azp_adj, requires_grad=False),
            )
        else:
            setattr(layer, self.azp_adj_name, None)
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` to int8 and run the CUTLASS scaled matmul."""
        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
        # ops.scaled_int8_quant supports both dynamic and static quant:
        # * dynamic, i_s is None and x_s computed from x.
        # * static, i_s is scalar and x_s is i_s.
        symmetric = azp_adj is None
        x_q, x_s, x_zp = ops.scaled_int8_quant(
            x.contiguous(), i_s, i_zp, symmetric=symmetric
        )
        if x_zp is not None:
            # Currently, static is always per-tensor and dynamic is per-token
            static = i_zp is not None
            azp = None if static else x_zp
            return ops.cutlass_scaled_mm_azp(
                x_q,
                w_q,
                scale_a=x_s,
                scale_b=w_s,
                out_dtype=x.dtype,
                azp_adj=azp_adj,
                azp=azp,
                bias=bias,
            )
        return ops.cutlass_scaled_mm(
            x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
        )

View File

@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa: E501
triton_scaled_mm,
)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
    """W8A8 linear kernel backed by a Triton scaled-mm implementation.

    Supports symmetric int8 activation quantization only; zero points and
    azp adjustment terms are dropped after weight loading.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        # Available on NVIDIA CUDA and ROCm alike.
        if not current_platform.is_cuda_alike():
            return False, "Requires ROCm or CUDA."
        return True, None

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        if c.input_symmetric:
            return True, None
        return False, "Only symmetric input is supported."

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-layout weights/scales once, before the first forward pass."""
        # WEIGHT: transpose into the [in, out] layout the matmul expects.
        quant_weight = getattr(layer, self.w_q_name)
        replace_parameter(
            layer,
            self.w_q_name,
            torch.nn.Parameter(quant_weight.t().data, requires_grad=False),
        )

        # INPUT SCALE
        if self.config.is_static_input_scheme:
            # Collapse per-partition scales to a single per-tensor scale.
            act_scale = getattr(layer, self.i_s_name)
            replace_parameter(
                layer,
                self.i_s_name,
                torch.nn.Parameter(act_scale.max(), requires_grad=False),
            )
        else:
            # Dynamic quantization: scale computed per call from x.
            setattr(layer, self.i_s_name, None)

        # Symmetric-only kernel: never keeps a zero point or azp term.
        setattr(layer, self.i_zp_name, None)
        setattr(layer, self.azp_adj_name, None)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` symmetrically and run the Triton scaled matmul."""
        w_q, w_s, i_s, i_zp, _ = self._get_weight_params(layer)
        x_q, x_s, x_zp = ops.scaled_int8_quant(
            x.contiguous(), i_s, i_zp, symmetric=True
        )
        assert x_zp is None, "Triton kernel only supports symmetric quantization"
        return triton_scaled_mm(
            x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
        )

View File

@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
import torch
from functorch.experimental.control_flow import cond # noqa: F401
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
    """Int8 W8A8 linear kernel for TPU via torch_xla's quantized matmul.

    Only dynamic, symmetric activation quantization with channelwise
    weight scales is supported; the XLA op quantizes the activation
    internally (``quantize_activation=True``).
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if current_platform.is_tpu():
            return True, None
        return False, "Requires TPU."

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        # Guard-clause checks mirroring the kernel's hard constraints.
        if not current_platform.is_tpu():
            return False, "ScaledMMXLA requires running on TPU."
        if c.is_static_input_scheme:
            return False, "ScaledMMXLA requires dynamic activation scales."
        if not c.input_symmetric:
            return False, "ScaledMMXLA requires symmetric activation scales."
        if not c.is_channelwise:
            return False, "ScaledMMXLA requires channelwise weight scales"
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Normalize weight/scale layouts for the XLA op; drop unused params."""
        # WEIGHT
        # Kept as [out, in] (different than cutlass_scaled_mm).
        quant_weight = getattr(layer, self.w_q_name)
        replace_parameter(
            layer,
            self.w_q_name,
            torch.nn.Parameter(quant_weight.data, requires_grad=False),
        )

        # WEIGHT SCALE
        # XLA kernels support only per-tensor and per-channel. A fused
        # module (QKV, MLP) with per-tensor scales passes N scales, so
        # convert that case to per-channel.
        scale = getattr(layer, self.w_s_name)
        if len(layer.logical_widths) > 1 and not self.config.is_channelwise:
            scale = convert_to_channelwise(scale, layer.logical_widths)
        # Flatten to [out_channel,] (different than cutlass_scaled_mm).
        scale = scale.squeeze(-1)
        replace_parameter(
            layer,
            self.w_s_name,
            torch.nn.Parameter(scale.data, requires_grad=False),
        )

        # Only symmetric dynamic activation quantization is supported, so
        # no input scale, zero point, or azp adjustment is retained.
        setattr(layer, self.i_s_name, None)
        setattr(layer, self.i_zp_name, None)
        setattr(layer, self.azp_adj_name, None)

        # Filter warning for cond usage in apply_weights. It is okay
        # to specialize the graph since bias is not dynamic.
        warnings.filterwarnings(
            "ignore",
            message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.",  # noqa: E501
        )

    def no_add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
        # cond() branch: bias is absent, pass activations through.
        return x

    def add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
        # cond() branch: bias is present, add it.
        return x + bias

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the fused XLA int8 matmul, then conditionally add bias."""
        w_q, w_s, _, _, _ = self._get_weight_params(layer)

        # Required to register custom ops.
        import torch_xla.experimental.custom_kernel  # noqa: F401

        result = torch.ops.xla.quantized_matmul_int8(
            x,
            w_q,
            w_s,
            quantize_activation=True,
        )
        # Explicitly capture control flow to make dynamo happy.
        # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
        return cond(bias is None, self.no_add_bias, self.add_bias, [result, bias])