Files
2026-01-09 15:09:53 +08:00

93 lines
3.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import numpy as np
# Optional native fast path: if the compiled `lightop` extension can be
# imported, its repack kernel is used; any failure while importing it (not
# only a missing module) silently selects the pure-PyTorch fallback below.
try:
    from lightop import awq_marlin_repack_w4a8
except Exception:
    use_lightop = False
else:
    use_lightop = True
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
    """Unpack a 2-D int8 tensor that stores two 4-bit values per byte.

    Each byte of the ``[N, K // 2]`` input is split into its two nibbles,
    producing an ``[N, K]`` ``torch.int32`` tensor whose elements are the
    unsigned nibble values ``0..15`` (upper bits zero, no sign extension).
    For the input byte at column ``j``, the HIGH nibble (bits 4-7) is written
    to output column ``2 * j`` and the LOW nibble (bits 0-3) to column
    ``2 * j + 1``.

    Args:
        tensor_int8: Input tensor of shape ``[N, K // 2]``, dtype ``torch.int8``.

    Returns:
        A new ``[N, K]`` ``torch.int32`` tensor on the input's device.

    Raises:
        ValueError: If the input dtype is not ``torch.int8`` or it is not 2-D.
    """
    if tensor_int8.dtype != torch.int8:
        raise ValueError("Input tensor must be of type torch.int8")
    if tensor_int8.dim() != 2:
        raise ValueError(
            f"Input tensor must be 2-D, got shape {tuple(tensor_int8.shape)}"
        )
    n_rows, k_half = tensor_int8.shape
    # Reinterpret the raw bits as unsigned (zero-copy) so that the right shift
    # below cannot sign-extend bytes whose top bit is set.
    as_uint8 = tensor_int8.view(torch.uint8)
    high_nibble = (as_uint8 >> 4) & 0x0F
    low_nibble = as_uint8 & 0x0F
    unpacked = torch.empty(
        (n_rows, k_half * 2), dtype=torch.int32, device=tensor_int8.device
    )
    # For every source byte: high nibble first (even column), low nibble second.
    unpacked[:, 0::2] = high_nibble.to(torch.int32)
    unpacked[:, 1::2] = low_nibble.to(torch.int32)
    return unpacked
def get_weight_perms(interleave: bool=True):
    """Build the flat index permutation applied to one 32x64 weight tile.

    Indices point into a row-major 32-row x 64-column tile
    (``index = row * 64 + col``). The tile is visited in 64 groups ``g``:
    group ``g`` spans columns ``(g % 16) * 4 .. +3`` and rows
    ``(g // 16) * 8 .. +7``, and inside a group the indices are emitted
    column by column, eight rows per column. When ``interleave`` is truthy,
    every run of 8 consecutive indices is reordered by
    ``[4, 0, 5, 1, 6, 2, 7, 3]``.

    Returns:
        A 1-D tensor (via ``torch.from_numpy``) holding 2048 indices.
    """
    group_row = np.arange(4).reshape(4, 1, 1, 1)     # g // 16
    group_col = np.arange(16).reshape(1, 16, 1, 1)   # g % 16
    col_in_group = np.arange(4).reshape(1, 1, 4, 1)
    row_in_group = np.arange(8).reshape(1, 1, 1, 8)
    rows = group_row * 8 + row_in_group
    cols = group_col * 4 + col_in_group
    # Broadcasting gives shape (4, 16, 4, 8); C-order flattening reproduces the
    # (group, column-in-group, row-in-group) nesting order.
    flat = (rows * 64 + cols).reshape(-1)
    if interleave:
        order = np.array([4, 0, 5, 1, 6, 2, 7, 3])
        flat = flat.reshape(-1, 8)[:, order].reshape(-1)
    return torch.from_numpy(flat)
