import torch
import numpy as np
# Optional fast path: if the `lightop` package is importable, its
# `awq_marlin_repack_w4a8` kernel is used by `w4a8_weight_repack_impl`
# below; otherwise `use_lightop` stays False and the pure-PyTorch
# fallback is taken.  A broad `except` is deliberate here — any import
# failure simply disables the fast path.
try:
    from lightop import awq_marlin_repack_w4a8
    use_lightop = True
except Exception:
    use_lightop = False
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
    """Unpack an [N, K//2] torch.int8 tensor into an [N, K] torch.int32 tensor.

    Each int8 byte holds two 4-bit values. The high nibble of each byte goes
    to the even output columns and the low nibble to the odd output columns,
    so byte 0xAB unpacks to [0xA, 0xB]. Every output element lies in
    [0, 15]; the upper 28 bits of each int32 are zero.

    Args:
        tensor_int8 (torch.Tensor): Input tensor of shape [N, K//2] with
            dtype torch.int8.

    Returns:
        torch.Tensor: Output tensor of shape [N, K] with dtype torch.int32.

    Raises:
        ValueError: If the input dtype is not torch.int8.
    """
    if tensor_int8.dtype != torch.int8:
        raise ValueError("Input tensor must be of type torch.int8")

    N, K_half = tensor_int8.shape
    # Reinterpret as uint8 so the right shift below is logical, not
    # arithmetic (negative int8 values would otherwise sign-extend).
    tensor_uint8 = tensor_int8.to(torch.uint8)
    # NOTE(review): the original code named these the other way around
    # ("high4" held bits 0-3); renamed to match what they actually extract.
    low_nibble = tensor_uint8 & 0x0F            # bits 0-3 of each byte
    high_nibble = (tensor_uint8 >> 4) & 0x0F    # bits 4-7 of each byte

    unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device)
    unpacked[:, 0::2] = high_nibble.to(torch.int32)
    unpacked[:, 1::2] = low_nibble.to(torch.int32)

    return unpacked
def get_weight_perms(interleave: bool = True) -> torch.Tensor:
    """Build the 2048-entry weight permutation used by the marlin repack.

    The permutation covers a 16x64 index grid: for each of 64 groups it
    emits 4 columns x 8 rows of flattened (row * 64 + col) indices.  When
    `interleave` is True, every consecutive run of 8 indices is
    additionally reordered by the pattern [4, 0, 5, 1, 6, 2, 7, 3].

    Args:
        interleave: Whether to apply the per-8 interleave reordering.

    Returns:
        torch.Tensor: 1-D int64 tensor of 2048 permutation indices.
    """
    perm = []
    for i in range(64):
        for col in range(4):
            cur_col = (i % 16) * 4 + col
            for row in range(8):
                cur_row = (i // 16) * 8 + row
                perm.append(cur_row * 64 + cur_col)

    perm = np.array(perm)
    if interleave:
        # Use a distinct name for the index pattern instead of shadowing
        # the boolean `interleave` parameter (original code reassigned it).
        interleave_order = np.array([4, 0, 5, 1, 6, 2, 7, 3])
        perm = perm.reshape((-1, 8))[:, interleave_order].ravel()

    return torch.from_numpy(perm)
def marlin_weights(q_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8):
    """Tile, permute and bit-pack a quantized weight matrix.

    The [size_k, size_n] input is rearranged into k_tile x n_tile tiles,
    shuffled with `weight_perm` (applied within every chunk of
    `weight_perm.numel()` elements), and then every `pack_factor`
    consecutive values along the last axis are packed 4 bits apiece into
    one int32.

    Args:
        q_w: Quantized weights of shape [size_k, size_n]; each element is
            expected to fit in 4 bits.
        weight_perm: 1-D index tensor used to shuffle the tiled layout.
        k_tile: Tile height along the k dimension.
        n_tile: Tile width along the n dimension.
        pack_factor: Number of 4-bit values packed per int32.

    Returns:
        torch.Tensor: int32 tensor of shape
        [size_k // k_tile, size_n * k_tile // pack_factor].
    """
    size_k, size_n = q_w.shape

    # Group into tiles, bring the two tile axes together, then flatten
    # each row of tiles.
    tiled = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
    tiled = tiled.permute((0, 2, 1, 3))
    tiled = tiled.reshape((size_k // k_tile, size_n * k_tile))

    # Shuffle within perm-sized chunks, restoring the tiled shape after.
    tiled_shape = tiled.shape
    shuffled = tiled.reshape((-1, weight_perm.numel()))[:, weight_perm]
    shuffled = shuffled.reshape(tiled_shape)

    device = shuffled.device
    shuffled = shuffled.contiguous().to(torch.int32)
    rows, cols = shuffled.shape
    assert cols % pack_factor == 0, f"size_n ({cols}) must be divisible by pack_factor ({pack_factor})"

    # Pack pack_factor strided columns into each int32, 4 bits per value.
    packed = torch.zeros((rows, cols // pack_factor), dtype=torch.int32, device=device)
    for slot in range(pack_factor):
        packed += shuffled[:, slot::pack_factor] << (4 * slot)

    return packed
def w4a8_2_marlin_weight(w4a8_w):
    """Convert one packed w4a8 weight ([N, K//2] int8) to marlin layout.

    The nibbles are unpacked to an [N, K] int32 tensor, transposed to
    [K, N], and then tiled/permuted/packed with the default marlin
    parameters (k_tile=32, n_tile=64, pack_factor=8).
    """
    unpacked_kn = unpack_int8_to_int4(w4a8_w).T
    perm = get_weight_perms()
    return marlin_weights(unpacked_kn, perm, k_tile=32, n_tile=64, pack_factor=8)
def w4a8_weight_repack_impl(input):
    """Repack a batch of w4a8 weights into marlin layout.

    Args:
        input: Batched packed weights of shape [batch, N, K//2], two
            4-bit values per int8 element.

    Returns:
        torch.Tensor: int32 tensor of shape [batch, K // 32, N * 4],
        produced by the `lightop` kernel when available, otherwise by the
        pure-PyTorch fallback applied one batch entry at a time.
    """
    if not use_lightop:
        # Fallback path: repack each batch entry independently.
        repacked = [w4a8_2_marlin_weight(input[b]) for b in range(input.shape[0])]
        return torch.stack(repacked, dim=0)

    size_batch, size_n = input.shape[0], input.shape[1]
    size_k = input.shape[2] * 2  # each int8 packs two 4-bit values
    output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32)
    awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
    return output