Restruct sgl-kernel benchmark (#10861)

This commit is contained in:
Xiaoyu Zhang
2025-09-25 07:45:25 +08:00
committed by GitHub
parent 7a06ef984d
commit c4e314f986
27 changed files with 425 additions and 319 deletions

View File

@@ -5,7 +5,8 @@ import tilelang
import tilelang.language as T
import torch
import triton
from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
from deep_gemm import ceil_div
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
)
@@ -131,7 +132,7 @@ def fp8_gemm_deepgemm(
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
# Run DeepGEMM kernel
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
return out
@@ -179,7 +180,7 @@ def calculate_diff(m: int, n: int, k: int):
x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
out_deepgemm = fp8_gemm_deepgemm(
x_fp8.clone(),
@@ -300,7 +301,7 @@ def get_benchmark(tp_size):
# Preprocess data before benchmarking
x_fp8, x_scale = per_token_cast_to_fp8(x)
y_fp8, y_scale = per_block_cast_to_fp8(y)
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
quantiles = [0.5, 0.2, 0.8]