Fix sgl-kernel benchmark dead code (#11022)

This commit is contained in:
Xiaoyu Zhang
2025-09-29 15:06:40 +08:00
committed by GitHub
parent 71959545df
commit 11965b0daf
25 changed files with 1019 additions and 260 deletions

View File

@@ -1,4 +1,11 @@
import argparse
import os
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
import torch
import torch.nn.functional as F
@@ -6,21 +13,37 @@ import triton
import triton.testing
from sgl_kernel import dsv3_router_gemm
# CI environment uses simplified parameters
if IS_CI:
num_tokens_vals = [1] # Only test 1 value in CI
line_vals = ["sgl-kernel-256"] # Only test one implementation in CI
else:
num_tokens_vals = [i + 1 for i in range(16)] # Test 1-16 in full mode
line_vals = ["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=[i + 1 for i in range(16)],
x_vals=num_tokens_vals,
x_log=False,
line_arg="impl",
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
line_names=[
"torch-256",
"dsv3_router_gemm-256",
"torch-384",
"dsv3_router_gemm-384",
],
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
line_vals=line_vals,
line_names=(
[
"torch-256",
"dsv3_router_gemm-256",
"torch-384",
"dsv3_router_gemm-384",
]
if not IS_CI
else ["dsv3_router_gemm-256"]
),
styles=(
[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]
if not IS_CI
else [("orange", "-")]
),
ylabel="TFLOPs",
plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
args={},
@@ -64,17 +87,25 @@ def benchmark_bf16_output(num_tokens, impl):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=[i + 1 for i in range(16)],
x_vals=num_tokens_vals,
x_log=False,
line_arg="impl",
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
line_names=[
"torch-256",
"dsv3_router_gemm-256",
"torch-384",
"dsv3_router_gemm-384",
],
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
line_vals=line_vals,
line_names=(
[
"torch-256",
"dsv3_router_gemm-256",
"torch-384",
"dsv3_router_gemm-384",
]
if not IS_CI
else ["dsv3_router_gemm-256"]
),
styles=(
[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]
if not IS_CI
else [("orange", "-")]
),
ylabel="TFLOPs",
plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
args={},