linear support deepgemm (#4199)
Co-authored-by: yinfan98 <1106310035@qq.com>
@@ -29,10 +29,13 @@ fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
 
 _is_cuda = torch.cuda.is_available() and torch.version.cuda
 if _is_cuda:
+    import deep_gemm
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
 
 logger = logging.getLogger(__name__)
 
+_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
+
 
 @triton.jit
 def _per_token_group_quant_fp8(
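A note on the new flag: it is parsed with `int()` once at import time, so it must hold an integer literal (`"1"`, not `"true"`) and must be exported before this module is first imported. A minimal, illustrative sketch (the script name is hypothetical):

```python
# Launch with the flag exported, e.g.:
#   SGL_ENABLE_JIT_DEEPGEMM=1 python serve.py
# The module reads the variable exactly once at import time.
import os

if int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0")):
    print("JIT DeepGEMM path enabled for w8a8 block-FP8 matmul")
```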
@@ -722,34 +725,39 @@ def w8a8_block_fp8_matmul(
     num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
         N, config["BLOCK_SIZE_N"]
     )
-    kernel = (
-        _w8a8_block_fp8_matmul_unrolledx4
-        if (is_hip_ == True and num_workgroups <= get_device_core_count())
-        else _w8a8_block_fp8_matmul
-    )
-
-    kernel[grid](
-        A,
-        B,
-        C,
-        As,
-        Bs,
-        M,
-        N,
-        K,
-        block_n,
-        block_k,
-        A.stride(-2),
-        A.stride(-1),
-        B.stride(1),
-        B.stride(0),
-        C.stride(-2),
-        C.stride(-1),
-        As.stride(-2),
-        As.stride(-1),
-        Bs.stride(1),
-        Bs.stride(0),
-        **config,
-    )
+    # deepgemm only support bf16
+    if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
+        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+    else:
+        kernel = (
+            _w8a8_block_fp8_matmul_unrolledx4
+            if (is_hip_ == True and num_workgroups <= get_device_core_count())
+            else _w8a8_block_fp8_matmul
+        )
+
+        kernel[grid](
+            A,
+            B,
+            C,
+            As,
+            Bs,
+            M,
+            N,
+            K,
+            block_n,
+            block_k,
+            A.stride(-2),
+            A.stride(-1),
+            B.stride(1),
+            B.stride(0),
+            C.stride(-2),
+            C.stride(-1),
+            As.stride(-2),
+            As.stride(-1),
+            Bs.stride(1),
+            Bs.stride(0),
+            **config,
+        )
 
     return C
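For reference, a hedged sketch of the tensor layout the `deep_gemm.gemm_fp8_fp8_bf16_nt` call above expects. The `(A, As)`/`(B, Bs)` pairing and the in-place bf16 output come from the diff itself; the 128-wide scale granularity matches the `BLOCK_SIZE = [[128, 128]]` used in the tests below, and DeepGEMM imposes additional alignment and scale-layout constraints, so treat this as a shape sketch rather than a drop-in snippet:

```python
import torch
import deep_gemm  # assumed installed; CUDA-only, hence the _is_cuda guard above

M, N, K = 64, 1536, 7168  # one of the DeepSeek-V3 shapes from the test below
# A: fp8 activations [M, K] with per-token-group scales As [M, K // 128]
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
As = torch.ones(M, K // 128, device="cuda", dtype=torch.float32)
# B: fp8 weights stored [N, K] ("nt" layout, i.e. the product is A @ B.T)
# with per-[128, 128]-block scales Bs [N // 128, K // 128]
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
Bs = torch.ones(N // 128, K // 128, device="cuda", dtype=torch.float32)
# C: bf16 output [M, N], written in place ("deepgemm only support bf16")
C = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)

deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
```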
@@ -1,4 +1,5 @@
 import itertools
+import os
 import unittest
 
 import torch
@@ -11,6 +12,8 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     w8a8_block_fp8_matmul,
 )
 
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+
 
 # For test
 def native_per_token_group_quant_fp8(
@@ -208,13 +211,35 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl
 
 
 class TestW8A8BlockFP8Matmul(unittest.TestCase):
-    OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
-    M = [1, 7, 83, 512, 2048]
-    N = [128, 512, 1024, 4096, 7748, 13824]
-    K = [256, 4096, 5120, 3884, 13824]
-    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
-    BLOCK_SIZE = [[128, 128]]
-    SEEDS = [0]
+    if not _is_cuda:
+        OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
+        M = [1, 7, 83, 512, 2048]
+        NKs = [
+            (N, K)
+            for N in [128, 512, 1024, 4096, 7748, 13824]
+            for K in [256, 4096, 5120, 3884, 13824]
+        ]
+        # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]
+    else:
+        # use practical shape in DeepSeek V3 for test
+        OUT_DTYPES = [torch.bfloat16]
+        M = [64, 128, 512, 1024, 4096]
+        NKs = [
+            (1536, 7168),
+            (3072, 1536),
+            (24576, 7168),
+            (4096, 512),
+            (7168, 2048),
+            (4608, 7168),
+            (512, 7168),
+            (7168, 2304),
+            (7168, 512),
+        ]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]
 
     @classmethod
     def setUpClass(cls):
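The shape refactor above trades the old independent `N` and `K` lists for explicit `(N, K)` pairs: the non-CUDA branch keeps the same full cross product (now written as a comprehension), while the CUDA branch pins the practical DeepSeek-V3 weight shapes named in its comment. A small sketch of how the pairs then flow through `itertools.product` (values illustrative):

```python
import itertools

M = [64, 128]
NKs = [(1536, 7168), (3072, 1536)]
# Each params tuple is now (M, (N, K), block_size, out_dtype, seed);
# merging N and K into one tuple shortens it by one element.
for m, (n, k) in itertools.product(M, NKs):
    print(m, n, k)
```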
@@ -222,7 +247,8 @@ class TestW8A8BlockFP8Matmul(unittest.TestCase):
             raise unittest.SkipTest("CUDA is not available")
         torch.set_default_device("cuda")
 
-    def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
+    def _w8a8_block_fp8_matmul(self, M, NK, block_size, out_dtype, seed):
+        N, K = NK
         torch.manual_seed(seed)
         # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
         factor_for_scale = 1e-2
@@ -257,19 +283,17 @@ class TestW8A8BlockFP8Matmul(unittest.TestCase):
     def test_w8a8_block_fp8_matmul(self):
         for params in itertools.product(
             self.M,
-            self.N,
-            self.K,
+            self.NKs,
             self.BLOCK_SIZE,
             self.OUT_DTYPES,
             self.SEEDS,
         ):
             with self.subTest(
                 M=params[0],
-                N=params[1],
-                K=params[2],
-                block_size=params[3],
-                out_dtype=params[4],
-                seed=params[5],
+                NKs=params[1],
+                block_size=params[2],
+                out_dtype=params[3],
+                seed=params[4],
             ):
                 self._w8a8_block_fp8_matmul(*params)
 
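Because `(N, K)` now travels as a single tuple in `params[1]`, every later `subTest` index shifts down by one, and `_w8a8_block_fp8_matmul` unpacks the pair via the new `N, K = NK` line.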
@@ -17,7 +17,7 @@ class TestFP8Base(unittest.TestCase):
         cls.K = 512
         cls.group_size = 128
         cls.quant_type = torch.float8_e4m3fn
-        cls.output_type = torch.float16
+        cls.output_type = torch.bfloat16
 
     @staticmethod
     def _make_A(M, K, group_size, out_dtype):
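Switching `output_type` from `torch.float16` to `torch.bfloat16` is consistent with the dispatch guard above ("deepgemm only support bf16"): only a bf16 output tensor can take the `gemm_fp8_fp8_bf16_nt` path, so the test presumably now exercises the DeepGEMM-compatible dtype.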