[perf] experimental enhance fp8 per-tensor quant (#5370)
This commit is contained in:
@@ -7,10 +7,12 @@ import torch
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_tensor_quant_mla_fp8,
|
||||
per_token_group_quant_fp8,
|
||||
static_quant_fp8,
|
||||
w8a8_block_fp8_matmul,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import input_to_float8
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||
@@ -155,6 +157,61 @@ class TestStaticQuantFP8(CustomTestCase):
|
||||
self._static_quant_fp8(*params)
|
||||
|
||||
|
||||
class TestPerTensorQuantMlaFP8(CustomTestCase):
    """Check per_tensor_quant_mla_fp8 against the input_to_float8 reference.

    Runs the CUDA kernel and the reference quantizer on the same
    (transposed) input and asserts the FP8 outputs and the per-tensor
    scales agree within tolerance.
    """

    # Parameter grid swept by the test.
    DTYPES = [torch.half, torch.bfloat16, torch.float32]
    NUM_TOKENS = [7, 83, 2048]
    D = [512, 4096, 5120, 13824]
    LAST_D_EXT = [1024, 0]
    LAST_D = [512]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        # The kernel under test is CUDA-only; skip the whole class otherwise.
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _per_tensor_quant_mla_fp8(self, num_tokens, d, last_d_ext, last_d, dtype, seed):
        torch.manual_seed(seed)

        # Allocate a wider tensor and split off the first `last_d` columns.
        # NOTE(review): when last_d_ext > 0 this presumably yields a
        # non-contiguous view to exercise the kernel's strided path — confirm.
        padded = torch.rand(
            (num_tokens, d // last_d, last_d + last_d_ext),
            dtype=dtype,
        )
        x_sub = padded.split([last_d, last_d_ext], dim=-1)[0]

        with torch.inference_mode():
            ref_out, ref_s = input_to_float8(x_sub.transpose(0, 1))
            out, out_s = per_tensor_quant_mla_fp8(x_sub.transpose(0, 1))

        # Kernel must densify its output even though the input is a view.
        self.assertTrue(out.is_contiguous())

        out_f32 = out.to(torch.float32)
        ref_f32 = ref_out.to(torch.float32)
        # FP8 rounding differences are large, hence the loose relative bound.
        self.assertTrue(torch.allclose(out_f32, ref_f32, rtol=0.50))
        self.assertTrue(
            torch.allclose(out_s.to(torch.float32), ref_s.to(torch.float32))
        )

    def test_per_tensor_quant_mla_fp8(self):
        # Sweep the full cartesian grid; subTest keeps failures attributable.
        for num_tokens, d, last_d_ext, last_d, dtype, seed in itertools.product(
            self.NUM_TOKENS,
            self.D,
            self.LAST_D_EXT,
            self.LAST_D,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                num_tokens=num_tokens,
                d=d,
                last_d_ext=last_d_ext,
                last_d=last_d,
                dtype=dtype,
                seed=seed,
            ):
                self._per_tensor_quant_mla_fp8(
                    num_tokens, d, last_d_ext, last_d, dtype, seed
                )
|
||||
|
||||
|
||||
# For test
|
||||
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
|
||||
"""This function performs matrix multiplication with block-wise quantization using native torch.
|
||||
|
||||
Reference in New Issue
Block a user