sglang v0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
121
sgl-kernel/tests/test_marlin_gemm.py
Normal file
121
sgl-kernel/tests/test_marlin_gemm.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import gptq_marlin_gemm
|
||||
from sgl_kernel.scalar_type import scalar_types
|
||||
|
||||
from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace
|
||||
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
|
||||
|
||||
# (m_factor, n_factor, k_factor) triples; the test builds GEMM shapes as
# size_m = m_factor, size_n = n_chunk * n_factor, size_k = k_chunk * k_factor.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
    (257, 13, 11),
    (658, 13, 11),
]
|
||||
|
||||
|
||||
# uint4 for awq
# uint4b8 for gptq
@pytest.mark.parametrize("k_chunk", [128])
@pytest.mark.parametrize("n_chunk", [64, 256])
@pytest.mark.parametrize("quant_type", [scalar_types.uint4, scalar_types.uint4b8])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", [False, True])
@pytest.mark.parametrize("is_k_full", [False, True])
@pytest.mark.parametrize("use_atomic_add", [False, True])
@pytest.mark.parametrize("use_fp32_reduce", [False, True])
def test_gptq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    act_order,
    is_k_full,
    use_atomic_add,
    use_fp32_reduce,
):
    """Compare ``gptq_marlin_gemm`` against a plain fp16 matmul reference.

    Quantizes a random fp16 weight matrix with either the AWQ-style
    (unsigned + runtime zero-point) or GPTQ-style (unsigned + symmetric
    bias) Marlin packing, runs the Marlin GEMM kernel on CUDA, and asserts
    that the mean relative error versus ``torch.matmul`` on the
    dequantized reference weights stays below 4%.
    """
    m_factor, n_factor, k_factor = mnk_factors
    # Zero-point (AWQ) path applies only to the plain unsigned types.
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # Unsupported parameter combinations are reported as skipped (not
    # silently passed) so the parametrized matrix stays honest.
    if act_order:
        if group_size == -1:
            pytest.skip("act_order requires a finite group_size")
        if group_size == size_k:
            pytest.skip("act_order with group_size == size_k is unsupported")
        if has_zp:
            pytest.skip("act_order is unsupported with zero-point quantization")

    if group_size != -1 and size_k % group_size != 0:
        pytest.skip("size_k must be divisible by group_size")

    a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")

    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        if group_size == 16:
            pytest.skip("group_size == 16 is unsupported for AWQ packing")
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, quant_type, group_size
        )
        # No act-order permutation or secondary scales on the AWQ path.
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        # GPTQ style, unsigned + symmetric bias
        if group_size == 16:
            pytest.skip("group_size == 16 is unsupported for GPTQ packing")
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, quant_type, group_size, act_order
        )
        # No zero-point or secondary scales on the GPTQ path.
        marlin_zp = None
        marlin_s2 = None

    workspace = marlin_make_workspace(w_ref.device)

    # marlin gemm
    output = gptq_marlin_gemm(
        a_input,
        None,  # no global scale tensor for these configs
        marlin_q_w,
        marlin_s,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],  # size_m
        b_weight.shape[1],  # size_n
        a_input.shape[1],  # size_k
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    # ref gemm against the dequantized reference weights
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    # Mean relative error (not a max): tolerant of per-element outliers
    # while still catching systematic quantization/kernel errors.
    rel_err = torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref)
    )

    assert rel_err < 0.04
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import subprocess

    # Allow running this file directly; delegate collection and reporting
    # to pytest so behavior matches a normal test-suite run.
    subprocess.call(["pytest", "--tb=short", str(__file__)])
|
||||
Reference in New Issue
Block a user