Fix sgl-kernel benchmark dead code (#11022)

This commit is contained in:
Xiaoyu Zhang
2025-09-29 15:06:40 +08:00
committed by GitHub
parent 71959545df
commit 11965b0daf
25 changed files with 1019 additions and 260 deletions

View File

@@ -2,6 +2,7 @@
# (batch_size, seq_len, hidden_size) and prints speed-up.
import argparse
import itertools
import os
import re
from typing import List, Optional, Tuple, Union
@@ -10,9 +11,31 @@ import torch
import torch.nn as nn
import triton
import triton.testing
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from sgl_kernel.utils import is_arch_support_pdl
from vllm import _custom_ops as vllm_ops
# Optional imports
try:
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
FLASHINFER_AVAILABLE = True
except ImportError:
fused_add_rmsnorm = None
rmsnorm = None
FLASHINFER_AVAILABLE = False
try:
from vllm import _custom_ops as vllm_ops
VLLM_AVAILABLE = True
except ImportError:
vllm_ops = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def str2int_list(arg: str) -> List[int]:
@@ -79,6 +102,10 @@ def rmsnorm_flashinfer(
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
if not FLASHINFER_AVAILABLE:
# Fallback to naive implementation if FlashInfer is not available
return rmsnorm_naive(x, weight, residual, eps)
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
@@ -103,6 +130,10 @@ def rmsnorm_vllm(
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
if not VLLM_AVAILABLE:
# Fallback to naive implementation if vLLM is not available
return rmsnorm_naive(x, weight, residual, eps)
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
@@ -179,37 +210,72 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
output_sglang = output_sglang[0]
print(f"Naive output={output_naive}")
print(f"FlashInfer output={output_flashinfer}")
print(f"VLLM output={output_vllm}")
if FLASHINFER_AVAILABLE:
print(f"FlashInfer output={output_flashinfer}")
else:
print("FlashInfer not available, skipped")
if VLLM_AVAILABLE:
print(f"VLLM output={output_vllm}")
else:
print("vLLM not available, skipped")
print(f"SGLang output={output_sglang}")
if (
torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2)
and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2)
and torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
):
print("✅ All implementations match")
# Only compare available implementations
all_match = torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
if FLASHINFER_AVAILABLE:
all_match = all_match and torch.allclose(
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
)
if VLLM_AVAILABLE:
all_match = all_match and torch.allclose(
output_naive, output_vllm, atol=1e-2, rtol=1e-2
)
if all_match:
print("✅ All available implementations match")
else:
print("❌ Implementations differ")
default_batch_sizes = [2**i for i in range(0, 7, 2)] # 1, 4, 16, 64
default_seq_lens = [2**i for i in range(6, 11, 1)] # 64, 128, 256, 512, 1024
default_hidden_sizes = [32 * 128, 48 * 128] # 4096, 6144
# CI environment uses simplified parameters
if IS_CI:
default_batch_sizes = [1] # Single batch size for CI
default_seq_lens = [64] # Single sequence length for CI
default_hidden_sizes = [4096] # Single hidden size for CI
else:
default_batch_sizes = [2**i for i in range(0, 7, 2)] # 1, 4, 16, 64
default_seq_lens = [2**i for i in range(6, 11, 1)] # 64, 128, 256, 512, 1024
default_hidden_sizes = [32 * 128, 48 * 128] # 4096, 6144
def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]:
return list(itertools.product(bsizes, slens, hsizes))
# Filter providers based on availability
available_providers = ["huggingface", "sglang"]
available_names = ["HuggingFace", "SGL Kernel"]
available_styles = [("blue", "-"), ("orange", "-")]
if FLASHINFER_AVAILABLE:
available_providers.insert(-1, "flashinfer")
available_names.insert(-1, "FlashInfer")
available_styles.insert(-1, ("green", "-"))
if VLLM_AVAILABLE:
available_providers.insert(-1, "vllm")
available_names.insert(-1, "vLLM")
available_styles.insert(-1, ("red", "-"))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len", "hidden_size"],
x_vals=[],
line_arg="provider",
line_vals=["huggingface", "flashinfer", "vllm", "sglang"],
line_names=["HuggingFace", "FlashInfer", "vLLM", "SGL Kernel"],
styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("orange", "-")],
line_vals=available_providers,
line_names=available_names,
styles=available_styles,
ylabel="µs (median) or × (speed-up)",
plot_name="rmsnorm-performance",
args={},
@@ -242,6 +308,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
)
)
elif provider == "flashinfer":
if not FLASHINFER_AVAILABLE:
return (0, 0, 0)
return timed(
lambda: rmsnorm_flashinfer(
x.clone(),
@@ -250,6 +318,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
)
)
elif provider == "vllm":
if not VLLM_AVAILABLE:
return (0, 0, 0)
return timed(
lambda: rmsnorm_vllm(
x.clone(),
@@ -267,13 +337,22 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
)
# provider == "speedup"
t_ref, _, _ = timed(
lambda: rmsnorm_vllm(
x.clone(),
weight,
residual.clone() if residual is not None else None,
if VLLM_AVAILABLE:
t_ref, _, _ = timed(
lambda: rmsnorm_vllm(
x.clone(),
weight,
residual.clone() if residual is not None else None,
)
)
else:
t_ref, _, _ = timed(
lambda: rmsnorm_naive(
x.clone(),
weight,
residual.clone() if residual is not None else None,
)
)
)
t_sgl, _, _ = timed(
lambda: rmsnorm_sglang(
x.clone(),
@@ -281,7 +360,7 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
residual.clone() if residual is not None else None,
)
)
spd = t_ref / t_sgl
spd = t_ref / t_sgl if t_ref > 0 else 1.0
return (spd, spd, spd)