Fix sgl-kernel benchmark dead code (#11022)

This commit is contained in:
Xiaoyu Zhang
2025-09-29 15:06:40 +08:00
committed by GitHub
parent 71959545df
commit 11965b0daf
25 changed files with 1019 additions and 260 deletions

View File

@@ -155,6 +155,50 @@ jobs:
cd test/srt cd test/srt
python3 test_mla_deepseek_v3.py python3 test_mla_deepseek_v3.py
sgl-kernel-benchmark-test:
needs: [check-changes, sgl-kernel-build-wheels]
if: always() && !failure() && !cancelled()
runs-on: 1-gpu-runner
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
CI: true
steps:
- uses: actions/checkout@v4
- name: Cleanup
run: |
ls -alh sgl-kernel/dist || true
rm -rf sgl-kernel/dist/* || true
- name: Download artifacts
uses: actions/download-artifact@v4
with:
path: sgl-kernel/dist/
merge-multiple: true
pattern: wheel-python3.10-cuda12.9
- name: Install dependencies
run: |
CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/ci_install_dependency.sh
- name: Run benchmark tests
timeout-minutes: 45
run: |
cd sgl-kernel/benchmark
echo "Running sgl-kernel benchmark tests in CI mode..."
echo "CI environment variable: $CI"
echo "GITHUB_ACTIONS environment variable: $GITHUB_ACTIONS"
for bench_file in bench_*.py; do
echo "Testing $bench_file..."
timeout 60 python3 "$bench_file" || echo "Warning: $bench_file timed out or failed, continuing..."
echo "Completed $bench_file"
echo "---"
done
echo "All benchmark tests completed!"
# =============================================== primary ==================================================== # =============================================== primary ====================================================
unit-test-frontend: unit-test-frontend:
@@ -647,7 +691,7 @@ jobs:
check-changes, check-changes,
sgl-kernel-build-wheels, sgl-kernel-build-wheels,
sgl-kernel-unit-test, sgl-kernel-mla-test, sgl-kernel-unit-test, sgl-kernel-mla-test, sgl-kernel-benchmark-test,
unit-test-frontend, unit-test-backend-1-gpu, unit-test-frontend, unit-test-backend-1-gpu,
unit-test-backend-2-gpu, unit-test-backend-4-gpu, unit-test-backend-8-gpu, unit-test-backend-2-gpu, unit-test-backend-4-gpu, unit-test-backend-8-gpu,

View File

@@ -2460,7 +2460,7 @@ class BumpAllocator:
def log_info_on_rank0(logger, msg): def log_info_on_rank0(logger, msg):
from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed import get_tensor_model_parallel_rank
if get_tensor_model_parallel_rank() == 0: if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
logger.info(msg) logger.info(msg)

View File

@@ -2,6 +2,7 @@
# (kernel, dtype, batch_size, seq_len, dim) and prints speed-up. # (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
import argparse import argparse
import itertools import itertools
import os
import re import re
from typing import List, Tuple from typing import List, Tuple
@@ -11,7 +12,21 @@ import torch.nn.functional as F
import triton import triton
import triton.testing import triton.testing
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from vllm import _custom_ops as vllm_ops
# Optional vLLM import
try:
from vllm import _custom_ops as vllm_ops
VLLM_AVAILABLE = True
except ImportError:
vllm_ops = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
# gelu_quick is only available on HIP/ROCm platforms # gelu_quick is only available on HIP/ROCm platforms
try: try:
@@ -22,7 +37,7 @@ except ImportError:
GELU_QUICK_AVAILABLE = False GELU_QUICK_AVAILABLE = False
gelu_quick = None gelu_quick = None
if not hasattr(vllm_ops, "silu_and_mul"): if VLLM_AVAILABLE and not hasattr(vllm_ops, "silu_and_mul"):
vllm_ops = torch.ops._C vllm_ops = torch.ops._C
@@ -40,6 +55,13 @@ def calculate_diff(
"""Compare vLLM with SGLang for one shape.""" """Compare vLLM with SGLang for one shape."""
device = torch.device("cuda") device = torch.device("cuda")
if not VLLM_AVAILABLE:
print(
f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
f"L={seq_len:3d} | D={dim:5d}] ⚠️ vLLM not available, skipping comparison"
)
return True
# activation-only quick GELU # activation-only quick GELU
if kernel == "gelu_quick": if kernel == "gelu_quick":
if not GELU_QUICK_AVAILABLE: if not GELU_QUICK_AVAILABLE:
@@ -68,19 +90,30 @@ def calculate_diff(
return ok return ok
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"] # CI environment uses simplified parameters for kernels and dtypes too
if GELU_QUICK_AVAILABLE: if IS_CI:
kernels.append("gelu_quick") kernels = ["silu_and_mul"] # Only test one kernel in CI
dtypes = [torch.float16, torch.bfloat16] dtypes = [torch.float16] # Only test one dtype in CI
else:
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
if GELU_QUICK_AVAILABLE:
kernels.append("gelu_quick")
dtypes = [torch.float16, torch.bfloat16]
def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]: def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
return list(itertools.product(kernels, dtypes, bsizes, slens, dims_)) return list(itertools.product(kernels, dtypes, bsizes, slens, dims_))
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16 # CI environment uses simplified parameters
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64 if IS_CI:
default_dims = [2**i for i in range(10, 15)] # 1024...16384 default_batch_sizes = [1] # Single batch size for CI
default_seq_lens = [1] # Single sequence length for CI
default_dims = [1024] # Single dimension for CI
else:
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64
default_dims = [2**i for i in range(10, 15)] # 1024...16384
@triton.testing.perf_report( @triton.testing.perf_report(
@@ -102,16 +135,24 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device) x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device) y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
vllm_kernel = getattr(vllm_ops, kernel) if not VLLM_AVAILABLE and provider in ["vllm", "speedup"]:
# Skip vLLM-related benchmarks if vLLM is not available
return (0, 0, 0)
if VLLM_AVAILABLE:
vllm_kernel = getattr(vllm_ops, kernel)
if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE: if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE:
# Skip benchmark for gelu_quick if not available # Skip benchmark for gelu_quick if not available
return (0, 0, 0) return (0, 0, 0)
sglang_kernel = getattr(sgl_kernel, kernel) sglang_kernel = getattr(sgl_kernel, kernel)
def baseline(): def baseline():
tmp = y0.clone() if VLLM_AVAILABLE:
vllm_kernel(tmp, x) tmp = y0.clone()
return tmp vllm_kernel(tmp, x)
return tmp
else:
return torch.zeros_like(y0)
def sglang(): def sglang():
return sglang_kernel(x) return sglang_kernel(x)
@@ -134,7 +175,7 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
# provider == "speedup" # provider == "speedup"
t_ref, _, _ = timed(baseline) t_ref, _, _ = timed(baseline)
t_sgl, _, _ = timed(sglang) t_sgl, _, _ = timed(sglang)
spd = t_ref / t_sgl spd = t_ref / t_sgl if t_ref > 0 else 1.0
return (spd, spd, spd) return (spd, spd, spd)

View File

@@ -1,16 +1,34 @@
import itertools import itertools
import os
from typing import List, Tuple from typing import List, Tuple
import torch import torch
import triton import triton
import triton.testing import triton.testing
from sgl_kernel import awq_dequantize from sgl_kernel import awq_dequantize
from vllm import _custom_ops as ops
# Optional vLLM import
try:
from vllm import _custom_ops as ops
VLLM_AVAILABLE = True
except ImportError:
ops = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def vllm_awq_dequantize( def vllm_awq_dequantize(
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if not VLLM_AVAILABLE:
# Fallback to SGLang implementation
return sglang_awq_dequantize(qweight, scales, qzeros)
return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0) return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
@@ -43,6 +61,10 @@ def calculate_diff(qweight_row: int, qweight_col: int):
device=device, device=device,
) )
if not VLLM_AVAILABLE:
print("⚠️ vLLM not available, skipping comparison")
return
vllm_out = vllm_awq_dequantize(qweight, scales, qzeros) vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros) sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
@@ -56,8 +78,13 @@ def calculate_diff(qweight_row: int, qweight_col: int):
print("❌ Implementations differ") print("❌ Implementations differ")
qweight_row_range = [3584, 18944, 128, 256, 512, 1024] # CI environment uses simplified parameters
qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128] if IS_CI:
qweight_row_range = [128] # Single row size for CI
qweight_cols_range = [16] # Single column size for CI
else:
qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]
configs = list(itertools.product(qweight_row_range, qweight_cols_range)) configs = list(itertools.product(qweight_row_range, qweight_cols_range))
@@ -67,9 +94,9 @@ configs = list(itertools.product(qweight_row_range, qweight_cols_range))
x_names=["qweight_row", "qweight_col"], x_names=["qweight_row", "qweight_col"],
x_vals=configs, x_vals=configs,
line_arg="provider", line_arg="provider",
line_vals=["vllm", "sglang"], line_vals=["vllm", "sglang"] if VLLM_AVAILABLE else ["sglang"],
line_names=["VLLM", "SGL Kernel"], line_names=["VLLM", "SGL Kernel"] if VLLM_AVAILABLE else ["SGL Kernel"],
styles=[("blue", "-"), ("green", "-")], styles=[("blue", "-"), ("green", "-")] if VLLM_AVAILABLE else [("green", "-")],
ylabel="us", ylabel="us",
plot_name="awq-dequantize-performance", plot_name="awq-dequantize-performance",
args={}, args={},
@@ -100,6 +127,8 @@ def benchmark(qweight_row, qweight_col, provider):
quantiles = [0.5, 0.2, 0.8] quantiles = [0.5, 0.2, 0.8]
if provider == "vllm": if provider == "vllm":
if not VLLM_AVAILABLE:
return (0, 0, 0)
fn = lambda: vllm_awq_dequantize( fn = lambda: vllm_awq_dequantize(
qweight.clone(), scales.clone(), qzeros.clone() qweight.clone(), scales.clone(), qzeros.clone()
) )
@@ -114,5 +143,11 @@ def benchmark(qweight_row, qweight_col, provider):
if __name__ == "__main__": if __name__ == "__main__":
calculate_diff(qweight_row=3584, qweight_col=448) # Simplify for CI environment
if IS_CI:
qweight_row, qweight_col = 128, 16 # Smaller values for CI
else:
qweight_row, qweight_col = 3584, 448
calculate_diff(qweight_row=qweight_row, qweight_col=qweight_col)
benchmark.run(print_data=True) benchmark.run(print_data=True)

View File

@@ -1,13 +1,27 @@
import argparse import argparse
import copy import copy
import itertools import itertools
import os
import torch import torch
import triton import triton
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
bs_range = [1, 8, 32, 64, 128, 256] from sglang.srt.utils import get_device_capability
qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
# CI environment uses simplified parameters
if IS_CI:
bs_range = [1] # Single batch size for CI
qlen_range = [64] # Single sequence length for CI
else:
bs_range = [1, 8, 32, 64, 128, 256]
qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
configs = list(itertools.product(bs_range, qlen_range)) configs = list(itertools.product(bs_range, qlen_range))
@@ -131,13 +145,34 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
for block_size in args.block_sizes: # Skip in CI environment or unsupported architectures
for kv_split in args.num_kv_splits: if IS_CI:
print(f"block_size={block_size}, num_kv_splits={kv_split}: ") major, minor = get_device_capability()
benchmark.run( if major is None or major < 10: # Requires compute capability 10.0+
print_data=True, print("Skipping Cutlass MLA benchmark in CI environment")
block_size=block_size, if major is not None:
num_kv_splits=kv_split, print(
) f"Cutlass MLA requires compute capability 10.0+, but found {major}.{minor}"
)
print("Benchmark finished!") else:
print("Could not determine device capability")
else:
for block_size in args.block_sizes:
for kv_split in args.num_kv_splits:
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
benchmark.run(
print_data=True,
block_size=block_size,
num_kv_splits=kv_split,
)
print("Benchmark finished!")
else:
for block_size in args.block_sizes:
for kv_split in args.num_kv_splits:
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
benchmark.run(
print_data=True,
block_size=block_size,
num_kv_splits=kv_split,
)
print("Benchmark finished!")

View File

@@ -1,4 +1,11 @@
import argparse import argparse
import os
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -6,16 +13,28 @@ import triton
import triton.testing import triton.testing
from sgl_kernel import dsv3_fused_a_gemm from sgl_kernel import dsv3_fused_a_gemm
# CI environment uses simplified parameters
if IS_CI:
num_tokens_vals = [1] # Only test 1 value in CI
line_vals = ["sgl-kernel"] # Only test sgl-kernel implementation in CI
else:
num_tokens_vals = [i + 1 for i in range(16)] # Test 1-16 in full mode
line_vals = ["torch", "sgl-kernel"]
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["num_tokens"], x_names=["num_tokens"],
x_vals=[i + 1 for i in range(16)], x_vals=num_tokens_vals,
x_log=False, x_log=False,
line_arg="impl", line_arg="impl",
line_vals=["torch", "sgl-kernel"], line_vals=line_vals,
line_names=["torch (bf16)", "dsv3_fused_a_gemm"], line_names=(
styles=[("blue", "-"), ("orange", "-")], ["torch (bf16)", "dsv3_fused_a_gemm"]
if not IS_CI
else ["dsv3_fused_a_gemm"]
),
styles=[("blue", "-"), ("orange", "-")] if not IS_CI else [("orange", "-")],
ylabel="TFLOPs", ylabel="TFLOPs",
plot_name="bf16 dsv3 fused a GEMM throughput", plot_name="bf16 dsv3 fused a GEMM throughput",
args={}, args={},

View File

@@ -1,4 +1,11 @@
import argparse import argparse
import os
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -6,21 +13,37 @@ import triton
import triton.testing import triton.testing
from sgl_kernel import dsv3_router_gemm from sgl_kernel import dsv3_router_gemm
# CI environment uses simplified parameters
if IS_CI:
num_tokens_vals = [1] # Only test 1 value in CI
line_vals = ["sgl-kernel-256"] # Only test one implementation in CI
else:
num_tokens_vals = [i + 1 for i in range(16)] # Test 1-16 in full mode
line_vals = ["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"]
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["num_tokens"], x_names=["num_tokens"],
x_vals=[i + 1 for i in range(16)], x_vals=num_tokens_vals,
x_log=False, x_log=False,
line_arg="impl", line_arg="impl",
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"], line_vals=line_vals,
line_names=[ line_names=(
"torch-256", [
"dsv3_router_gemm-256", "torch-256",
"torch-384", "dsv3_router_gemm-256",
"dsv3_router_gemm-384", "torch-384",
], "dsv3_router_gemm-384",
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")], ]
if not IS_CI
else ["dsv3_router_gemm-256"]
),
styles=(
[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]
if not IS_CI
else [("orange", "-")]
),
ylabel="TFLOPs", ylabel="TFLOPs",
plot_name="input-bf16-output-bf16 dsv3 router gemm throughput", plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
args={}, args={},
@@ -64,17 +87,25 @@ def benchmark_bf16_output(num_tokens, impl):
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["num_tokens"], x_names=["num_tokens"],
x_vals=[i + 1 for i in range(16)], x_vals=num_tokens_vals,
x_log=False, x_log=False,
line_arg="impl", line_arg="impl",
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"], line_vals=line_vals,
line_names=[ line_names=(
"torch-256", [
"dsv3_router_gemm-256", "torch-256",
"torch-384", "dsv3_router_gemm-256",
"dsv3_router_gemm-384", "torch-384",
], "dsv3_router_gemm-384",
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")], ]
if not IS_CI
else ["dsv3_router_gemm-256"]
),
styles=(
[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]
if not IS_CI
else [("orange", "-")]
),
ylabel="TFLOPs", ylabel="TFLOPs",
plot_name="input-bf16-output-fp32 dsv3 router gemm throughput", plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
args={}, args={},

View File

@@ -2,6 +2,7 @@ import argparse
import copy import copy
import csv import csv
import itertools import itertools
import os
import pytest import pytest
import torch import torch
@@ -9,6 +10,14 @@ import triton
from flashinfer import mm_fp4 from flashinfer import mm_fp4
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
from sglang.srt.utils import get_device_capability
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
FLOAT4_E2M1_MAX = 6.0 FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
@@ -33,27 +42,34 @@ def get_weight_shapes(args):
] ]
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [1, 8] # Simplified for CI
else:
batch_sizes = [
1,
2,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
3072,
4096,
8192,
16384,
]
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["batch_size"], x_names=["batch_size"],
x_vals=[ x_vals=batch_sizes,
1,
2,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
3072,
4096,
8192,
16384,
],
# x_vals = [64], # x_vals = [64],
x_log=False, x_log=False,
line_arg="provider", line_arg="provider",
@@ -188,21 +204,38 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
# Simplify for CI environment
if IS_CI:
args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size
if args.csv: if args.csv:
with open(args.csv, "w", newline="") as f: with open(args.csv, "w", newline="") as f:
writer = csv.writer(f) writer = csv.writer(f)
writer.writerow(["provider", "m", "n", "k", "time_ms"]) writer.writerow(["provider", "m", "n", "k", "time_ms"])
NKs = get_weight_shapes(args) # Check architecture compatibility - FP4 operations require sm100a/sm103a
for N, K in NKs: major, minor = get_device_capability()
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ") if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a)
benchmark.run( print("Skipping FP4 GEMM benchmark")
print_data=True, if major is not None:
N=N, print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}")
K=K, else:
dtype=args.dtype, print("Could not determine device capability")
correctness=args.correctness, else:
csv_file=args.csv, NKs = get_weight_shapes(args)
)
print("Benchmark finished!") # Limit iterations in CI
if IS_CI:
NKs = NKs[:2] # Only test first 2 shapes in CI
for N, K in NKs:
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
benchmark.run(
print_data=True,
N=N,
K=K,
dtype=args.dtype,
correctness=args.correctness,
csv_file=args.csv,
)
print("Benchmark finished!")

View File

@@ -1,18 +1,33 @@
import argparse import argparse
import copy import copy
import itertools import itertools
import os
import deep_gemm import deep_gemm
import torch import torch
import triton import triton
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
from sgl_kernel import fp8_blockwise_scaled_mm from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
# Optional vLLM import
try:
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
VLLM_AVAILABLE = True
except ImportError:
vllm_scaled_mm = None
VLLM_AVAILABLE = False
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul, w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
) )
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def get_weight_shapes(args): def get_weight_shapes(args):
models_tps = list(itertools.product(args.models, args.tp_sizes)) models_tps = list(itertools.product(args.models, args.tp_sizes))
@@ -80,15 +95,46 @@ def scale_shape(shape, group_shape):
return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape))) return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [1, 8] # Simplified for CI
else:
batch_sizes = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
# Filter providers based on availability
available_providers = ["sgl-kernel"]
available_names = ["sgl-kernel"]
available_styles = [("orange", "-")]
if VLLM_AVAILABLE:
available_providers.insert(0, "vllm")
available_names.insert(0, "vllm")
available_styles.insert(0, ("blue", "-"))
available_providers.append("triton")
available_names.append("sglang triton")
available_styles.append(("red", "-"))
# Add deepgemm if available
try:
import deep_gemm
available_providers.append("deepgemm")
available_names.append("deepgemm")
available_styles.append(("yellow", "-"))
except ImportError:
pass
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["batch_size"], x_names=["batch_size"],
x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], x_vals=batch_sizes,
x_log=False, x_log=False,
line_arg="provider", line_arg="provider",
line_vals=["vllm", "sgl-kernel", "triton", "deepgemm"], line_vals=available_providers,
line_names=["vllm", "sgl-kernel", "sglang triton", "deepgemm"], line_names=available_names,
styles=[("blue", "-"), ("orange", "-"), ("red", "-"), ("yellow", "-")], styles=available_styles,
ylabel="GB/s", ylabel="GB/s",
plot_name="fp8 blockwise scaled matmul", plot_name="fp8 blockwise scaled matmul",
args={}, args={},
@@ -123,14 +169,16 @@ def benchmark(batch_size, provider, N, K):
), ),
quantiles=quantiles, quantiles=quantiles,
) )
if provider == "vllm": elif provider == "vllm":
if not VLLM_AVAILABLE:
return (0, 0, 0)
scale_a = scale_a.t().contiguous().t() scale_a = scale_a.t().contiguous().t()
b_fp8, scale_b = b_fp8.t(), scale_b.t() b_fp8, scale_b = b_fp8.t(), scale_b.t()
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16), lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
quantiles=quantiles, quantiles=quantiles,
) )
if provider == "triton": elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: w8a8_block_fp8_matmul( lambda: w8a8_block_fp8_matmul(
a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16 a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
@@ -166,7 +214,17 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
# Simplify for CI environment
if IS_CI:
args.models = [args.models[0]] # Use only first model
args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size
NK_model_names = get_weight_shapes(args) NK_model_names = get_weight_shapes(args)
# Limit iterations in CI
if IS_CI:
NK_model_names = NK_model_names[:2] # Only test first 2 shapes in CI
for N, K, model_name in NK_model_names: for N, K, model_name in NK_model_names:
if N % 128 != 0 or K % 128 != 0: if N % 128 != 0 or K % 128 != 0:
print(f"Skip {N=}, {K=} now") print(f"Skip {N=}, {K=} now")

View File

@@ -1,4 +1,11 @@
import argparse import argparse
import os
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
import random import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Tuple from typing import List, Tuple
@@ -290,36 +297,44 @@ def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--num-warmup", type=int, default=3) parser.add_argument("--num-warmup", type=int, default=3)
parser.add_argument("--num-run", type=int, default=10) parser.add_argument("--num-run", type=int, default=10)
shape_args = [
# Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8 # CI environment uses simplified parameters
ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256), if IS_CI:
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 8 shape_args = [
ShapeArg(expected_m_per_group=256, n=512, k=7168, num_groups=256), # Only test one simple shape in CI
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 16 ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
ShapeArg(expected_m_per_group=256, n=256, k=7168, num_groups=256), ]
# Prefill, DeepSeek-R1, gateup, chunk_size = 16384, TP = 16 else:
ShapeArg(expected_m_per_group=512, n=256, k=7168, num_groups=256), shape_args = [
# Decode, DeepSeek-R1, gateup, bs = 32, TP = 8 # Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8
ShapeArg(expected_m_per_group=1, n=512, k=7168, num_groups=256), ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
# Decode, DeepSeek-R1, gateup, bs = 64, TP = 16 # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 8
ShapeArg(expected_m_per_group=2, n=256, k=7168, num_groups=256), ShapeArg(expected_m_per_group=256, n=512, k=7168, num_groups=256),
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, EP = 8 # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 16
ShapeArg(expected_m_per_group=256, n=4096, k=7168, num_groups=32), ShapeArg(expected_m_per_group=256, n=256, k=7168, num_groups=256),
# Prefill, DeepSeek-R1, gateup, chunk_size = 16384, EP = 16 # Prefill, DeepSeek-R1, gateup, chunk_size = 16384, TP = 16
ShapeArg(expected_m_per_group=512, n=4096, k=7168, num_groups=16), ShapeArg(expected_m_per_group=512, n=256, k=7168, num_groups=256),
# Decode, DeepSeek-R1, gateup, bs = 128, EP = 8 # Decode, DeepSeek-R1, gateup, bs = 32, TP = 8
ShapeArg(expected_m_per_group=4, n=4096, k=7168, num_groups=32), ShapeArg(expected_m_per_group=1, n=512, k=7168, num_groups=256),
# Decode, DeepSeek-R1, gateup, bs = 256, EP = 16 # Decode, DeepSeek-R1, gateup, bs = 64, TP = 16
ShapeArg(expected_m_per_group=8, n=4096, k=7168, num_groups=16), ShapeArg(expected_m_per_group=2, n=256, k=7168, num_groups=256),
# Prefill, Qwen3-235B-A22B-FP8, gateup, chunk_size = 16384, TP = 4 # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, EP = 8
ShapeArg(expected_m_per_group=1024, n=768, k=4096, num_groups=128), ShapeArg(expected_m_per_group=256, n=4096, k=7168, num_groups=32),
# Prefill, Qwen3-235B-A22B-FP8, down, chunk_size = 16384, TP = 4 # Prefill, DeepSeek-R1, gateup, chunk_size = 16384, EP = 16
ShapeArg(expected_m_per_group=1024, n=4096, k=384, num_groups=128), ShapeArg(expected_m_per_group=512, n=4096, k=7168, num_groups=16),
# Decode, Qwen3-235B-A22B-FP8, gateup, bs = 256, TP = 4 # Decode, DeepSeek-R1, gateup, bs = 128, EP = 8
ShapeArg(expected_m_per_group=16, n=768, k=4096, num_groups=128), ShapeArg(expected_m_per_group=4, n=4096, k=7168, num_groups=32),
# Decode, Qwen3-235B-A22B-FP8, down, bs = 256, TP = 4 # Decode, DeepSeek-R1, gateup, bs = 256, EP = 16
ShapeArg(expected_m_per_group=16, n=4096, k=384, num_groups=128), ShapeArg(expected_m_per_group=8, n=4096, k=7168, num_groups=16),
] # Prefill, Qwen3-235B-A22B-FP8, gateup, chunk_size = 16384, TP = 4
ShapeArg(expected_m_per_group=1024, n=768, k=4096, num_groups=128),
# Prefill, Qwen3-235B-A22B-FP8, down, chunk_size = 16384, TP = 4
ShapeArg(expected_m_per_group=1024, n=4096, k=384, num_groups=128),
# Decode, Qwen3-235B-A22B-FP8, gateup, bs = 256, TP = 4
ShapeArg(expected_m_per_group=16, n=768, k=4096, num_groups=128),
# Decode, Qwen3-235B-A22B-FP8, down, bs = 256, TP = 4
ShapeArg(expected_m_per_group=16, n=4096, k=384, num_groups=128),
]
args = parser.parse_args() args = parser.parse_args()
benchmark_one_shape(shape_args, args.num_warmup, args.num_run) benchmark_one_shape(shape_args, args.num_warmup, args.num_run)

View File

@@ -1,14 +1,30 @@
import argparse import argparse
import copy import copy
import itertools import itertools
import os
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import triton import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from sgl_kernel import sgl_per_tensor_quant_fp8 from sgl_kernel import sgl_per_tensor_quant_fp8
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant # Optional vLLM import
try:
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
VLLM_AVAILABLE = True
except ImportError:
vllm_scaled_mm = None
vllm_scaled_fp8_quant = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
# Weight Shapes are in the format # Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM) # ([K, N], TP_SPLIT_DIM)
@@ -86,25 +102,48 @@ def sglang_scaled_fp8_quant(
return output, scale return output, scale
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [1] # Single batch size for CI
else:
batch_sizes = [1, 16, 64, 128, 256, 512, 1024, 2048]
# Filter line_vals based on vLLM availability
if VLLM_AVAILABLE:
line_vals = [
"vllm-fp8-fp16",
"vllm-fp8-bf16",
"sglang-fp8-fp16",
"sglang-fp8-bf16",
]
line_names = [
"vllm-fp8-fp16",
"vllm-fp8-bf16",
"sglang-fp8-fp16",
"sglang-fp8-bf16",
]
styles = [("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")]
else:
line_vals = [
"sglang-fp8-fp16",
"sglang-fp8-bf16",
]
line_names = [
"sglang-fp8-fp16",
"sglang-fp8-bf16",
]
styles = [("blue", "-"), ("blue", "--")]
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["batch_size"], x_names=["batch_size"],
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048], x_vals=batch_sizes,
x_log=False, x_log=False,
line_arg="provider", line_arg="provider",
line_vals=[ line_vals=line_vals,
"vllm-fp8-fp16", line_names=line_names,
"vllm-fp8-bf16", styles=styles,
"sglang-fp8-fp16",
"sglang-fp8-bf16",
],
line_names=[
"vllm-fp8-fp16",
"vllm-fp8-bf16",
"sglang-fp8-fp16",
"sglang-fp8-bf16",
],
styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
ylabel="GB/s", ylabel="GB/s",
plot_name="fp8 scaled matmul", plot_name="fp8 scaled matmul",
args={}, args={},
@@ -115,6 +154,9 @@ def benchmark(batch_size, provider, N, K):
M = batch_size M = batch_size
a = torch.ones((M, K), device="cuda") * 5.0 a = torch.ones((M, K), device="cuda") * 5.0
b = torch.ones((N, K), device="cuda") * 5.0 b = torch.ones((N, K), device="cuda") * 5.0
# vLLM expects scalar scales, while sglang can handle per-token scales
scale_a_scalar = torch.randn(1, device="cuda", dtype=torch.float32)
scale_b_scalar = torch.randn(1, device="cuda", dtype=torch.float32)
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
quantiles = [0.5, 0.2, 0.8] quantiles = [0.5, 0.2, 0.8]
@@ -122,8 +164,11 @@ def benchmark(batch_size, provider, N, K):
dtype = torch.float16 if "fp16" in provider else torch.bfloat16 dtype = torch.float16 if "fp16" in provider else torch.bfloat16
if "vllm-fp8" in provider: if "vllm-fp8" in provider:
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) if not VLLM_AVAILABLE:
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) # Return zero if vLLM is not available
return (0, 0, 0)
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_scalar)
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b_scalar)
b_fp8 = b_fp8.t() b_fp8 = b_fp8.t()
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype), lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
@@ -174,6 +219,11 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
# Simplify for CI environment
if IS_CI:
args.models = [args.models[0]] # Use only first model
args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size
KN_model_names = prepare_shapes(args) KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names: for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ") print(f"{model_name} N={N} K={K}: ")

View File

@@ -1,11 +1,26 @@
import argparse import argparse
import copy import copy
import itertools import itertools
import os
import torch import torch
import triton import triton
from sgl_kernel import int8_scaled_mm from sgl_kernel import int8_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
# Optional vLLM import
try:
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
VLLM_AVAILABLE = True
except ImportError:
vllm_scaled_mm = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def to_int8(tensor: torch.Tensor) -> torch.Tensor: def to_int8(tensor: torch.Tensor) -> torch.Tensor:
@@ -62,15 +77,32 @@ WEIGHT_SHAPES = {
} }
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [1] # Single batch size for CI
else:
batch_sizes = [1, 16, 32, 64, 128, 256, 512, 1024, 2048]
# Filter providers based on vLLM availability
if VLLM_AVAILABLE:
line_vals = ["vllm", "sgl-kernel"]
line_names = ["vllm int8 gemm", "sgl-kernel int8 gemm"]
styles = [("blue", "-"), ("orange", "-")]
else:
line_vals = ["sgl-kernel"]
line_names = ["sgl-kernel int8 gemm"]
styles = [("orange", "-")]
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["batch_size"], x_names=["batch_size"],
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048], x_vals=batch_sizes,
x_log=False, x_log=False,
line_arg="provider", line_arg="provider",
line_vals=["vllm", "sgl-kernel"], line_vals=line_vals,
line_names=["vllm int8 gemm", "sgl-kernel int8 gemm"], line_names=line_names,
styles=[("blue", "-"), ("orange", "-")], styles=styles,
ylabel="GB/s", ylabel="GB/s",
plot_name="int8 scaled matmul", plot_name="int8 scaled matmul",
args={}, args={},
@@ -90,7 +122,9 @@ def benchmark(batch_size, provider, N, K):
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias), lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
quantiles=quantiles, quantiles=quantiles,
) )
if provider == "vllm": elif provider == "vllm":
if not VLLM_AVAILABLE:
return (0, 0, 0)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias), lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
quantiles=quantiles, quantiles=quantiles,
@@ -136,9 +170,16 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
KN_model_names = prepare_shapes(args) # Skip in CI environment due to architecture compatibility issues
for K, N, model_name in KN_model_names: if IS_CI:
print(f"{model_name} N={N} K={K}: ") print(
benchmark.run(print_data=True, N=N, K=K) "Skipping INT8 GEMM benchmark in CI environment due to architecture compatibility issues"
)
print("INT8 operations may not be supported on all GPU architectures")
else:
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(print_data=True, N=N, K=K)
print("Benchmark finished!") print("Benchmark finished!")

View File

@@ -1,11 +1,18 @@
import itertools import itertools
import math import math
import os
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from sgl_kernel import lightning_attention_decode from sgl_kernel import lightning_attention_decode
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def next_power_of_2(n): def next_power_of_2(n):
return 2 ** (int(math.ceil(math.log(n, 2)))) return 2 ** (int(math.ceil(math.log(n, 2))))
@@ -207,7 +214,12 @@ def calculate_diff(batch_size):
print("❌ Implementations differ") print("❌ Implementations differ")
batch_size_range = [i for i in range(1, 65)] # 1 to 128 # Simplified for CI environment
if IS_CI:
batch_size_range = [1] # Single batch size for CI
else:
batch_size_range = [i for i in range(1, 65)] # 1 to 64
configs = [(bs,) for bs in batch_size_range] configs = [(bs,) for bs in batch_size_range]
@@ -292,8 +304,9 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
# Run correctness test # Run correctness test - simplified for CI
calculate_diff(batch_size=4) test_batch_size = 1 if IS_CI else 4
calculate_diff(batch_size=test_batch_size)
# Run performance benchmark # Run performance benchmark
benchmark.run(print_data=True) benchmark.run(print_data=True)

View File

@@ -1,5 +1,6 @@
import argparse import argparse
import itertools import itertools
import os
import torch import torch
import triton import triton
@@ -8,8 +9,17 @@ from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
try: try:
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
VLLM_AVAILABLE = True
except ImportError: except ImportError:
ops = None ops = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
USE_RANDOM_PERM = False USE_RANDOM_PERM = False
@@ -197,19 +207,23 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
num_tokens_post_pad_triton, num_tokens_post_pad_triton,
) )
try: if VLLM_AVAILABLE:
ops.moe_align_block_size( try:
topk_ids, ops.moe_align_block_size(
num_experts, topk_ids,
block_size, num_experts,
sorted_ids_vllm, block_size,
expert_ids_vllm, sorted_ids_vllm,
num_tokens_post_pad_vllm, expert_ids_vllm,
) num_tokens_post_pad_vllm,
print(f"✅ VLLM implementation works with {num_experts} experts!") )
vllm_works = True print(f"✅ VLLM implementation works with {num_experts} experts!")
except Exception as e: vllm_works = True
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}") except Exception as e:
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
vllm_works = False
else:
print("⚠️ vLLM not available, skipping vLLM test")
vllm_works = False vllm_works = False
if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose( if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
@@ -394,8 +408,18 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk) # Simplify for CI environment
if IS_CI:
num_tokens = 256 # Smaller for CI
num_experts = 8 # Smaller for CI
topk = 2 # Smaller for CI
else:
num_tokens = 1024
num_experts = args.num_experts
topk = args.topk
if not args.skip_full_benchmark: calculate_diff(num_tokens=num_tokens, num_experts=num_experts, topk=topk)
if not args.skip_full_benchmark and not IS_CI: # Skip full benchmark in CI
print(f"\n📊 Running performance benchmark for {args.num_experts} experts...") print(f"\n📊 Running performance benchmark for {args.num_experts} experts...")
benchmark.run(print_data=True) benchmark.run(print_data=True)

View File

@@ -1,9 +1,22 @@
import os
import torch import torch
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
import triton import triton
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096] # CI environment uses simplified parameters
if IS_CI:
batch_sizes = [64, 128] # Only test 2 values in CI
else:
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
configs = [(bs,) for bs in batch_sizes] configs = [(bs,) for bs in batch_sizes]

View File

@@ -1,5 +1,6 @@
import itertools import itertools
import math import math
import os
import torch import torch
import triton import triton
@@ -8,6 +9,12 @@ from sgl_kernel import moe_fused_gate
from sglang.srt.layers.moe.topk import biased_grouped_topk from sglang.srt.layers.moe.topk import biased_grouped_topk
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk): def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
return biased_grouped_topk( return biased_grouped_topk(
@@ -28,7 +35,12 @@ def biased_grouped_topk_org_fuse_kernel(
return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk) return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)
seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000] # CI environment uses simplified parameters
if IS_CI:
seq_length_range = [5000] # Only test one sequence length in CI
else:
seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
configs = [(sq,) for sq in seq_length_range] configs = [(sq,) for sq in seq_length_range]

View File

@@ -1,13 +1,32 @@
import itertools import itertools
import os
import pytest import pytest
import torch import torch
import triton import triton
from sgl_kernel import topk_softmax from sgl_kernel import topk_softmax
from vllm import _custom_ops as vllm_custom_ops
# Optional vLLM import
try:
from vllm import _custom_ops as vllm_custom_ops
VLLM_AVAILABLE = True
except ImportError:
vllm_custom_ops = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def vllm_topk_softmax(gating_output, topk): def vllm_topk_softmax(gating_output, topk):
if not VLLM_AVAILABLE:
# Fallback to SGLang implementation if vLLM is not available
return sglang_topk_softmax(gating_output, topk)
num_tokens, num_experts = gating_output.shape num_tokens, num_experts = gating_output.shape
topk_weights = torch.empty( topk_weights = torch.empty(
@@ -54,6 +73,10 @@ def calculate_diff(num_tokens, num_experts, topk):
weights_diff = torch.abs(weights_vllm - weights_sglang).mean().item() weights_diff = torch.abs(weights_vllm - weights_sglang).mean().item()
indices_match = torch.equal(indices_vllm, indices_sglang) indices_match = torch.equal(indices_vllm, indices_sglang)
if not VLLM_AVAILABLE:
print("⚠️ vLLM not available, skipping comparison")
return
if ( if (
torch.allclose(weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3) torch.allclose(weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3)
and indices_match and indices_match
@@ -65,21 +88,38 @@ def calculate_diff(num_tokens, num_experts, topk):
) )
num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768] # CI environment uses simplified parameters
num_experts_range = [32, 64, 128, 256, 12, 512] if IS_CI:
topk_range = [1, 2, 4, 8] num_tokens_range = [128] # Single value for CI
num_experts_range = [32] # Single value for CI
topk_range = [2] # Single value for CI
else:
num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
num_experts_range = [32, 64, 128, 256, 12, 512]
topk_range = [1, 2, 4, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range)) configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
# Filter providers based on vLLM availability
if VLLM_AVAILABLE:
line_vals = ["sglang", "vllm"]
line_names = ["SGLang", "VLLM"]
styles = [("blue", "-"), ("green", "-")]
else:
line_vals = ["sglang"]
line_names = ["SGLang"]
styles = [("blue", "-")]
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["num_tokens", "num_experts", "topk"], x_names=["num_tokens", "num_experts", "topk"],
x_vals=configs, x_vals=configs,
line_arg="provider", line_arg="provider",
line_vals=["sglang", "vllm"], line_vals=line_vals,
line_names=["SGLang", "VLLM"], line_names=line_names,
styles=[("blue", "-"), ("green", "-")], styles=styles,
ylabel="Latency (us)", ylabel="Latency (us)",
plot_name="topk-softmax-performance", plot_name="topk-softmax-performance",
args={}, args={},
@@ -92,6 +132,8 @@ def benchmark(num_tokens, num_experts, topk, provider):
) )
if provider == "vllm" or provider == "vllm1": if provider == "vllm" or provider == "vllm1":
if not VLLM_AVAILABLE:
return (0, 0, 0)
fn = lambda: vllm_topk_softmax(gating_output, topk) fn = lambda: vllm_topk_softmax(gating_output, topk)
elif provider == "sglang" or provider == "sglang1": elif provider == "sglang" or provider == "sglang1":
fn = lambda: sglang_topk_softmax(gating_output, topk) fn = lambda: sglang_topk_softmax(gating_output, topk)
@@ -103,14 +145,19 @@ def benchmark(num_tokens, num_experts, topk, provider):
if __name__ == "__main__": if __name__ == "__main__":
configs = [ # Simplify configs for CI environment
(20, 256, 4), if IS_CI:
(20, 256, 8), test_configs = [(20, 32, 2)] # Single config for CI
(20, 12, 4), else:
(20, 12, 1), test_configs = [
(20, 512, 4), (20, 256, 4),
(20, 512, 1), (20, 256, 8),
] (20, 12, 4),
for num_tokens, num_experts, topk in configs: (20, 12, 1),
(20, 512, 4),
(20, 512, 1),
]
for num_tokens, num_experts, topk in test_configs:
calculate_diff(num_tokens, num_experts, topk) calculate_diff(num_tokens, num_experts, topk)
benchmark.run(print_data=True) benchmark.run(print_data=True)

View File

@@ -1,11 +1,20 @@
import argparse import argparse
import copy import copy
import itertools import itertools
import os
import torch import torch
import triton import triton
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
from sglang.srt.utils import get_device_capability
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
FLOAT4_E2M1_MAX = 6.0 FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
@@ -162,9 +171,22 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
KN_model_names = prepare_shapes(args) # Check architecture compatibility - FP4 operations require sm100a/sm103a
for K, N, model_name in KN_model_names: major, minor = get_device_capability()
print(f"{model_name} N={N} K={K}: ") if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a)
benchmark.run(print_data=True, N=N, K=K) print("Skipping NVIDIA FP4 scaled GEMM benchmark")
if major is not None:
print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}")
else:
print("Could not determine device capability")
else:
KN_model_names = prepare_shapes(args)
print("Benchmark finished!") # Limit iterations in CI
if IS_CI:
KN_model_names = KN_model_names[:2] # Only test first 2 shapes in CI
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(print_data=True, N=N, K=K)
print("Benchmark finished!")

View File

@@ -1,5 +1,6 @@
import itertools import itertools
import math import math
import os
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import numpy as np import numpy as np
@@ -7,11 +8,26 @@ import torch
import triton import triton
import triton.testing import triton.testing
from sgl_kernel import sgl_per_tensor_quant_fp8 from sgl_kernel import sgl_per_tensor_quant_fp8
from vllm import _custom_ops as ops
# Optional imports
try:
from vllm import _custom_ops as ops
VLLM_AVAILABLE = True
except ImportError:
ops = None
VLLM_AVAILABLE = False
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip
_is_hip = is_hip() _is_hip = is_hip()
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
@@ -19,6 +35,9 @@ def vllm_scaled_fp8_quant(
input: torch.Tensor, input: torch.Tensor,
scale: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if not VLLM_AVAILABLE:
# Fallback to SGLang implementation
return sglang_scaled_fp8_quant(input, scale)
return ops.scaled_fp8_quant(input, scale) return ops.scaled_fp8_quant(input, scale)
@@ -42,6 +61,10 @@ def calculate_diff(batch_size: int, seq_len: int):
device = torch.device("cuda") device = torch.device("cuda")
x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device) x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)
if not VLLM_AVAILABLE:
print("⚠️ vLLM not available, skipping comparison")
return
vllm_out, vllm_scale = vllm_scaled_fp8_quant(x) vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
sglang_out, sglang_scale = sglang_scaled_fp8_quant(x) sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)
@@ -56,8 +79,13 @@ def calculate_diff(batch_size: int, seq_len: int):
print("❌ Implementations differ") print("❌ Implementations differ")
batch_size_range = [16, 32, 64, 128] # CI environment uses simplified parameters
seq_len_range = [64, 128, 256, 512, 1024, 2048] if IS_CI:
batch_size_range = [16] # Single batch size for CI
seq_len_range = [64] # Single sequence length for CI
else:
batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048]
configs = list(itertools.product(batch_size_range, seq_len_range)) configs = list(itertools.product(batch_size_range, seq_len_range))

View File

@@ -1,4 +1,5 @@
import itertools import itertools
import os
import time import time
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
@@ -16,15 +17,28 @@ from sglang.srt.layers.quantization.fp8_kernel import (
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
_is_hip = is_hip() _is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384] # CI environment uses simplified parameters
hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1 if IS_CI:
group_size_range = [128] # For DeepSeek V3/R1 num_tokens_range = [64] # Single value for CI
# TODO test int8 hidden_dim_range = [1536] # Single value for CI
dst_dtype_range = [fp8_type_] group_size_range = [128] # Keep as is
dst_dtype_range = [fp8_type_] # Keep as is
else:
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1
group_size_range = [128] # For DeepSeek V3/R1
# TODO test int8
dst_dtype_range = [fp8_type_]
flags_range = [ flags_range = [
dict( dict(
column_major_scales=False, column_major_scales=False,
@@ -82,7 +96,7 @@ def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
fn, kernel_names = { fn, kernel_names = {
"triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"), "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_8bit"),
"sglang": ( "sglang": (
sglang_per_token_group_quant_8bit, sglang_per_token_group_quant_8bit,
"per_token_group_quant_8bit_kernel", "per_token_group_quant_8bit_kernel",

View File

@@ -1,15 +1,31 @@
import itertools import itertools
import os
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import triton import triton
import triton.testing import triton.testing
from sgl_kernel import sgl_per_token_quant_fp8 from sgl_kernel import sgl_per_token_quant_fp8
from vllm import _custom_ops as ops
# Optional vLLM import
try:
from vllm import _custom_ops as ops
VLLM_AVAILABLE = True
except ImportError:
ops = None
VLLM_AVAILABLE = False
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip
_is_hip = is_hip() _is_hip = is_hip()
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
# Get correct FP8 E4M3 maximum value # Get correct FP8 E4M3 maximum value
@@ -49,6 +65,9 @@ def torch_per_token_quant_fp8(
def vllm_per_token_quant_fp8( def vllm_per_token_quant_fp8(
input: torch.Tensor, input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if not VLLM_AVAILABLE:
# Fallback to SGLang implementation
return sglang_per_token_quant_fp8(input)
return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True) return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
@@ -74,6 +93,17 @@ def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
vllm_out, vllm_scale = vllm_per_token_quant_fp8(x) vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
sglang_out, sglang_scale = sglang_per_token_quant_fp8(x) sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
if not VLLM_AVAILABLE:
print("⚠️ vLLM not available, skipping vLLM comparison")
# Only compare Torch vs SGLang
torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item()
torch_sglang_out_diff = (
torch.abs(torch_out.float() - sglang_out.float()).mean().item()
)
print(f"Scale difference (Torch vs SGLang): {torch_sglang_scale_diff:.8f}")
print(f"Output difference (Torch vs SGLang): {torch_sglang_out_diff:.8f}")
return
print(f"\n=== Comparison for hidden_dim={hidden_dim} ===") print(f"\n=== Comparison for hidden_dim={hidden_dim} ===")
# Compare scales # Compare scales
@@ -125,9 +155,15 @@ def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
print(f" VLLM vs SGLang: {'' if vllm_sglang_match else ''}") print(f" VLLM vs SGLang: {'' if vllm_sglang_match else ''}")
batch_size_range = [16, 32, 64, 128] # CI environment uses simplified parameters
seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096] if IS_CI:
hidden_dim_range = [1368, 2048, 4096] batch_size_range = [16] # Single batch size for CI
seq_len_range = [64] # Single sequence length for CI
hidden_dim_range = [2048] # Single hidden dimension for CI
else:
batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
hidden_dim_range = [1368, 2048, 4096]
configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range)) configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range))
@@ -137,9 +173,19 @@ configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_ran
x_names=["batch_size", "seq_len", "hidden_dim"], x_names=["batch_size", "seq_len", "hidden_dim"],
x_vals=configs, x_vals=configs,
line_arg="provider", line_arg="provider",
line_vals=["torch", "vllm", "sglang"], line_vals=(
line_names=["Torch Reference", "VLLM", "SGL Kernel"], ["torch", "vllm", "sglang"] if VLLM_AVAILABLE else ["torch", "sglang"]
styles=[("red", "-"), ("blue", "-"), ("green", "-")], ),
line_names=(
["Torch Reference", "VLLM", "SGL Kernel"]
if VLLM_AVAILABLE
else ["Torch Reference", "SGL Kernel"]
),
styles=(
[("red", "-"), ("blue", "-"), ("green", "-")]
if VLLM_AVAILABLE
else [("red", "-"), ("green", "-")]
),
ylabel="us", ylabel="us",
plot_name="per-token-dynamic-quant-fp8-performance", plot_name="per-token-dynamic-quant-fp8-performance",
args={}, args={},
@@ -156,6 +202,8 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
if provider == "torch": if provider == "torch":
fn = lambda: torch_per_token_quant_fp8(x.clone()) fn = lambda: torch_per_token_quant_fp8(x.clone())
elif provider == "vllm": elif provider == "vllm":
if not VLLM_AVAILABLE:
return (0, 0, 0)
fn = lambda: vllm_per_token_quant_fp8(x.clone()) fn = lambda: vllm_per_token_quant_fp8(x.clone())
elif provider == "sglang": elif provider == "sglang":
fn = lambda: sglang_per_token_quant_fp8(x.clone()) fn = lambda: sglang_per_token_quant_fp8(x.clone())
@@ -166,11 +214,16 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
if __name__ == "__main__": if __name__ == "__main__":
# Test various hidden dimensions for correctness # Test various hidden dimensions for correctness - simplified for CI
test_dims = [1368, 2048, 4096] if IS_CI:
test_dims = [2048] # Single dimension for CI
batch_size, seq_len = 4, 64 # Smaller values for CI
else:
test_dims = [1368, 2048, 4096]
batch_size, seq_len = 4, 4096
for dim in test_dims: for dim in test_dims:
calculate_diff(batch_size=4, seq_len=4096, hidden_dim=dim) calculate_diff(batch_size=batch_size, seq_len=seq_len, hidden_dim=dim)
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("Starting performance benchmark...") print("Starting performance benchmark...")

View File

@@ -1,6 +1,7 @@
import argparse import argparse
import copy import copy
import itertools import itertools
import os
import torch import torch
import triton import triton
@@ -10,6 +11,12 @@ from sgl_kernel import (
qserve_w4a8_per_group_gemm, qserve_w4a8_per_group_gemm,
) )
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def to_int8(tensor: torch.Tensor) -> torch.Tensor: def to_int8(tensor: torch.Tensor) -> torch.Tensor:
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
@@ -65,10 +72,17 @@ WEIGHT_SHAPES = {
} }
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [1, 16] # Simplified for CI
else:
batch_sizes = [1, 16, 32, 64, 128, 256, 512, 1024, 2048]
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["batch_size"], x_names=["batch_size"],
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048], x_vals=batch_sizes,
x_log=False, x_log=False,
line_arg="provider", line_arg="provider",
line_vals=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"], line_vals=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
@@ -184,13 +198,19 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
KN_model_names = prepare_shapes(args) # Skip in CI environment
for K, N, model_name in KN_model_names: if IS_CI:
print(f"{model_name} N={N} K={K}: ") print("Skipping QServe W4A8 GEMM benchmark in CI environment")
benchmark.run( print("QServe operations may have compatibility issues in CI")
print_data=True, else:
N=N, KN_model_names = prepare_shapes(args)
K=K,
)
print("Benchmark finished!") for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True,
N=N,
K=K,
)
print("Benchmark finished!")

View File

@@ -2,6 +2,7 @@
# (batch_size, seq_len, hidden_size) and prints speed-up. # (batch_size, seq_len, hidden_size) and prints speed-up.
import argparse import argparse
import itertools import itertools
import os
import re import re
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@@ -10,9 +11,31 @@ import torch
import torch.nn as nn import torch.nn as nn
import triton import triton
import triton.testing import triton.testing
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from sgl_kernel.utils import is_arch_support_pdl from sgl_kernel.utils import is_arch_support_pdl
from vllm import _custom_ops as vllm_ops
# Optional imports
try:
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
FLASHINFER_AVAILABLE = True
except ImportError:
fused_add_rmsnorm = None
rmsnorm = None
FLASHINFER_AVAILABLE = False
try:
from vllm import _custom_ops as vllm_ops
VLLM_AVAILABLE = True
except ImportError:
vllm_ops = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def str2int_list(arg: str) -> List[int]: def str2int_list(arg: str) -> List[int]:
@@ -79,6 +102,10 @@ def rmsnorm_flashinfer(
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
eps: float = 1e-6, eps: float = 1e-6,
): ):
if not FLASHINFER_AVAILABLE:
# Fallback to naive implementation if FlashInfer is not available
return rmsnorm_naive(x, weight, residual, eps)
orig_shape = x.shape orig_shape = x.shape
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
if residual is not None: if residual is not None:
@@ -103,6 +130,10 @@ def rmsnorm_vllm(
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
eps: float = 1e-6, eps: float = 1e-6,
): ):
if not VLLM_AVAILABLE:
# Fallback to naive implementation if vLLM is not available
return rmsnorm_naive(x, weight, residual, eps)
orig_shape = x.shape orig_shape = x.shape
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
if residual is not None: if residual is not None:
@@ -179,37 +210,72 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
output_sglang = output_sglang[0] output_sglang = output_sglang[0]
print(f"Naive output={output_naive}") print(f"Naive output={output_naive}")
print(f"FlashInfer output={output_flashinfer}") if FLASHINFER_AVAILABLE:
print(f"VLLM output={output_vllm}") print(f"FlashInfer output={output_flashinfer}")
else:
print("FlashInfer not available, skipped")
if VLLM_AVAILABLE:
print(f"VLLM output={output_vllm}")
else:
print("vLLM not available, skipped")
print(f"SGLang output={output_sglang}") print(f"SGLang output={output_sglang}")
if ( # Only compare available implementations
torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2) all_match = torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2) if FLASHINFER_AVAILABLE:
and torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2) all_match = all_match and torch.allclose(
): output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
print("✅ All implementations match") )
if VLLM_AVAILABLE:
all_match = all_match and torch.allclose(
output_naive, output_vllm, atol=1e-2, rtol=1e-2
)
if all_match:
print("✅ All available implementations match")
else: else:
print("❌ Implementations differ") print("❌ Implementations differ")
default_batch_sizes = [2**i for i in range(0, 7, 2)] # 1, 4, 16, 64 # CI environment uses simplified parameters
default_seq_lens = [2**i for i in range(6, 11, 1)] # 64, 128, 256, 512, 1024 if IS_CI:
default_hidden_sizes = [32 * 128, 48 * 128] # 4096, 6144 default_batch_sizes = [1] # Single batch size for CI
default_seq_lens = [64] # Single sequence length for CI
default_hidden_sizes = [4096] # Single hidden size for CI
else:
default_batch_sizes = [2**i for i in range(0, 7, 2)] # 1, 4, 16, 64
default_seq_lens = [2**i for i in range(6, 11, 1)] # 64, 128, 256, 512, 1024
default_hidden_sizes = [32 * 128, 48 * 128] # 4096, 6144
def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]: def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]:
return list(itertools.product(bsizes, slens, hsizes)) return list(itertools.product(bsizes, slens, hsizes))
# Filter providers based on availability
available_providers = ["huggingface", "sglang"]
available_names = ["HuggingFace", "SGL Kernel"]
available_styles = [("blue", "-"), ("orange", "-")]
if FLASHINFER_AVAILABLE:
available_providers.insert(-1, "flashinfer")
available_names.insert(-1, "FlashInfer")
available_styles.insert(-1, ("green", "-"))
if VLLM_AVAILABLE:
available_providers.insert(-1, "vllm")
available_names.insert(-1, "vLLM")
available_styles.insert(-1, ("red", "-"))
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["batch_size", "seq_len", "hidden_size"], x_names=["batch_size", "seq_len", "hidden_size"],
x_vals=[], x_vals=[],
line_arg="provider", line_arg="provider",
line_vals=["huggingface", "flashinfer", "vllm", "sglang"], line_vals=available_providers,
line_names=["HuggingFace", "FlashInfer", "vLLM", "SGL Kernel"], line_names=available_names,
styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("orange", "-")], styles=available_styles,
ylabel="µs (median) or × (speed-up)", ylabel="µs (median) or × (speed-up)",
plot_name="rmsnorm-performance", plot_name="rmsnorm-performance",
args={}, args={},
@@ -242,6 +308,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
) )
) )
elif provider == "flashinfer": elif provider == "flashinfer":
if not FLASHINFER_AVAILABLE:
return (0, 0, 0)
return timed( return timed(
lambda: rmsnorm_flashinfer( lambda: rmsnorm_flashinfer(
x.clone(), x.clone(),
@@ -250,6 +318,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
) )
) )
elif provider == "vllm": elif provider == "vllm":
if not VLLM_AVAILABLE:
return (0, 0, 0)
return timed( return timed(
lambda: rmsnorm_vllm( lambda: rmsnorm_vllm(
x.clone(), x.clone(),
@@ -267,13 +337,22 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
) )
# provider == "speedup" # provider == "speedup"
t_ref, _, _ = timed( if VLLM_AVAILABLE:
lambda: rmsnorm_vllm( t_ref, _, _ = timed(
x.clone(), lambda: rmsnorm_vllm(
weight, x.clone(),
residual.clone() if residual is not None else None, weight,
residual.clone() if residual is not None else None,
)
)
else:
t_ref, _, _ = timed(
lambda: rmsnorm_naive(
x.clone(),
weight,
residual.clone() if residual is not None else None,
)
) )
)
t_sgl, _, _ = timed( t_sgl, _, _ = timed(
lambda: rmsnorm_sglang( lambda: rmsnorm_sglang(
x.clone(), x.clone(),
@@ -281,7 +360,7 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
residual.clone() if residual is not None else None, residual.clone() if residual is not None else None,
) )
) )
spd = t_ref / t_sgl spd = t_ref / t_sgl if t_ref > 0 else 1.0
return (spd, spd, spd) return (spd, spd, spd)

View File

@@ -1,4 +1,5 @@
import itertools import itertools
import os
import torch import torch
import triton import triton
@@ -12,17 +13,31 @@ from sgl_kernel.testing.rotary_embedding import (
from sglang.srt.bench_utils import bench_kineto from sglang.srt.bench_utils import bench_kineto
configs = [ # CI environment detection
(batch_size, seq_len, save_kv_cache) IS_CI = (
for batch_size, seq_len in ( os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
# CI environment uses simplified parameters
if IS_CI:
batch_seq_configs = [(1, 1)] # Single config for CI
save_kv_configs = [False] # Single option for CI
else:
batch_seq_configs = [
(1, 1), (1, 1),
(32, 1), (32, 1),
(128, 1), (128, 1),
(512, 1), (512, 1),
(2, 512), (2, 512),
(4, 4096), (4, 4096),
) ]
for save_kv_cache in (False, True) save_kv_configs = [False, True]
configs = [
(batch_size, seq_len, save_kv_cache)
for batch_size, seq_len in batch_seq_configs
for save_kv_cache in save_kv_configs
] ]

View File

@@ -1,10 +1,17 @@
import itertools import itertools
import os
import sgl_kernel import sgl_kernel
import torch import torch
import triton import triton
import triton.testing import triton.testing
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def torch_top_k_top_p_joint_sampling_from_probs( def torch_top_k_top_p_joint_sampling_from_probs(
normalized_prob, top_k, top_p, eps=1e-4 normalized_prob, top_k, top_p, eps=1e-4
@@ -67,10 +74,16 @@ def calculate_diff(batch_size, vocab_size, p):
) )
# parameter space # parameter space - simplified for CI
batch_size_range = [16, 64, 128] if IS_CI:
vocab_size_range = [111, 32000] batch_size_range = [16] # Single batch size for CI
p_range = [0.1, 0.5] vocab_size_range = [111] # Single vocab size for CI
p_range = [0.1] # Single p value for CI
else:
batch_size_range = [16, 64, 128]
vocab_size_range = [111, 32000]
p_range = [0.1, 0.5]
configs = list(itertools.product(batch_size_range, vocab_size_range, p_range)) configs = list(itertools.product(batch_size_range, vocab_size_range, p_range))
@@ -114,15 +127,19 @@ def benchmark_sampling(batch_size, vocab_size, p, provider):
filter_apply_order="joint", filter_apply_order="joint",
) )
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
fn, quantiles=[0.5, 0.2, 0.8]
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__": if __name__ == "__main__":
# Correctness check # Correctness check - simplified for CI
for cfg in configs: if IS_CI:
# Only test one configuration in CI
test_configs = [configs[0]] if configs else [(16, 111, 0.1)]
else:
test_configs = configs
for cfg in test_configs:
calculate_diff(*cfg) calculate_diff(*cfg)
print("\n" + "=" * 60) print("\n" + "=" * 60)