Reorganize CI and test files (#9027)
This commit is contained in:
45
test/srt/quant/test_awq.py
Normal file
45
test/srt/quant/test_awq.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestAWQ(CustomTestCase):
    """Smoke-test an AWQ-quantized MoE model: launch a server, check MMLU accuracy."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        # One shared server process for the whole class; torn down in tearDownClass.
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--trust-remote-code"],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        eval_args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        score = run_eval(eval_args)["score"]
        self.assertGreater(score, 0.64)


if __name__ == "__main__":
    unittest.main()
|
||||
175
test/srt/quant/test_awq_dequant.py
Normal file
175
test/srt/quant/test_awq_dequant.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/quantization/test_awq_triton.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
unittest version of the AWQ Triton kernel tests.
|
||||
|
||||
Run with:
|
||||
python -m unittest test_awq_dequant.py
|
||||
"""
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.quantization.awq_triton import (
|
||||
AWQ_TRITON_SUPPORTED_GROUP_SIZES,
|
||||
awq_dequantize_triton,
|
||||
awq_gemm_triton,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
device = "cuda"
|
||||
|
||||
|
||||
def reverse_awq_order(t: torch.Tensor) -> torch.Tensor:
    """Undo AWQ's interleaved nibble ordering along the last dim, masking to 4 bits."""
    pack_factor = 32 // 4  # eight 4-bit nibbles per packed int32
    reverse_order = [0, 4, 1, 5, 2, 6, 3, 7]
    cols = torch.arange(t.shape[-1], dtype=torch.int32, device=t.device)
    cols = cols.reshape(-1, pack_factor)[:, reverse_order].reshape(-1)
    return (t[:, cols] & 0xF).contiguous()
|
||||
|
||||
|
||||
def awq_dequantize_torch(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    group_size: int,
) -> torch.Tensor:
    """Pure-torch reference for AWQ dequantization.

    Unpacks the 4-bit weights and zero points packed into int32 tensors,
    undoes AWQ's interleaved column ordering, and applies per-group
    (value - zero) * scale dequantization.  ``group_size == -1`` means a
    single group spanning all rows of ``qweight``.
    """
    if group_size == -1:
        group_size = qweight.shape[0]

    bits = 4
    # One 4-bit shift per nibble packed into each int32 element.
    shifts = torch.arange(0, 32, bits, device=qzeros.device)

    # Unpack 8 nibbles per int32 column, then restore natural column order.
    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )
    iweights = reverse_awq_order(iweights.view(iweights.shape[0], -1))

    zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )
    zeros = reverse_awq_order(zeros.view(qzeros.shape[0], -1))

    # Mask to the low 4 bits (the arithmetic shift above can leave high bits set).
    iweights = torch.bitwise_and(iweights, (2**bits) - 1)
    zeros = torch.bitwise_and(zeros, (2**bits) - 1)

    # Broadcast the per-group scales and zero points over each group's rows.
    scales = scales.repeat_interleave(group_size, dim=0)
    zeros = zeros.repeat_interleave(group_size, dim=0)
    return (iweights - zeros) * scales
