[2/n] Decouple quantization implementation from vLLM dependency (#8112)

Co-authored-by: walker-ai <yiyun.wyt@antgroup.com>
Co-authored-by: leoneo <1320612015@qq.com>
Authored by Peng Zhang on 2025-08-14 18:19:03 +08:00, committed by GitHub
parent 4dbf43601d
commit 5aa1ebd242
32 changed files with 6506 additions and 202 deletions

@@ -0,0 +1,131 @@
import pytest
import torch
from sgl_kernel import gptq_gemm
from sglang.srt.layers.quantization.utils import pack_cols, pack_rows


def torch_dequantize(q_weight, q_zeros, scales, g_idx, use_shuffle, bit, K, N):
assert bit == 4, "Reference dequantization only supports 4-bit"
group_size = K // scales.shape[0]
pack_factor = 32 // bit
# unpack q_weight: (K//pack_factor, N) -> (K, N)
unpacked_q_weight = torch.empty(
q_weight.shape[0] * pack_factor,
q_weight.shape[1],
dtype=torch.uint8,
device=q_weight.device,
)
for i in range(pack_factor):
unpacked_q_weight[i::pack_factor, :] = (q_weight >> (i * 4)) & 0x0F
# unpack q_zeros: (num_groups, N//pack_factor) -> (num_groups, N)
unpacked_q_zeros = torch.empty(
q_zeros.shape[0],
q_zeros.shape[1] * pack_factor,
dtype=torch.uint8,
device=q_zeros.device,
)
for i in range(pack_factor):
unpacked_q_zeros[:, i::pack_factor] = (q_zeros >> (i * 4)) & 0x0F
unpacked_q_zeros += 1
unpacked_q_zeros = unpacked_q_zeros.to(scales.dtype)
scale_zeros = unpacked_q_zeros * scales # (num_groups, N)
current_g_idx = torch.tensor(
[i // group_size for i in range(K)], dtype=torch.int32, device=q_weight.device
)
scale_mat = scales[current_g_idx] # (K, N)
scale_zeros_mat = scale_zeros[current_g_idx] # (K, N)
# dequant: weight * scale - scale_zeros
dequantized_b = unpacked_q_weight.to(scales.dtype) * scale_mat - scale_zeros_mat
return dequantized_b.reshape(K, N)
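
# Note: GPTQ checkpoints store each group zero point as (zero_point - 1); the "+ 1"
# above restores the true zero point, mirroring the "- 1" applied when
# _test_gptq_gemm_once packs q_zeros below. The dequantized weight is then
# q * scale - zero_point * scale.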


def torch_gptq_gemm(
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
):
K, N = a.shape[1], b_q_weight.shape[1]
b_dequant = torch_dequantize(
b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit, K, N
)
c = torch.matmul(a, b_dequant)
return c


def _test_gptq_gemm_once(M, N, K, bit, group_size, use_shuffle, dtype, device="cuda"):
b_fp = torch.randn(K, N, dtype=dtype, device=device)
assert K % group_size == 0, "K must be divisible by group_size"
num_groups = K // group_size
if use_shuffle:
return
else:
g_idx = torch.tensor(
[i // group_size for i in range(K)], dtype=torch.int32, device=device
)
b_shuffled = b_fp[g_idx]
b_grouped = b_shuffled.reshape(num_groups, group_size, N)
b_max = torch.max(b_grouped, dim=1, keepdim=True)[0]
b_min = torch.min(b_grouped, dim=1, keepdim=True)[0]
scales = (b_max - b_min) / (2**bit - 1)
scales = scales.clamp(min=1e-6)
zeros_float = (-b_min / scales).round()
q_b = (
(b_grouped / scales + zeros_float).round().clamp(0, 2**bit - 1).to(torch.uint8)
)
q_zeros_unpacked = zeros_float.to(torch.uint8) - 1
b_q_weight = pack_rows(q_b.reshape(K, N), bit, K, N)
q_zeros_unpacked = q_zeros_unpacked.reshape(num_groups, N)
b_gptq_qzeros = pack_cols(q_zeros_unpacked, bit, num_groups, N)
b_gptq_scales = scales.squeeze(1)
a = torch.randn(M, K, dtype=dtype, device=device)
c_ref = torch_gptq_gemm(
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, g_idx, use_shuffle, bit
)
c_out = gptq_gemm(
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, g_idx, use_shuffle, bit
)
rtol = 4e-2
atol = 4e-2
torch.testing.assert_close(c_ref, c_out, rtol=rtol, atol=atol)
print(
f"✅ Test passed: M={M}, N={N}, K={K}, bit={bit}, group_size={group_size}, use_shuffle={use_shuffle}, dtype={dtype}"
)
@pytest.mark.parametrize("M", [1, 8, 128])
@pytest.mark.parametrize("N", [2048, 4096])
@pytest.mark.parametrize("K", [2048, 4096])
@pytest.mark.parametrize("bit", [4])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("use_shuffle", [False])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gptq_gemm(M, N, K, bit, group_size, use_shuffle, dtype):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
_test_gptq_gemm_once(M, N, K, bit, group_size, use_shuffle, dtype, "cuda")


if __name__ == "__main__":
pytest.main([__file__, "-v"])
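
The test above round-trips weights through pack_rows/pack_cols before handing them to gptq_gemm. As a reading aid, here is a minimal sketch of the bit-packing those helpers are assumed to perform (32 // bit values interleaved into each int32, matching the unpacking loops in torch_dequantize). The actual implementations live in sglang.srt.layers.quantization.utils and may differ in detail; the _sketch names are illustrative, not part of the API.

import torch


def pack_rows_sketch(q_w: torch.Tensor, num_bits: int, size_k: int, size_n: int) -> torch.Tensor:
    # Pack along K: every pack_factor consecutive rows share one int32 row, with the
    # i-th row of each group occupying bits [num_bits * i, num_bits * (i + 1)).
    assert q_w.shape == (size_k, size_n)
    pack_factor = 32 // num_bits
    assert size_k % pack_factor == 0
    q_w = q_w.to(torch.int32)
    packed = torch.zeros((size_k // pack_factor, size_n), dtype=torch.int32, device=q_w.device)
    for i in range(pack_factor):
        packed |= q_w[i::pack_factor, :] << (num_bits * i)
    return packed


def pack_cols_sketch(q_w: torch.Tensor, num_bits: int, size_k: int, size_n: int) -> torch.Tensor:
    # Same idea along N: every pack_factor consecutive columns share one int32 column.
    assert q_w.shape == (size_k, size_n)
    pack_factor = 32 // num_bits
    assert size_n % pack_factor == 0
    q_w = q_w.to(torch.int32)
    packed = torch.zeros((size_k, size_n // pack_factor), dtype=torch.int32, device=q_w.device)
    for i in range(pack_factor):
        packed |= q_w[:, i::pack_factor] << (num_bits * i)
    return packed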

@@ -0,0 +1,121 @@
import pytest
import torch
from sgl_kernel import gptq_marlin_gemm
from sgl_kernel.scalar_type import scalar_types
from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize

MNK_FACTORS = [
(1, 1, 1),
(1, 4, 8),
(1, 7, 5),
(13, 17, 67),
(26, 37, 13),
(67, 13, 11),
(257, 13, 11),
(658, 13, 11),
]

# uint4   -> AWQ-style: unsigned 4-bit with a runtime (per-group) zero point
# uint4b8 -> GPTQ-style: unsigned 4-bit with a symmetric bias (zero point fixed at 8)
@pytest.mark.parametrize("k_chunk", [128])
@pytest.mark.parametrize("n_chunk", [64, 256])
@pytest.mark.parametrize("quant_type", [scalar_types.uint4, scalar_types.uint4b8])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", [False, True])
@pytest.mark.parametrize("is_k_full", [False, True])
@pytest.mark.parametrize("use_atomic_add", [False, True])
@pytest.mark.parametrize("use_fp32_reduce", [False, True])
def test_gptq_marlin_gemm(
k_chunk,
n_chunk,
quant_type,
group_size,
mnk_factors,
act_order,
is_k_full,
use_atomic_add,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
if act_order:
if group_size == -1:
return
if group_size == size_k:
return
if has_zp:
return
if size_k % group_size != 0:
return
a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")
if has_zp:
# AWQ style, unsigned + runtime zero-point
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, quant_type, group_size
)
g_idx = None
sort_indices = None
marlin_s2 = None
else:
# GPTQ style, unsigned + symmetric bias
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, act_order
)
marlin_zp = None
marlin_s2 = None
workspace = marlin_make_workspace(w_ref.device)
# marlin gemm
output = gptq_marlin_gemm(
a_input,
None,
marlin_q_w,
marlin_s,
marlin_s2,
marlin_zp,
g_idx,
sort_indices,
workspace,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
# ref gemm
output_ref = torch.matmul(a_input, w_ref)
torch.cuda.synchronize()
    # Mean relative error between the kernel output and the fp16 reference matmul.
    rel_diff = torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref)
    )
    assert rel_diff < 0.04


if __name__ == "__main__":
import subprocess
subprocess.call(["pytest", "--tb=short", str(__file__)])
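
For orientation, the two quant_type values exercised above differ only in how the zero point is handled: uint4 (AWQ style) carries a per-group zero point at runtime, while uint4b8 (GPTQ style) is, as the name suggests, assumed to be an unsigned 4-bit value stored with a fixed bias of 8. Below is a rough per-element dequantization sketch of that distinction, not the Marlin kernel's actual fused math; the helper names are illustrative only.

import torch


def dequant_uint4b8_sketch(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # GPTQ-style symmetric: subtract the constant bias of 8, then scale.
    return (q.float() - 8.0) * scale.float()


def dequant_uint4_sketch(q: torch.Tensor, scale: torch.Tensor, zp: torch.Tensor) -> torch.Tensor:
    # AWQ-style asymmetric: subtract the per-group zero point supplied at runtime.
    return (q.float() - zp.float()) * scale.float()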

@@ -1,16 +1,32 @@
import numpy as np
import pytest
import torch
from sgl_kernel import awq_marlin_repack
from sgl_kernel import awq_marlin_repack, gptq_marlin_repack
from sgl_kernel.scalar_type import scalar_types
from sglang.srt.layers.quantization.utils import (
get_pack_factor,
gptq_quantize_weights,
pack_cols,
pack_rows,
quantize_weights,
sort_weights,
)
from sglang.test.test_marlin_utils import get_weight_perm, marlin_weights

GPTQ_MARLIN_TILE = 16
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]
MNK_FACTORS = [
(1, 1, 1),
(1, 4, 8),
(1, 7, 5),
(13, 17, 67),
(26, 37, 13),
(67, 13, 11),
(257, 13, 11),
(658, 13, 11),
]


def awq_pack(
@@ -35,70 +51,6 @@ def awq_pack(
return pack_cols(q_w, num_bits, size_k, size_n)
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
assert q_w.shape == (size_k, size_n)
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"
# Permute weights to 16x64 marlin tiles
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
q_w = q_w.permute((0, 2, 1, 3))
q_w = q_w.reshape((size_k // tile, size_n * tile))
q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
return q_w
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
# Permute
q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
# Pack
pack_factor = get_pack_factor(num_bits)
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(np.uint32)
q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
for i in range(pack_factor):
q_packed |= q_w[:, i::pack_factor] << num_bits * i
q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
return q_packed
def get_weight_perm(num_bits: int):
perm_list: list[int] = []
for i in range(32):
perm1: list[int] = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
perm = np.array(perm_list)
if num_bits == 4:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = np.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
perm = torch.from_numpy(perm)
return perm
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("k_tiles,n_tiles", [(1, 1), (2, 2)])
@pytest.mark.parametrize("group_size", [16, 32])
@@ -130,6 +82,66 @@ def test_awq_marlin_repack_correct(num_bits, k_tiles, n_tiles, group_size):
torch.testing.assert_close(out_gpu, q_w_marlin)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", [scalar_types.uint4b8])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [False, True])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(
k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
):
m_factor, n_factor, k_factor = mnk_factors
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
# Filter act_order
if act_order:
if group_size == -1:
return
if group_size == size_k:
return
# Normalize group_size
if group_size == -1:
group_size = size_k
assert group_size <= size_k
if size_k % group_size != 0:
pytest.skip("size_k must be divisible by group_size")
# Create input
b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")
# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
b_weight, quant_type, group_size, act_order
)
q_w_gptq = pack_rows(q_w, quant_type.size_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
if act_order:
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
marlin_layout_perm = get_weight_perm(quant_type.size_bits)
q_w_marlin_ref = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, marlin_layout_perm
)
# Run Marlin repack GPU kernel
q_w_marlin = gptq_marlin_repack(
q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits
)
torch.cuda.synchronize()
torch.testing.assert_close(q_w_marlin, q_w_marlin_ref)
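
# The act_order branch above relies on sort_weights to reorder rows so that group ids
# become non-decreasing before repacking. A minimal sketch of what that helper is
# assumed to do (the real implementation is imported from
# sglang.srt.layers.quantization.utils and may differ in detail):
def sort_weights_sketch(q_w: torch.Tensor, g_idx: torch.Tensor):
    # Sort rows of q_w so that g_idx is non-decreasing, and return the permutation
    # so the kernel can account for the reordering at runtime.
    sort_indices = torch.argsort(g_idx).to(torch.int32)
    return q_w[sort_indices.long(), :], g_idx[sort_indices.long()], sort_indices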


if __name__ == "__main__":
import subprocess