Restruct sgl-kernel benchmark (#10861)

This commit is contained in:
Xiaoyu Zhang
2025-09-25 07:45:25 +08:00
committed by GitHub
parent 7a06ef984d
commit c4e314f986
27 changed files with 425 additions and 319 deletions

View File

@@ -4,7 +4,8 @@ import deep_gemm
import torch
import triton
import triton.language as tl
from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor
from deep_gemm import calc_diff
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
# Import shared functionality from the regular GEMM benchmark
from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
@@ -71,9 +72,9 @@ def construct_grouped_and_flat_fp8(
# Transpose earlier for testing
x_fp8_grouped = (
x_fp8_grouped[0],
get_col_major_tma_aligned_tensor(x_fp8_grouped[1]),
get_mn_major_tma_aligned_tensor(x_fp8_grouped[1]),
)
x_fp8_flat = (x_fp8_flat[0], get_col_major_tma_aligned_tensor(x_fp8_flat[1]))
x_fp8_flat = (x_fp8_flat[0], get_mn_major_tma_aligned_tensor(x_fp8_flat[1]))
return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out
@@ -240,7 +241,7 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
x_fp8_grouped,
y_fp8_grouped,
out,