init src 0.9.2
This commit is contained in:
93
vllm/model_executor/layers/quantization/utils/w4a8_utils.py
Normal file
93
vllm/model_executor/layers/quantization/utils/w4a8_utils.py
Normal file
@@ -0,0 +1,93 @@
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from lightop import awq_marlin_repack_w4a8
|
||||
use_lightop = True
|
||||
except Exception:
|
||||
use_lightop = False
|
||||
|
||||
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
|
||||
每个int8包含两个int4,分别提取到int32的低4位,其余位为0。
|
||||
|
||||
Args:
|
||||
tensor_int8 (torch.Tensor): 输入张量,形状为[N, K//2],类型为torch.int8。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量,形状为[N, K],类型为torch.int32。
|
||||
"""
|
||||
if tensor_int8.dtype != torch.int8:
|
||||
raise ValueError("Input tensor must be of type torch.int8")
|
||||
|
||||
N, K_half = tensor_int8.shape
|
||||
tensor_uint8 = tensor_int8.to(torch.uint8)
|
||||
high4 = tensor_uint8 & 0x0F
|
||||
low4 = (tensor_uint8 >> 4) & 0x0F
|
||||
unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device)
|
||||
unpacked[:, 0::2] = low4.to(torch.int32)
|
||||
unpacked[:, 1::2] = high4.to(torch.int32)
|
||||
|
||||
return unpacked
|
||||
|
||||
def get_weight_perms(interleave: bool=True):
|
||||
perm = []
|
||||
for i in range(64):
|
||||
|
||||
for col in range(4):
|
||||
cur_col = (i % 16) * 4 + col
|
||||
for row in range(8):
|
||||
cur_row = (i // 16) * 8 + row
|
||||
cur_idx = cur_row * 64 + cur_col
|
||||
perm.append(cur_idx)
|
||||
|
||||
perm = np.array(perm)
|
||||
if interleave:
|
||||
interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
|
||||
perm = perm.reshape((-1, 8))[:, interleave].ravel()
|
||||
|
||||
perm = torch.from_numpy(perm)
|
||||
|
||||
return perm
|
||||
|
||||
def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8):
|
||||
size_k, size_n = q_w.shape
|
||||
q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
|
||||
q_w = q_w.permute((0, 2, 1, 3))
|
||||
q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
|
||||
q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
|
||||
|
||||
orig_device = q_w.device
|
||||
q_w = q_w.contiguous().to(torch.int32)
|
||||
M, N = q_w.shape
|
||||
assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})"
|
||||
q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device)
|
||||
for i in range(pack_factor):
|
||||
q_packed += q_w[:, i::pack_factor] << (4 * i)
|
||||
|
||||
return q_packed
|
||||
|
||||
def w4a8_2_marlin_weight(w4a8_w):
|
||||
full_w4a8_w = unpack_int8_to_int4(w4a8_w)
|
||||
full_w4a8_w = full_w4a8_w.T
|
||||
weight_perm = get_weight_perms()
|
||||
marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8)
|
||||
return marlin_q_w
|
||||
|
||||
def w4a8_weight_repack_impl(input):
|
||||
if use_lightop:
|
||||
size_batch = input.shape[0]
|
||||
size_n = input.shape[1]
|
||||
size_k = input.shape[2] * 2
|
||||
output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32)
|
||||
awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
|
||||
else:
|
||||
w_marlin_list = []
|
||||
for e in range(input.shape[0]):
|
||||
w_marlin_in = w4a8_2_marlin_weight(input[e])
|
||||
w_marlin_list.append(w_marlin_in)
|
||||
output = torch.stack(w_marlin_list, dim=0)
|
||||
|
||||
return output
|
||||
Reference in New Issue
Block a user