sgl-kernel: use latest cutlass version for fp8 blockwise gemm (#5207)

Yi Zhang authored on 2025-04-10 02:47:04 +08:00; committed by GitHub
parent 7f875f1293
commit ebf495f013
6 changed files with 86 additions and 923 deletions
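
For context, the updated benchmark calls fp8_blockwise_scaled_mm with 128-wide per-token scales for A and 128x128 block scales for B, passing B and its scales transposed to (K, N) layout. Below is a minimal sketch of that call, using only the API as it appears in the benchmark diff; the toy sizes and random scale values are illustrative, not part of this commit.

    import torch
    from sgl_kernel import fp8_blockwise_scaled_mm

    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    M, N, K = 64, 512, 7168  # illustrative sizes; N and K must be multiples of 128

    # Quantize A as (M, K) row-major and B as (N, K), as the benchmark does
    a_fp8 = ((torch.rand(M, K, device="cuda") - 0.5) * 2 * fp8_max).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    b_fp8 = ((torch.rand(N, K, device="cuda") - 0.5) * 2 * fp8_max).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

    # Per-group scales: (1, 128) groups for A -> (M, K // 128); (128, 128) groups for B -> (N // 128, K // 128)
    scale_a = torch.randn(M, K // 128, device="cuda", dtype=torch.float32)
    scale_b = torch.randn(N // 128, K // 128, device="cuda", dtype=torch.float32)

    # The kernel is fed a column-major scale_a and B / scale_b transposed to (K, N) / (K // 128, N // 128)
    scale_a = scale_a.t().contiguous().t()
    out = fp8_blockwise_scaled_mm(a_fp8, b_fp8.t(), scale_a, scale_b.t(), torch.float16)
    print(out.shape)  # (M, N)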


@@ -2,18 +2,22 @@ import argparse
import copy
import itertools
+import deep_gemm
import torch
import triton
+from deep_gemm import get_col_major_tma_aligned_tensor
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
+from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
def get_weight_shapes(args):
models_tps = list(itertools.product(args.models, args.tp_sizes))
# NOTE(HandH1998): The weight shapes below only work for DeepSeek-V3. Modify them if you tune a different model.
# cannot TP
total = [
-# (512 + 64, 7168), # this weight is not supported by current kernel
+(512 + 64, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(7168, 16384),
@@ -52,6 +56,23 @@ def cdiv(a: int, b: int) -> int:
return -(a // -b)
+def fp8_gemm_deepgemm(
+x_fp8: torch.Tensor,
+x_scale: torch.Tensor,
+y_fp8: torch.Tensor,
+y_scale: torch.Tensor,
+m: int,
+n: int,
+k: int,
+):
+"""DeepGEMM implementation of FP8 GEMM"""
+out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
+# Run DeepGEMM kernel
+deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
+return out
def scale_shape(shape, group_shape):
assert len(shape) == len(group_shape)
return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
@@ -60,12 +81,12 @@ def scale_shape(shape, group_shape):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
-x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
+x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
x_log=False,
line_arg="provider",
line_vals=["vllm", "sgl-kernel"],
line_names=["vllm fp8 blockwise gemm", "sgl-kernel fp8 blockwise gemm"],
styles=[("blue", "-"), ("orange", "-")],
line_vals=["vllm", "sgl-kernel", "triton", "deepgemm"],
line_names=["vllm", "sgl-kernel", "sglang triton", "deepgemm"],
styles=[("blue", "-"), ("orange", "-"), ("red", "-"), ("yellow", "-")],
ylabel="GB/s",
plot_name="fp8 blockwise scaled matmul",
args={},
@@ -80,7 +101,7 @@ def benchmark(batch_size, provider, N, K):
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
-b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
+b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
scale_a_group_shape = (1, 128)
scale_b_group_shape = (128, 128)
@@ -89,11 +110,11 @@ def benchmark(batch_size, provider, N, K):
scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)
-scale_a = scale_a.t().contiguous().t()
-scale_b = scale_b.t().contiguous().t()
quantiles = [0.5, 0.2, 0.8]
if provider == "sgl-kernel":
+scale_a = scale_a.t().contiguous().t()
+b_fp8, scale_b = b_fp8.t(), scale_b.t()
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: fp8_blockwise_scaled_mm(
a_fp8, b_fp8, scale_a, scale_b, torch.float16
@@ -101,19 +122,28 @@ def benchmark(batch_size, provider, N, K):
quantiles=quantiles,
)
if provider == "vllm":
+scale_a = scale_a.t().contiguous().t()
+b_fp8, scale_b = b_fp8.t(), scale_b.t()
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
quantiles=quantiles,
)
-gbps = (
-lambda ms: (
-(2 * M * N * K - M * N) * a_fp8.element_size()
-+ (3 * M * N) * scale_a.element_size()
-)
-* 1e-9
-/ (ms * 1e-3)
-)
-return gbps(ms), gbps(max_ms), gbps(min_ms)
+if provider == "triton":
+ms, min_ms, max_ms = triton.testing.do_bench(
+lambda: w8a8_block_fp8_matmul(
+a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
+),
+quantiles=quantiles,
+)
+if provider == "deepgemm":
+scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
+ms, min_ms, max_ms = triton.testing.do_bench(
+lambda: fp8_gemm_deepgemm(
+a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
+),
+quantiles=quantiles,
+)
+return ms * 1000, max_ms * 1000, min_ms * 1000  # convert ms to us
if __name__ == "__main__":
@@ -136,6 +166,9 @@ if __name__ == "__main__":
NK_model_names = get_weight_shapes(args)
for N, K, model_name in NK_model_names:
+if N % 128 != 0 or K % 128 != 0:
+print(f"Skip {N=}, {K=} now")
+continue
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True,
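
The new deepgemm provider feeds the A scales to DeepGEMM in its column-major, TMA-aligned layout via get_col_major_tma_aligned_tensor, while B stays in (N, K) form for the NT kernel. Below is a minimal sketch of that path, mirroring the fp8_gemm_deepgemm helper and the deepgemm branch in the diff above; the sizes and random scale values are illustrative only.

    import torch
    import deep_gemm
    from deep_gemm import get_col_major_tma_aligned_tensor

    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    M, N, K = 64, 512, 7168  # illustrative; N and K are multiples of 128

    # Quantize A as (M, K) and B as (N, K), matching the benchmark's layouts
    a_fp8 = ((torch.rand(M, K, device="cuda") - 0.5) * 2 * fp8_max).to(torch.float8_e4m3fn)
    b_fp8 = ((torch.rand(N, K, device="cuda") - 0.5) * 2 * fp8_max).to(torch.float8_e4m3fn)
    scale_a = torch.randn(M, K // 128, device="cuda", dtype=torch.float32)
    scale_b = torch.randn(N // 128, K // 128, device="cuda", dtype=torch.float32)

    # DeepGEMM expects the per-token A scales in a column-major, TMA-aligned layout
    scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())

    # NT layout: A is (M, K) row-major, B is stored as (N, K); output is bf16
    out = torch.empty((M, N), device="cuda", dtype=torch.bfloat16)
    deep_gemm.gemm_fp8_fp8_bf16_nt((a_fp8, scale_a_col_major), (b_fp8, scale_b), out)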