Reorganize CI and test files (#9027)

This commit is contained in:
Lianmin Zheng
2025-08-10 12:30:06 -07:00
committed by GitHub
parent b58ae7a2a0
commit 2c7f01bc89
66 changed files with 161 additions and 195 deletions

View File

@@ -0,0 +1,45 @@
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestAWQ(CustomTestCase):
    """End-to-end MMLU accuracy check against a server running an AWQ MoE model."""

    @classmethod
    def setUpClass(cls):
        # Launch a local server serving the AWQ MoE test model.
        cls.model = DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_args = ["--trust-remote-code"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def tearDownClass(cls):
        # Terminate the server process together with all of its children.
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        """MMLU score over 64 examples must exceed 0.64."""
        eval_args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(eval_args)
        self.assertGreater(metrics["score"], 0.64)


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,175 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/quantization/test_awq_triton.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
unittest version of the AWQ Triton kernel tests.
Run with:
python -m unittest test_awq_dequant.py
"""
import unittest
import torch
from sglang.srt.layers.quantization.awq_triton import (
AWQ_TRITON_SUPPORTED_GROUP_SIZES,
awq_dequantize_triton,
awq_gemm_triton,
)
from sglang.test.test_utils import CustomTestCase
device = "cuda"
def reverse_awq_order(t: torch.Tensor) -> torch.Tensor:
    """Undo the AWQ nibble-packing permutation along the last dimension.

    AWQ stores the 8 4-bit values of each int32 in the interleaved order
    [0, 4, 1, 5, 2, 6, 3, 7]; this gathers them back into logical order and
    masks each element to its low 4 bits.
    """
    pack_factor = 32 // 4  # 8 nibbles per packed int32
    interleave = [0, 4, 1, 5, 2, 6, 3, 7]
    positions = torch.arange(t.shape[-1], dtype=torch.int32, device=t.device)
    positions = positions.view(-1, pack_factor)[:, interleave].view(-1)
    return (t[:, positions] & 0xF).contiguous()
def awq_dequantize_torch(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    group_size: int,
) -> torch.Tensor:
    """Pure-torch AWQ dequantization reference.

    Unpacks 4-bit weights and zero points from packed int32 tensors, undoes
    the AWQ packing order, and returns (weights - zeros) * scales.
    """
    # group_size == -1 means one quantization group spanning all rows.
    if group_size == -1:
        group_size = qweight.shape[0]
    bits = 4
    # Bit offsets of the 8 nibbles packed into each int32: 0, 4, ..., 28.
    shifts = torch.arange(0, 32, bits, device=qzeros.device)
    # Expand each int32 into 8 shifted copies, then flatten to unpacked columns.
    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )
    iweights = reverse_awq_order(iweights.view(iweights.shape[0], -1))
    zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )
    zeros = reverse_awq_order(zeros.view(qzeros.shape[0], -1))
    # Keep only the low 4 bits of every unpacked value.
    iweights = torch.bitwise_and(iweights, (2**bits) - 1)
    zeros = torch.bitwise_and(zeros, (2**bits) - 1)
    # Broadcast per-group scales and zero points to every row in the group.
    scales = scales.repeat_interleave(group_size, dim=0)
    zeros = zeros.repeat_interleave(group_size, dim=0)
    return (iweights - zeros) * scales
class TestAWQTriton(CustomTestCase):
    """Checks the AWQ Triton dequantize and GEMM kernels against torch references."""

    def test_dequantize(self):
        # Sweep a grid of (rows, cols, group_size) shapes; one subTest each.
        rows_list = [3584, 18944, 128, 256, 512, 1024]
        cols_list = [448, 576, 4736, 16, 32, 64, 128]
        for qweight_rows in rows_list:
            for qweight_cols in cols_list:
                for group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES:
                    with self.subTest(
                        rows=qweight_rows, cols=qweight_cols, g=group_size
                    ):
                        self._run_dequant_case(
                            qweight_rows=qweight_rows,
                            qweight_cols=qweight_cols,
                            group_size=group_size,
                        )

    def _run_dequant_case(self, qweight_rows, qweight_cols, group_size):
        """Compare awq_dequantize_triton with awq_dequantize_torch for one shape."""
        # group_size == -1 denotes a single group spanning all rows.
        if group_size == -1:
            group_size = qweight_rows
        torch.manual_seed(0)
        qweight = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (qweight_rows, qweight_cols),
            dtype=torch.int32,
            device=device,
        )
        # Each packed int32 column expands to 8 4-bit values, hence the * 8.
        scales = torch.rand(
            qweight_rows // group_size,
            qweight_cols * 8,
            dtype=torch.float16,
            device=device,
        )
        zeros = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (qweight_rows // group_size, qweight_cols),
            dtype=torch.int32,
            device=device,
        )
        ref = awq_dequantize_torch(qweight, scales, zeros, group_size)
        tri = awq_dequantize_triton(qweight, scales, zeros)
        # sanity: no infs/NaNs before the exact comparison
        self.assertFalse(torch.any(torch.isinf(tri)) or torch.any(torch.isnan(tri)))
        torch.testing.assert_close(ref, tri)

    # GEMM
    def test_gemm(self):
        # Sweep GEMM shapes and split-K factors; one subTest each.
        N_list = [1, 2, 4, 8, 14, 17, 23, 32]
        K_list = [128]
        M_list = [16, 24, 32]
        splitK_list = [1, 8]
        for N in N_list:
            for K in K_list:
                for M in M_list:
                    for group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES:
                        for splitK in splitK_list:
                            with self.subTest(N=N, K=K, M=M, g=group_size, sk=splitK):
                                self._run_gemm_case(
                                    N=N,
                                    K=K,
                                    M=M,
                                    group_size=group_size,
                                    splitK=splitK,
                                )

    def _run_gemm_case(self, N, K, M, group_size, splitK):
        """Run awq_gemm_triton and compare against matmul on dequantized weights."""
        # group_size == -1 denotes a single group spanning the K dimension.
        if group_size == -1:
            group_size = K
        torch.manual_seed(0)
        x = torch.rand((N, K), dtype=torch.float32, device=device)
        # Packed weights: 8 4-bit values per int32 along the output dimension.
        qweight = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (K, M // 8),
            dtype=torch.int32,
            device=device,
        )
        qzeros = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (K // group_size, M // 8),
            dtype=torch.int32,
            device=device,
        )
        scales = torch.rand((K // group_size, M), dtype=torch.float32, device=device)
        tri_out = awq_gemm_triton(x, qweight, scales, qzeros, splitK)
        self.assertFalse(
            torch.any(torch.isinf(tri_out)) or torch.any(torch.isnan(tri_out))
        )
        # dequantize & compare against a plain float matmul reference
        w_deq = awq_dequantize_triton(qweight, scales, qzeros)
        ref_out = torch.matmul(x, w_deq)
        self.assertFalse(
            torch.any(torch.isinf(ref_out)) or torch.any(torch.isnan(ref_out))
        )
        torch.testing.assert_close(tri_out.cpu(), ref_out.cpu(), atol=1e-1, rtol=1e-1)


if __name__ == "__main__":
    unittest.main(verbosity=2)

View File

@@ -0,0 +1,227 @@
import itertools
import unittest
import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts
from sglang.test.test_utils import CustomTestCase
# For test
def native_per_token_group_quant_int8(x, group_size, eps=1e-10, dtype=torch.int8):
    """Per-token-group symmetric int8 quantization of `x` using native torch.

    Each contiguous group of `group_size` elements along the last dimension is
    scaled so that its absolute maximum maps to the integer maximum of `dtype`.

    Args:
        x: Input tensor; its last dimension must be divisible by `group_size`.
        group_size: Number of elements per quantization group.
        eps: Lower bound on the per-group amax to avoid division by zero.
        dtype: Integer dtype to quantize to (default torch.int8).

    Returns:
        Tuple of (quantized tensor with the same shape as `x`, per-group
        float32 scales of shape x.shape[:-1] + (x.shape[-1] // group_size,)).
    """
    # Fixed: docstring previously claimed float8 output (copy-paste from the
    # fp8 variant) and the assert message stated the inverted condition.
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    iinfo = torch.iinfo(dtype)
    int8_min = iinfo.min
    int8_max = iinfo.max

    x_ = x.reshape(x.numel() // group_size, group_size)
    # Symmetric scale: per-group absolute maximum mapped onto int8_max.
    amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
    x_s = amax / int8_max
    x_q = (x_ / x_s).clamp(min=int8_min, max=int8_max).to(dtype)
    x_q = x_q.reshape(x.shape)
    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
    return x_q, x_s
# For test
def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """Block-wise quantized matmul reference implemented with native torch.

    Multiplies `A` (with per-token-group scales `As`) by `B` (with per-block
    scales `Bs`) tile by tile and returns the result in `output_dtype`.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)
    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (N,)
    A = A.reshape(M, A.shape[-1])
    As = As.reshape(M, As.shape[-1])
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C = torch.zeros((M, N), dtype=torch.float32, device=A.device)
    # Accumulate one (block_n x block_k) tile product at a time; k is
    # traversed in increasing order for each output tile, matching the
    # accumulation order of the quantized kernel being validated.
    for j in range(n_tiles):
        row_lo, row_hi = j * block_n, min((j + 1) * block_n, N)
        c_view = C[:, row_lo:row_hi]
        for i in range(k_tiles):
            col_lo, col_hi = i * block_k, min((i + 1) * block_k, K)
            a_tile = A[:, col_lo:col_hi]
            b_tile = B[row_lo:row_hi, col_lo:col_hi]
            tile_scale = As[:, i : i + 1] * Bs[j][i]
            c_view[:, :] += torch.matmul(a_tile, b_tile.t()) * tile_scale

    return C.reshape(origin_C_shape).to(output_dtype)
# For test
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """Fused-MoE reference with block-wise int8 quantization in native torch.

    Args:
        a: Activations of shape (B, D).
        w1: First-layer expert weights of shape (E, 2N, K), int8.
        w2: Second-layer expert weights of shape (E, K, N), int8.
        w1_s: Per-block scales for w1.
        w2_s: Per-block scales for w2.
        score: Router logits of shape (B, E).
        topk: Number of experts selected per token.
        block_shape: [block_n, block_k] quantization block sizes.

    Returns:
        Routed, probability-weighted expert outputs of shape (B, w2.shape[1]).
    """
    B, D = a.shape
    # Duplicate every token once per selected expert.
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    _, block_k = block_shape[0], block_shape[1]
    a_q, a_s = native_per_token_group_quant_int8(a, block_k)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_int8_matmul(
                a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
            )
            act_out = SiluAndMul().forward_native(inter_out)
            # Re-quantize the activations before the second matmul.
            # (Removed a dead `act_out = act_out.to(torch.float32)` here —
            # act_out was never read again after quantization.)
            act_out_q, act_out_s = native_per_token_group_quant_int8(act_out, block_k)
            out[mask] = native_w8a8_block_int8_matmul(
                act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
            )
    # Weight each expert output by its routing probability and sum per token.
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
class TestW8A8BlockINT8FusedMoE(CustomTestCase):
    """Compares fused_moe with block-wise int8 quantization against the
    native torch reference torch_w8a8_block_int8_moe."""

    # Sweep dimensions for the parameterized test below.
    DTYPES = [torch.half, torch.bfloat16]
    M = [1, 33, 64, 222]
    N = [128, 1024]
    K = [256, 4096]
    E = [8, 24]
    TOP_KS = [2, 6]
    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        # The fused kernel requires a GPU; skip the whole class otherwise.
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w8a8_block_int8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed):
        """Run one (M, N, K, E, topk, block_size) configuration and compare
        the fused kernel against the torch reference within 2% relative error."""
        torch.manual_seed(seed)
        # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
        factor_for_scale = 1e-2
        int8_info = torch.iinfo(torch.int8)
        int8_max, int8_min = int8_info.max, int8_info.min

        a = torch.randn((M, K), dtype=dtype) / 10

        # Random int8 weights for both expert layers.
        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
        w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
        w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

        # One scale per (block_n x block_k) tile of each weight matrix.
        block_n, block_k = block_size[0], block_size[1]
        n_tiles_w1 = (2 * N + block_n - 1) // block_n
        n_tiles_w2 = (K + block_n - 1) // block_n
        k_tiles_w1 = (K + block_k - 1) // block_k
        k_tiles_w2 = (N + block_k - 1) // block_k
        w1_s = (
            torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
            * factor_for_scale
        )
        w2_s = (
            torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
            * factor_for_scale
        )

        score = torch.randn((M, E), dtype=dtype)
        topk_output = select_experts(
            hidden_states=a,
            router_logits=score,
            top_k=topk,
        )
        with torch.inference_mode():
            out = fused_moe(
                a,
                w1,
                w2,
                topk_output,
                use_int8_w8a8=True,
                w1_scale=w1_s,
                w2_scale=w2_s,
                block_shape=block_size,
            )
            ref_out = torch_w8a8_block_int8_moe(
                a, w1, w2, w1_s, w2_s, score, topk, block_size
            )
        # Relative mean absolute error must stay below 2%.
        self.assertTrue(
            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
            / torch.mean(torch.abs(ref_out.to(torch.float32)))
            < 0.02
        )

    def test_w8a8_block_int8_fused_moe(self):
        # Cartesian product over all configured dimensions; one subTest each.
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
                block_size=params[5],
                dtype=params[6],
                seed=params[7],
            ):
                self._w8a8_block_int8_fused_moe(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)

View File

@@ -0,0 +1,126 @@
import unittest
import torch
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
)
from sglang.test.test_utils import CustomTestCase
class TestFP8Base(CustomTestCase):
    """Shared shapes and tensor factories for the FP8 kernel tests."""

    @classmethod
    def setUpClass(cls):
        cls.M = 256
        # test non-aligned
        cls.N = 1024 + 64
        cls.K = 512
        cls.group_size = 128
        cls.quant_type = torch.float8_e4m3fn
        cls.output_type = torch.bfloat16

    @staticmethod
    def _make_A(M, K, group_size, out_dtype):
        """Fabricate (A, quant_A, scale) such that quantizing A per
        (token, group) with these scales reproduces quant_A exactly."""
        quant_A = torch.rand(
            M, K // group_size, group_size, dtype=torch.float32, device="cuda"
        )
        # -1 ~ 1
        quant_A = quant_A * 2 - 1
        # scaling abs max to fmax
        finfo = torch.finfo(out_dtype)
        fmax = finfo.max
        scaling = fmax / quant_A.abs().amax(-1, keepdim=True)
        quant_A *= scaling
        # Round-trip through out_dtype so quant_A holds exactly representable values.
        quant_A = quant_A.to(out_dtype).to(torch.float32)

        # create scale and A
        scale = torch.rand(M, K // group_size, dtype=torch.float32, device="cuda")
        scale /= fmax
        A = quant_A * scale[..., None]

        A = A.reshape(M, K)
        quant_A = quant_A.reshape(M, K).to(out_dtype)
        return A, quant_A, scale

    @staticmethod
    def _make_B(K, N, group_size, out_dtype):
        """Fabricate (B, quant_B, scale) with one scale per
        (group_size x group_size) block; outputs are cropped to (K, N)
        when the requested sizes are not block-aligned."""

        def _aligned_size(a, b):
            # Round a up to the nearest multiple of b.
            return (a + b - 1) // b * b

        K_aligned = _aligned_size(K, group_size)
        N_aligned = _aligned_size(N, group_size)

        quant_B = torch.rand(
            K_aligned // group_size,
            group_size,
            N_aligned // group_size,
            group_size,
            dtype=torch.float32,
            device="cuda",
        )
        quant_B = quant_B * 2 - 1

        # scaling abs max to fmax
        finfo = torch.finfo(out_dtype)
        fmax = finfo.max
        scaling = fmax / quant_B.abs().amax((1, 3), keepdim=True)
        quant_B *= scaling
        # Round-trip through out_dtype so quant_B holds exactly representable values.
        quant_B = quant_B.to(out_dtype).to(torch.float32)

        scale = torch.rand(
            K_aligned // group_size,
            1,
            N_aligned // group_size,
            1,
            dtype=torch.float32,
            device="cuda",
        )
        scale /= fmax

        B = quant_B * scale

        B = B.reshape(K_aligned, N_aligned)[:K, :N]
        quant_B = quant_B.reshape(K_aligned, N_aligned).to(out_dtype)[:K, :N]
        scale = scale.reshape(K_aligned // group_size, N_aligned // group_size)
        return B, quant_B, scale
class TestPerTokenGroupQuantFP8(TestFP8Base):
    """Validates the per-token-group FP8 quantization kernel."""

    def test_per_token_group_quant_fp8(self):
        """Quantize a fabricated A and compare against its exact ground truth."""
        # FP8 kernels need a GPU with compute capability >= 9; previously this
        # silently `return`ed (reported as a pass) and crashed when no CUDA
        # device was present at all.
        if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
            self.skipTest("requires a GPU with compute capability >= 9")
        A, A_quant_gt, scale_gt = self._make_A(
            M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
        )
        A_quant, scale = per_token_group_quant_fp8(x=A, group_size=self.group_size)
        torch.testing.assert_close(scale, scale_gt)
        # Allow a tiny fraction of rounding mismatches between kernel and reference.
        diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs()
        diff_count = (diff > 1e-5).count_nonzero()
        self.assertLess(diff_count.item() / diff.numel(), 1e-4)
class TestW8A8BlockFP8Matmul(TestFP8Base):
    """Validates the block-wise W8A8 FP8 matmul kernel."""

    def test_w8a8_block_fp8_matmul(self):
        """Compare the quantized matmul against a full-precision reference."""
        # FP8 kernels need a GPU with compute capability >= 9; previously this
        # silently `return`ed (reported as a pass) and crashed when no CUDA
        # device was present at all.
        if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
            self.skipTest("requires a GPU with compute capability >= 9")
        A, A_quant_gt, A_scale_gt = self._make_A(
            M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
        )
        B, B_quant_gt, B_scale_gt = self._make_B(
            K=self.K, N=self.N, group_size=self.group_size, out_dtype=self.quant_type
        )
        C_gt = A.to(self.output_type) @ B.to(self.output_type)
        # The kernel expects B and Bs transposed relative to _make_B's layout.
        C = w8a8_block_fp8_matmul(
            A=A_quant_gt,
            B=B_quant_gt.T.contiguous(),
            As=A_scale_gt,
            Bs=B_scale_gt.T.contiguous(),
            block_size=[128, 128],
            output_dtype=self.output_type,
        )
        torch.testing.assert_close(C, C_gt, atol=0.5, rtol=1e-4)


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,114 @@
import os
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestFp8KvcacheBase(CustomTestCase):
    """Shared server setup for FP8 KV-cache tests; subclasses set model_config."""

    # Subclasses must provide {"model_name": ..., "config_filename": ...}.
    model_config = None

    @classmethod
    def setUpClass(cls):
        if cls.model_config is None:
            raise NotImplementedError("model_config must be specified in subclass")
        cls.model = cls.model_config["model_name"]
        cls.base_url = DEFAULT_URL_FOR_TEST
        # KV-cache scaling factors live next to this test file.
        scales_path = os.path.join(
            os.path.dirname(__file__), cls.model_config["config_filename"]
        )
        server_args = [
            "--kv-cache-dtype",
            "fp8_e4m3",
            "--quantization-param-path",
            scales_path,
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=server_args,
        )
class TestFp8KvcacheLlama(TestFp8KvcacheBase):
    """FP8 KV-cache accuracy checks for the Llama-3-8B test model."""

    model_config = {
        "model_name": DEFAULT_MODEL_NAME_FOR_TEST,
        "config_filename": "kv_cache_scales_llama3_8b.json",
    }

    @classmethod
    def tearDownClass(cls):
        # Shut down the server launched by the base class.
        kill_process_tree(cls.process.pid)

    def test_mgsm_en(self):
        """MGSM-EN score must exceed 0.80 with the FP8 KV cache."""
        eval_args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )
        self.assertGreater(run_eval(eval_args)["score"], 0.80)

    def test_mmlu(self):
        """MMLU score over 64 examples must be at least 0.65."""
        eval_args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        self.assertGreaterEqual(run_eval(eval_args)["score"], 0.65)
class TestFp8KvcacheQwen(TestFp8KvcacheBase):
    """FP8 KV-cache accuracy checks for the small Qwen2 test model."""

    model_config = {
        "model_name": DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
        "config_filename": "kv_cache_scales_qwen2_1_5b.json",
    }

    @classmethod
    def tearDownClass(cls):
        # Shut down the server launched by the base class.
        kill_process_tree(cls.process.pid)

    def test_mgsm_en(self):
        """MGSM-EN score must exceed the (low) 0.01 bar for this small model."""
        eval_args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )
        self.assertGreater(run_eval(eval_args)["score"], 0.01)

    def test_mmlu(self):
        """MMLU score over 64 examples must be at least 0.3."""
        eval_args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        self.assertGreaterEqual(run_eval(eval_args)["score"], 0.3)


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,169 @@
import itertools
import unittest
import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.test.test_utils import CustomTestCase
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
    """Matrix multiplication function that supports per-token input quantization and per-column weight quantization"""
    A = A.to(torch.float32)
    B = B.to(torch.float32)
    assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
    assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"

    # Flatten all leading dimensions of A into one token axis.
    num_tokens = A.numel() // A.shape[-1]
    weight_t = B.t()  # (in_features, out_features)
    in_features, out_features = weight_t.shape
    out_shape = A.shape[:-1] + (out_features,)
    flat_A = A.reshape(num_tokens, in_features)

    # Raw product, then apply As (per-token, [M, 1]) and Bs (per-column, [out]).
    acc = torch.matmul(flat_A, weight_t)
    scaled = As * acc * Bs.view(1, -1)
    return scaled.reshape(out_shape).to(output_dtype)
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
    """Fused-MoE reference with per-token activation quantization and
    per-column int8 weight quantization, implemented in native torch.

    Args:
        a: Activations of shape (B, D).
        w1: First-layer expert weights of shape (E, 2N, K), int8.
        w2: Second-layer expert weights of shape (E, K, N), int8.
        w1_s: Per-column scales for w1, shape (E, 2N).
        w2_s: Per-column scales for w2, shape (E, K).
        score: Router logits of shape (B, E).
        topk: Number of experts selected per token.

    Returns:
        Routed, probability-weighted expert outputs of shape (B, w2.shape[1]).
    """
    B, D = a.shape
    # Perform per-token quantization
    a_q, a_s = per_token_quant_int8(a)
    # Repeat tokens to match topk
    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    # Also repeat the scale
    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    # Calculate routing
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    # Process each expert
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            # First MLP layer: note that a_s is now per-token
            inter_out = native_w8a8_per_token_matmul(
                a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype
            )
            # Activation function
            act_out = SiluAndMul().forward_native(inter_out)
            # Quantize activation output with per-token
            act_out_q, act_out_s = per_token_quant_int8(act_out)
            # Second MLP layer
            out[mask] = native_w8a8_per_token_matmul(
                act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
            )
    # Apply routing weights and sum
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
class TestW8A8Int8FusedMoE(CustomTestCase):
    """Compares fused_moe with per-channel int8 w8a8 quantization against
    the native torch reference torch_w8a8_per_column_moe."""

    # Sweep dimensions for the parameterized test below.
    DTYPES = [torch.half, torch.bfloat16]
    M = [1, 33]
    N = [128, 1024]
    K = [256, 4096]
    E = [8]
    TOP_KS = [2, 6]
    # Fixed: BLOCK_SIZE was assigned twice and the first list was silently
    # overwritten; keep the alternatives commented out like the sibling test.
    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        # The fused kernel requires a GPU; skip the whole class otherwise.
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w8a8_int8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed):
        """Run one configuration and compare fused kernel vs. reference
        within 5% relative error.

        Note: `block_size` only labels the subTest; the kernel itself runs
        with block_shape=None (per-channel quantization).
        """
        torch.manual_seed(seed)
        # Keep scales small to avoid overflow in the low-precision output dtype.
        factor_for_scale = 1e-2
        int8_max = 127
        int8_min = -128

        # Input tensor, M x K
        a = torch.randn((M, K), dtype=dtype) / 10

        # Generate int8 weights for both expert layers.
        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
        w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
        w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

        # Generate scale for each column (per-column quantization).
        w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
        w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale

        score = torch.randn((M, E), dtype=dtype)

        with torch.inference_mode():
            ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
            topk_output = select_experts(
                hidden_states=a,
                router_logits=score,
                top_k=topk,
            )
            out = fused_moe(
                a,
                w1,
                w2,
                topk_output,
                use_fp8_w8a8=False,  # Not using fp8
                use_int8_w8a16=False,  # Not using int8-w8a16
                use_int8_w8a8=True,  # Using int8-w8a8
                per_channel_quant=True,
                w1_scale=w1_s,
                w2_scale=w2_s,
                block_shape=None,  # Not using block quantization
            )

        # Relative mean absolute error must stay below 5%.
        self.assertTrue(
            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
            / torch.mean(torch.abs(ref_out.to(torch.float32)))
            < 0.05
        )

    def test_w8a8_int8_fused_moe(self):
        # Cartesian product over all configured dimensions; one subTest each.
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
                block_size=params[5],
                dtype=params[6],
                seed=params[7],
            ):
                self._w8a8_int8_fused_moe(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)

