linear support deepgemm (#4199)
Co-authored-by: yinfan98 <1106310035@qq.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import itertools
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
@@ -11,6 +12,8 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
w8a8_block_fp8_matmul,
|
||||
)
|
||||
|
||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||
|
||||
|
||||
# For test
|
||||
def native_per_token_group_quant_fp8(
|
||||
@@ -208,13 +211,35 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl
|
||||
|
||||
|
||||
class TestW8A8BlockFP8Matmul(unittest.TestCase):
|
||||
# Parameter grid consumed by test_w8a8_block_fp8_matmul.
# NKs holds (N, K) pairs; each pair is handed to the test helper as a single
# argument and unpacked there. The former stand-alone N and K lists are gone
# because nothing iterates them any more.
if not _is_cuda:
    OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
    M = [1, 7, 83, 512, 2048]
    # Full cross product of the N and K candidates, N varying slowest.
    NKs = [
        (N, K)
        for N in [128, 512, 1024, 4096, 7748, 13824]
        for K in [256, 4096, 5120, 3884, 13824]
    ]
    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]
else:
    # use practical shape in DeepSeek V3 for test
    OUT_DTYPES = [torch.bfloat16]
    M = [64, 128, 512, 1024, 4096]
    NKs = [
        (1536, 7168),
        (3072, 1536),
        (24576, 7168),
        (4096, 512),
        (7168, 2048),
        (4608, 7168),
        (512, 7168),
        (7168, 2304),
        (7168, 512),
    ]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -222,7 +247,8 @@ class TestW8A8BlockFP8Matmul(unittest.TestCase):
|
||||
raise unittest.SkipTest("CUDA is not available")
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
|
||||
def _w8a8_block_fp8_matmul(self, M, NK, block_size, out_dtype, seed):
|
||||
N, K = NK
|
||||
torch.manual_seed(seed)
|
||||
# NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
|
||||
factor_for_scale = 1e-2
|
||||
@@ -257,19 +283,17 @@ class TestW8A8BlockFP8Matmul(unittest.TestCase):
|
||||
def test_w8a8_block_fp8_matmul(self):
    """Run the block-FP8 matmul check over the class-level parameter grid.

    Iterates the Cartesian product of ``self.M``, ``self.NKs``,
    ``self.BLOCK_SIZE``, ``self.OUT_DTYPES`` and ``self.SEEDS``. Each
    combination runs inside its own ``subTest`` labelled with its parameters.
    """
    for M, NK, block_size, out_dtype, seed in itertools.product(
        self.M,
        self.NKs,
        self.BLOCK_SIZE,
        self.OUT_DTYPES,
        self.SEEDS,
    ):
        with self.subTest(
            M=M,
            NKs=NK,
            block_size=block_size,
            out_dtype=out_dtype,
            seed=seed,
        ):
            # Argument order follows the helper's signature:
            # (M, NK, block_size, out_dtype, seed); NK is an (N, K) pair.
            self._w8a8_block_fp8_matmul(M, NK, block_size, out_dtype, seed)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user