Apply sgl w8a8 fp8 kernel (#3148)

This commit is contained in:
HandH1998
2025-03-09 16:03:32 +08:00
committed by GitHub
parent 9fb48f951f
commit 0dd6cda288
13 changed files with 523 additions and 37 deletions

View File

@@ -7,6 +7,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
static_quant_fp8,
w8a8_block_fp8_matmul,
)
@@ -63,7 +64,7 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
out, scale = per_token_group_quant_fp8(x, group_size)
self.assertTrue(
torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20)
)
self.assertTrue(torch.allclose(scale, ref_scale))
@@ -85,6 +86,71 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
self._per_token_group_quant_fp8(*params)
# For test
def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn):
"""Function to perform static quantization on an input tensor `x` using native torch.
It converts the tensor values into float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
"""
assert x.is_contiguous(), "`x` is not contiguous"
assert x_s.numel() == 1, "only supports per-tensor scale"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // x.shape[-1], x.shape[-1])
x_s_inv = 1.0 / x_s
x_q = (x_ * x_s_inv).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
return x_q, x_s
class TestStaticQuantFP8(unittest.TestCase):
    """Checks the static_quant_fp8 kernel against the native torch reference
    across a grid of token counts, hidden sizes, and input dtypes."""

    DTYPES = [torch.half, torch.bfloat16, torch.float32]
    NUM_TOKENS = [7, 83, 2048]
    D = [512, 4096, 5120, 13824]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        # The kernel under test is GPU-only; skip the whole class otherwise.
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _static_quant_fp8(self, num_tokens, d, dtype, seed):
        torch.manual_seed(seed)
        x = torch.rand(num_tokens, d, dtype=dtype)
        # Per-tensor scale chosen so the input max maps to the FP8 max.
        x_s = x.max() / torch.finfo(torch.float8_e4m3fn).max

        with torch.inference_mode():
            ref_out, _ = native_static_quant_fp8(x, x_s)
            out, _ = static_quant_fp8(x, x_s, repeat_scale=True)

        self.assertTrue(
            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50)
        )

    def test_static_quant_fp8(self):
        for num_tokens, d, dtype, seed in itertools.product(
            self.NUM_TOKENS, self.D, self.DTYPES, self.SEEDS
        ):
            with self.subTest(num_tokens=num_tokens, d=d, dtype=dtype, seed=seed):
                self._static_quant_fp8(num_tokens, d, dtype, seed)
# For test
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""This function performs matrix multiplication with block-wise quantization using native torch.