Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -127,10 +127,10 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
|
||||
PlatformEnum.CUDA: [
|
||||
MarlinLinearKernel,
|
||||
CutlassW4A8LinearKernel,
|
||||
MacheteLinearKernel,
|
||||
AllSparkLinearKernel,
|
||||
MarlinLinearKernel,
|
||||
ConchLinearKernel,
|
||||
ExllamaLinearKernel,
|
||||
],
|
||||
|
||||
@@ -69,7 +69,6 @@ class MacheteLinearKernel(MPLinearKernel):
|
||||
# `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
c = self.config
|
||||
|
||||
if c.has_g_idx:
|
||||
assert self.w_gidx_name is not None
|
||||
perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int)
|
||||
@@ -86,19 +85,17 @@ class MacheteLinearKernel(MPLinearKernel):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
if c.has_g_idx:
|
||||
x_unpacked = unpack_quantized_values_into_int32(
|
||||
x.data, c.weight_type, packed_dim=0
|
||||
)
|
||||
x_unpacked = unpack_quantized_values_into_int32(x.data,
|
||||
c.weight_type,
|
||||
packed_dim=0)
|
||||
x_perm = x_unpacked[perm, :]
|
||||
x.data = pack_quantized_values_into_int32(
|
||||
x_perm, c.weight_type, packed_dim=0
|
||||
)
|
||||
x.data = ops.machete_prepack_B(
|
||||
x.data.t().contiguous().t(),
|
||||
a_type=c.act_type,
|
||||
b_type=c.weight_type,
|
||||
group_scales_type=c.act_type,
|
||||
)
|
||||
x.data = pack_quantized_values_into_int32(x_perm,
|
||||
c.weight_type,
|
||||
packed_dim=0)
|
||||
x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
|
||||
a_type=c.act_type,
|
||||
b_type=c.weight_type,
|
||||
group_scales_type=c.act_type)
|
||||
return x
|
||||
|
||||
def transform_w_s(x):
|
||||
@@ -144,16 +141,14 @@ class MacheteLinearKernel(MPLinearKernel):
|
||||
else:
|
||||
w_zp = None
|
||||
|
||||
output = ops.machete_mm(
|
||||
a=x_2d,
|
||||
b_q=w_q,
|
||||
b_type=c.weight_type,
|
||||
b_group_zeros=w_zp,
|
||||
b_group_scales=w_s,
|
||||
b_group_size=c.group_size,
|
||||
)
|
||||
output = ops.machete_mm(a=x_2d,
|
||||
b_q=w_q,
|
||||
b_type=c.weight_type,
|
||||
b_group_zeros=w_zp,
|
||||
b_group_scales=w_s,
|
||||
b_group_size=c.group_size)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
return output.reshape(out_shape)
|
||||
@@ -23,9 +23,95 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
import ixformer.inference.functions as ixf_ops
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def unpack_rows(packed_w: torch.Tensor, num_bits: int) -> torch.Tensor:
|
||||
"""
|
||||
Efficient vectorized unpacking.
|
||||
Converts [K // pack_factor, N] int32 tensor → [K, N] int8 tensor.
|
||||
|
||||
Args:
|
||||
packed_w: torch.int32 tensor of shape [K // pack_factor, N].
|
||||
num_bits: Number of bits per packed element (e.g., 4).
|
||||
|
||||
Returns:
|
||||
unpacked: torch.int8 tensor of shape [K, N].
|
||||
"""
|
||||
pack_factor = 32 // num_bits
|
||||
k_packed, n = packed_w.shape
|
||||
k = k_packed * pack_factor
|
||||
|
||||
mask = (1 << num_bits) - 1
|
||||
|
||||
# [pack_factor, 1, 1]
|
||||
shifts = (num_bits * torch.arange(pack_factor, device=packed_w.device)).view(-1, 1, 1)
|
||||
|
||||
# [pack_factor, k_packed, n]
|
||||
packed_expanded = packed_w.unsqueeze(0)
|
||||
|
||||
# Extract each group of num_bits using bitwise ops
|
||||
unpacked_groups = ((packed_expanded >> shifts) & mask).to(torch.int8)
|
||||
# [pack_factor, k_packed, n] → [k, n]
|
||||
unpacked = unpacked_groups.permute(1, 0, 2).reshape(k, n)
|
||||
|
||||
return unpacked
|
||||
|
||||
|
||||
def pack_cols(x: torch.Tensor, pack_num: int = 8, order_map=None) -> torch.Tensor:
|
||||
"""
|
||||
Efficient vectorized version: pack int4 values (0–15) into int32.
|
||||
Each int32 element contains `pack_num` 4-bit values.
|
||||
|
||||
Args:
|
||||
x: Tensor of shape [rows, cols * pack_num], dtype=int32.
|
||||
Represents unpacked int4 values.
|
||||
pack_num: Number of 4-bit elements to pack into each int32.
|
||||
order_map: Index mapping defining the order of 4-bit packing,
|
||||
must match the unpack order used in `unpack_tensor`.
|
||||
|
||||
Returns:
|
||||
Tensor of shape [rows, cols], dtype=int32 — packed result.
|
||||
"""
|
||||
# Default sequential order if none provided
|
||||
if order_map is None:
|
||||
order_map = list(range(pack_num))
|
||||
order_map = torch.tensor(order_map, device=x.device)
|
||||
|
||||
# Number of bits per packed element (e.g., 32 / 8 = 4 bits)
|
||||
unit = 32 // pack_num
|
||||
rows, cols_pack = x.shape
|
||||
assert cols_pack % pack_num == 0, "Number of columns must be a multiple of pack_num"
|
||||
cols = cols_pack // pack_num
|
||||
|
||||
# Reshape input into groups of `pack_num` int4 values
|
||||
# Shape: [rows, cols, pack_num]
|
||||
x_reshape = x.view(rows, cols, pack_num)
|
||||
|
||||
# Reorder elements according to order_map
|
||||
# order_map is broadcasted to match shape [rows, cols, pack_num]
|
||||
x_reorder = torch.gather(x_reshape, 2, order_map.view(1, 1, -1).expand(rows, cols, -1))
|
||||
|
||||
# Keep only the lower 4 bits of each value
|
||||
x_reorder = x_reorder & 0xF
|
||||
|
||||
# Compute bit shifts for each position (e.g., [0, 4, 8, 12, 16, 20, 24, 28])
|
||||
shifts = (unit * torch.arange(pack_num, device=x.device)).view(1, 1, -1)
|
||||
|
||||
# Shift and combine (bitwise OR) along the last dimension
|
||||
# Using sum() is safe since bits don't overlap between 4-bit slots
|
||||
res = (x_reorder << shifts).sum(dim=-1).to(torch.int32)
|
||||
|
||||
return res
|
||||
|
||||
class MarlinLinearKernel(MPLinearKernel):
|
||||
@classmethod
|
||||
@@ -79,96 +165,133 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
getattr(layer, self.w_s_name).data = (
|
||||
getattr(layer, self.w_s_name).data * 512
|
||||
)
|
||||
assert (c.weight_type.size_bits == 4) , f"MarlinLinearKernel now only support uint4, uint4b8, \
|
||||
now quant weight_type {c.weight_typ}"
|
||||
|
||||
# device = getattr(layer, self.w_q_name).device
|
||||
|
||||
|
||||
row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
|
||||
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
|
||||
# row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
|
||||
# self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
|
||||
|
||||
# Allocate marlin workspace.
|
||||
self.workspace = marlin_make_workspace_new(device)
|
||||
# self.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
# Default names since marlin requires empty parameters for these,
|
||||
# TODO: remove this requirement from marlin (allow optional tensors)
|
||||
if self.w_gidx_name is None:
|
||||
self.w_gidx_name = "g_idx"
|
||||
if self.w_zp_name is None:
|
||||
self.w_zp_name = "w_zp"
|
||||
# if self.w_gidx_name is None:
|
||||
# self.w_gidx_name = "g_idx"
|
||||
# if self.w_zp_name is None:
|
||||
# self.w_zp_name = "w_zp"
|
||||
if c.has_g_idx:
|
||||
assert self.w_gidx_name is not None
|
||||
perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int)
|
||||
|
||||
self.act_perm = lambda x: x[:, perm]
|
||||
|
||||
def transform_w_q(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
x.data = ops.gptq_marlin_repack(
|
||||
x.data.contiguous(),
|
||||
perm=layer.g_idx_sort_indices,
|
||||
size_k=c.partition_weight_shape[0],
|
||||
size_n=c.partition_weight_shape[1],
|
||||
num_bits=c.weight_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
# assert isinstance(x, BasevLLMParameter)
|
||||
# permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
# x.data = ops.gptq_marlin_repack(
|
||||
# x.data.contiguous(),
|
||||
# perm=layer.g_idx_sort_indices,
|
||||
# size_k=c.partition_weight_shape[0],
|
||||
# size_n=c.partition_weight_shape[1],
|
||||
# num_bits=c.weight_type.size_bits,
|
||||
# is_a_8bit=is_a_8bit,
|
||||
# )
|
||||
assert x.data.ndim == 2
|
||||
if x._packed_dim == 1: #CompressedTensorsWNA16
|
||||
#[oc, ic // 8] - > [oc, ic]
|
||||
x_unpacked = unpack_quantized_values_into_int32(x.data,
|
||||
c.weight_type,
|
||||
packed_dim=1)
|
||||
if c.has_g_idx:
|
||||
x_unpacked = x_unpacked[:,perm]
|
||||
#[oc, ic] -> [ic, oc]
|
||||
x_unpacked = x_unpacked.t().contiguous()
|
||||
|
||||
elif x._packed_dim == 0: #GPTQMarlinLinearMethod
|
||||
|
||||
#[ic // 8, oc] -> [ic , oc]
|
||||
x_unpacked = unpack_rows(x.data,c.weight_type.size_bits)
|
||||
if c.has_g_idx:
|
||||
x_unpacked = x_unpacked[perm:]
|
||||
raise NotImplementedError(f"GPTQMarlinLinearMethod has_g_idx not test, \
|
||||
Please check whether the model's inference results are correct, and annotate/modify the statement. ")
|
||||
else:
|
||||
raise NotImplementedError(f"transform_w_q pack_dim {x._packed_dim} not implement")
|
||||
|
||||
#[ic, oc]-> [ic, oc//8]
|
||||
x_packed = pack_cols(x_unpacked, order_map=[0, 2, 4, 6, 1, 3, 5, 7])
|
||||
x.data = x_packed.contiguous()
|
||||
x._packed_dim = 1
|
||||
return x
|
||||
|
||||
def transform_w_s(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1)
|
||||
x.data = marlin_permute_scales(
|
||||
x.data.contiguous(),
|
||||
size_k=c.partition_weight_shape[0],
|
||||
size_n=c.partition_weight_shape[1],
|
||||
group_size=c.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
x.data = x.data.contiguous()
|
||||
return x.to(dtype=c.act_type)
|
||||
|
||||
if c.group_size == -1:
|
||||
num_groups = 1
|
||||
else:
|
||||
num_groups = c.partition_weight_shape[0] // c.group_size
|
||||
|
||||
if c.act_type == torch.int8 and num_groups > 1:
|
||||
x.data, input_global_scale = marlin_act_int8_process_scales(x.data)
|
||||
layer.register_parameter(
|
||||
"input_global_scale",
|
||||
torch.nn.Parameter(input_global_scale, requires_grad=False),
|
||||
)
|
||||
else:
|
||||
layer.input_global_scale = None
|
||||
# if c.has_g_idx:
|
||||
# g_idx, g_idx_sort_indices = marlin_sort_g_idx(
|
||||
# getattr(layer, self.w_gidx_name)
|
||||
# )
|
||||
# self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
|
||||
# layer.g_idx_sort_indices = g_idx_sort_indices
|
||||
# else:
|
||||
# setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
|
||||
# layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
|
||||
def transform_w_zp(x):
|
||||
grouped_k = (c.partition_weight_shape[0] //
|
||||
c.group_size if c.group_size != -1 else 1)
|
||||
x_unpacked = unpack_cols(x.clone().t(), c.weight_type.size_bits, grouped_k, c.partition_weight_shape[1])
|
||||
x_packed = pack_cols(x_unpacked, order_map=[0, 2, 4, 6, 1, 3, 5, 7])
|
||||
x.data = x_packed.contiguous()
|
||||
return x
|
||||
|
||||
if c.has_g_idx:
|
||||
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
|
||||
getattr(layer, self.w_gidx_name)
|
||||
)
|
||||
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
|
||||
layer.g_idx_sort_indices = g_idx_sort_indices
|
||||
else:
|
||||
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
|
||||
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
|
||||
|
||||
|
||||
if c.zero_points:
|
||||
grouped_k = (
|
||||
c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
|
||||
)
|
||||
self._transform_param(
|
||||
layer,
|
||||
self.w_zp_name,
|
||||
lambda x: marlin_zero_points(
|
||||
unpack_cols(
|
||||
x.t(),
|
||||
c.weight_type.size_bits,
|
||||
grouped_k,
|
||||
c.partition_weight_shape[1],
|
||||
),
|
||||
size_k=grouped_k,
|
||||
size_n=c.partition_weight_shape[1],
|
||||
num_bits=c.weight_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
),
|
||||
)
|
||||
# grouped_k = (
|
||||
# c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
|
||||
# )
|
||||
# self._transform_param(
|
||||
# layer,
|
||||
# self.w_zp_name,
|
||||
# lambda x: marlin_zero_points(
|
||||
# unpack_cols(
|
||||
# x.t(),
|
||||
# c.weight_type.size_bits,
|
||||
# grouped_k,
|
||||
# c.partition_weight_shape[1],
|
||||
# ),
|
||||
# size_k=grouped_k,
|
||||
# size_n=c.partition_weight_shape[1],
|
||||
# num_bits=c.weight_type.size_bits,
|
||||
# ),
|
||||
# )
|
||||
self._transform_param(layer, self.w_zp_name, transform_w_zp)
|
||||
else:
|
||||
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
|
||||
# setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
|
||||
#weight_type = uint4b8, using c.weight_type.bias as zero point,according quant method.
|
||||
#[ic, oc]-> [ic, oc//8]
|
||||
w_zp = torch.full_like(getattr(layer, self.w_s_name), c.weight_type.bias, dtype=torch.int32)
|
||||
w_zp_pack = pack_cols(w_zp, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous()
|
||||
weight_zero_point = torch.nn.Parameter(
|
||||
w_zp_pack,
|
||||
requires_grad=False)
|
||||
|
||||
if hasattr(layer, self.w_zp_name):
|
||||
replace_parameter(layer, self.w_zp_name, weight_zero_point) #GPTQMarlinLinearMethod
|
||||
else:
|
||||
layer.register_parameter("weight_zero_point", weight_zero_point) #CompressedTensorsWNA16
|
||||
|
||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||
|
||||
if hasattr(layer, "bias") and layer.bias is not None:
|
||||
layer.bias.data = marlin_permute_bias(layer.bias)
|
||||
# if hasattr(layer, "bias") and layer.bias is not None:
|
||||
# layer.bias.data = marlin_permute_bias(layer.bias)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
@@ -179,22 +302,39 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
c = self.config
|
||||
w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
|
||||
|
||||
# `process_weights_after_loading` will ensure w_zp and w_gidx are not
|
||||
# None for marlin
|
||||
pack_factor = 32 // c.weight_type.size_bits
|
||||
|
||||
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
|
||||
x_2d = x.reshape(-1, x.shape[-1])
|
||||
|
||||
if c.has_g_idx:
|
||||
x_2d = self.act_perm(x_2d)
|
||||
|
||||
out = ops.custom_gptq_marlin_gemm(input = x_2d,
|
||||
qweight = w_q,
|
||||
scales = w_s,
|
||||
qzeros = w_zp,
|
||||
pack_factor = pack_factor,
|
||||
group_size = c.group_size,
|
||||
bias = bias)
|
||||
out = out.reshape(out_shape)
|
||||
# if bias is not None:
|
||||
# out.add_(bias)
|
||||
return out
|
||||
|
||||
|
||||
return apply_gptq_marlin_linear(
|
||||
input=x,
|
||||
weight=w_q,
|
||||
weight_scale=w_s,
|
||||
weight_zp=w_zp, # type: ignore
|
||||
g_idx=w_gidx, # type: ignore
|
||||
g_idx_sort_indices=layer.g_idx_sort_indices,
|
||||
workspace=self.workspace,
|
||||
wtype=c.weight_type,
|
||||
input_size_per_partition=c.partition_weight_shape[0],
|
||||
output_size_per_partition=c.partition_weight_shape[1],
|
||||
is_k_full=self.is_k_full,
|
||||
input_global_scale=getattr(layer, "input_global_scale", None),
|
||||
bias=bias,
|
||||
input_dtype=c.act_type,
|
||||
)
|
||||
# # `process_weights_after_loading` will ensure w_zp and w_gidx are not
|
||||
# # None for marlin
|
||||
# return apply_gptq_marlin_linear(
|
||||
# input=x,
|
||||
# weight=w_q,
|
||||
# weight_scale=w_s,
|
||||
# weight_zp=w_zp, # type: ignore
|
||||
# g_idx=w_gidx, # type: ignore
|
||||
# g_idx_sort_indices=layer.g_idx_sort_indices,
|
||||
# workspace=self.workspace,
|
||||
# wtype=c.weight_type,
|
||||
# input_size_per_partition=c.partition_weight_shape[0],
|
||||
# output_size_per_partition=c.partition_weight_shape[1],
|
||||
# is_k_full=self.is_k_full,
|
||||
# bias=bias)
|
||||
|
||||
@@ -18,7 +18,6 @@ from .ScaledMMLinearKernel import (
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
import vllm.envs as envs
|
||||
|
||||
class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
@@ -38,13 +37,28 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
config = self.config
|
||||
# WEIGHT
|
||||
# Cutlass kernels need transposed weight.
|
||||
weight = getattr(layer, w_q_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_q_name,
|
||||
# torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||
torch.nn.Parameter(weight.data if envs.VLLM_W8A8_LINEAR_USE_W4A8 else weight.t().data, requires_grad=False),
|
||||
)
|
||||
weight = getattr(layer, w_q_name)
|
||||
if layer.scheme.is_w4a8_linear:
|
||||
self.format = "NN"
|
||||
replace_parameter(layer, w_q_name, torch.nn.Parameter(weight.data.contiguous(), requires_grad=False))
|
||||
else:
|
||||
self.format = "TN" #默认weight都是按T排布
|
||||
m, k = weight.shape
|
||||
if(m % 64 == 0 and k % 64 == 0):
|
||||
self.format= "NN"
|
||||
replace_parameter(
|
||||
layer, w_q_name,
|
||||
torch.nn.Parameter(weight.t().data.contiguous(), requires_grad=False))#原始排布是T[m,k] 处理完后是N[k, m]
|
||||
else:
|
||||
if k % 64 != 0:
|
||||
pad_k = (k // 64 + 1) * 64
|
||||
weight_pad = torch.empty((m, pad_k), dtype=weight.dtype, device=weight.device)
|
||||
_weight = weight_pad[:, :k]
|
||||
_weight.copy_(weight)
|
||||
weight = _weight
|
||||
replace_parameter(
|
||||
layer, w_q_name,
|
||||
torch.nn.Parameter(weight.t(), requires_grad=False))
|
||||
|
||||
# WEIGHT SCALE
|
||||
# Cutlass kernels support only per-tensor and per-channel.
|
||||
@@ -114,6 +128,7 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
is_w4a8_linear: bool = False,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||
|
||||
@@ -121,9 +136,15 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
# * static, i_s is scalar and x_s is i_s.
|
||||
symmetric = azp_adj is None
|
||||
x_q, x_s, x_zp = ops.scaled_int8_quant(
|
||||
x.contiguous(), i_s, i_zp, symmetric=symmetric
|
||||
)
|
||||
if isinstance(x, tuple):
|
||||
x_q, x_s, out_dtype = x
|
||||
x_zp = None
|
||||
else:
|
||||
out_dtype = x.dtype
|
||||
x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(),
|
||||
i_s,
|
||||
i_zp,
|
||||
symmetric=symmetric)
|
||||
|
||||
if x_zp is not None:
|
||||
# Currently, static is always per-tensor and dynamic is per-token
|
||||
@@ -134,14 +155,21 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
w_q,
|
||||
scale_a=x_s,
|
||||
scale_b=w_s,
|
||||
out_dtype=x.dtype,
|
||||
out_dtype=out_dtype,
|
||||
azp_adj=azp_adj,
|
||||
azp=azp,
|
||||
bias=bias,
|
||||
)
|
||||
if self.format == "NN" and x_q.shape[-1] != w_q.shape[0]:
|
||||
padding = w_q.shape[0] - x_q.shape[-1]
|
||||
x_align = torch.nn.functional.pad(x_q, (0, padding), mode='constant', value=0)
|
||||
elif self.format == "TN" and x_q.shape[-1] != w_q.shape[-1]:
|
||||
padding = w_q.shape[-1] - x_q.shape[-1]
|
||||
x_align = torch.nn.functional.pad(x_q, (0, padding), mode='constant', value=0)
|
||||
else:
|
||||
x_align = x_q
|
||||
return ops.cutlass_scaled_mm(
|
||||
# x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
|
||||
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias, format="NN" if envs.VLLM_W8A8_LINEAR_USE_W4A8 else "TN"
|
||||
x_align, w_q, scale_a=x_s, scale_b=w_s, out_dtype=out_dtype, bias=bias, format=self.format, is_w4a8_linear=is_w4a8_linear
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user