First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,83 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
import torch
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.scalar_type import ScalarType
@dataclass
class MPLinearLayerConfig:
full_weight_shape: Tuple[int, int] # [in, out]
partition_weight_shape: Tuple[int, int]
weight_type: ScalarType
act_type: torch.dtype
group_size: int
zero_points: bool
has_g_idx: bool
class MPLinearKernel(ABC):
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
raise NotImplementedError
@classmethod
@abstractmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
raise NotImplementedError
def __init__(self,
c: MPLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
w_zp_param_name: Optional[str] = None,
w_gidx_param_name: Optional[str] = None) -> None:
assert self.can_implement(c)
self.config = c
self.w_q_name = w_q_param_name
self.w_s_name = w_s_param_name
self.w_zp_name = w_zp_param_name
self.w_gidx_name = w_gidx_param_name
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
raise NotImplementedError
@abstractmethod
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
raise NotImplementedError
def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
fn: Callable) -> None:
if name is not None and getattr(layer, name, None) is not None:
old_param = getattr(layer, name)
new_param = fn(old_param)
# replace the parameter with torch.nn.Parameter for TorchDynamo
# compatibility
replace_parameter(
layer, name,
torch.nn.Parameter(new_param.data, requires_grad=False))
def _get_weight_params(
self, layer: torch.nn.Module
) -> Tuple[torch.Tensor, # w_q
torch.Tensor, # w_s
Optional[torch.Tensor], # w_zp,
Optional[torch.Tensor] # w_gidx
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),
getattr(layer, self.w_zp_name or "", None),
getattr(layer, self.w_gidx_name or "", None),
)

View File

@@ -0,0 +1,72 @@
import os
from typing import List, Optional, Type
from vllm.model_executor.layers.quantization.kernels.machete import (
MacheteLinearKernel)
from vllm.model_executor.layers.quantization.kernels.marlin import (
MarlinLinearKernel)
from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
MPLinearKernel, MPLinearLayerConfig)
from vllm.platforms import current_platform
# in priority/performance order (when available)
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
MacheteLinearKernel,
MarlinLinearKernel,
]
def choose_mp_linear_kernel(
config: MPLinearLayerConfig,
compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
"""
Choose an MPLinearKernel that can implement the given config for the given
compute capability. Attempts to choose the best kernel in terms of
performance.
Args:
config (MPLinearLayerConfig): Description of the linear layer to be
implemented.
compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get the compute
capability. Defaults to None.
Raises:
ValueError: If no kernel can implement the given config.
Returns:
Type[MPLinearKernel]: Chosen kernel.
"""
if compute_capability is None:
if current_platform is None:
raise ValueError("Cannot determine compute capability")
_cc = current_platform.get_device_capability()
compute_capability = _cc[0] * 10 + _cc[1]
failure_reasons = []
for kernel in _POSSIBLE_KERNELS:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
.split(","):
failure_reasons.append(
f' {kernel.__name__} disabled by environment variable')
continue
if kernel.get_min_capability() > compute_capability:
failure_reasons.append(
f"{kernel.__name__} requires capability "
f"{kernel.get_min_capability()}, current compute capability "
f"is {compute_capability}")
continue
can_implement, failure_reason = kernel.can_implement(config)
if can_implement:
return kernel
else:
failure_reasons.append(
f' {kernel.__name__} cannot implement due to: {failure_reason}'
)
raise ValueError(
"Failed to find a kernel that can implement the "\
"WNA16 linear layer. Reasons: \n"
+ '\n'.join(failure_reasons))

View File

@@ -0,0 +1,118 @@
from functools import partial
from typing import Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
query_machete_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_weights_into_int32, unpack_weights_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MacheteLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 90
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Machete, "\
"when the input features are partitioned across "\
"devices"
if c.zero_points:
return False, "Zero points currently not supported by "\
" Compressed Tensors + Machete. (Kernel supports it"\
" but CompressedTensorsWNA16 does not so support has"\
" not been added to MacheteWNA16Kernel yet"
if c.weight_type not in query_machete_supported_quant_types(
c.zero_points):
return False, f"Quant type ({c.weight_type}) not supported by "\
"Machete, supported types are: "\
f"{query_machete_supported_quant_types(c.zero_points)}"
if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES:
return False, f"Group size ({c.group_size}) not supported by "\
"Machete, supported group sizes are: "\
f"{MACHETE_SUPPORTED_GROUP_SIZES}"
return check_machete_supports_shape(c.partition_weight_shape[0],
c.partition_weight_shape[1])
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
if c.has_g_idx:
assert self.w_gidx_name is not None
perm = torch.argsort(getattr(layer, self.w_gidx_name))\
.to(torch.int)
self.act_perm = lambda x: x[:, perm]
# use `ops.permute_cols` if possible
if c.act_type in [torch.float16, torch.bfloat16] \
and c.partition_weight_shape[0] % 8 == 0:
self.act_perm = partial(ops.permute_cols, perm=perm)
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
if c.has_g_idx:
x_unpacked = unpack_weights_into_int32(x.data,
c.weight_type,
packed_dim=0)
x_perm = x_unpacked[perm, :]
x.data = pack_weights_into_int32(x_perm,
c.weight_type,
packed_dim=0)
x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
self.config.weight_type)
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = x.data.contiguous()
return x
# Repack weights and scales for Machete
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer)
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
if c.has_g_idx:
x_2d = self.act_perm(x_2d)
output = ops.machete_gemm(a=x_2d,
b_q=w_q,
b_type=c.weight_type,
b_zeros=None,
b_scales=w_s,
b_group_size=c.group_size)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)

