linear support deepgemm (#4199)

Co-authored-by: yinfan98 <1106310035@qq.com>
This commit is contained in:
lukec
2025-03-11 15:38:37 +08:00
committed by GitHub
parent 4d27eb9ad1
commit dce303e279
3 changed files with 76 additions and 44 deletions

View File

@@ -1,4 +1,5 @@
import itertools
import os
import unittest
import torch
@@ -11,6 +12,8 @@ from sglang.srt.layers.quantization.fp8_kernel import (
w8a8_block_fp8_matmul,
)
_is_cuda = torch.cuda.is_available() and torch.version.cuda
# For test
def native_per_token_group_quant_fp8(
@@ -208,13 +211,35 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl
class TestW8A8BlockFP8Matmul(unittest.TestCase):
OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
M = [1, 7, 83, 512, 2048]
N = [128, 512, 1024, 4096, 7748, 13824]
K = [256, 4096, 5120, 3884, 13824]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
if not _is_cuda:
OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
M = [1, 7, 83, 512, 2048]
NKs = [
(N, K)
for N in [128, 512, 1024, 4096, 7748, 13824]
for K in [256, 4096, 5120, 3884, 13824]
]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
else:
# use practical shape in DeepSeek V3 for test
OUT_DTYPES = [torch.bfloat16]
M = [64, 128, 512, 1024, 4096]
NKs = [
(1536, 7168),
(3072, 1536),
(24576, 7168),
(4096, 512),
(7168, 2048),
(4608, 7168),
(512, 7168),
(7168, 2304),
(7168, 512),
]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
@classmethod
def setUpClass(cls):
@@ -222,7 +247,8 @@ class TestW8A8BlockFP8Matmul(unittest.TestCase):
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")
def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
def _w8a8_block_fp8_matmul(self, M, NK, block_size, out_dtype, seed):
N, K = NK
torch.manual_seed(seed)
# NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
factor_for_scale = 1e-2
@@ -257,19 +283,17 @@ class TestW8A8BlockFP8Matmul(unittest.TestCase):
def test_w8a8_block_fp8_matmul(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.NKs,
self.BLOCK_SIZE,
self.OUT_DTYPES,
self.SEEDS,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
block_size=params[3],
out_dtype=params[4],
seed=params[5],
NKs=params[1],
block_size=params[2],
out_dtype=params[3],
seed=params[4],
):
self._w8a8_block_fp8_matmul(*params)