# Marlin repack utilities for W4A8 quantized weights.
import torch
import numpy as np
try:
|
||
from lightop import awq_marlin_repack_w4a8
|
||
use_lightop = True
|
||
except Exception:
|
||
use_lightop = False
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
|
||
"""
|
||
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
|
||
每个int8包含两个int4,分别提取到int32的低4位,其余位为0。
|
||
|
||
Args:
|
||
tensor_int8 (torch.Tensor): 输入张量,形状为[N, K//2],类型为torch.int8。
|
||
|
||
Returns:
|
||
torch.Tensor: 输出张量,形状为[N, K],类型为torch.int32。
|
||
"""
|
||
if tensor_int8.dtype != torch.int8:
|
||
raise ValueError("Input tensor must be of type torch.int8")
|
||
|
||
N, K_half = tensor_int8.shape
|
||
tensor_uint8 = tensor_int8.to(torch.uint8)
|
||
high4 = tensor_uint8 & 0x0F
|
||
low4 = (tensor_uint8 >> 4) & 0x0F
|
||
unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device)
|
||
unpacked[:, 0::2] = low4.to(torch.int32)
|
||
unpacked[:, 1::2] = high4.to(torch.int32)
|
||
|
||
return unpacked
def get_weight_perms(interleave: bool=True):
|
||
perm = []
|
||
for i in range(64):
|
||
|
||
for col in range(4):
|
||
cur_col = (i % 16) * 4 + col
|
||
for row in range(8):
|
||
cur_row = (i // 16) * 8 + row
|
||
cur_idx = cur_row * 64 + cur_col
|
||
perm.append(cur_idx)
|
||
|
||
perm = np.array(perm)
|
||
if interleave:
|
||
interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
|
||
perm = perm.reshape((-1, 8))[:, interleave].ravel()
|
||
|
||
perm = torch.from_numpy(perm)
|
||
|
||
return perm
def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8):
|
||
size_k, size_n = q_w.shape
|
||
q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
|
||
q_w = q_w.permute((0, 2, 1, 3))
|
||
q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
|
||
q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
|
||
|
||
orig_device = q_w.device
|
||
q_w = q_w.contiguous().to(torch.int32)
|
||
M, N = q_w.shape
|
||
assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})"
|
||
q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device)
|
||
for i in range(pack_factor):
|
||
q_packed += q_w[:, i::pack_factor] << (4 * i)
|
||
|
||
return q_packed
def w4a8_2_marlin_weight(w4a8_w):
|
||
full_w4a8_w = unpack_int8_to_int4(w4a8_w)
|
||
full_w4a8_w = full_w4a8_w.T
|
||
weight_perm = get_weight_perms()
|
||
marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8)
|
||
return marlin_q_w
def w4a8_weight_repack_impl(input):
|
||
if use_lightop:
|
||
size_batch = input.shape[0]
|
||
size_n = input.shape[1]
|
||
size_k = input.shape[2] * 2
|
||
output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32)
|
||
awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
|
||
else:
|
||
w_marlin_list = []
|
||
for e in range(input.shape[0]):
|
||
w_marlin_in = w4a8_2_marlin_weight(input[e])
|
||
w_marlin_list.append(w_marlin_in)
|
||
output = torch.stack(w_marlin_list, dim=0)
|
||
|
||
return output |