View File

@@ -0,0 +1,133 @@
from typing import Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
query_marlin_supported_quant_types)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MarlinLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if c.zero_points:
return False, "Zero points currently not supported by "\
" MarlinLinearKernel. Will be added when AWQMarlin "\
"is migrated over to using MPLinearKernel backend"
quant_types = query_marlin_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:
return False, f"Quant type ({c.weight_type}) not supported by"\
f" Marlin, supported types are: {quant_types}"
if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
return False, f"Group size ({c.group_size}) not supported by "\
"Marlin, supported group sizes are: "\
f"{MARLIN_SUPPORTED_GROUP_SIZES}"
return check_marlin_supports_shape(
c.partition_weight_shape[1], # out_features
c.partition_weight_shape[0], # in_features
c.full_weight_shape[0], # in_features
c.group_size)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device
c = self.config
row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
# Allocate marlin workspace.
self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
device)
# Default names since marlin requires empty parameters for these,
# TODO: remove this requirement from marlin (allow optional tensors)
if self.w_gidx_name is None:
self.w_gidx_name = "g_idx"
if self.w_zp_name is None:
self.w_zp_name = "w_zp"
if c.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
getattr(layer, self.w_gidx_name))
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
if c.zero_points:
pass
# TODO (lucas): add the following when AWQMarlin is migrated over to
# using MPLinearKernel backend
# self._transform_param(layer, self.w_zp_name, lambda x: \
# marlin_zero_points(
# x,
# size_k=c.partition_weight_shape[0],
# size_n=c.partition_weight_shape[1],
# num_bits=c.weight_type.size_bits))
else:
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.gptq_marlin_repack(x.data.contiguous(),
perm=layer.g_idx_sort_indices,
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits)
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = marlin_permute_scales(x.data.contiguous(),
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
group_size=c.group_size)
return x
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config
w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
# `process_weights_after_loading` will ensure w_zp and w_gidx are not
# None for marlin
return apply_gptq_marlin_linear(
input=x,
weight=w_q,
weight_scale=w_s,
weight_zp=w_zp, # type: ignore
g_idx=w_gidx, # type: ignore
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=self.workspace,
wtype=c.weight_type,
input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1],
is_k_full=self.is_k_full,
bias=bias)