Fix sgl-kernel benchmark dead code (#11022)
This commit is contained in:
@@ -1,4 +1,11 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
# CI environment detection.
# IS_CI is True when either the CI or the GITHUB_ACTIONS environment variable
# is set to "true" (compared case-insensitively); an unset variable is treated
# as "false".
IS_CI = any(
    os.getenv(env_var, "false").lower() == "true"
    for env_var in ("CI", "GITHUB_ACTIONS")
)
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -6,16 +13,28 @@ import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import dsv3_fused_a_gemm
|
||||
|
||||
# CI environment uses simplified parameters: a single num_tokens value and
# only the sgl-kernel line. Outside CI the full sweep (1-16 tokens, torch and
# sgl-kernel) is used.
if IS_CI:
    num_tokens_vals = [1]  # Only test 1 value in CI
    line_vals = ["sgl-kernel"]  # Only test sgl-kernel implementation in CI
else:
    num_tokens_vals = list(range(1, 17))  # Test 1-16 in full mode
    line_vals = ["torch", "sgl-kernel"]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens"],
|
||||
x_vals=[i + 1 for i in range(16)],
|
||||
x_vals=num_tokens_vals,
|
||||
x_log=False,
|
||||
line_arg="impl",
|
||||
line_vals=["torch", "sgl-kernel"],
|
||||
line_names=["torch (bf16)", "dsv3_fused_a_gemm"],
|
||||
styles=[("blue", "-"), ("orange", "-")],
|
||||
line_vals=line_vals,
|
||||
line_names=(
|
||||
["torch (bf16)", "dsv3_fused_a_gemm"]
|
||||
if not IS_CI
|
||||
else ["dsv3_fused_a_gemm"]
|
||||
),
|
||||
styles=[("blue", "-"), ("orange", "-")] if not IS_CI else [("orange", "-")],
|
||||
ylabel="TFLOPs",
|
||||
plot_name="bf16 dsv3 fused a GEMM throughput",
|
||||
args={},
|
||||
|
||||
Reference in New Issue
Block a user