Restruct sgl-kernel benchmark (#10861)
This commit is contained in:
@@ -5,7 +5,7 @@ import itertools
|
||||
import deep_gemm
|
||||
import torch
|
||||
import triton
|
||||
from deep_gemm import get_col_major_tma_aligned_tensor
|
||||
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
||||
from sgl_kernel import fp8_blockwise_scaled_mm
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
|
||||
@@ -71,7 +71,7 @@ def fp8_gemm_deepgemm(
|
||||
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Run DeepGEMM kernel
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
|
||||
deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
|
||||
return out
|
||||
|
||||
|
||||
@@ -117,7 +117,7 @@ def benchmark(batch_size, provider, N, K):
|
||||
if provider == "sgl-kernel":
|
||||
scale_a = scale_a.t().contiguous().t()
|
||||
b_fp8, scale_b = b_fp8.t(), scale_b.t()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: fp8_blockwise_scaled_mm(
|
||||
a_fp8, b_fp8, scale_a, scale_b, torch.float16
|
||||
),
|
||||
@@ -126,20 +126,20 @@ def benchmark(batch_size, provider, N, K):
|
||||
if provider == "vllm":
|
||||
scale_a = scale_a.t().contiguous().t()
|
||||
b_fp8, scale_b = b_fp8.t(), scale_b.t()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "triton":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: w8a8_block_fp8_matmul(
|
||||
a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "deepgemm":
|
||||
scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
scale_a_col_major = get_mn_major_tma_aligned_tensor(scale_a.clone())
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: fp8_gemm_deepgemm(
|
||||
a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
|
||||
),
|
||||
@@ -174,8 +174,6 @@ if __name__ == "__main__":
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path="bench_fp8_blockwise_res",
|
||||
N=N,
|
||||
K=K,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user