Add unit test for block-wise FP8 (#3156)
This commit is contained in:
@@ -52,6 +52,7 @@ suites = {
|
||||
"test_w8a8_quantization.py",
|
||||
"test_session_control.py",
|
||||
"test_fp8_kvcache.py",
|
||||
"test_fp8_kernel.py",
|
||||
],
|
||||
"nightly": [
|
||||
"test_nightly_gsm8k_eval.py",
|
||||
|
||||
129
test/srt/test_fp8_kernel.py
Normal file
129
test/srt/test_fp8_kernel.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_token_group_quant_fp8,
|
||||
w8a8_block_fp8_matmul,
|
||||
)
|
||||
|
||||
|
||||
class TestFP8Base(unittest.TestCase):
    """Shared problem sizes and reference-data builders for the FP8 tests.

    The helpers construct tensors whose quantized form and scales are known
    by construction, so subclasses can compare kernel outputs against them.
    """

    @classmethod
    def setUpClass(cls):
        cls.M = 256
        # N is deliberately not a multiple of group_size so that the
        # non-aligned path is exercised.
        cls.N = 1024 + 64
        cls.K = 512
        cls.group_size = 128
        cls.quant_type = torch.float8_e4m3fn
        cls.output_type = torch.float16

    @staticmethod
    def _make_A(M, K, group_size, out_dtype, device="cuda"):
        """Build an activation matrix with known per-group quantization.

        Args:
            M: Number of rows.
            K: Number of columns; split into ``K // group_size`` groups.
            group_size: Number of consecutive elements sharing one scale.
            out_dtype: Quantized dtype whose ``finfo.max`` bounds each group.
            device: Device on which the tensors are created.

        Returns:
            A tuple ``(A, quant_A, scale)`` where ``A`` is the float32
            ``(M, K)`` dequantized matrix, ``quant_A`` is the ``(M, K)``
            matrix in ``out_dtype`` and ``scale`` is the float32
            ``(M, K // group_size)`` per-group scale.
        """
        num_groups = K // group_size
        # Uniform samples in [0, 1) shifted to [-1, 1).
        quant_A = torch.rand(
            M, num_groups, group_size, dtype=torch.float32, device=device
        )
        quant_A = quant_A * 2 - 1

        # Rescale every group so that its absolute maximum equals fmax.
        fmax = torch.finfo(out_dtype).max
        scaling = fmax / quant_A.abs().amax(-1, keepdim=True)
        quant_A *= scaling
        # Round-trip through out_dtype so the float32 copy holds exactly the
        # values that are representable in the quantized dtype.
        quant_A = quant_A.to(out_dtype).to(torch.float32)

        # Random per-group scales; A is the dequantized reference.
        scale = torch.rand(M, num_groups, dtype=torch.float32, device=device)
        scale /= fmax
        A = quant_A * scale[..., None]

        A = A.reshape(M, K)
        quant_A = quant_A.reshape(M, K).to(out_dtype)
        return A, quant_A, scale

    @staticmethod
    def _make_B(K, N, group_size, out_dtype, device="cuda"):
        """Build a weight matrix with known per-block quantization.

        Blocks are ``group_size x group_size``. ``K`` and ``N`` are rounded up
        to a multiple of ``group_size`` for generation and the matrices are
        then cropped back to ``(K, N)``.

        Args:
            K: Number of rows.
            N: Number of columns.
            group_size: Side length of each square quantization block.
            out_dtype: Quantized dtype whose ``finfo.max`` bounds each block.
            device: Device on which the tensors are created.

        Returns:
            A tuple ``(B, quant_B, scale)`` where ``B`` is the float32
            ``(K, N)`` dequantized matrix, ``quant_B`` is the ``(K, N)``
            matrix in ``out_dtype`` and ``scale`` is the float32 per-block
            scale of shape ``(ceil(K / group_size), ceil(N / group_size))``.
        """

        def _aligned_size(a, b):
            # Round a up to the next multiple of b.
            return (a + b - 1) // b * b

        K_aligned = _aligned_size(K, group_size)
        N_aligned = _aligned_size(N, group_size)
        k_blocks = K_aligned // group_size
        n_blocks = N_aligned // group_size

        # Layout: (row block, row in block, column block, column in block).
        quant_B = torch.rand(
            k_blocks,
            group_size,
            n_blocks,
            group_size,
            dtype=torch.float32,
            device=device,
        )
        quant_B = quant_B * 2 - 1

        # Rescale every block so that its absolute maximum equals fmax.
        fmax = torch.finfo(out_dtype).max
        scaling = fmax / quant_B.abs().amax((1, 3), keepdim=True)
        quant_B *= scaling
        quant_B = quant_B.to(out_dtype).to(torch.float32)

        # One random scale per block, broadcast over the in-block dims.
        scale = torch.rand(
            k_blocks,
            1,
            n_blocks,
            1,
            dtype=torch.float32,
            device=device,
        )
        scale /= fmax

        B = quant_B * scale

        # Crop the padded matrices back to the requested (K, N) shape.
        B = B.reshape(K_aligned, N_aligned)[:K, :N]
        quant_B = quant_B.reshape(K_aligned, N_aligned).to(out_dtype)[:K, :N]
        scale = scale.reshape(k_blocks, n_blocks)
        return B, quant_B, scale
|
||||
|
||||
|
||||
class TestPerTokenGroupQuantFP8(TestFP8Base):
    """Checks per_token_group_quant_fp8 against constructed reference data."""

    def test_per_token_group_quant_fp8(self):
        # Report unsupported hardware as a skip rather than a silent pass.
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")
        if torch.cuda.get_device_capability()[0] < 9:
            self.skipTest("requires CUDA compute capability major version >= 9")

        A, A_quant_gt, scale_gt = self._make_A(
            M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
        )
        A_quant, scale = per_token_group_quant_fp8(
            x=A, group_size=self.group_size, dtype=self.quant_type
        )

        torch.testing.assert_close(scale, scale_gt)

        # Compare in float16 and tolerate a tiny fraction of differing
        # elements rather than requiring an exact match.
        diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs()
        diff_count = (diff > 1e-5).count_nonzero()
        mismatch_ratio = diff_count.item() / diff.numel()
        self.assertLess(mismatch_ratio, 1e-4)
|
||||
|
||||
|
||||
class TestW8A8BlockFP8Matmul(TestFP8Base):
    """Checks w8a8_block_fp8_matmul against a float16 reference matmul."""

    def test_w8a8_block_fp8_matmul(self):
        # Report unsupported hardware as a skip rather than a silent pass.
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")
        if torch.cuda.get_device_capability()[0] < 9:
            self.skipTest("requires CUDA compute capability major version >= 9")

        A, A_quant_gt, A_scale_gt = self._make_A(
            M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
        )
        B, B_quant_gt, B_scale_gt = self._make_B(
            K=self.K, N=self.N, group_size=self.group_size, out_dtype=self.quant_type
        )

        # Reference result from the dequantized operands.
        C_gt = A.to(self.output_type) @ B.to(self.output_type)

        # The weight and its scales are passed transposed relative to the
        # (K, N) layout produced by _make_B.
        C = w8a8_block_fp8_matmul(
            A=A_quant_gt,
            B=B_quant_gt.T.contiguous(),
            As=A_scale_gt,
            Bs=B_scale_gt.T.contiguous(),
            # Keep the kernel block size in step with the fixture group size.
            block_size=[self.group_size, self.group_size],
            output_dtype=self.output_type,
        )

        torch.testing.assert_close(C, C_gt, atol=0.5, rtol=1e-4)
|
||||
|
||||
|
||||
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user