[perf] experimental enhance fp8 per-tensor quant (#5370)

This commit is contained in:
JieXin Liang
2025-04-15 03:35:43 +08:00
committed by GitHub
parent e9fc2ac7b6
commit bdde237562
4 changed files with 178 additions and 13 deletions

View File

@@ -7,10 +7,12 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
per_tensor_quant_mla_fp8,
per_token_group_quant_fp8,
static_quant_fp8,
w8a8_block_fp8_matmul,
)
from sglang.srt.layers.quantization.fp8_utils import input_to_float8
from sglang.test.test_utils import CustomTestCase
# Truthy only on a CUDA build with a visible device; note torch.version.cuda
# is the CUDA version *string* (None on CPU/ROCm builds), so _is_cuda holds
# that string rather than a bool when CUDA is usable.
_is_cuda = torch.cuda.is_available() and torch.version.cuda
@@ -155,6 +157,61 @@ class TestStaticQuantFP8(CustomTestCase):
self._static_quant_fp8(*params)
class TestPerTensorQuantMlaFP8(CustomTestCase):
    """Check per_tensor_quant_mla_fp8 against the input_to_float8 reference.

    The input is deliberately built as a non-contiguous, transposed view so
    the kernel's handling of strided inputs is exercised; the quantized
    output is still required to come back contiguous.
    """

    DTYPES = [torch.half, torch.bfloat16, torch.float32]
    NUM_TOKENS = [7, 83, 2048]
    D = [512, 4096, 5120, 13824]
    LAST_D_EXT = [1024, 0]
    LAST_D = [512]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        # The kernel under test is CUDA-only; skip the whole class otherwise.
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _per_tensor_quant_mla_fp8(self, num_tokens, d, last_d_ext, last_d, dtype, seed):
        torch.manual_seed(seed)
        # Pad the trailing dim by last_d_ext, then slice the padding away so
        # the tensor handed to the kernel is a view, not a fresh allocation.
        padded = torch.rand(
            (num_tokens, d // last_d, last_d + last_d_ext),
            dtype=dtype,
        )
        view = padded.split([last_d, last_d_ext], dim=-1)[0]

        with torch.inference_mode():
            ref_out, ref_s = input_to_float8(view.transpose(0, 1))
            out, out_s = per_tensor_quant_mla_fp8(view.transpose(0, 1))

        self.assertTrue(out.is_contiguous())
        # Loose rtol: fp8 rounding between the two implementations can differ
        # per element, so only gross disagreement should fail.
        self.assertTrue(
            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50)
        )
        # Scales must agree to default (tight) tolerance.
        self.assertTrue(
            torch.allclose(out_s.to(torch.float32), ref_s.to(torch.float32))
        )

    def test_per_tensor_quant_mla_fp8(self):
        labels = ("num_tokens", "d", "last_d_ext", "last_d", "dtype", "seed")
        for combo in itertools.product(
            self.NUM_TOKENS,
            self.D,
            self.LAST_D_EXT,
            self.LAST_D,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(**dict(zip(labels, combo))):
                self._per_tensor_quant_mla_fp8(*combo)
# Reference implementation used only by the tests below.
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""This function performs matrix multiplication with block-wise quantization using native torch.