[perf] introduce deep gemm group_gemm_masked as bmm (#5432)

This commit is contained in:
JieXin Liang
2025-04-20 15:38:27 +08:00
committed by GitHub
parent d07e797ace
commit 99456bcacb
3 changed files with 361 additions and 20 deletions

View File

@@ -7,6 +7,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
per_tensor_quant_mla_deep_gemm_masked_fp8,
per_tensor_quant_mla_fp8,
per_token_group_quant_fp8,
static_quant_fp8,
@@ -212,6 +213,62 @@ class TestPerTensorQuantMlaFP8(CustomTestCase):
self._per_tensor_quant_mla_fp8(*params)
class TestPerTokenGroupQuantMlaDeepGemmMaskedFP8(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16, torch.float32]
B = [128]
NUM_TOKENS = [7, 83, 2048, 1024 * 16]
D = [512, 128]
GROUP_SIZE = [128]
SEEDS = [0]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")
def _per_token_group_quant_mla_deep_gemm_masked_fp8(
self, b, num_tokens, d, dtype, group_size, seed
):
torch.manual_seed(seed)
x = torch.rand(b, num_tokens, d, dtype=dtype)
with torch.inference_mode():
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size, 1e-12)
out, scale, _, _, _ = per_tensor_quant_mla_deep_gemm_masked_fp8(
x, group_size
)
out = out[:, :num_tokens, :]
scale = scale[:, :num_tokens, :]
self.assertTrue(
torch.allclose(
out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20, atol=1e-2
)
)
self.assertTrue(torch.allclose(scale, ref_scale))
def test_per_token_group_quant_mla_deep_gemm_masked_fp8(self):
for params in itertools.product(
self.B,
self.NUM_TOKENS,
self.D,
self.DTYPES,
self.GROUP_SIZE,
self.SEEDS,
):
with self.subTest(
b=params[0],
num_tokens=params[1],
d=params[2],
dtype=params[3],
group_size=params[4],
seed=params[5],
):
self._per_token_group_quant_mla_deep_gemm_masked_fp8(*params)
# For test
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""This function performs matrix multiplication with block-wise quantization using native torch.
@@ -485,5 +542,115 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
self._w8a8_block_fp8_fused_moe(*params)
# For test
def torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_shape, out_dtype):
"""This function performs bmm with block-wise quantization using native torch."""
B, N, _ = w.shape
_, M, _ = a.shape
out = torch.empty((B, M, N), dtype=out_dtype, device=a.device)
for i in range(B):
out[i] = native_w8a8_block_fp8_matmul(
a[i], w[i], a_s[i], w_s[i], block_shape, output_dtype=out_dtype
)
return out
class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
DTYPES = [torch.bfloat16]
M = [1, 33, 64, 222, 8192]
N = [128, 512]
K = [128, 512]
BATCH = [128]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
try:
import deep_gemm
except ImportError:
raise unittest.SkipTest("DeepGEMM is not available")
torch.set_default_device("cuda")
def _w8a8_block_fp8_batched_deep_gemm(self, M, N, K, B, block_size, dtype, seed):
torch.manual_seed(seed)
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a_fp32 = torch.randn((B, M, K), dtype=torch.float32) / 10
a = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
w_fp32 = (torch.rand((B, N, K), dtype=torch.float32) - 0.5) * 2 * fp8_max
w = w_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w = (N + block_n - 1) // block_n
k_tiles_w = (K + block_k - 1) // block_k
w_s = (
torch.rand((B, n_tiles_w, k_tiles_w), dtype=torch.float32)
* factor_for_scale
)
a_s = torch.rand((B, M, k_tiles_w), dtype=torch.float32) * factor_for_scale
ae = a.new_empty(B, (M + 255) // 256 * 256, K)
ae_s = a_s.new_empty(B, (M + 255) // 256 * 256, k_tiles_w)
oe = torch.empty((B, (M + 255) // 256 * 256, N), dtype=dtype)
ae[:, :M, :] = a
ae_s[:, :M, :] = a_s
masked_m = torch.full((B,), M, dtype=torch.int)
expected_m = M
lhs = (
ae,
ae_s,
)
rhs = (
w,
w_s,
)
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
with torch.inference_mode():
ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
out = oe[:, :M, :]
self.assertTrue(
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
/ torch.mean(torch.abs(ref_out.to(torch.float32)))
< 0.0001
)
def test_w8a8_block_fp8_batched_deep_gemm(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.BATCH,
self.BLOCK_SIZE,
self.DTYPES,
self.SEEDS,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
B=params[3],
block_size=params[4],
dtype=params[5],
seed=params[6],
):
self._w8a8_block_fp8_batched_deep_gemm(*params)
if __name__ == "__main__":
unittest.main(verbosity=2)