support blockwise fp8 matmul kernel (#3267)
This commit is contained in:
148
sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py
Normal file
148
sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import fp8_blockwise_scaled_mm
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
|
||||
|
||||
def get_weight_shapes(args):
    """Build the list of ``[N, K, model]`` weight shapes to benchmark.

    For every ``(model, tp_size)`` pair drawn from ``args.models`` and
    ``args.tp_sizes``, three groups of hard-coded shapes are emitted in order:
    shapes used unchanged, shapes whose first dimension (N) is floor-divided
    by ``tp_size``, and shapes whose second dimension (K) is floor-divided by
    ``tp_size``.

    Args:
        args: Object exposing ``models`` (iterable of model names) and
            ``tp_sizes`` (iterable of integer tensor-parallel sizes), e.g. an
            ``argparse.Namespace``.

    Returns:
        A list of ``[N, K, model]`` lists, grouped per ``(model, tp_size)``
        pair in ``itertools.product`` order.

    Raises:
        ValueError: If a requested model is not in the supported list.
    """
    # NOTE(HandH1998): The weight shapes only work for DeepSeek-V3. Modify
    # them if you tune for a different model.
    # Shapes that cannot be split by tensor parallelism.
    total = [
        # (512 + 64, 7168),  # this weight is not supported by current kernel
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # Shapes whose N dimension can be split by tensor parallelism.
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # Shapes whose K dimension can be split by tensor parallelism.
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
    # Only DeepSeek-V3 is supported.
    SUPPORT_MODEL = ["deepseek-ai/DeepSeek-V3"]

    weight_shapes = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        # Raise explicitly instead of asserting: asserts are removed under
        # ``python -O``, which would silently benchmark DeepSeek-V3 shapes
        # under another model's name.
        if model not in SUPPORT_MODEL:
            raise ValueError(
                f"Unsupported model {model!r}; supported models: {SUPPORT_MODEL}"
            )
        weight_shapes.extend([n, k, model] for n, k in total)
        weight_shapes.extend([n // tp_size, k, model] for n, k in n_tp)
        weight_shapes.extend([n, k // tp_size, model] for n, k in k_tp)
    return weight_shapes
|
||||
|
||||
|
||||
def cdiv(a: int, b: int) -> int:
    """Divide ``a`` by ``b`` and round the quotient up to the next integer."""
    # Floor division by the negated divisor rounds toward negative infinity;
    # negating that result turns the floor into a ceiling.
    negated_floor = a // -b
    return -negated_floor
|
||||
|
||||
|
||||
def scale_shape(shape, group_shape):
    """Return the number of groups needed to cover ``shape`` per dimension.

    Each dimension of ``shape`` is ceiling-divided by the matching entry of
    ``group_shape``; both sequences must have the same length.

    Returns:
        A tuple with one group count per dimension.
    """
    assert len(shape) == len(group_shape)
    groups_per_dim = []
    for dim, group_extent in enumerate(group_shape):
        groups_per_dim.append(cdiv(shape[dim], group_extent))
    return tuple(groups_per_dim)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["vllm", "sgl-kernel"],
        line_names=["vllm fp8 blockwise gemm", "sgl-kernel fp8 blockwise gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="GB/s",
        plot_name="fp8 blockwise scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Time one fp8 block-scaled matmul for a provider and problem size.

    Random fp8 (``float8_e4m3fn``) operands are created on the CUDA device:
    ``a`` with shape ``(M, K)`` where ``M = batch_size``, and ``b`` built as
    ``(N, K)`` and then transposed. Float32 scale tensors are created with
    one entry per ``(1, 128)`` group of ``a`` and per ``(128, 128)`` group of
    ``b``. The selected kernel is timed with ``triton.testing.do_bench`` and
    asked for float16 output.

    Args:
        batch_size: Number of rows of the left operand (M).
        provider: ``"sgl-kernel"`` or ``"vllm"``; selects the kernel to time.
        N: First dimension of the weight before it is transposed.
        K: Shared inner dimension.

    Returns:
        A 3-tuple of throughput figures derived from the timings that
        ``do_bench`` returns for the 0.5, 0.8 and 0.2 quantiles, in that
        order.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    # Resolve the kernel first so an unknown provider fails immediately with
    # a clear message instead of an unbound-variable error after the setup.
    if provider == "sgl-kernel":
        scaled_mm = fp8_blockwise_scaled_mm
    elif provider == "vllm":
        scaled_mm = vllm_scaled_mm
    else:
        raise ValueError(
            f"Unknown provider {provider!r}; expected 'sgl-kernel' or 'vllm'"
        )

    M = batch_size
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    # Uniform values scaled to [-fp8_max, fp8_max), clamped, then cast to fp8.
    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    # The weight is created as (N, K) and handed to the kernels transposed.
    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()

    scale_a_group_shape = (1, 128)
    scale_b_group_shape = (128, 128)
    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)

    scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
    scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)
    # Transpose, make contiguous, transpose back: same logical shape with the
    # transposed memory layout (presumably what the kernels expect; confirm).
    scale_a = scale_a.t().contiguous().t()
    scale_b = scale_b.t().contiguous().t()

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
        quantiles=quantiles,
    )

    def to_gbps(elapsed_ms):
        # NOTE(review): the numerator weights an operation-count-like term
        # (2*M*N*K - M*N) by the element size; confirm this is the quantity
        # intended for an axis labelled "GB/s".
        return (
            (2 * M * N * K - M * N) * a_fp8.element_size()
            + (3 * M * N) * scale_a.element_size()
        ) * 1e-9 / (elapsed_ms * 1e-3)

    # A longer time yields a smaller figure, so the value from the slowest
    # quantile goes in the second slot and the fastest in the third.
    return to_gbps(ms), to_gbps(max_ms), to_gbps(min_ms)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line options: which models and tensor-parallel sizes to sweep.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["deepseek-ai/DeepSeek-V3"],
        help="List of models to benchmark",
    )
    cli.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = cli.parse_args()

    # Reporting options shared by every benchmark run.
    report_options = {
        "print_data": True,
        "show_plots": True,
        "save_path": "bench_fp8_blockwise_res",
    }

    # Run the batch-size sweep once for each (N, K) weight shape.
    for N, K, model_name in get_weight_shapes(args):
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(N=N, K=K, **report_options)

    print("Benchmark finished!")
|
||||
Reference in New Issue
Block a user