from typing import Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
    check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
    marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
    query_marlin_supported_quant_types)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           permute_param_layout_)

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class MarlinLinearKernel(MPLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        # Marlin requires Ampere (compute capability 8.0) or newer.
        return 80

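    # Reject configs the Marlin kernel cannot run, returning a reason string
    # so the caller can fall back to a different MPLinearKernel.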
    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.zero_points:
            return False, "Zero points currently not supported by "\
                          "MarlinLinearKernel. Will be added when AWQMarlin "\
                          "is migrated over to using MPLinearKernel backend"

        quant_types = query_marlin_supported_quant_types(c.zero_points)
        if c.weight_type not in quant_types:
            return False, f"Quant type ({c.weight_type}) not supported by"\
                          f" Marlin, supported types are: {quant_types}"

        if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
            return False, f"Group size ({c.group_size}) not supported by "\
                          "Marlin, supported group sizes are: "\
                          f"{MARLIN_SUPPORTED_GROUP_SIZES}"

        return check_marlin_supports_shape(
            c.partition_weight_shape[1],  # out_features
            c.partition_weight_shape[0],  # in_features
            c.full_weight_shape[0],  # in_features
            c.group_size)

    # Note: this assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale`  is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = getattr(layer, self.w_q_name).device
        c = self.config

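        # With activation reordering (g_idx), a row-parallel shard covers
        # only a slice of the K (input) dimension, so the kernel must be
        # told whether K is full for this partition.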
        row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
                                               device)

        # Default names, since marlin requires (possibly empty) tensors for
        # these parameters.
        # TODO: remove this requirement from marlin (allow optional tensors)
        if self.w_gidx_name is None:
            self.w_gidx_name = "g_idx"
        if self.w_zp_name is None:
            self.w_zp_name = "w_zp"

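        # Marlin expects g_idx sorted by group; keep the sort permutation so
        # the packed weight rows can be reordered to match during repacking.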
        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name))
            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

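        # can_implement() currently rejects zero-point configs, so only the
        # dummy-tensor branch below is reachable for now.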
        if c.zero_points:
            pass
            # TODO (lucas): add the following when AWQMarlin is migrated over
            # to using MPLinearKernel backend
            # self._transform_param(layer, self.w_zp_name, lambda x: \
            #     marlin_zero_points(
            #         x,
            #         size_k=c.partition_weight_shape[0],
            #         size_n=c.partition_weight_shape[1],
            #         num_bits=c.weight_type.size_bits))
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))

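        # Repack the quantized weight from its on-disk (GPTQ-style) layout
        # into Marlin's tiled layout, applying the g_idx sort permutation.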
        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.gptq_marlin_repack(x.data.contiguous(),
                                            perm=layer.g_idx_sort_indices,
                                            size_k=c.partition_weight_shape[0],
                                            size_n=c.partition_weight_shape[1],
                                            num_bits=c.weight_type.size_bits)
            return x

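        # Permute the per-group scales into the order the Marlin kernel
        # expects for this K x N partition.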
        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(x.data.contiguous(),
                                           size_k=c.partition_weight_shape[0],
                                           size_n=c.partition_weight_shape[1],
                                           group_size=c.group_size)
            return x

        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

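    # Runs the quantized GEMM: out = x @ W (+ bias), where W is dequantized
    # on the fly from the repacked weight and scales prepared above.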
    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        c = self.config
        w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)

        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        # None for marlin
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            bias=bias)
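

# A rough usage sketch (illustrative only; the constructor arguments are
# assumptions based on the MPLinearKernel interface, not verified against
# actual callers):
#
#   ok, reason = MarlinLinearKernel.can_implement(config)
#   if ok:
#       kernel = MarlinLinearKernel(config,
#                                   w_q_param_name="qweight",
#                                   w_s_param_name="scales")
#       kernel.process_weights_after_loading(layer)  # once, after load
#       out = kernel.apply_weights(layer, x, bias)   # each forward pass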