[perf] introduce deep gemm group_gemm_masked as bmm (#5432)
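DeepGEMM's masked grouped GEMM (m_grouped_gemm_fp8_fp8_bf16_nt_masked) computes, for each group i, only the first masked_m[i] rows of that group's output. If every group is given the same row count M, the kernel therefore behaves as a batched matmul (bmm) over the B groups, which is what the new TestW8A8BlockFP8BatchedDeepGemm test below checks against a native torch reference. A minimal sketch of the equivalence being tested (plain float tensors and illustrative names, not part of this commit; the real test uses FP8 data with block-wise scales):

    import torch

    B, M, N, K = 4, 8, 16, 32
    a = torch.randn(B, M, K)        # per-group activations
    w = torch.randn(B, N, K)        # per-group weights, NT layout (K-major)
    masked_m = torch.full((B,), M)  # uniform mask: every group keeps all M rows
    # with a uniform mask, a grouped masked GEMM degenerates to a plain bmm:
    out = torch.einsum("bmk,bnk->bmn", a, w)  # == torch.bmm(a, w.transpose(1, 2))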
@@ -7,6 +7,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
    per_tensor_quant_mla_deep_gemm_masked_fp8,
    per_tensor_quant_mla_fp8,
    per_token_group_quant_fp8,
    static_quant_fp8,
@@ -212,6 +213,62 @@ class TestPerTensorQuantMlaFP8(CustomTestCase):
                self._per_tensor_quant_mla_fp8(*params)


class TestPerTokenGroupQuantMlaDeepGemmMaskedFP8(CustomTestCase):
    DTYPES = [torch.half, torch.bfloat16, torch.float32]
    B = [128]
    NUM_TOKENS = [7, 83, 2048, 1024 * 16]
    D = [512, 128]
    GROUP_SIZE = [128]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _per_token_group_quant_mla_deep_gemm_masked_fp8(
        self, b, num_tokens, d, dtype, group_size, seed
    ):
        torch.manual_seed(seed)

        x = torch.rand(b, num_tokens, d, dtype=dtype)

        with torch.inference_mode():
            ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size, 1e-12)
            out, scale, _, _, _ = per_tensor_quant_mla_deep_gemm_masked_fp8(
                x, group_size
            )
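            # the deep-gemm masked helper returns buffers padded along the
            # token dim, so slice back to num_tokens before comparing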
            out = out[:, :num_tokens, :]
            scale = scale[:, :num_tokens, :]

        self.assertTrue(
            torch.allclose(
                out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20, atol=1e-2
            )
        )
        self.assertTrue(torch.allclose(scale, ref_scale))

    def test_per_token_group_quant_mla_deep_gemm_masked_fp8(self):
        for params in itertools.product(
            self.B,
            self.NUM_TOKENS,
            self.D,
            self.DTYPES,
            self.GROUP_SIZE,
            self.SEEDS,
        ):
            with self.subTest(
                b=params[0],
                num_tokens=params[1],
                d=params[2],
                dtype=params[3],
                group_size=params[4],
                seed=params[5],
            ):
                self._per_token_group_quant_mla_deep_gemm_masked_fp8(*params)


# For test
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """This function performs matrix multiplication with block-wise quantization using native torch.
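# The committed body of this reference is elided by the diff context above.
# For orientation, a minimal sketch of a block-wise dequant matmul
# (illustrative only, not the exact implementation; assumes As has shape
# (M, ceil(K / block_k)) and Bs has shape (ceil(N / block_n), ceil(K / block_k))):
def _sketch_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    block_n, block_k = block_size
    M, K = A.shape
    N, _ = B.shape
    # expand per-block scales to per-element scales, dequantize, then matmul
    a_scale = As.repeat_interleave(block_k, dim=1)[:, :K]
    b_scale = Bs.repeat_interleave(block_n, dim=0).repeat_interleave(block_k, dim=1)[
        :N, :K
    ]
    return ((A.to(torch.float32) * a_scale) @ (B.to(torch.float32) * b_scale).t()).to(
        output_dtype
    )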
@@ -485,5 +542,115 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
                self._w8a8_block_fp8_fused_moe(*params)


# For test
def torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_shape, out_dtype):
    """This function performs bmm with block-wise quantization using native torch."""

    B, N, _ = w.shape
    _, M, _ = a.shape
    out = torch.empty((B, M, N), dtype=out_dtype, device=a.device)

    for i in range(B):
        out[i] = native_w8a8_block_fp8_matmul(
            a[i], w[i], a_s[i], w_s[i], block_shape, output_dtype=out_dtype
        )

    return out


class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
    DTYPES = [torch.bfloat16]
    M = [1, 33, 64, 222, 8192]
    N = [128, 512]
    K = [128, 512]
    BATCH = [128]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        try:
            import deep_gemm
        except ImportError:
            raise unittest.SkipTest("DeepGEMM is not available")
        torch.set_default_device("cuda")

    def _w8a8_block_fp8_batched_deep_gemm(self, M, N, K, B, block_size, dtype, seed):
        torch.manual_seed(seed)
        factor_for_scale = 1e-2
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        a_fp32 = torch.randn((B, M, K), dtype=torch.float32) / 10
        a = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        w_fp32 = (torch.rand((B, N, K), dtype=torch.float32) - 0.5) * 2 * fp8_max
        w = w_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        block_n, block_k = block_size[0], block_size[1]
        n_tiles_w = (N + block_n - 1) // block_n
        k_tiles_w = (K + block_k - 1) // block_k

        w_s = (
            torch.rand((B, n_tiles_w, k_tiles_w), dtype=torch.float32)
            * factor_for_scale
        )
        a_s = torch.rand((B, M, k_tiles_w), dtype=torch.float32) * factor_for_scale

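        # the masked kernel expects every group padded to the same aligned row
        # count (here a multiple of 256); only the first masked_m[i] rows of
        # each group carry real data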
        ae = a.new_empty(B, (M + 255) // 256 * 256, K)
        ae_s = a_s.new_empty(B, (M + 255) // 256 * 256, k_tiles_w)
        oe = torch.empty((B, (M + 255) // 256 * 256, N), dtype=dtype)
        ae[:, :M, :] = a
        ae_s[:, :M, :] = a_s

        masked_m = torch.full((B,), M, dtype=torch.int)
        expected_m = M
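        # masked_m is uniform (every group keeps all M rows), so the masked
        # grouped GEMM below reduces to a plain bmm over the B groups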
        lhs = (ae, ae_s)
        rhs = (w, w_s)

        from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked

        with torch.inference_mode():
            ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
            m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
            out = oe[:, :M, :]

        self.assertTrue(
            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
            / torch.mean(torch.abs(ref_out.to(torch.float32)))
            < 0.0001
        )

    def test_w8a8_block_fp8_batched_deep_gemm(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.BATCH,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                B=params[3],
                block_size=params[4],
                dtype=params[5],
                seed=params[6],
            ):
                self._w8a8_block_fp8_batched_deep_gemm(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)