Upgrade to vllm 0.17.0 corex v4.1 overlay

This commit is contained in:
2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -127,10 +127,10 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
# in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
PlatformEnum.CUDA: [
MarlinLinearKernel,
CutlassW4A8LinearKernel,
MacheteLinearKernel,
AllSparkLinearKernel,
MarlinLinearKernel,
ConchLinearKernel,
ExllamaLinearKernel,
],
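The list above is consumed in priority order: the first kernel whose can_implement check passes for the layer config is selected, which is why moving MarlinLinearKernel to the front makes it the preferred CUDA kernel here. A minimal sketch of that selection pattern (the function name and error text are illustrative, not vLLM's exact selection code):

def pick_first_supported_kernel(config, candidate_kernels):
    # Walk the candidates in priority/performance order and return the first
    # kernel class that reports it can implement this layer configuration.
    failure_reasons = []
    for kernel_cls in candidate_kernels:
        can_use, reason = kernel_cls.can_implement(config)
        if can_use:
            return kernel_cls
        failure_reasons.append(f"  {kernel_cls.__name__}: {reason}")
    raise ValueError(
        "No compatible mixed-precision linear kernel:\n" + "\n".join(failure_reasons))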

View File

@@ -69,7 +69,6 @@ class MacheteLinearKernel(MPLinearKernel):
# `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
if c.has_g_idx:
assert self.w_gidx_name is not None
perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int)
@@ -86,19 +85,17 @@ class MacheteLinearKernel(MPLinearKernel):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
if c.has_g_idx:
x_unpacked = unpack_quantized_values_into_int32(
x.data, c.weight_type, packed_dim=0
)
x_unpacked = unpack_quantized_values_into_int32(x.data,
c.weight_type,
packed_dim=0)
x_perm = x_unpacked[perm, :]
x.data = pack_quantized_values_into_int32(
x_perm, c.weight_type, packed_dim=0
)
x.data = ops.machete_prepack_B(
x.data.t().contiguous().t(),
a_type=c.act_type,
b_type=c.weight_type,
group_scales_type=c.act_type,
)
x.data = pack_quantized_values_into_int32(x_perm,
c.weight_type,
packed_dim=0)
x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
a_type=c.act_type,
b_type=c.weight_type,
group_scales_type=c.act_type)
return x
def transform_w_s(x):
@@ -144,16 +141,14 @@ class MacheteLinearKernel(MPLinearKernel):
else:
w_zp = None
output = ops.machete_mm(
a=x_2d,
b_q=w_q,
b_type=c.weight_type,
b_group_zeros=w_zp,
b_group_scales=w_s,
b_group_size=c.group_size,
)
output = ops.machete_mm(a=x_2d,
b_q=w_q,
b_type=c.weight_type,
b_group_zeros=w_zp,
b_group_scales=w_s,
b_group_size=c.group_size)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
return output.reshape(out_shape)
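For the g_idx path above, the kernel unpacks the int4 weights, reorders the K rows by torch.argsort(g_idx), and repacks before machete_prepack_B. A standalone toy version of that row permutation using only torch bit ops (the helper name and packing convention here are illustrative assumptions, not the vLLM utilities):

import torch

def permute_packed_rows(packed: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    # packed: [K // 8, N] int32, each int32 holding 8 uint4 values along K.
    # perm:   [K] row order, e.g. torch.argsort(g_idx).
    k_packed, n = packed.shape
    shifts = 4 * torch.arange(8, device=packed.device)
    # Unpack to [K, N] uint4 values (stored as integers 0..15).
    unpacked = ((packed.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF).reshape(-1, n)
    # Reorder rows along K, then repack 8 consecutive rows per int32.
    regrouped = unpacked[perm, :].reshape(k_packed, 8, n)
    return (regrouped << shifts.view(1, -1, 1)).sum(dim=1).to(torch.int32)

# Identity permutation round-trips to the original packed tensor.
packed = torch.randint(0, 2**16, (4, 3), dtype=torch.int32)
assert torch.equal(permute_packed_rows(packed, torch.arange(32)), packed)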

View File

@@ -23,9 +23,95 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
from vllm.scalar_type import ScalarType, scalar_types
import ixformer.inference.functions as ixf_ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.logger import init_logger
logger = init_logger(__name__)
def unpack_rows(packed_w: torch.Tensor, num_bits: int) -> torch.Tensor:
"""
Efficient vectorized unpacking.
Converts [K // pack_factor, N] int32 tensor → [K, N] int8 tensor.
Args:
packed_w: torch.int32 tensor of shape [K // pack_factor, N].
num_bits: Number of bits per packed element (e.g., 4).
Returns:
unpacked: torch.int8 tensor of shape [K, N].
"""
pack_factor = 32 // num_bits
k_packed, n = packed_w.shape
k = k_packed * pack_factor
mask = (1 << num_bits) - 1
# [pack_factor, 1, 1]
shifts = (num_bits * torch.arange(pack_factor, device=packed_w.device)).view(-1, 1, 1)
# [pack_factor, k_packed, n]
packed_expanded = packed_w.unsqueeze(0)
# Extract each group of num_bits using bitwise ops
unpacked_groups = ((packed_expanded >> shifts) & mask).to(torch.int8)
# [pack_factor, k_packed, n] → [k, n]
unpacked = unpacked_groups.permute(1, 0, 2).reshape(k, n)
return unpacked
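A quick round-trip check for unpack_rows (illustrative only; assumes the helper above is in scope):

import torch

# Pack K rows of uint4 values 8-per-int32 along dim 0, then unpack and compare.
k, n, num_bits = 32, 6, 4
pack_factor = 32 // num_bits
w = torch.randint(0, 2 ** num_bits, (k, n), dtype=torch.int64)
shifts = (num_bits * torch.arange(pack_factor)).view(1, -1, 1)
packed_w = (w.view(k // pack_factor, pack_factor, n) << shifts).sum(dim=1).to(torch.int32)
assert torch.equal(unpack_rows(packed_w, num_bits).to(torch.int64), w)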
def pack_cols(x: torch.Tensor, pack_num: int = 8, order_map=None) -> torch.Tensor:
"""
Efficient vectorized version: pack int4 values (0–15) into int32.
Each int32 element contains `pack_num` 4-bit values.
Args:
x: Tensor of shape [rows, cols * pack_num], dtype=int32.
Represents unpacked int4 values.
pack_num: Number of 4-bit elements to pack into each int32.
order_map: Index mapping defining the order of 4-bit packing,
must match the unpack order used in `unpack_tensor`.
Returns:
Tensor of shape [rows, cols], dtype=int32 — packed result.
"""
# Default sequential order if none provided
if order_map is None:
order_map = list(range(pack_num))
order_map = torch.tensor(order_map, device=x.device)
# Number of bits per packed element (e.g., 32 / 8 = 4 bits)
unit = 32 // pack_num
rows, cols_pack = x.shape
assert cols_pack % pack_num == 0, "Number of columns must be a multiple of pack_num"
cols = cols_pack // pack_num
# Reshape input into groups of `pack_num` int4 values
# Shape: [rows, cols, pack_num]
x_reshape = x.view(rows, cols, pack_num)
# Reorder elements according to order_map
# order_map is broadcasted to match shape [rows, cols, pack_num]
x_reorder = torch.gather(x_reshape, 2, order_map.view(1, 1, -1).expand(rows, cols, -1))
# Keep only the lower 4 bits of each value
x_reorder = x_reorder & 0xF
# Compute bit shifts for each position (e.g., [0, 4, 8, 12, 16, 20, 24, 28])
shifts = (unit * torch.arange(pack_num, device=x.device)).view(1, 1, -1)
# Shift and combine (bitwise OR) along the last dimension
# Using sum() is safe since bits don't overlap between 4-bit slots
res = (x_reorder << shifts).sum(dim=-1).to(torch.int32)
return res
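A round-trip check for pack_cols with the interleaved order used by this kernel (illustrative only; assumes pack_cols as defined above is in scope):

import torch

order_map = [0, 2, 4, 6, 1, 3, 5, 7]
rows, cols, pack_num = 4, 3, 8
x = torch.randint(0, 16, (rows, cols * pack_num), dtype=torch.int32)
packed = pack_cols(x, pack_num=pack_num, order_map=order_map)

# Nibble j of each packed int32 holds the value that sat at position order_map[j].
unpacked = torch.empty(rows, cols, pack_num, dtype=torch.int32)
for j, src in enumerate(order_map):
    unpacked[:, :, src] = (packed >> (4 * j)) & 0xF
assert torch.equal(unpacked, x.view(rows, cols, pack_num))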
class MarlinLinearKernel(MPLinearKernel):
@classmethod
@@ -79,96 +165,133 @@ class MarlinLinearKernel(MPLinearKernel):
getattr(layer, self.w_s_name).data = (
getattr(layer, self.w_s_name).data * 512
)
assert c.weight_type.size_bits == 4, f"MarlinLinearKernel currently only supports uint4/uint4b8, \
but got weight_type {c.weight_type}"
# device = getattr(layer, self.w_q_name).device
row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
# row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
# self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
# Allocate marlin workspace.
self.workspace = marlin_make_workspace_new(device)
# self.workspace = marlin_make_workspace_new(device)
# Default names since marlin requires empty parameters for these,
# TODO: remove this requirement from marlin (allow optional tensors)
if self.w_gidx_name is None:
self.w_gidx_name = "g_idx"
if self.w_zp_name is None:
self.w_zp_name = "w_zp"
# if self.w_gidx_name is None:
# self.w_gidx_name = "g_idx"
# if self.w_zp_name is None:
# self.w_zp_name = "w_zp"
if c.has_g_idx:
assert self.w_gidx_name is not None
perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int)
self.act_perm = lambda x: x[:, perm]
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.gptq_marlin_repack(
x.data.contiguous(),
perm=layer.g_idx_sort_indices,
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits,
is_a_8bit=is_a_8bit,
)
# assert isinstance(x, BasevLLMParameter)
# permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
# x.data = ops.gptq_marlin_repack(
# x.data.contiguous(),
# perm=layer.g_idx_sort_indices,
# size_k=c.partition_weight_shape[0],
# size_n=c.partition_weight_shape[1],
# num_bits=c.weight_type.size_bits,
# is_a_8bit=is_a_8bit,
# )
assert x.data.ndim == 2
if x._packed_dim == 1: #CompressedTensorsWNA16
#[oc, ic // 8] -> [oc, ic]
x_unpacked = unpack_quantized_values_into_int32(x.data,
c.weight_type,
packed_dim=1)
if c.has_g_idx:
x_unpacked = x_unpacked[:,perm]
#[oc, ic] -> [ic, oc]
x_unpacked = x_unpacked.t().contiguous()
elif x._packed_dim == 0: #GPTQMarlinLinearMethod
#[ic // 8, oc] -> [ic , oc]
x_unpacked = unpack_rows(x.data,c.weight_type.size_bits)
if c.has_g_idx:
x_unpacked = x_unpacked[perm, :]
raise NotImplementedError("GPTQMarlinLinearMethod with has_g_idx is untested. \
Please verify that the model's inference results are correct, then remove or adjust this statement.")
else:
raise NotImplementedError(f"transform_w_q pack_dim {x._packed_dim} not implement")
#[ic, oc]-> [ic, oc//8]
x_packed = pack_cols(x_unpacked, order_map=[0, 2, 4, 6, 1, 3, 5, 7])
x.data = x_packed.contiguous()
x._packed_dim = 1
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = marlin_permute_scales(
x.data.contiguous(),
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
group_size=c.group_size,
is_a_8bit=is_a_8bit,
)
x.data = x.data.contiguous()
return x.to(dtype=c.act_type)
if c.group_size == -1:
num_groups = 1
else:
num_groups = c.partition_weight_shape[0] // c.group_size
if c.act_type == torch.int8 and num_groups > 1:
x.data, input_global_scale = marlin_act_int8_process_scales(x.data)
layer.register_parameter(
"input_global_scale",
torch.nn.Parameter(input_global_scale, requires_grad=False),
)
else:
layer.input_global_scale = None
# if c.has_g_idx:
# g_idx, g_idx_sort_indices = marlin_sort_g_idx(
# getattr(layer, self.w_gidx_name)
# )
# self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
# layer.g_idx_sort_indices = g_idx_sort_indices
# else:
# setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
# layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
def transform_w_zp(x):
grouped_k = (c.partition_weight_shape[0] //
c.group_size if c.group_size != -1 else 1)
x_unpacked = unpack_cols(x.clone().t(), c.weight_type.size_bits, grouped_k, c.partition_weight_shape[1])
x_packed = pack_cols(x_unpacked, order_map=[0, 2, 4, 6, 1, 3, 5, 7])
x.data = x_packed.contiguous()
return x
if c.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
getattr(layer, self.w_gidx_name)
)
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
if c.zero_points:
grouped_k = (
c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
)
self._transform_param(
layer,
self.w_zp_name,
lambda x: marlin_zero_points(
unpack_cols(
x.t(),
c.weight_type.size_bits,
grouped_k,
c.partition_weight_shape[1],
),
size_k=grouped_k,
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits,
is_a_8bit=is_a_8bit,
),
)
# grouped_k = (
# c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
# )
# self._transform_param(
# layer,
# self.w_zp_name,
# lambda x: marlin_zero_points(
# unpack_cols(
# x.t(),
# c.weight_type.size_bits,
# grouped_k,
# c.partition_weight_shape[1],
# ),
# size_k=grouped_k,
# size_n=c.partition_weight_shape[1],
# num_bits=c.weight_type.size_bits,
# ),
# )
self._transform_param(layer, self.w_zp_name, transform_w_zp)
else:
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
# setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
# weight_type = uint4b8: use c.weight_type.bias as the zero point, per the quantization method.
#[ic, oc]-> [ic, oc//8]
w_zp = torch.full_like(getattr(layer, self.w_s_name), c.weight_type.bias, dtype=torch.int32)
w_zp_pack = pack_cols(w_zp, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous()
weight_zero_point = torch.nn.Parameter(
w_zp_pack,
requires_grad=False)
if hasattr(layer, self.w_zp_name):
replace_parameter(layer, self.w_zp_name, weight_zero_point) #GPTQMarlinLinearMethod
else:
layer.register_parameter("weight_zero_point", weight_zero_point) #CompressedTensorsWNA16
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
if hasattr(layer, "bias") and layer.bias is not None:
layer.bias.data = marlin_permute_bias(layer.bias)
# if hasattr(layer, "bias") and layer.bias is not None:
# layer.bias.data = marlin_permute_bias(layer.bias)
def apply_weights(
self,
@@ -179,22 +302,39 @@ class MarlinLinearKernel(MPLinearKernel):
c = self.config
w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
# `process_weights_after_loading` will ensure w_zp and w_gidx are not
# None for marlin
pack_factor = 32 // c.weight_type.size_bits
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
x_2d = x.reshape(-1, x.shape[-1])
if c.has_g_idx:
x_2d = self.act_perm(x_2d)
out = ops.custom_gptq_marlin_gemm(input = x_2d,
qweight = w_q,
scales = w_s,
qzeros = w_zp,
pack_factor = pack_factor,
group_size = c.group_size,
bias = bias)
out = out.reshape(out_shape)
# if bias is not None:
# out.add_(bias)
return out
return apply_gptq_marlin_linear(
input=x,
weight=w_q,
weight_scale=w_s,
weight_zp=w_zp, # type: ignore
g_idx=w_gidx, # type: ignore
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=self.workspace,
wtype=c.weight_type,
input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1],
is_k_full=self.is_k_full,
input_global_scale=getattr(layer, "input_global_scale", None),
bias=bias,
input_dtype=c.act_type,
)
# # `process_weights_after_loading` will ensure w_zp and w_gidx are not
# # None for marlin
# return apply_gptq_marlin_linear(
# input=x,
# weight=w_q,
# weight_scale=w_s,
# weight_zp=w_zp, # type: ignore
# g_idx=w_gidx, # type: ignore
# g_idx_sort_indices=layer.g_idx_sort_indices,
# workspace=self.workspace,
# wtype=c.weight_type,
# input_size_per_partition=c.partition_weight_shape[0],
# output_size_per_partition=c.partition_weight_shape[1],
# is_k_full=self.is_k_full,
# bias=bias)
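The exact contract of ops.custom_gptq_marlin_gemm is defined by the IXformer/Corex backend; as a hedged reference, the W4A16 math it is expected to realize for a symmetric uint4b8 weight (zero point = weight_type.bias = 8, per-group scales along K) can be written in plain PyTorch as follows. The function name and signature here are illustrative assumptions, not the fused op's API:

import torch

def reference_w4a16_gemm(x, q_w, scales, group_size, bias=None, zero_point=8):
    # x:      [M, K] activations (fp16/bf16/fp32)
    # q_w:    [K, N] unpacked integer weights in 0..15
    # scales: [K // group_size, N] per-group scales (group_size == -1 means one group)
    k = q_w.shape[0]
    gs = k if group_size == -1 else group_size
    scale_rows = scales.repeat_interleave(gs, dim=0)   # expand to [K, N]
    w = (q_w.to(x.dtype) - zero_point) * scale_rows    # dequantize
    out = x @ w
    return out if bias is None else out + bias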

View File

@@ -18,7 +18,6 @@ from .ScaledMMLinearKernel import (
Int8ScaledMMLinearLayerConfig,
)
import vllm.envs as envs
class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@classmethod
@@ -38,13 +37,28 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
config = self.config
# WEIGHT
# Cutlass kernels need transposed weight.
weight = getattr(layer, w_q_name)
replace_parameter(
layer,
w_q_name,
# torch.nn.Parameter(weight.t().data, requires_grad=False),
torch.nn.Parameter(weight.data if envs.VLLM_W8A8_LINEAR_USE_W4A8 else weight.t().data, requires_grad=False),
)
weight = getattr(layer, w_q_name)
if layer.scheme.is_w4a8_linear:
self.format = "NN"
replace_parameter(layer, w_q_name, torch.nn.Parameter(weight.data.contiguous(), requires_grad=False))
else:
self.format = "TN" #默认weight都是按T排布
m, k = weight.shape
if(m % 64 == 0 and k % 64 == 0):
self.format= "NN"
replace_parameter(
layer, w_q_name,
torch.nn.Parameter(weight.t().data.contiguous(), requires_grad=False))#原始排布是T[m,k] 处理完后是N[k, m]
else:
if k % 64 != 0:
pad_k = (k // 64 + 1) * 64
weight_pad = torch.empty((m, pad_k), dtype=weight.dtype, device=weight.device)
_weight = weight_pad[:, :k]
_weight.copy_(weight)
weight = _weight
replace_parameter(
layer, w_q_name,
torch.nn.Parameter(weight.t(), requires_grad=False))
# WEIGHT SCALE
# Cutlass kernels support only per-tensor and per-channel.
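The padding branch above rounds K up to the next multiple of 64 before transposing the int8 weight. A standalone sketch of why zero-padding K on both operands preserves the result (dimensions and the helper name are illustrative):

import torch
import torch.nn.functional as F

def pad_k_to_64(w: torch.Tensor) -> torch.Tensor:
    # w: [N, K]; zero-pad the last (K) dimension up to a multiple of 64.
    n, k = w.shape
    pad_k = (k + 63) // 64 * 64
    return F.pad(w, (0, pad_k - k))

x = torch.randn(5, 70)
w = torch.randn(32, 70)
w_pad = pad_k_to_64(w)                                # [32, 128]
x_pad = F.pad(x, (0, w_pad.shape[1] - x.shape[1]))    # pad activations to match
# The extra K columns are zero on both sides, so the product is unchanged.
assert torch.allclose(x_pad @ w_pad.t(), x @ w.t(), atol=1e-5)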
@@ -114,6 +128,7 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
is_w4a8_linear: bool = False,
) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
@@ -121,9 +136,15 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric = azp_adj is None
x_q, x_s, x_zp = ops.scaled_int8_quant(
x.contiguous(), i_s, i_zp, symmetric=symmetric
)
if isinstance(x, tuple):
x_q, x_s, out_dtype = x
x_zp = None
else:
out_dtype = x.dtype
x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(),
i_s,
i_zp,
symmetric=symmetric)
if x_zp is not None:
# Currently, static is always per-tensor and dynamic is per-token
@@ -134,14 +155,21 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
w_q,
scale_a=x_s,
scale_b=w_s,
out_dtype=x.dtype,
out_dtype=out_dtype,
azp_adj=azp_adj,
azp=azp,
bias=bias,
)
if self.format == "NN" and x_q.shape[-1] != w_q.shape[0]:
padding = w_q.shape[0] - x_q.shape[-1]
x_align = torch.nn.functional.pad(x_q, (0, padding), mode='constant', value=0)
elif self.format == "TN" and x_q.shape[-1] != w_q.shape[-1]:
padding = w_q.shape[-1] - x_q.shape[-1]
x_align = torch.nn.functional.pad(x_q, (0, padding), mode='constant', value=0)
else:
x_align = x_q
return ops.cutlass_scaled_mm(
# x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias, format="NN" if envs.VLLM_W8A8_LINEAR_USE_W4A8 else "TN"
x_align, w_q, scale_a=x_s, scale_b=w_s, out_dtype=out_dtype, bias=bias, format=self.format, is_w4a8_linear=is_w4a8_linear
)
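For the symmetric case, a plain-PyTorch reference of the scaled int8 matmul invoked above (per-token activation scales, per-channel weight scales); this is a hedged sketch of the expected numerics, not the library's cutlass_scaled_mm implementation:

import torch

def reference_scaled_mm_int8(x_q, w_q, x_scale, w_scale, out_dtype, bias=None):
    # x_q: [M, K] int8, w_q: [K, N] int8
    # x_scale: [M, 1] per-token scale, w_scale: [1, N] per-channel scale
    # The real kernel accumulates in int32; float32 is exact here for modest K.
    acc = x_q.to(torch.float32) @ w_q.to(torch.float32)
    y = acc * x_scale * w_scale
    if bias is not None:
        y = y + bias.to(torch.float32)
    return y.to(out_dtype)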