Restructure sgl-kernel benchmark (#10861)
@@ -5,7 +5,8 @@ import tilelang
 import tilelang.language as T
 import torch
 import triton
-from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
+from deep_gemm import ceil_div
+from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
 )
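The import hunk above tracks a DeepGEMM API move: the TMA-alignment helper now lives in deep_gemm.utils.layout under the name get_mn_major_tma_aligned_tensor rather than being exported at top level as get_col_major_tma_aligned_tensor. A minimal compatibility shim (an illustration, not part of this PR) that runs against either DeepGEMM layout might look like:

    # Hypothetical shim: prefer the new deep_gemm.utils.layout location,
    # fall back to the old top-level export on older DeepGEMM builds.
    try:
        from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
    except ImportError:
        from deep_gemm import (
            get_col_major_tma_aligned_tensor as get_mn_major_tma_aligned_tensor,
        )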
@@ -131,7 +132,7 @@ def fp8_gemm_deepgemm(
     out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
 
     # Run DeepGEMM kernel
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
+    deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
     return out
 
 
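The kernel entry point is renamed as well: deep_gemm.gemm_fp8_fp8_bf16_nt becomes deep_gemm.fp8_gemm_nt, keeping the same ((lhs, lhs_scale), (rhs, rhs_scale), out) call shape. A standalone sketch of the new call follows; the tensor shapes and scale granularity are assumptions inferred from the benchmark's per-token/per-block FP8 cast helpers, not spelled out in this diff:

    import torch
    import deep_gemm
    from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor

    # Sketch only: assumes a row-major m x k lhs with per-token (1x128)
    # scales and an n x k rhs with per-block (128x128) scales (NT layout).
    m, n, k = 128, 4096, 7168
    x_fp8 = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
    x_scale = torch.ones(m, k // 128, device="cuda", dtype=torch.float32)
    y_fp8 = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)
    y_scale = torch.ones(n // 128, k // 128, device="cuda", dtype=torch.float32)
    out = torch.empty(m, n, device="cuda", dtype=torch.bfloat16)

    # DeepGEMM expects the lhs scales in its TMA-aligned layout, which is
    # what the helper renamed in this PR produces.
    x_scale = get_mn_major_tma_aligned_tensor(x_scale)

    # New-style call; DeepGEMM previously named this gemm_fp8_fp8_bf16_nt.
    deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)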
@@ -179,7 +180,7 @@ def calculate_diff(m: int, n: int, k: int):
 
     x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
     y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
-    x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
+    x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
 
     out_deepgemm = fp8_gemm_deepgemm(
         x_fp8.clone(),
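For context, per_token_cast_to_fp8 is the benchmark's activation quantizer. A rough sketch of what such a helper typically does, modeled on DeepGEMM's reference utilities (the exact clamping and e4m3 range constants here are assumptions):

    import torch

    def per_token_cast_to_fp8(x: torch.Tensor):
        # Quantize each row in groups of 128 columns: one fp32 scale per
        # (token, 128-column) group, values scaled into e4m3's ~448 range.
        m, n = x.shape
        assert n % 128 == 0
        x_view = x.view(m, -1, 128)
        amax = x_view.abs().float().amax(dim=2).clamp(min=1e-4)
        x_fp8 = (x_view * (448.0 / amax.unsqueeze(2))).to(torch.float8_e4m3fn)
        return x_fp8.view(m, n), amax / 448.0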
@@ -300,7 +301,7 @@ def get_benchmark(tp_size):
     # Preprocess data before benchmarking
     x_fp8, x_scale = per_token_cast_to_fp8(x)
     y_fp8, y_scale = per_block_cast_to_fp8(y)
-    x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
+    x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
 
     quantiles = [0.5, 0.2, 0.8]
 
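The quantiles list feeds Triton's benchmarking helper, which reports one latency per requested quantile. A sketch of the typical call (the fp8_gemm_deepgemm argument order below is inferred from the hunks above, not shown in this diff):

    import triton

    # do_bench returns one latency (ms) per requested quantile, so this
    # unpacks as (median, 20th percentile, 80th percentile).
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fp8_gemm_deepgemm(
            x_fp8, x_scale_col_major, y_fp8, y_scale, m, n, k
        ),
        quantiles=quantiles,
    )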