Add Benchmark for DeepGEMM Group GEMM (#3993)

This commit is contained in:
Stefan He
2025-03-02 17:47:21 -08:00
committed by GitHub
parent 9cf4077294
commit b7e274f2d9
3 changed files with 502 additions and 2 deletions

View File

@@ -211,6 +211,7 @@ def get_benchmark(tp_size):
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
# Preprocess data before benchmarking
x_fp8, x_scale = per_token_cast_to_fp8(x)
y_fp8, y_scale = per_block_cast_to_fp8(y)
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())