Restruct sgl-kernel benchmark (#10861)
This commit is contained in:
@@ -4,7 +4,8 @@ import deep_gemm
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
-from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor
|
||||
+from deep_gemm import calc_diff
|
||||
+from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
||||
|
||||
# Import shared functionality from the regular GEMM benchmark
|
||||
from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
|
||||
@@ -71,9 +72,9 @@ def construct_grouped_and_flat_fp8(
|
||||
# Transpose earlier for testing
|
||||
x_fp8_grouped = (
|
||||
x_fp8_grouped[0],
|
||||
-get_col_major_tma_aligned_tensor(x_fp8_grouped[1]),
|
||||
+get_mn_major_tma_aligned_tensor(x_fp8_grouped[1]),
|
||||
)
|
||||
-x_fp8_flat = (x_fp8_flat[0], get_col_major_tma_aligned_tensor(x_fp8_flat[1]))
|
||||
+x_fp8_flat = (x_fp8_flat[0], get_mn_major_tma_aligned_tensor(x_fp8_flat[1]))
|
||||
|
||||
return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out
|
||||
|
||||
@@ -240,7 +241,7 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
|
||||
|
||||
|
||||
def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
|
||||
-deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
+deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
|
||||
x_fp8_grouped,
|
||||
y_fp8_grouped,
|
||||
out,
|
||||
|
||||
Reference in New Issue
Block a user