vllm/model_executor/kernels/linear/mixed_precision/MPLinearKernel.py (new file)
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass

import torch

from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.scalar_type import ScalarType


@dataclass
class MPLinearLayerConfig:
    full_weight_shape: tuple[int, int]  # [in, out]
    partition_weight_shape: tuple[int, int]
    weight_type: ScalarType
    act_type: torch.dtype
    group_size: int
    zero_points: bool
    has_g_idx: bool
    out_type: torch.dtype | None = None


class MPLinearKernel(ABC):
    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        raise NotImplementedError

    def __init__(
        self,
        c: MPLinearLayerConfig,
        w_q_param_name: str,
        w_s_param_name: str,
        w_zp_param_name: str | None = None,
        w_gidx_param_name: str | None = None,
    ) -> None:
        # can_implement returns (ok, reason); assert on the boolean only,
        # since a non-empty tuple is always truthy
        assert self.can_implement(c)[0]
        self.config = c
        self.w_q_name = w_q_param_name
        self.w_s_name = w_s_param_name
        if c.zero_points:
            assert w_zp_param_name is not None
        if c.has_g_idx:
            assert w_gidx_param_name is not None
        self.w_zp_name = w_zp_param_name
        self.w_gidx_name = w_gidx_param_name

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def _transform_param(
        self, layer: torch.nn.Module, name: str | None, fn: Callable
    ) -> None:
        if name is not None and getattr(layer, name, None) is not None:
            old_param = getattr(layer, name)
            new_param = fn(old_param)
            # replace the parameter with torch.nn.Parameter for TorchDynamo
            # compatibility
            replace_parameter(
                layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
            )

    def _get_weight_params(
        self, layer: torch.nn.Module
    ) -> tuple[
        torch.Tensor,  # w_q
        torch.Tensor,  # w_s
        torch.Tensor | None,  # w_zp
        torch.Tensor | None,  # w_gidx
    ]:
        return (
            getattr(layer, self.w_q_name),
            getattr(layer, self.w_s_name),
            getattr(layer, self.w_zp_name or "", None),
            getattr(layer, self.w_gidx_name or "", None),
        )
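
The base class above fixes a three-phase contract: a class-level can_implement check against a config, a one-time process_weights_after_loading repack, and a per-forward apply_weights call. A minimal sketch of that lifecycle, where `kernel_cls` stands in for any concrete subclass from this PR (e.g. MarlinLinearKernel) and `layer`/`x` are whatever module and activations the caller holds:

import torch
from vllm.scalar_type import scalar_types

# Illustrative config: 4-bit weights (uint4 with a bias of 8), fp16 activations.
cfg = MPLinearLayerConfig(
    full_weight_shape=(4096, 4096),       # [in, out] of the unsharded layer
    partition_weight_shape=(4096, 4096),  # this rank's shard
    weight_type=scalar_types.uint4b8,
    act_type=torch.float16,
    group_size=128,
    zero_points=False,
    has_g_idx=False,
)

ok, reason = kernel_cls.can_implement(cfg)
if not ok:
    raise ValueError(f"{kernel_cls.__name__} cannot run this layer: {reason}")

kernel = kernel_cls(cfg, w_q_param_name="qweight", w_s_param_name="scales")
kernel.process_weights_after_loading(layer)   # one-time repack after checkpoint load
y = kernel.apply_weights(layer, x, bias=None) # every forward pass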
vllm/model_executor/kernels/linear/mixed_precision/__init__.py (new file)
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.model_executor.kernels.linear.mixed_precision.allspark import (
    AllSparkLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.conch import (
    ConchLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cpu import (
    CPUWNA16LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cutlass import (
    CutlassW4A8LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.dynamic_4bit import (
    Dynamic4bitLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.exllama import (
    ExllamaLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.machete import (
    MacheteLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
    MarlinLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import (
    MPLinearKernel,
    MPLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
    XPUwNa16LinearKernel,
)

__all__ = [
    "MPLinearKernel",
    "MPLinearLayerConfig",
    "AllSparkLinearKernel",
    "ConchLinearKernel",
    "CPUWNA16LinearKernel",
    "CutlassW4A8LinearKernel",
    "Dynamic4bitLinearKernel",
    "ExllamaLinearKernel",
    "MacheteLinearKernel",
    "MarlinLinearKernel",
    "XPUwNa16LinearKernel",
]
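
These kernels are normally chosen at runtime rather than hard-coded. A hedged sketch of such a selection loop; the preference ordering and the `choose_mp_linear_kernel` helper are illustrative, not part of this module, while `current_platform.get_device_capability()` is vLLM's existing capability query:

from vllm.platforms import current_platform

def choose_mp_linear_kernel(cfg: MPLinearLayerConfig) -> type[MPLinearKernel]:
    # Most specialized kernels first; fall through on any rejection.
    candidates = [
        MacheteLinearKernel,
        MarlinLinearKernel,
        AllSparkLinearKernel,
        ExllamaLinearKernel,
        ConchLinearKernel,
    ]
    capability = current_platform.get_device_capability().to_int()
    failures = []
    for kernel in candidates:
        if capability < kernel.get_min_capability():
            failures.append(f"{kernel.__name__}: device capability too low")
            continue
        ok, reason = kernel.can_implement(cfg)
        if ok:
            return kernel
        failures.append(f"{kernel.__name__}: {reason}")
    raise ValueError("no kernel supports this config:\n" + "\n".join(failures))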
vllm/model_executor/kernels/linear/mixed_precision/allspark.py (new file)
@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
    check_allspark_supported_dtype_shape,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.utils.platform_utils import num_compute_units

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class AllSparkLinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if c.has_g_idx:
            return False, "Act reordering currently not supported by AllSpark"

        if c.zero_points:
            return False, "Zero points currently not supported by AllSpark"

        return check_allspark_supported_dtype_shape(
            c.partition_weight_shape[0],  # in_features
            c.partition_weight_shape[1],  # out_features
            c.group_size,
            c.weight_type,
            c.act_type,
        )

    # note: assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = getattr(layer, self.w_q_name).device
        c = self.config

        # prepare the parameters required for the kernel
        properties = torch.cuda.get_device_properties(device.index)
        sm_count = num_compute_units(device.index)
        sm_version = properties.major * 10 + properties.minor
        gemm_args = {}
        gemm_args["sm_count"] = sm_count
        gemm_args["sm_version"] = sm_version

        self.gemm_args = gemm_args

        # transform param weight, scale
        old_weight_param = getattr(layer, self.w_q_name)
        old_scale_param = getattr(layer, self.w_s_name)

        assert isinstance(old_weight_param, BasevLLMParameter)
        permute_param_layout_(old_weight_param, input_dim=0, output_dim=1, packed_dim=0)

        assert isinstance(old_scale_param, BasevLLMParameter)
        permute_param_layout_(old_scale_param, input_dim=0, output_dim=1)

        # unpack weight from K/4 x N int32 to K x N uint8
        new_weight_param = torch.nn.Parameter(
            old_weight_param.data, requires_grad=False
        )
        new_weight_param.data = (
            new_weight_param.data.t().contiguous().view(dtype=torch.uint8)
        )
        new_weight_param.data = new_weight_param.data.t().contiguous()

        new_scale_param = torch.nn.Parameter(old_scale_param.data, requires_grad=False)

        # reorder the K x N weight into N32K16 format for Ampere W8A16
        new_weight_param.data, new_scale_param.data, _ = ops.allspark_repack_weight(
            new_weight_param.data, new_scale_param.data, None, c.zero_points
        )

        replace_parameter(layer, self.w_q_name, new_weight_param.data)
        replace_parameter(layer, self.w_s_name, new_scale_param.data)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        c = self.config
        gemm_args = self.gemm_args
        w_q, w_s, _, _ = self._get_weight_params(layer)

        reshaped_x = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        output = ops.allspark_w8a16_gemm(
            a=reshaped_x,
            b_qweight=w_q,
            b_scales=w_s,
            b_qzeros=None,
            n=c.partition_weight_shape[1],
            group_size=c.group_size,
            sm_count=gemm_args["sm_count"],
            sm_version=gemm_args["sm_version"],
            CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
            has_zp=c.zero_points,
            n32k16_reorder=True,
        )

        if bias is not None:
            output.add_(bias)  # in-place add

        return output.reshape(out_shape)
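
The unpack step above relies on `Tensor.view(dtype=...)` reinterpreting the last (contiguous) dimension, which is why the weight is transposed first. A toy round-trip using nothing beyond stock PyTorch (byte order assumes a little-endian host, which covers x86 and Arm):

import torch

# A "packed" K/4 x N int32 weight: 2 x 3 here, so K = 8, N = 3.
packed = torch.arange(6, dtype=torch.int32).reshape(2, 3)

# Transpose so K is the last dim, then reinterpret each int32 as 4 uint8s:
# N x K/4 int32 -> N x K uint8, then transpose back to K x N.
unpacked = packed.t().contiguous().view(dtype=torch.uint8).t().contiguous()
print(unpacked.shape)  # torch.Size([8, 3])

# Each column of `unpacked` holds the little-endian bytes of the
# corresponding int32 column of `packed`.
assert unpacked[:4, 0].tolist() == [0, 0, 0, 0]  # int32 value 0
assert unpacked[:4, 1].tolist() == [1, 0, 0, 0]  # int32 value 1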
vllm/model_executor/kernels/linear/mixed_precision/conch.py (new file)
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from importlib.util import find_spec
from typing import Final

import torch

from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

_CONCH_SUPPORTED_WEIGHT_TYPES: Final = [
    scalar_types.uint4,
    scalar_types.uint8,
    scalar_types.uint4b8,
    scalar_types.uint8b128,
]
_CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128]


class ConchLinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
            error_msg = (
                f"Weight type ({c.weight_type}) not supported by "
                "ConchLinearKernel, supported types are: "
                f"{_CONCH_SUPPORTED_WEIGHT_TYPES}"
            )
            return False, error_msg

        if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES:
            error_msg = (
                f"Group size ({c.group_size}) not supported by "
                "ConchLinearKernel, supported group sizes are: "
                f"{_CONCH_SUPPORTED_GROUP_SIZES}"
            )
            return False, error_msg

        if find_spec("conch") is None:
            error_msg = (
                "conch-triton-kernels is not installed, please "
                "install it via `pip install conch-triton-kernels` "
                "and try again!"
            )
            return False, error_msg

        return True, None

    # note: assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    #  `weight_zero_point` is: {input_dim = 1, output_dim = 0, packed_dim = 0}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = x.data.contiguous()
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x

        def transform_w_zp(x):
            # Zero points are stored PACKED as [N//pack_factor, K//G].
            # The Conch kernel expects UNPACKED zeros: [K//G, N],
            # so we need to unpack and reorder.
            assert isinstance(x, BasevLLMParameter)
            packed = x.data  # shape: [N//pack_factor, K//G], dtype: int32

            # Determine packing based on weight bit width
            size_bits = self.config.weight_type.size_bits
            pack_factor = 32 // size_bits  # 8 for 4-bit, 4 for 8-bit
            mask = (1 << size_bits) - 1  # 0xF for 4-bit, 0xFF for 8-bit

            n_packed, k_groups = packed.shape
            n_full = n_packed * pack_factor

            # Unpack using vectorized bitwise ops:
            # shifts = [0, size_bits, 2*size_bits, ...] for each packed position
            shifts = torch.arange(
                0, 32, size_bits, dtype=torch.int32, device=packed.device
            )
            # packed: [N//pack_factor, K//G] -> [N//pack_factor, K//G, 1]
            # shifts: [pack_factor] -> [1, 1, pack_factor]
            # result: [N//pack_factor, K//G, pack_factor]
            unpacked = (packed.unsqueeze(-1) >> shifts) & mask

            # Permute to [K//G, N//pack_factor, pack_factor], then reshape
            # to [K//G, N]
            unpacked = unpacked.permute(1, 0, 2).reshape(k_groups, n_full)

            x.data = unpacked.to(torch.uint8).contiguous()

            # Update metadata - zeros are no longer packed
            if hasattr(x, "_input_dim"):
                x._input_dim = 0
            if hasattr(x, "_output_dim"):
                x._output_dim = 1
            if hasattr(x, "_packed_factor"):
                x._packed_factor = 1
            return x

        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)
        if self.config.zero_points:
            self._transform_param(layer, self.w_zp_name, transform_w_zp)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        from conch.ops.quantization.gemm import mixed_precision_gemm

        w_q, w_s, w_zp, _ = self._get_weight_params(layer)

        output = mixed_precision_gemm(
            x=x,
            w_q_packed=w_q.data,
            w_s=w_s.data,
            w_zp=w_zp.data if w_zp is not None else None,
            weight_size_bits=self.config.weight_type.size_bits,
            weight_bias=self.config.weight_type.bias,
            group_size=self.config.group_size,
        )

        if bias is not None:
            output.add_(bias)  # in-place add

        return output
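
The vectorized unpack in `transform_w_zp` is just a broadcasted shift-and-mask. A quick self-check in plain PyTorch for the 4-bit case, with arbitrarily chosen values:

import torch

size_bits, pack_factor = 4, 8
mask = (1 << size_bits) - 1

# Pack eight 4-bit values [0..7] into one int32 word, lowest nibble first.
vals = list(range(8))
word = sum(v << (size_bits * i) for i, v in enumerate(vals))  # 0x76543210
packed = torch.tensor([[word]], dtype=torch.int32)  # [N//8, K//G] = [1, 1]

shifts = torch.arange(0, 32, size_bits, dtype=torch.int32)
unpacked = (packed.unsqueeze(-1) >> shifts) & mask  # [1, 1, 8]
unpacked = unpacked.permute(1, 0, 2).reshape(1, 8)  # [K//G, N] = [1, 8]

assert unpacked.flatten().tolist() == vals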
vllm/model_executor/kernels/linear/mixed_precision/cpu.py (new file)
@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_quantized_values_into_int32,
    unpack_quantized_values_into_int32,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

_CPUWNA16_SUPPORTED_QUANT_TYPES = (scalar_types.uint4, scalar_types.uint4b8)


class CPUWNA16LinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return -1

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if not current_platform.is_cpu():
            return False, "CPUWNA16 only supported on CPU"

        if c.weight_type not in _CPUWNA16_SUPPORTED_QUANT_TYPES:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "CPUWNA16, supported types are: "
                f"{_CPUWNA16_SUPPORTED_QUANT_TYPES}",
            )

        if c.group_size != -1 and c.group_size % 2 != 0:
            return (
                False,
                f"Group size ({c.group_size}) not supported by "
                "CPUWNA16, supported group sizes are multiples of 2",
            )

        if c.partition_weight_shape[0] % 32 != 0:
            return (
                False,
                f"Input size ({c.partition_weight_shape[0]}) not supported by "
                "CPUWNA16, supported sizes are multiples of 32",
            )

        if c.partition_weight_shape[1] % 32 != 0:
            return (
                False,
                f"Output size ({c.partition_weight_shape[1]}) not supported by "
                "CPUWNA16, supported sizes are multiples of 32",
            )

        return True, None

    # note: assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    #  `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
    def _process_gptq_weights(self, layer: torch.nn.Module):
        packed_weight = layer.qweight.data
        bits = self.config.weight_type.mantissa
        pack_factor = 32 // bits
        p_w_k, p_w_n = packed_weight.size()
        input_size = p_w_k * pack_factor
        output_size = p_w_n
        isa_hint = _get_isa_hint(layer.scales.dtype)
        layer.isa_hint = isa_hint

        layer.qzeros = None
        if not self.config.has_g_idx:
            layer.g_idx = None

        # convert input-dim packed to output-dim packed
        weight = unpack_quantized_values_into_int32(
            packed_weight, self.config.weight_type, 1
        ).view(p_w_k, p_w_n, pack_factor)
        weight = weight.permute(0, 2, 1).reshape(input_size, output_size).contiguous()
        weight = pack_quantized_values_into_int32(weight, self.config.weight_type, 1)
        # group every 16 output channels into a block and transpose to make
        # each block contiguous
        weight = (
            weight.view(input_size, -1, 16 // pack_factor)
            .permute(1, 0, 2)
            .reshape(-1, input_size * 16 // pack_factor)
            .contiguous()
        )
        layer.qweight.data = weight

    def process_weights_after_loading(self, layer: torch.nn.Module):
        if not self.config.zero_points:
            # GPTQ
            self._process_gptq_weights(layer)
        else:
            # AWQ
            raise NotImplementedError("AWQ is not supported in CPUWNA16LinearKernel")

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        x = ops.cpu_gemm_wna16(
            input=x,
            q_weight=layer.qweight,
            scales=layer.scales,
            zeros=layer.qzeros,
            g_idx=layer.g_idx,
            bias=bias,
            pack_factor=8,  # 32 // 4
            isa_hint=layer.isa_hint,
        )
        return x


def _get_isa_hint(dtype: torch.dtype) -> str:
    supports_amx = torch._C._cpu._is_amx_tile_supported()
    if supports_amx and dtype in (torch.bfloat16,):
        return "amx"
    else:
        return "vec"
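
The final view/permute/reshape in `_process_gptq_weights` tiles the packed matrix so that each group of 16 output channels becomes one contiguous row block. A shape-only walkthrough with stand-in sizes (4-bit weights, so pack_factor = 8 and 16 // pack_factor = 2 int32 words per block):

import torch

input_size, packed_cols = 64, 32  # K, and N // pack_factor (stand-in numbers)
pack_factor = 8                   # 32 // 4 for int4
weight = torch.zeros(input_size, packed_cols, dtype=torch.int32)

blocked = (
    weight.view(input_size, -1, 16 // pack_factor)  # [K, N/16, 2] int32 words
    .permute(1, 0, 2)                               # [N/16, K, 2]
    .reshape(-1, input_size * 16 // pack_factor)    # [N/16, K*2]
    .contiguous()
)
print(blocked.shape)  # torch.Size([16, 128]): 16 blocks of 16 output channels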
vllm/model_executor/kernels/linear/mixed_precision/cutlass.py (new file)
@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    convert_bf16_scales_to_fp8,
    convert_packed_uint4b8_to_signed_int4_inplace,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class CutlassW4A8LinearKernel(MPLinearKernel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # dynamic per-token fp8 activation quantization
        self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if not current_platform.is_cuda():
            return False, "CUTLASS only supported on CUDA"

        if not current_platform.is_device_capability(90):
            return False, "CUTLASS W4A8 requires compute capability of 90 (Hopper)"

        if c.act_type != torch.float8_e4m3fn:
            return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations"

        if c.has_g_idx:
            return False, "Act reordering not supported by CUTLASS W4A8"

        if c.zero_points:
            return False, "Zero points not supported by CUTLASS W4A8"

        if c.weight_type != scalar_types.int4:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "CUTLASS W4A8, only int4 is supported",
            )

        if c.group_size != 128:
            return False, "Only group_size 128 is supported"

        in_features, out_features = c.partition_weight_shape
        if in_features % 128 or out_features % 128:
            return (
                False,
                f"K and N must be divisible by 128, got {c.partition_weight_shape}",
            )

        if c.out_type != torch.bfloat16:
            return (
                False,
                f"Only bfloat16 output type currently supported, got {c.out_type=}",
            )

        return True, None

    # note: assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            convert_packed_uint4b8_to_signed_int4_inplace(x.data)
            torch.cuda.synchronize()
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous().to(torch.float8_e4m3fn)
            x.data = ops.cutlass_pack_scale_fp8(x.data)
            return x

        w_s = getattr(layer, self.w_s_name)
        fp8_scales, chan_scales = convert_bf16_scales_to_fp8(self.quant_fp8, w_s.data)
        w_s.data = fp8_scales

        # register per-channel scales
        layer.register_parameter(
            "weight_chan_scale", torch.nn.Parameter(chan_scales, requires_grad=False)
        )

        # Encode/reorder weights and pack scales
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        c = self.config
        w_q, w_s, _, _ = self._get_weight_params(layer)
        w_ch_s = layer.weight_chan_scale

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        x_2d, act_scales = self.quant_fp8(x_2d)
        output = ops.cutlass_w4a8_mm(
            a=x_2d,
            b_q=w_q,
            b_group_scales=w_s,
            b_group_size=c.group_size,
            a_token_scales=act_scales,
            b_channel_scales=w_ch_s,
        )

        if bias is not None:
            output.add_(bias)  # in-place add

        return output.reshape(out_shape)
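
Dynamic per-token FP8 quantization (what `QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)` performs here) amounts to one scale per row of the flattened activation. A minimal sketch of the idea in plain PyTorch; this illustrates the math, it is not the `QuantFP8` implementation:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def per_token_fp8_quant(x_2d: torch.Tensor):
    # One scale per token (row), sized so the row's max maps to the FP8 max.
    amax = x_2d.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scales = (amax / FP8_MAX).to(torch.float32)
    x_fp8 = (x_2d / scales).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scales  # the GEMM multiplies results back by `scales`

x = torch.randn(4, 128, dtype=torch.bfloat16)
x_fp8, s = per_token_fp8_quant(x.float())
print(x_fp8.shape, s.shape)  # torch.Size([4, 128]) torch.Size([4, 1])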
vllm/model_executor/kernels/linear/mixed_precision/dynamic_4bit.py (new file)
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.platforms import CpuArchEnum, current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


# This implementation is for the KleidiAI-accelerated w4a8int quantization
# scheme on Arm CPUs:
# torch.ops.aten._dyn_quant_matmul_4bit performs a dynamic quantized matmul.
# It takes:
#   - int4 weights packed along with bias/scales by
#     torch.ops.aten._dyn_quant_pack_4bit_weight
#   - float32/bfloat16 activations
# It then leverages KleidiAI ukernels that:
#   - dynamically quantize the activations to int8
#   - unpack the int4 weights to int8
#   - perform an int8 x int8 -> int32 matmul
#   - dequantize the int32 output to float32/bfloat16 outputs
class Dynamic4bitLinearKernel(MPLinearKernel):
    SUPPORTED_QUANT_TYPES = [scalar_types.int4]

    @classmethod
    def get_min_capability(cls) -> int:
        return 1

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if not current_platform.is_cpu():
            return False, "Only CPU is supported"
        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return False, f"Unsupported quant type {c.weight_type}"
        if (
            current_platform.get_cpu_architecture() == CpuArchEnum.ARM
            and c.act_type
            not in [
                torch.float32,
                torch.bfloat16,
                torch.float16,
            ]
        ):
            return (
                False,
                "Dynamic4bitLinearKernel on Arm requires Float32, "
                "BFloat16, or Float16 activations",
            )
        if c.full_weight_shape[0] % c.group_size != 0:
            return (
                False,
                f"Group size ({c.group_size}) does not evenly divide"
                " the number of input features "
                f"({c.full_weight_shape[0]})",
            )
        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
            try:
                # Attempt to retrieve the operation
                _ = torch.ops.aten._dyn_quant_matmul_4bit
            except AttributeError:
                return (
                    False,
                    f"PyTorch {torch.__version__} does not support"
                    " _dyn_quant_matmul_4bit. Install a newer version",
                )
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module):
        c = self.config
        packed_weight = getattr(layer, self.w_q_name)
        packed_weight = packed_weight.add(8)
        uint8_packed = (packed_weight[::, 1::2] << 4 | packed_weight[::, ::2]).to(
            torch.uint8
        )

        scales = getattr(layer, self.w_s_name)
        block_size = c.group_size

        # Handle scaling factors for partitioned weights
        if block_size == c.partition_weight_shape[0]:
            # Float32 & BFloat16 variants require float32 scales
            scales = scales.to(torch.float32)
            scales = scales.view(-1, 1)  # channel-wise scales
            if layer.bias is not None:
                # Float32 & BFloat16 variants require a float32 bias
                replace_parameter(
                    layer,
                    "bias",
                    torch.nn.Parameter(
                        layer.bias.to(torch.float32), requires_grad=False
                    ),
                )
        else:
            # The KleidiAI kernel requires bfloat16 scales with the
            # groupwise scheme
            scales = scales.to(torch.bfloat16)

        # Repack weights as per kernel requirement
        w = torch.ops.aten._dyn_quant_pack_4bit_weight(
            uint8_packed,
            scales,
            layer.bias,
            block_size,
            c.partition_weight_shape[0],
            c.partition_weight_shape[1],
        )
        replace_parameter(
            layer, self.w_q_name, torch.nn.Parameter(w, requires_grad=False)
        )
        setattr(layer, self.w_s_name, None)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # PyTorch / KleidiAI kernels natively support the following configs:
        #   - channelwise with bfloat16 / float32 activations
        #   - groupwise with float32 activations
        # To support:
        #   - groupwise with bfloat16/float16 activations: we need to upcast
        #     activations to float32 before matmul and downcast back to
        #     bfloat16/float16
        #   - channelwise with float16 activations: we need to upcast
        #     activations to float32 before matmul and downcast back to float16
        # Note: these activations will be dynamically quantized to int8 by the
        # kernel.

        c = self.config
        is_groupwise = c.group_size != c.partition_weight_shape[0]
        # dtype of activations before they get dynamically quantized to int8
        original_pre_quant_act_dtype = x.dtype
        pre_quant_act_dtype = original_pre_quant_act_dtype
        if (
            is_groupwise and pre_quant_act_dtype == torch.bfloat16
        ) or pre_quant_act_dtype == torch.float16:
            pre_quant_act_dtype = torch.float32

        x_2d = x.reshape(-1, x.shape[-1])
        if pre_quant_act_dtype != original_pre_quant_act_dtype:
            x_2d = x_2d.to(pre_quant_act_dtype)

        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        w_q = getattr(layer, self.w_q_name)
        output = torch.ops.aten._dyn_quant_matmul_4bit(
            x_2d,
            w_q,
            c.group_size,
            c.partition_weight_shape[0],
            c.partition_weight_shape[1],
        ).reshape(out_shape)

        if pre_quant_act_dtype != original_pre_quant_act_dtype:
            output = output.to(original_pre_quant_act_dtype)
        return output
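
The `add(8)` plus nibble-interleave in `process_weights_after_loading` maps signed int4 values [-8, 7] to unsigned [0, 15] and packs adjacent column pairs into single bytes, even column in the low nibble. A toy round-trip, assuming only stock PyTorch:

import torch

w = torch.tensor([[-8, -1, 0, 7]], dtype=torch.int32)  # signed int4 values

shifted = w.add(8)  # [0, 7, 8, 15]
packed = (shifted[:, 1::2] << 4 | shifted[:, ::2]).to(torch.uint8)
print(packed.tolist())  # [[112, 248]]: 0x70 = (7 << 4) | 0, 0xF8 = (15 << 4) | 8

# Unpack: low nibble is the even column, high nibble the odd column.
low = (packed & 0xF).to(torch.int32) - 8
high = (packed >> 4).to(torch.int32) - 8
restored = torch.stack([low, high], dim=-1).reshape(w.shape)
assert torch.equal(restored, w)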
vllm/model_executor/kernels/linear/mixed_precision/exllama.py (new file)
@@ -0,0 +1,168 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_quantized_values_into_int32,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class ExllamaLinearKernel(MPLinearKernel):
    SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
    # In theory this also supports `scalar_types.uint2b2, scalar_types.uint3b4`,
    # but those are currently untested, so they are not added to the list

    @classmethod
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if not current_platform.is_cuda_alike():
            return (
                False,
                "Exllama is only supported on CUDA and ROCm",
            )

        if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return (
                False,
                "Act reordering currently not supported by Exllama "
                "when the input features are partitioned across "
                "devices",
            )

        if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0:
            return (
                False,
                "Output features must be a multiple of the pack "
                "factor (32 / num_bits) so that we can correctly "
                "pack the zero points",
            )

        if c.act_type != torch.float16:
            return False, "Exllama only supports float16 activations"

        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "Exllama, supported types are: "
                f"{cls.SUPPORTED_QUANT_TYPES}",
            )

        if c.full_weight_shape[0] % c.group_size != 0:
            return (
                False,
                f"Group size ({c.group_size}) does not evenly divide"
                " the number of input features "
                f"({c.full_weight_shape[0]})",
            )

        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module):
        c = self.config
        # defined up front: both the zero-point and g_idx branches below need it
        device = getattr(layer, self.w_q_name).device

        # For Exllama, we need to set a zero-point tensor if there is not one
        if not c.zero_points:
            self.w_zp_name = "qzeros"
            groups = c.partition_weight_shape[0] // c.group_size
            out_features = c.partition_weight_shape[1]

            if c.weight_type.has_bias():
                # If the type has a bias we have to create a zeros tensor that
                # contains the bias values repeated for each group (-1 due to
                # a bug in the original GPTQ checkpoint format leading to the
                # exllama kernel adding 1 to the zero points during inference).
                # Documentation of the bug can be found here:
                # https://garden.danieldk.eu/GPTQ-Checkpoint-Format
                zeros = torch.full(
                    (groups, out_features),
                    c.weight_type.bias - 1,
                    dtype=torch.int32,
                    device=device,
                )
            else:
                raise NotImplementedError(
                    "A 0 zero-point is not supported by Exllama due to "
                    "a bug in the original GPTQ checkpoint format leading to "
                    "the exllama kernel adding 1 to the zero points during "
                    "inference"
                )
            zeros = pack_quantized_values_into_int32(zeros, c.weight_type, packed_dim=1)
            setattr(
                layer, self.w_zp_name, torch.nn.Parameter(zeros, requires_grad=False)
            )

        if c.has_g_idx:

            def transform_w_g_idx(x):
                # Exllama wants the permutation array instead of the group
                # indices
                return torch.argsort(x).to(torch.int)

            self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)  # type: ignore
        else:
            self.w_gidx_name = "g_idx"
            empty_g_idx = torch.nn.Parameter(
                torch.empty((0,), dtype=torch.int, device=device), requires_grad=False
            )
            setattr(layer, self.w_gidx_name, empty_g_idx)

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            assert self.w_gidx_name is not None
            g_idx = getattr(layer, self.w_gidx_name)

            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x_cont = x.data.contiguous()
            ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits)
            return x_cont

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x.to(dtype=c.act_type)

        # Repack weights and scales for Exllama
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        c = self.config

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)

        # gptq_gemm supports the GPTQv2 format by passing use_v2_format=True.
        # However, MPLinearLayerConfig doesn't contain format info,
        # so hardcode the GPTQv1 format here to keep behavior unchanged.
        use_v2_format = False

        assert w_zp is not None, "Zero points are required by Exllama"
        assert w_g_idx is not None, "Group index is required by Exllama"
        output = ops.gptq_gemm(
            x_2d, w_q, w_zp, w_s, w_g_idx, True, use_v2_format, c.weight_type.size_bits
        )

        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)
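
The `torch.argsort(g_idx)` trick converts per-row group indices into the row permutation Exllama wants: after applying it, rows belonging to the same quantization group are adjacent. A small illustration with made-up indices and group_size = 2:

import torch

# g_idx[k] = group of input row k; act-order checkpoints store rows shuffled.
g_idx = torch.tensor([1, 0, 1, 0], dtype=torch.int32)

perm = torch.argsort(g_idx).to(torch.int)
print(perm.tolist())  # [1, 3, 0, 2]: group-0 rows first, then group-1 rows

# Applying `perm` to activation columns (x[:, perm]) matches the shuffled
# weight-row order, so each group is contiguous during the GEMM.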
vllm/model_executor/kernels/linear/mixed_precision/machete.py (new file)
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from functools import partial

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
    check_machete_supports_shape,
    query_machete_supported_group_sizes,
    query_machete_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_quantized_values_into_int32,
    unpack_quantized_values_into_int32,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class MacheteLinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        # Machete uses CUTLASS, so it is only compatible with Nvidia
        if not current_platform.is_cuda():
            return False, "Machete only supported on CUDA"

        if not current_platform.is_device_capability(90):
            return False, "Machete requires compute capability of 90 (Hopper)"

        if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return (
                False,
                "Act reordering currently not supported by Machete "
                "when the input features are partitioned across "
                "devices",
            )

        if c.weight_type not in query_machete_supported_quant_types(c.zero_points):
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "Machete, supported types are: "
                f"{query_machete_supported_quant_types(c.zero_points)}",
            )

        if c.group_size not in query_machete_supported_group_sizes(c.act_type):
            return (
                False,
                f"Group size ({c.group_size}) not supported by "
                "Machete, supported group sizes are: "
                f"{query_machete_supported_group_sizes(c.act_type)}",
            )

        return check_machete_supports_shape(
            c.partition_weight_shape[0], c.partition_weight_shape[1]
        )

    # note: assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    #  `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        c = self.config

        if c.has_g_idx:
            assert self.w_gidx_name is not None
            perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int)

            self.act_perm = lambda x: x[:, perm]
            # use `ops.permute_cols` if possible
            if (
                c.act_type in [torch.float16, torch.bfloat16]
                and c.partition_weight_shape[0] % 8 == 0
            ):
                self.act_perm = partial(ops.permute_cols, perm=perm)

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            if c.has_g_idx:
                x_unpacked = unpack_quantized_values_into_int32(
                    x.data, c.weight_type, packed_dim=0
                )
                x_perm = x_unpacked[perm, :]
                x.data = pack_quantized_values_into_int32(
                    x_perm, c.weight_type, packed_dim=0
                )
            x.data = ops.machete_prepack_B(
                x.data.t().contiguous().t(),
                a_type=c.act_type,
                b_type=c.weight_type,
                group_scales_type=c.act_type,
            )
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x

        def transform_w_zp(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1)
            x_unpacked = unpack_quantized_values_into_int32(
                x.data, c.weight_type, packed_dim=1
            )
            w_s = getattr(layer, self.w_s_name).data
            # pre-apply scales to zero-points
            x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous()
            return x

        # Repack weights and scales for Machete
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)
        if c.zero_points:
            self._transform_param(layer, self.w_zp_name, transform_w_zp)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        c = self.config
        w_q, w_s, w_zp, _ = self._get_weight_params(layer)

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        if c.has_g_idx:
            x_2d = self.act_perm(x_2d)

        if c.zero_points:
            assert w_zp is not None
        else:
            w_zp = None

        output = ops.machete_mm(
            a=x_2d,
            b_q=w_q,
            b_type=c.weight_type,
            b_group_zeros=w_zp,
            b_group_scales=w_s,
            b_group_size=c.group_size,
        )

        if bias is not None:
            output.add_(bias)  # in-place add

        return output.reshape(out_shape)
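
Pre-multiplying zeros by scales in `transform_w_zp` works because dequantization is affine: w = s * (q - z) = s * q + (-s * z), so the kernel can fuse the second term as a precomputed additive constant. A numeric sanity check with arbitrary values:

import torch

s = torch.tensor([0.5, 2.0])   # per-group scales
z = torch.tensor([8.0, 8.0])   # per-group zero points
q = torch.tensor([3.0, 10.0])  # quantized weight values

fused_zero = -1.0 * s * z      # what transform_w_zp stores
assert torch.equal(s * (q - z), s * q + fused_zero)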
vllm/model_executor/kernels/linear/mixed_precision/marlin.py (new file)
@@ -0,0 +1,200 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MARLIN_SUPPORTED_GROUP_SIZES,
    apply_gptq_marlin_linear,
    check_marlin_supports_shape,
    marlin_act_int8_process_scales,
    marlin_is_k_full,
    marlin_make_empty_g_idx,
    marlin_make_workspace_new,
    marlin_permute_bias,
    marlin_permute_scales,
    marlin_sort_g_idx,
    marlin_zero_points,
    query_marlin_supported_quant_types,
    unpack_cols,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class MarlinLinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        # Marlin uses inline PTX, so it is only compatible with Nvidia
        if not current_platform.is_cuda():
            return False, "Marlin only supported on CUDA"

        quant_types = query_marlin_supported_quant_types(c.zero_points)
        if c.weight_type not in quant_types:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by"
                f" Marlin, supported types are: {quant_types}",
            )

        if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
            return (
                False,
                f"Group size ({c.group_size}) not supported by "
                "Marlin, supported group sizes are: "
                f"{MARLIN_SUPPORTED_GROUP_SIZES}",
            )

        return check_marlin_supports_shape(
            c.partition_weight_shape[1],  # out_features
            c.partition_weight_shape[0],  # in_features
            c.full_weight_shape[0],  # in_features
            c.group_size,
        )

    # note: assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = getattr(layer, self.w_q_name).device
        c = self.config
        is_a_8bit = c.act_type is not None and c.act_type.itemsize == 1

        if is_a_8bit:
            assert c.weight_type == scalar_types.uint4b8, (
                "W8A8 is not supported by the marlin kernel."
            )

            if c.act_type == torch.float8_e4m3fn:
                ops.marlin_int4_fp8_preprocess(
                    getattr(layer, self.w_q_name), inplace=True
                )
                getattr(layer, self.w_s_name).data = (
                    getattr(layer, self.w_s_name).data * 512
                )

        row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace_new(device)

        # Default names, since marlin requires empty parameters for these.
        # TODO: remove this requirement from marlin (allow optional tensors)
        if self.w_gidx_name is None:
            self.w_gidx_name = "g_idx"
        if self.w_zp_name is None:
            self.w_zp_name = "w_zp"

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.gptq_marlin_repack(
                x.data.contiguous(),
                perm=layer.g_idx_sort_indices,
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                num_bits=c.weight_type.size_bits,
                is_a_8bit=is_a_8bit,
            )
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(
                x.data.contiguous(),
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                group_size=c.group_size,
                is_a_8bit=is_a_8bit,
            )

            if c.group_size == -1:
                num_groups = 1
            else:
                num_groups = c.partition_weight_shape[0] // c.group_size

            if c.act_type == torch.int8 and num_groups > 1:
                x.data, input_global_scale = marlin_act_int8_process_scales(x.data)
                layer.register_parameter(
                    "input_global_scale",
                    torch.nn.Parameter(input_global_scale, requires_grad=False),
                )
            else:
                layer.input_global_scale = None
            return x

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name)
            )
            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        if c.zero_points:
            grouped_k = (
                c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
            )
            self._transform_param(
                layer,
                self.w_zp_name,
                lambda x: marlin_zero_points(
                    unpack_cols(
                        x.t(),
                        c.weight_type.size_bits,
                        grouped_k,
                        c.partition_weight_shape[1],
                    ),
                    size_k=grouped_k,
                    size_n=c.partition_weight_shape[1],
                    num_bits=c.weight_type.size_bits,
                    is_a_8bit=is_a_8bit,
                ),
            )
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

        if hasattr(layer, "bias") and layer.bias is not None:
            layer.bias.data = marlin_permute_bias(layer.bias)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        c = self.config
        w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)

        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        # None for marlin
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            input_global_scale=getattr(layer, "input_global_scale", None),
            bias=bias,
            input_dtype=c.act_type,
        )
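
A convention recurring across these kernels, made explicit in `transform_w_s` above: group_size = -1 means a single scale group spanning the whole K dimension (channelwise quantization); otherwise K splits into K // group_size groups. A tiny helper capturing that rule, written here only for illustration:

def num_scale_groups(k: int, group_size: int) -> int:
    # group_size == -1 denotes channelwise quantization: one group.
    return 1 if group_size == -1 else k // group_size

assert num_scale_groups(4096, -1) == 1
assert num_scale_groups(4096, 128) == 32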
vllm/model_executor/kernels/linear/mixed_precision/xpu.py (new file)
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch
from torch.nn.parameter import Parameter

from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

_XPUWNA16_SUPPORTED_QUANT_TYPES = (scalar_types.uint4, scalar_types.uint4b8)


class XPUwNa16LinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return -1

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if not current_platform.is_xpu():
            return False, "XPUwNa16 only supported on XPU"

        if c.act_type != torch.bfloat16 and c.act_type != torch.float16:
            return False, "XPUwNa16 only supports BF16/FP16 activations"

        if c.weight_type not in _XPUWNA16_SUPPORTED_QUANT_TYPES:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "XPUwNa16, supported types are: "
                f"{_XPUWNA16_SUPPORTED_QUANT_TYPES}",
            )
        if c.group_size != -1 and c.group_size % 32 != 0:
            return (
                False,
                f"Group size ({c.group_size}) not supported by "
                "XPUwNa16, supported group sizes are multiples of 32",
            )

        if c.partition_weight_shape[0] % 32 != 0:
            return (
                False,
                f"Input size ({c.partition_weight_shape[0]}) not supported by "
                "XPUwNa16, supported sizes are multiples of 32",
            )

        if c.partition_weight_shape[1] % 32 != 0:
            return (
                False,
                f"Output size ({c.partition_weight_shape[1]}) not supported by "
                "XPUwNa16, supported sizes are multiples of 32",
            )

        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module):
        layer.weight_scale.data = layer.weight_scale.t().contiguous()

        if self.config.zero_points:
            layer.weight_zero_point.data = layer.weight_zero_point.t().contiguous()
        else:
            weight_zero_point = torch.Tensor([8]).to(torch.int8).to("xpu")
            layer.weight_zero_point = Parameter(weight_zero_point, requires_grad=False)
        if self.config.has_g_idx:
            layer.g_idx.data = layer.g_idx.t().contiguous()
        else:
            layer.g_idx = None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        reshaped_x = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (self.config.partition_weight_shape[1],)
        out = torch.ops._xpu_C.int4_gemm_w4a16(
            reshaped_x,
            layer.weight_packed.t(),
            bias,
            layer.weight_scale,
            layer.weight_zero_point,
            self.config.group_size,
            layer.g_idx,
        )
        # restore the original leading dimensions, mirroring the other kernels
        return out.reshape(out_shape)