View File

@@ -0,0 +1,75 @@
import time
import unittest
from types import SimpleNamespace
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestW8A8(CustomTestCase):
    """Accuracy and throughput checks for a w8a8-int8 quantized Llama model."""

    @classmethod
    def setUpClass(cls):
        # Launch a local server with w8a8_int8 quantization enabled.
        cls.model = "neuralmagic/Meta-Llama-3-8B-Instruct-quantized.w8a8"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--quantization", "w8a8_int8"],
        )

    @classmethod
    def tearDownClass(cls):
        # Terminate the server process together with all of its children.
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Five-shot GSM8K accuracy over 200 questions must exceed 0.69."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval(args)
        print(metrics)
        self.assertGreater(metrics["accuracy"], 0.69)

    def run_decode(self, max_new_tokens):
        """Issue one greedy /generate request and return the decoded JSON body."""
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                },
                "ignore_eos": True,
            },
        )
        return response.json()

    def test_throughput(self):
        """Single-request decode throughput must reach 140 tokens/s."""
        max_tokens = 256

        tic = time.perf_counter()
        res = self.run_decode(max_tokens)
        tok = time.perf_counter()
        print(res["text"])
        throughput = max_tokens / (tok - tic)
        print(f"Throughput: {throughput} tokens/s")
        # Fixed: use a unittest assertion — a bare `assert` is stripped under
        # `python -O` and is inconsistent with the assertions above.
        self.assertGreaterEqual(throughput, 140)


if __name__ == "__main__":
    unittest.main()