|
||||
|
||||
|
||||
class TestAWQTriton(CustomTestCase):
    """Compares the Triton AWQ dequantize/GEMM kernels against torch references."""

    def test_dequantize(self):
        rows_to_try = [3584, 18944, 128, 256, 512, 1024]
        cols_to_try = [448, 576, 4736, 16, 32, 64, 128]

        for n_rows in rows_to_try:
            for n_cols in cols_to_try:
                for gs in AWQ_TRITON_SUPPORTED_GROUP_SIZES:
                    with self.subTest(rows=n_rows, cols=n_cols, g=gs):
                        self._run_dequant_case(
                            qweight_rows=n_rows,
                            qweight_cols=n_cols,
                            group_size=gs,
                        )

    def _run_dequant_case(self, qweight_rows, qweight_cols, group_size):
        if group_size == -1:
            group_size = qweight_rows

        torch.manual_seed(0)

        int32_max = torch.iinfo(torch.int32).max
        qweight = torch.randint(
            0,
            int32_max,
            (qweight_rows, qweight_cols),
            dtype=torch.int32,
            device=device,
        )
        scales = torch.rand(
            qweight_rows // group_size,
            qweight_cols * 8,
            dtype=torch.float16,
            device=device,
        )
        zeros = torch.randint(
            0,
            int32_max,
            (qweight_rows // group_size, qweight_cols),
            dtype=torch.int32,
            device=device,
        )

        ref = awq_dequantize_torch(qweight, scales, zeros, group_size)
        tri = awq_dequantize_triton(qweight, scales, zeros)

        # The kernel must produce finite values before we compare numerics.
        self.assertFalse(torch.any(torch.isinf(tri)) or torch.any(torch.isnan(tri)))
        torch.testing.assert_close(ref, tri)

    # GEMM
    def test_gemm(self):
        n_values = [1, 2, 4, 8, 14, 17, 23, 32]
        k_values = [128]
        m_values = [16, 24, 32]
        split_k_values = [1, 8]

        for N in n_values:
            for K in k_values:
                for M in m_values:
                    for gs in AWQ_TRITON_SUPPORTED_GROUP_SIZES:
                        for splitK in split_k_values:
                            with self.subTest(N=N, K=K, M=M, g=gs, sk=splitK):
                                self._run_gemm_case(
                                    N=N, K=K, M=M, group_size=gs, splitK=splitK
                                )

    def _run_gemm_case(self, N, K, M, group_size, splitK):
        if group_size == -1:
            group_size = K

        torch.manual_seed(0)

        int32_max = torch.iinfo(torch.int32).max
        x = torch.rand((N, K), dtype=torch.float32, device=device)
        qweight = torch.randint(
            0, int32_max, (K, M // 8), dtype=torch.int32, device=device
        )
        qzeros = torch.randint(
            0, int32_max, (K // group_size, M // 8), dtype=torch.int32, device=device
        )
        scales = torch.rand((K // group_size, M), dtype=torch.float32, device=device)

        tri_out = awq_gemm_triton(x, qweight, scales, qzeros, splitK)

        self.assertFalse(
            torch.any(torch.isinf(tri_out)) or torch.any(torch.isnan(tri_out))
        )

        # Dequantize the weights and compare against a plain matmul.
        w_deq = awq_dequantize_triton(qweight, scales, qzeros)
        ref_out = torch.matmul(x, w_deq)

        self.assertFalse(
            torch.any(torch.isinf(ref_out)) or torch.any(torch.isnan(ref_out))
        )

        torch.testing.assert_close(tri_out.cpu(), ref_out.cpu(), atol=1e-1, rtol=1e-1)


if __name__ == "__main__":
    unittest.main(verbosity=2)
|
||||
227
test/srt/quant/test_block_int8.py
Normal file
227
test/srt/quant/test_block_int8.py
Normal file
@@ -0,0 +1,227 @@
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
# For test
|
||||
def native_per_token_group_quant_int8(x, group_size, eps=1e-10, dtype=torch.int8):
    """Per-token-group int8 quantization of `x` using native torch.

    Each contiguous group of `group_size` elements along the last dimension
    is scaled so that its absolute maximum maps to the `dtype` maximum, then
    cast into `dtype`.  Returns the quantized tensor and the per-group
    float32 scaling factors.

    (The previous docstring was copied from an fp8 variant and wrongly
    claimed float8 output; this function quantizes to int8.)
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    iinfo = torch.iinfo(dtype)
    int8_min = iinfo.min
    int8_max = iinfo.max

    x_ = x.reshape(x.numel() // group_size, group_size)
    # Clamp with eps so an all-zero group does not produce a zero divisor.
    amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
    x_s = amax / int8_max
    x_q = (x_ / x_s).clamp(min=int8_min, max=int8_max).to(dtype)
    x_q = x_q.reshape(x.shape)
    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))

    return x_q, x_s
|
||||
|
||||
|
||||
# For test
|
||||
def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """Block-quantized matmul reference: C = (A @ B.T) with per-block scales.

    `A` is scaled per (token, k-block) by `As`; `B` per (n-block, k-block)
    by `Bs`.  Accumulation happens in float32; the result is returned in
    `output_dtype`.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)
    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    out_shape = A.shape[:-1] + (N,)
    A = A.reshape(M, A.shape[-1])
    As = As.reshape(M, As.shape[-1])
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C = torch.zeros((M, N), dtype=torch.float32, device=A.device)

    # Pre-slice views of every tile so the accumulation loop stays readable.
    a_blocks = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
    b_blocks = [
        [
            B[
                j * block_n : min((j + 1) * block_n, N),
                i * block_k : min((i + 1) * block_k, K),
            ]
            for i in range(k_tiles)
        ]
        for j in range(n_tiles)
    ]
    c_blocks = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
    a_scales = [As[:, i : i + 1] for i in range(k_tiles)]

    for i in range(k_tiles):
        for j in range(n_tiles):
            scale = a_scales[i] * Bs[j][i]
            c_blocks[j][:, :] += torch.matmul(a_blocks[i], b_blocks[j][i].t()) * scale

    return C.reshape(out_shape).to(output_dtype)
|
||||
|
||||
|
||||
# For test
|
||||
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """Reference fused-MoE with block-wise int8 quantization, in native torch.

    Every token is duplicated once per selected expert; each routed expert
    runs a w1 -> SiluAndMul -> w2 pipeline via the block-quantized reference
    matmul, and the expert outputs are combined with softmax top-k weights.
    """

    B, D = a.shape
    # Duplicate every token once per selected expert: [B * topk, D].
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    # Only the k-block size is used for activation quantization.
    _, block_k = block_shape[0], block_shape[1]
    a_q, a_s = native_per_token_group_quant_int8(a, block_k)
    for i in range(w1.shape[0]):
        # Rows routed to expert i.
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_int8_matmul(
                a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
            )
            act_out = SiluAndMul().forward_native(inter_out)
            # Re-quantize the activation output before the second matmul.
            act_out_q, act_out_s = native_per_token_group_quant_int8(act_out, block_k)
            act_out = act_out.to(torch.float32)
            out[mask] = native_w8a8_block_int8_matmul(
                act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
            )
    # Weight each routed copy by its gate value and sum over the topk copies.
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
|
||||
|
||||
|
||||
class TestW8A8BlockINT8FusedMoE(CustomTestCase):
    """Compares fused_moe (int8 W8A8, block-quantized) with the torch reference."""

    DTYPES = [torch.half, torch.bfloat16]
    M = [1, 33, 64, 222]
    N = [128, 1024]
    K = [256, 4096]
    E = [8, 24]
    TOP_KS = [2, 6]
    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w8a8_block_int8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed):
        torch.manual_seed(seed)
        # NOTE(HandH1998): small scales avoid overflow when out_dtype = torch.half
        factor_for_scale = 1e-2
        int8_info = torch.iinfo(torch.int8)
        int8_max, int8_min = int8_info.max, int8_info.min

        a = torch.randn((M, K), dtype=dtype) / 10

        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
        w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
        w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

        block_n, block_k = block_size
        n_tiles_w1 = (2 * N + block_n - 1) // block_n
        n_tiles_w2 = (K + block_n - 1) // block_n
        k_tiles_w1 = (K + block_k - 1) // block_k
        k_tiles_w2 = (N + block_k - 1) // block_k

        w1_s = (
            torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
            * factor_for_scale
        )
        w2_s = (
            torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
            * factor_for_scale
        )

        score = torch.randn((M, E), dtype=dtype)

        topk_output = select_experts(
            hidden_states=a,
            router_logits=score,
            top_k=topk,
        )

        with torch.inference_mode():
            out = fused_moe(
                a,
                w1,
                w2,
                topk_output,
                use_int8_w8a8=True,
                w1_scale=w1_s,
                w2_scale=w2_s,
                block_shape=block_size,
            )
            ref_out = torch_w8a8_block_int8_moe(
                a, w1, w2, w1_s, w2_s, score, topk, block_size
            )

        # Mean relative error of the fused kernel must stay below 2%.
        rel_err = torch.mean(
            torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
        ) / torch.mean(torch.abs(ref_out.to(torch.float32)))
        self.assertTrue(rel_err < 0.02)

    def test_w8a8_block_int8_fused_moe(self):
        for M, N, K, E, topk, block_size, dtype, seed in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                M=M,
                N=N,
                K=K,
                E=E,
                topk=topk,
                block_size=block_size,
                dtype=dtype,
                seed=seed,
            ):
                self._w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed)


if __name__ == "__main__":
    unittest.main(verbosity=2)
|
||||
126
test/srt/quant/test_fp8_kernel.py
Normal file
126
test/srt/quant/test_fp8_kernel.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_token_group_quant_fp8,
|
||||
w8a8_block_fp8_matmul,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
class TestFP8Base(CustomTestCase):
    """Shared fixtures: fp8-representable operands plus their exact float images."""

    @classmethod
    def setUpClass(cls):
        cls.M = 256
        # N is deliberately not a multiple of the group size (non-aligned case).
        cls.N = 1024 + 64
        cls.K = 512
        cls.group_size = 128
        cls.quant_type = torch.float8_e4m3fn
        cls.output_type = torch.bfloat16

    @staticmethod
    def _make_A(M, K, group_size, out_dtype):
        """Return (A, quant_A, scale): float input, its fp8 image, per-group scales."""
        raw = torch.rand(
            M, K // group_size, group_size, dtype=torch.float32, device="cuda"
        )
        raw = raw * 2 - 1  # shift into [-1, 1]
        # Stretch each group so its abs-max lands exactly on the fp8 maximum.
        fmax = torch.finfo(out_dtype).max
        raw *= fmax / raw.abs().amax(-1, keepdim=True)
        quant_A = raw.to(out_dtype).to(torch.float32)

        # Per-(token, group) scales; dividing by fmax keeps A in a sane range.
        scale = torch.rand(M, K // group_size, dtype=torch.float32, device="cuda")
        scale /= fmax
        A = quant_A * scale[..., None]

        A = A.reshape(M, K)
        quant_A = quant_A.reshape(M, K).to(out_dtype)
        return A, quant_A, scale

    @staticmethod
    def _make_B(K, N, group_size, out_dtype):
        """Return (B, quant_B, scale) with one scale per (k-group, n-group) block."""

        def _round_up(a, b):
            return (a + b - 1) // b * b

        K_pad = _round_up(K, group_size)
        N_pad = _round_up(N, group_size)

        raw = torch.rand(
            K_pad // group_size,
            group_size,
            N_pad // group_size,
            group_size,
            dtype=torch.float32,
            device="cuda",
        )
        raw = raw * 2 - 1
        # Stretch each (k-group, n-group) block so its abs-max hits the fp8 maximum.
        fmax = torch.finfo(out_dtype).max
        raw *= fmax / raw.abs().amax((1, 3), keepdim=True)
        quant_B = raw.to(out_dtype).to(torch.float32)

        scale = torch.rand(
            K_pad // group_size,
            1,
            N_pad // group_size,
            1,
            dtype=torch.float32,
            device="cuda",
        )
        scale /= fmax

        # Crop the aligned tensors back to the requested (K, N).
        B = (quant_B * scale).reshape(K_pad, N_pad)[:K, :N]
        quant_B = quant_B.reshape(K_pad, N_pad).to(out_dtype)[:K, :N]
        scale = scale.reshape(K_pad // group_size, N_pad // group_size)
        return B, quant_B, scale
|
||||
|
||||
|
||||
class TestPerTokenGroupQuantFP8(TestFP8Base):
    """Checks per_token_group_quant_fp8 against the exactly-representable fixture."""

    def test_per_token_group_quant_fp8(self):
        # fp8 kernels need Hopper (SM90) or newer.  The old code silently
        # `return`ed here, which reported a *pass* on unsupported hardware
        # (and crashed outright with no CUDA device); skip explicitly instead.
        if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
            self.skipTest("requires a CUDA device with compute capability >= 9.0")
        A, A_quant_gt, scale_gt = self._make_A(
            M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
        )
        A_quant, scale = per_token_group_quant_fp8(x=A, group_size=self.group_size)
        torch.testing.assert_close(scale, scale_gt)
        diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs()
        diff_count = (diff > 1e-5).count_nonzero()
        # Unittest assertion instead of bare `assert` (stripped under -O).
        self.assertLess(diff_count / diff.numel(), 1e-4)
|
||||
|
||||
|
||||
class TestW8A8BlockFP8Matmul(TestFP8Base):
    """Checks w8a8_block_fp8_matmul against a matmul on the dequantized operands."""

    def test_w8a8_block_fp8_matmul(self):
        # fp8 kernels need Hopper (SM90) or newer.  The old code silently
        # `return`ed here, which reported a *pass* on unsupported hardware
        # (and crashed outright with no CUDA device); skip explicitly instead.
        if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
            self.skipTest("requires a CUDA device with compute capability >= 9.0")
        A, A_quant_gt, A_scale_gt = self._make_A(
            M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
        )
        B, B_quant_gt, B_scale_gt = self._make_B(
            K=self.K, N=self.N, group_size=self.group_size, out_dtype=self.quant_type
        )
        C_gt = A.to(self.output_type) @ B.to(self.output_type)
        C = w8a8_block_fp8_matmul(
            A=A_quant_gt,
            B=B_quant_gt.T.contiguous(),
            As=A_scale_gt,
            Bs=B_scale_gt.T.contiguous(),
            block_size=[128, 128],
            output_dtype=self.output_type,
        )
        torch.testing.assert_close(C, C_gt, atol=0.5, rtol=1e-4)


if __name__ == "__main__":
    unittest.main()
|
||||
114
test/srt/quant/test_fp8_kvcache.py
Normal file
114
test/srt/quant/test_fp8_kvcache.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import os
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestFp8KvcacheBase(CustomTestCase):
    """Launches a server with an fp8 KV cache; subclasses supply `model_config`."""

    # Subclasses must set {"model_name": ..., "config_filename": ...}.
    model_config = None

    @classmethod
    def setUpClass(cls):
        if cls.model_config is None:
            raise NotImplementedError("model_config must be specified in subclass")

        cls.model = cls.model_config["model_name"]
        cls.base_url = DEFAULT_URL_FOR_TEST
        # The KV-cache scale file lives next to this test module.
        scales_path = os.path.join(
            os.path.dirname(__file__), cls.model_config["config_filename"]
        )

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--kv-cache-dtype",
                "fp8_e4m3",
                "--quantization-param-path",
                scales_path,
            ],
        )
|
||||
|
||||
|
||||
class TestFp8KvcacheLlama(TestFp8KvcacheBase):
    """fp8 KV-cache accuracy checks for the Llama-3 8B test model."""

    model_config = {
        "model_name": DEFAULT_MODEL_NAME_FOR_TEST,
        "config_filename": "kv_cache_scales_llama3_8b.json",
    }

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def _run_eval(self, eval_name, num_examples, num_threads):
        # Small helper so both tests share the SimpleNamespace plumbing.
        return run_eval(
            SimpleNamespace(
                base_url=self.base_url,
                model=self.model,
                eval_name=eval_name,
                num_examples=num_examples,
                num_threads=num_threads,
            )
        )

    def test_mgsm_en(self):
        metrics = self._run_eval("mgsm_en", None, 1024)
        self.assertGreater(metrics["score"], 0.80)

    def test_mmlu(self):
        metrics = self._run_eval("mmlu", 64, 32)
        self.assertGreaterEqual(metrics["score"], 0.65)
|
||||
|
||||
|
||||
class TestFp8KvcacheQwen(TestFp8KvcacheBase):
    """fp8 KV-cache accuracy checks for the small Qwen2 test model."""

    model_config = {
        "model_name": DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
        "config_filename": "kv_cache_scales_qwen2_1_5b.json",
    }

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def _run_eval(self, eval_name, num_examples, num_threads):
        # Small helper so both tests share the SimpleNamespace plumbing.
        return run_eval(
            SimpleNamespace(
                base_url=self.base_url,
                model=self.model,
                eval_name=eval_name,
                num_examples=num_examples,
                num_threads=num_threads,
            )
        )

    def test_mgsm_en(self):
        metrics = self._run_eval("mgsm_en", None, 1024)
        self.assertGreater(metrics["score"], 0.01)

    def test_mmlu(self):
        metrics = self._run_eval("mmlu", 64, 32)
        self.assertGreaterEqual(metrics["score"], 0.3)


if __name__ == "__main__":
    unittest.main()
|
||||
169
test/srt/quant/test_int8_kernel.py
Normal file
169
test/srt/quant/test_int8_kernel.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
    """Reference matmul with per-token input scales and per-column weight scales.

    A:  [..., K] quantized activations; As: [M, 1] per-token scales.
    B:  [N, K] quantized weights (one output channel per row);
    Bs: [N] per-column (output-channel) scales.
    Returns (A @ B.T) * As * Bs in `output_dtype`.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)

    assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
    assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"

    M = A.numel() // A.shape[-1]
    B = B.t()  # [K, N]
    # Fix: this previously read `N, K = B.shape`, silently swapping the meaning
    # of N and K for the rest of the function and contradicting the comments.
    K, N = B.shape
    origin_C_shape = A.shape[:-1] + (N,)
    A = A.reshape(M, K)

    # As is per-token [M, 1], Bs is per-column [N]
    C = torch.matmul(A, B)  # [M, N]
    C = As * C * Bs.view(1, -1)  # broadcast the per-column scale

    return C.reshape(origin_C_shape).to(output_dtype)
|
||||
|
||||
|
||||
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
    """Reference fused-MoE with per-column int8 quantization, in native torch.

    Activations are quantized per token and weights per output column; each
    routed expert runs a w1 -> SiluAndMul -> w2 pipeline and the expert
    outputs are combined with softmax top-k routing weights.
    """

    B, D = a.shape
    # Perform per-token quantization
    a_q, a_s = per_token_quant_int8(a)
    # Repeat tokens to match topk
    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    # Also repeat the scale
    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]

    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    # Calculate routing
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    # Process each expert
    for i in range(w1.shape[0]):
        # Rows routed to expert i.
        mask = topk_ids == i
        if mask.sum():
            # First MLP layer: note that a_s is now per-token
            inter_out = native_w8a8_per_token_matmul(
                a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype
            )
            # Activation function
            act_out = SiluAndMul().forward_native(inter_out)
            # Quantize activation output with per-token
            act_out_q, act_out_s = per_token_quant_int8(act_out)

            # Second MLP layer
            out[mask] = native_w8a8_per_token_matmul(
                act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
            )
    # Apply routing weights and sum
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
|
||||
|
||||
|
||||
class TestW8A8Int8FusedMoE(CustomTestCase):
    """Compares fused_moe (int8 W8A8, per-channel) against the torch reference."""

    DTYPES = [torch.half, torch.bfloat16]
    M = [1, 33]
    N = [128, 1024]
    K = [256, 4096]
    E = [8]
    TOP_KS = [2, 6]
    # Fix: BLOCK_SIZE was assigned twice in a row; the first (larger) sweep was
    # dead code, immediately shadowed.  Keep it as a comment like the sibling
    # test_block_int8.py does.
    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w8a8_int8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed):
        """Run one parameter combination; `block_size` is unused (per-channel quant)."""
        torch.manual_seed(seed)
        # Initialize int8 quantization parameters
        factor_for_scale = 1e-2
        int8_max = 127
        int8_min = -128

        # Input tensor: M x K
        a = torch.randn((M, K), dtype=dtype) / 10

        # Generate int8 weights
        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
        w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
        w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

        # Generate scale for each column (per-column quantization)
        w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
        w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
        score = torch.randn((M, E), dtype=dtype)

        with torch.inference_mode():
            ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
            topk_output = select_experts(
                hidden_states=a,
                router_logits=score,
                top_k=topk,
            )
            out = fused_moe(
                a,
                w1,
                w2,
                topk_output,
                use_fp8_w8a8=False,  # Not using fp8
                use_int8_w8a16=False,  # Not using int8-w8a16
                use_int8_w8a8=True,  # Using int8-w8a8
                per_channel_quant=True,
                w1_scale=w1_s,
                w2_scale=w2_s,
                block_shape=None,  # Not using block quantization
            )

        # Check results: mean relative error of the fused kernel below 5%.
        self.assertTrue(
            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
            / torch.mean(torch.abs(ref_out.to(torch.float32)))
            < 0.05
        )

    def test_w8a8_int8_fused_moe(self):
        for M, N, K, E, topk, block_size, dtype, seed in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                M=M,
                N=N,
                K=K,
                E=E,
                topk=topk,
                block_size=block_size,
                dtype=dtype,
                seed=seed,
            ):
                self._w8a8_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed)


if __name__ == "__main__":
    unittest.main(verbosity=2)
|
||||
75
test/srt/quant/test_w8a8_quantization.py
Normal file
75
test/srt/quant/test_w8a8_quantization.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import time
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestW8A8(CustomTestCase):
    """End-to-end checks for a w8a8_int8-quantized Llama-3 served by sglang."""

    @classmethod
    def setUpClass(cls):
        cls.model = "neuralmagic/Meta-Llama-3-8B-Instruct-quantized.w8a8"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--quantization", "w8a8_int8"],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval(args)
        print(metrics)

        self.assertGreater(metrics["accuracy"], 0.69)

    def run_decode(self, max_new_tokens):
        """POST a fixed greedy-decoding request and return the JSON response."""
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                },
                "ignore_eos": True,
            },
        )
        return response.json()

    def test_throughput(self):
        max_tokens = 256

        tic = time.perf_counter()
        res = self.run_decode(max_tokens)
        tok = time.perf_counter()
        print(res["text"])
        throughput = max_tokens / (tok - tic)
        print(f"Throughput: {throughput} tokens/s")
        # Fix: bare `assert` is stripped under -O and gives no useful failure
        # message; use the unittest assertion instead.
        self.assertGreaterEqual(throughput, 140)


if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user