Fix sgl-kernel benchmark dead code (#11022)
This commit is contained in:
@@ -1,4 +1,11 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -6,21 +13,37 @@ import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import dsv3_router_gemm
|
||||
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
num_tokens_vals = [1] # Only test 1 value in CI
|
||||
line_vals = ["sgl-kernel-256"] # Only test one implementation in CI
|
||||
else:
|
||||
num_tokens_vals = [i + 1 for i in range(16)] # Test 1-16 in full mode
|
||||
line_vals = ["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens"],
|
||||
x_vals=[i + 1 for i in range(16)],
|
||||
x_vals=num_tokens_vals,
|
||||
x_log=False,
|
||||
line_arg="impl",
|
||||
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
|
||||
line_names=[
|
||||
"torch-256",
|
||||
"dsv3_router_gemm-256",
|
||||
"torch-384",
|
||||
"dsv3_router_gemm-384",
|
||||
],
|
||||
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
|
||||
line_vals=line_vals,
|
||||
line_names=(
|
||||
[
|
||||
"torch-256",
|
||||
"dsv3_router_gemm-256",
|
||||
"torch-384",
|
||||
"dsv3_router_gemm-384",
|
||||
]
|
||||
if not IS_CI
|
||||
else ["dsv3_router_gemm-256"]
|
||||
),
|
||||
styles=(
|
||||
[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]
|
||||
if not IS_CI
|
||||
else [("orange", "-")]
|
||||
),
|
||||
ylabel="TFLOPs",
|
||||
plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
|
||||
args={},
|
||||
@@ -64,17 +87,25 @@ def benchmark_bf16_output(num_tokens, impl):
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens"],
|
||||
x_vals=[i + 1 for i in range(16)],
|
||||
x_vals=num_tokens_vals,
|
||||
x_log=False,
|
||||
line_arg="impl",
|
||||
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
|
||||
line_names=[
|
||||
"torch-256",
|
||||
"dsv3_router_gemm-256",
|
||||
"torch-384",
|
||||
"dsv3_router_gemm-384",
|
||||
],
|
||||
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
|
||||
line_vals=line_vals,
|
||||
line_names=(
|
||||
[
|
||||
"torch-256",
|
||||
"dsv3_router_gemm-256",
|
||||
"torch-384",
|
||||
"dsv3_router_gemm-384",
|
||||
]
|
||||
if not IS_CI
|
||||
else ["dsv3_router_gemm-256"]
|
||||
),
|
||||
styles=(
|
||||
[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]
|
||||
if not IS_CI
|
||||
else [("orange", "-")]
|
||||
),
|
||||
ylabel="TFLOPs",
|
||||
plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
|
||||
args={},
|
||||
|
||||
Reference in New Issue
Block a user