Fix sgl-kernel benchmark dead code (#11022)
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# (batch_size, seq_len, hidden_size) and prints speed-up.
|
||||
import argparse
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
@@ -10,9 +11,31 @@ import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
import triton.testing
|
||||
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
|
||||
from sgl_kernel.utils import is_arch_support_pdl
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
# Optional imports
|
||||
try:
|
||||
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
|
||||
|
||||
FLASHINFER_AVAILABLE = True
|
||||
except ImportError:
|
||||
fused_add_rmsnorm = None
|
||||
rmsnorm = None
|
||||
FLASHINFER_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
vllm_ops = None
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
def str2int_list(arg: str) -> List[int]:
|
||||
@@ -79,6 +102,10 @@ def rmsnorm_flashinfer(
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
if not FLASHINFER_AVAILABLE:
|
||||
# Fallback to naive implementation if FlashInfer is not available
|
||||
return rmsnorm_naive(x, weight, residual, eps)
|
||||
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
@@ -103,6 +130,10 @@ def rmsnorm_vllm(
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
if not VLLM_AVAILABLE:
|
||||
# Fallback to naive implementation if vLLM is not available
|
||||
return rmsnorm_naive(x, weight, residual, eps)
|
||||
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
@@ -179,37 +210,72 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
||||
output_sglang = output_sglang[0]
|
||||
|
||||
print(f"Naive output={output_naive}")
|
||||
print(f"FlashInfer output={output_flashinfer}")
|
||||
print(f"VLLM output={output_vllm}")
|
||||
if FLASHINFER_AVAILABLE:
|
||||
print(f"FlashInfer output={output_flashinfer}")
|
||||
else:
|
||||
print("FlashInfer not available, skipped")
|
||||
if VLLM_AVAILABLE:
|
||||
print(f"VLLM output={output_vllm}")
|
||||
else:
|
||||
print("vLLM not available, skipped")
|
||||
print(f"SGLang output={output_sglang}")
|
||||
|
||||
if (
|
||||
torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2)
|
||||
and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2)
|
||||
and torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
|
||||
):
|
||||
print("✅ All implementations match")
|
||||
# Only compare available implementations
|
||||
all_match = torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
|
||||
if FLASHINFER_AVAILABLE:
|
||||
all_match = all_match and torch.allclose(
|
||||
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
if VLLM_AVAILABLE:
|
||||
all_match = all_match and torch.allclose(
|
||||
output_naive, output_vllm, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
|
||||
if all_match:
|
||||
print("✅ All available implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
default_batch_sizes = [2**i for i in range(0, 7, 2)] # 1, 4, 16, 64
|
||||
default_seq_lens = [2**i for i in range(6, 11, 1)] # 64, 128, 256, 512, 1024
|
||||
default_hidden_sizes = [32 * 128, 48 * 128] # 4096, 6144
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
default_batch_sizes = [1] # Single batch size for CI
|
||||
default_seq_lens = [64] # Single sequence length for CI
|
||||
default_hidden_sizes = [4096] # Single hidden size for CI
|
||||
else:
|
||||
default_batch_sizes = [2**i for i in range(0, 7, 2)] # 1, 4, 16, 64
|
||||
default_seq_lens = [2**i for i in range(6, 11, 1)] # 64, 128, 256, 512, 1024
|
||||
default_hidden_sizes = [32 * 128, 48 * 128] # 4096, 6144
|
||||
|
||||
|
||||
def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]:
|
||||
return list(itertools.product(bsizes, slens, hsizes))
|
||||
|
||||
|
||||
# Filter providers based on availability
|
||||
available_providers = ["huggingface", "sglang"]
|
||||
available_names = ["HuggingFace", "SGL Kernel"]
|
||||
available_styles = [("blue", "-"), ("orange", "-")]
|
||||
|
||||
if FLASHINFER_AVAILABLE:
|
||||
available_providers.insert(-1, "flashinfer")
|
||||
available_names.insert(-1, "FlashInfer")
|
||||
available_styles.insert(-1, ("green", "-"))
|
||||
|
||||
if VLLM_AVAILABLE:
|
||||
available_providers.insert(-1, "vllm")
|
||||
available_names.insert(-1, "vLLM")
|
||||
available_styles.insert(-1, ("red", "-"))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len", "hidden_size"],
|
||||
x_vals=[],
|
||||
line_arg="provider",
|
||||
line_vals=["huggingface", "flashinfer", "vllm", "sglang"],
|
||||
line_names=["HuggingFace", "FlashInfer", "vLLM", "SGL Kernel"],
|
||||
styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("orange", "-")],
|
||||
line_vals=available_providers,
|
||||
line_names=available_names,
|
||||
styles=available_styles,
|
||||
ylabel="µs (median) or × (speed-up)",
|
||||
plot_name="rmsnorm-performance",
|
||||
args={},
|
||||
@@ -242,6 +308,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
|
||||
)
|
||||
)
|
||||
elif provider == "flashinfer":
|
||||
if not FLASHINFER_AVAILABLE:
|
||||
return (0, 0, 0)
|
||||
return timed(
|
||||
lambda: rmsnorm_flashinfer(
|
||||
x.clone(),
|
||||
@@ -250,6 +318,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
|
||||
)
|
||||
)
|
||||
elif provider == "vllm":
|
||||
if not VLLM_AVAILABLE:
|
||||
return (0, 0, 0)
|
||||
return timed(
|
||||
lambda: rmsnorm_vllm(
|
||||
x.clone(),
|
||||
@@ -267,13 +337,22 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
|
||||
)
|
||||
|
||||
# provider == "speedup"
|
||||
t_ref, _, _ = timed(
|
||||
lambda: rmsnorm_vllm(
|
||||
x.clone(),
|
||||
weight,
|
||||
residual.clone() if residual is not None else None,
|
||||
if VLLM_AVAILABLE:
|
||||
t_ref, _, _ = timed(
|
||||
lambda: rmsnorm_vllm(
|
||||
x.clone(),
|
||||
weight,
|
||||
residual.clone() if residual is not None else None,
|
||||
)
|
||||
)
|
||||
else:
|
||||
t_ref, _, _ = timed(
|
||||
lambda: rmsnorm_naive(
|
||||
x.clone(),
|
||||
weight,
|
||||
residual.clone() if residual is not None else None,
|
||||
)
|
||||
)
|
||||
)
|
||||
t_sgl, _, _ = timed(
|
||||
lambda: rmsnorm_sglang(
|
||||
x.clone(),
|
||||
@@ -281,7 +360,7 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
|
||||
residual.clone() if residual is not None else None,
|
||||
)
|
||||
)
|
||||
spd = t_ref / t_sgl
|
||||
spd = t_ref / t_sgl if t_ref > 0 else 1.0
|
||||
return (spd, spd, spd)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user