def marlin_weights(q_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8):
    """Tile, permute and bit-pack a ``[size_k, size_n]`` matrix of 4-bit values.

    Steps:
      1. Split ``q_w`` into ``k_tile x n_tile`` tiles and lay out each row of
         tiles as a single row of length ``size_n * k_tile`` (tile after tile,
         each tile flattened row-major).
      2. Apply ``weight_perm`` independently to every consecutive chunk of
         ``weight_perm.numel()`` elements of that layout.
      3. Pack each group of ``pack_factor`` consecutive values into one
         ``int32``: element ``i`` of the group occupies bits ``4*i .. 4*i+3``.

    Args:
        q_w: Tensor of shape ``[size_k, size_n]``. Values are assumed to lie in
            ``0..15``; larger values would overlap neighbouring nibbles.
        weight_perm: 1-D index tensor (e.g. from ``get_weight_perms``).
        k_tile: Tile height; ``size_k`` must be divisible by it, otherwise
            the reshape raises.
        n_tile: Tile width; ``size_n`` must be divisible by it, otherwise
            the reshape raises.
        pack_factor: Number of 4-bit values packed into each ``int32``.

    Returns:
        ``int32`` tensor of shape
        ``[size_k // k_tile, size_n * k_tile // pack_factor]`` on ``q_w``'s device.

    Raises:
        ValueError: If ``size_n * k_tile`` is not divisible by ``pack_factor``.
    """
    size_k, size_n = q_w.shape
    q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
    q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
    orig_device = q_w.device
    q_w = q_w.contiguous().to(torch.int32)
    num_rows, row_len = q_w.shape
    # Explicit check instead of `assert`, so it is not stripped under `python -O`.
    if row_len % pack_factor != 0:
        raise ValueError(
            f"size_n * k_tile ({row_len}) must be divisible by "
            f"pack_factor ({pack_factor})"
        )
    q_packed = torch.zeros(
        (num_rows, row_len // pack_factor), dtype=torch.int32, device=orig_device
    )
    for i in range(pack_factor):
        # Each value owns a distinct nibble, so OR equals the sum for 0..15
        # inputs without depending on carry-free signed addition.
        q_packed |= q_w[:, i::pack_factor] << (4 * i)
    return q_packed
def w4a8_2_marlin_weight(w4a8_w):
    """Convert one packed W4A8 weight matrix to the Marlin int32 layout.

    Pure-PyTorch path: ``w4a8_w`` (``[N, K // 2]`` int8, two 4-bit values per
    byte) is unpacked to ``[N, K]``, transposed to ``[K, N]``, and then tiled,
    permuted and packed by ``marlin_weights`` using 32x64 tiles and 8 values
    per int32, yielding an int32 tensor of shape ``[K // 32, N * 4]``.
    """
    return marlin_weights(
        unpack_int8_to_int4(w4a8_w).T,
        get_weight_perms(),
        k_tile=32,
        n_tile=64,
        pack_factor=8,
    )
def w4a8_weight_repack_impl(input):
    """Repack a batch of W4A8 weights into the Marlin int32 layout.

    Args:
        input: Tensor of shape ``[batch, N, K // 2]``; each ``input[e]`` is the
            packed int8 layout accepted by ``unpack_int8_to_int4``. (The name
            shadows the builtin but is kept for keyword-argument callers.)

    Returns:
        ``torch.int32`` tensor of shape ``[batch, K // 32, N * 4]``. Uses the
        native ``lightop`` kernel when it imported successfully, otherwise
        converts each batch entry with the pure-PyTorch fallback.
    """
    if use_lightop:
        size_batch = input.shape[0]
        size_n = input.shape[1]
        size_k = input.shape[2] * 2
        output = torch.zeros(
            (size_batch, size_k // 32, size_n * 4),
            device=input.device,
            dtype=torch.int32,
        )
        # `output` is passed to the kernel and returned; presumably the kernel
        # fills it in place (NOTE(review): confirm against lightop docs).
        awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
        return output
    if input.shape[0] == 0:
        # torch.stack rejects an empty list; return the same empty shape and
        # dtype that the native path would produce.
        size_n = input.shape[1]
        size_k = input.shape[2] * 2
        return torch.zeros(
            (0, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32
        )
    return torch.stack(
        [w4a8_2_marlin_weight(expert_w) for expert_w in input], dim=0
    )