linear support deepgemm (#4199)
Co-authored-by: yinfan98 <1106310035@qq.com>
@@ -29,10 +29,13 @@ fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
 _is_cuda = torch.cuda.is_available() and torch.version.cuda
 if _is_cuda:
+    import deep_gemm
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
 
 logger = logging.getLogger(__name__)
 
+_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
+
 
 @triton.jit
 def _per_token_group_quant_fp8(
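The first hunk gates the new dependency twice: deep_gemm is imported only on CUDA builds, and _enable_jit_deepgemm is parsed from the environment once at import time, so the DeepGEMM path stays off unless explicitly requested. A minimal sketch of the flag's behavior, assuming only what the diff shows (the assertion is illustrative):

    import os

    # Opt in before the sglang module is imported; the flag is parsed once,
    # at import time, and any non-zero integer enables the DeepGEMM path.
    os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"

    _enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
    assert _enable_jit_deepgemm == 1  # with the default "0" the path stays disabled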
@@ -722,34 +725,39 @@ def w8a8_block_fp8_matmul(
     num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
         N, config["BLOCK_SIZE_N"]
     )
-    kernel = (
-        _w8a8_block_fp8_matmul_unrolledx4
-        if (is_hip_ == True and num_workgroups <= get_device_core_count())
-        else _w8a8_block_fp8_matmul
-    )
-
-    kernel[grid](
-        A,
-        B,
-        C,
-        As,
-        Bs,
-        M,
-        N,
-        K,
-        block_n,
-        block_k,
-        A.stride(-2),
-        A.stride(-1),
-        B.stride(1),
-        B.stride(0),
-        C.stride(-2),
-        C.stride(-1),
-        As.stride(-2),
-        As.stride(-1),
-        Bs.stride(1),
-        Bs.stride(0),
-        **config,
-    )
+    # deepgemm only support bf16
+    if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
+        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+    else:
+        kernel = (
+            _w8a8_block_fp8_matmul_unrolledx4
+            if (is_hip_ == True and num_workgroups <= get_device_core_count())
+            else _w8a8_block_fp8_matmul
+        )
+
+        kernel[grid](
+            A,
+            B,
+            C,
+            As,
+            Bs,
+            M,
+            N,
+            K,
+            block_n,
+            block_k,
+            A.stride(-2),
+            A.stride(-1),
+            B.stride(1),
+            B.stride(0),
+            C.stride(-2),
+            C.stride(-1),
+            As.stride(-2),
+            As.stride(-1),
+            Bs.stride(1),
+            Bs.stride(0),
+            **config,
+        )
 
     return C
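The second hunk is the functional core of the change: when the build is CUDA, the output buffer is bf16, and the flag is set, w8a8_block_fp8_matmul hands the GEMM to deep_gemm.gemm_fp8_fp8_bf16_nt, which takes (values, scales) pairs for both operands and writes into C in place; in every other case the original Triton kernel selection, including the unrolled-x4 variant for ROCm, runs unchanged. A condensed sketch of the dispatch, with triton_fallback standing in for the kernel launch shown in the diff:

    import torch

    def dispatch_block_fp8_matmul(A, As, B, Bs, C, *, is_cuda,
                                  enable_jit_deepgemm, triton_fallback):
        # Per the comment in the diff, DeepGEMM only supports bf16 output,
        # so all three gates must hold before taking the fast path.
        if is_cuda and C.dtype == torch.bfloat16 and enable_jit_deepgemm:
            import deep_gemm
            # Operands travel as (values, per-block-scales) tuples; the
            # result is written into C rather than returned.
            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
        else:
            triton_fallback()  # the pre-existing Triton path, unchanged
        return C

The remaining hunks adapt the block-FP8 matmul unit test to this new path.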
@@ -1,4 +1,5 @@
+import itertools
 import os
 import unittest
 
 import torch
@@ -11,6 +12,8 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     w8a8_block_fp8_matmul,
 )
 
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+
 
 # For test
 def native_per_token_group_quant_fp8(
@@ -208,13 +211,35 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl
 
 
 class TestW8A8BlockFP8Matmul(unittest.TestCase):
-    OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
-    M = [1, 7, 83, 512, 2048]
-    N = [128, 512, 1024, 4096, 7748, 13824]
-    K = [256, 4096, 5120, 3884, 13824]
-    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
-    BLOCK_SIZE = [[128, 128]]
-    SEEDS = [0]
+    if not _is_cuda:
+        OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
+        M = [1, 7, 83, 512, 2048]
+        NKs = [
+            (N, K)
+            for N in [128, 512, 1024, 4096, 7748, 13824]
+            for K in [256, 4096, 5120, 3884, 13824]
+        ]
+        # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]
+    else:
+        # use practical shape in DeepSeek V3 for test
+        OUT_DTYPES = [torch.bfloat16]
+        M = [64, 128, 512, 1024, 4096]
+        NKs = [
+            (1536, 7168),
+            (3072, 1536),
+            (24576, 7168),
+            (4096, 512),
+            (7168, 2048),
+            (4608, 7168),
+            (512, 7168),
+            (7168, 2304),
+            (7168, 512),
+        ]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]
 
     @classmethod
     def setUpClass(cls):
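With DeepGEMM in play, the test parameters change shape: N and K are no longer independent axes of the cross-product but paired (N, K) tuples, and on CUDA the pairs are practical DeepSeek V3 projection shapes with bf16-only outputs, matching the constraint above. A quick expansion of the CUDA configuration (values copied from the diff; the count is simple arithmetic):

    import itertools

    import torch

    M = [64, 128, 512, 1024, 4096]
    NKs = [(1536, 7168), (3072, 1536), (24576, 7168), (4096, 512), (7168, 2048),
           (4608, 7168), (512, 7168), (7168, 2304), (7168, 512)]
    BLOCK_SIZE = [[128, 128]]
    OUT_DTYPES = [torch.bfloat16]
    SEEDS = [0]

    cases = list(itertools.product(M, NKs, BLOCK_SIZE, OUT_DTYPES, SEEDS))
    print(len(cases))  # 5 * 9 * 1 * 1 * 1 = 45 parameter sets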
@@ -222,7 +247,8 @@ class TestW8A8BlockFP8Matmul(unittest.TestCase):
             raise unittest.SkipTest("CUDA is not available")
         torch.set_default_device("cuda")
 
-    def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
+    def _w8a8_block_fp8_matmul(self, M, NK, block_size, out_dtype, seed):
+        N, K = NK
         torch.manual_seed(seed)
         # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
         factor_for_scale = 1e-2
@@ -257,19 +283,17 @@
     def test_w8a8_block_fp8_matmul(self):
         for params in itertools.product(
             self.M,
-            self.N,
-            self.K,
+            self.NKs,
             self.BLOCK_SIZE,
             self.OUT_DTYPES,
             self.SEEDS,
         ):
             with self.subTest(
                 M=params[0],
-                N=params[1],
-                K=params[2],
-                block_size=params[3],
-                out_dtype=params[4],
-                seed=params[5],
+                NKs=params[1],
+                block_size=params[2],
+                out_dtype=params[3],
+                seed=params[4],
             ):
                 self._w8a8_block_fp8_matmul(*params)
 
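Taken together, the last two hunks keep each (N, K) pair intact end to end: itertools.product yields the tuple as a single element, subTest labels the case with it, and the helper unpacks it via N, K = NK. A self-contained toy version of the same pattern (the assertion merely stands in for the real matmul comparison):

    import itertools
    import unittest

    class TestPairedShapes(unittest.TestCase):
        def test_pairs(self):
            # Shapes iterate as (N, K) tuples rather than a cross-product
            # of independent N and K lists, mirroring the diff.
            for M, (N, K) in itertools.product([1, 2], [(4, 8), (16, 32)]):
                with self.subTest(M=M, NK=(N, K)):
                    self.assertGreater(M * N * K, 0)  # stand-in check

    if __name__ == "__main__":
        unittest.main()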