diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index bc7d48594..cd13f6e9a 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -155,6 +155,50 @@ jobs: cd test/srt python3 test_mla_deepseek_v3.py + sgl-kernel-benchmark-test: + needs: [check-changes, sgl-kernel-build-wheels] + if: always() && !failure() && !cancelled() + runs-on: 1-gpu-runner + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + CI: true + steps: + - uses: actions/checkout@v4 + + - name: Cleanup + run: | + ls -alh sgl-kernel/dist || true + rm -rf sgl-kernel/dist/* || true + + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + path: sgl-kernel/dist/ + merge-multiple: true + pattern: wheel-python3.10-cuda12.9 + + - name: Install dependencies + run: | + CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/ci_install_dependency.sh + + - name: Run benchmark tests + timeout-minutes: 45 + run: | + cd sgl-kernel/benchmark + echo "Running sgl-kernel benchmark tests in CI mode..." + + echo "CI environment variable: $CI" + echo "GITHUB_ACTIONS environment variable: $GITHUB_ACTIONS" + + for bench_file in bench_*.py; do + echo "Testing $bench_file..." + timeout 60 python3 "$bench_file" || echo "Warning: $bench_file timed out or failed, continuing..." + echo "Completed $bench_file" + echo "---" + done + + echo "All benchmark tests completed!" + # =============================================== primary ==================================================== unit-test-frontend: @@ -647,7 +691,7 @@ jobs: check-changes, sgl-kernel-build-wheels, - sgl-kernel-unit-test, sgl-kernel-mla-test, + sgl-kernel-unit-test, sgl-kernel-mla-test, sgl-kernel-benchmark-test, unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu, unit-test-backend-4-gpu, unit-test-backend-8-gpu, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 2dd68dab4..8038ccf8a 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2460,7 +2460,7 @@ class BumpAllocator: def log_info_on_rank0(logger, msg): from sglang.srt.distributed import get_tensor_model_parallel_rank - if get_tensor_model_parallel_rank() == 0: + if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0: logger.info(msg) diff --git a/sgl-kernel/benchmark/bench_activation.py b/sgl-kernel/benchmark/bench_activation.py index 0c59cceee..3caa5b936 100644 --- a/sgl-kernel/benchmark/bench_activation.py +++ b/sgl-kernel/benchmark/bench_activation.py @@ -2,6 +2,7 @@ # (kernel, dtype, batch_size, seq_len, dim) and prints speed-up. 
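The hunks below apply one recurring pattern to every `bench_*.py` file: detect CI through the `CI`/`GITHUB_ACTIONS` environment variables, turn the vLLM import into an optional dependency, and shrink the swept parameter grid so each script fits inside the 60-second per-file timeout used by the workflow loop above. A minimal sketch of that pattern, with illustrative list values (the concrete kernels, shapes, and providers differ per file):

```python
# Sketch of the CI-detection / optional-dependency pattern repeated in the
# benchmark hunks below; the parameter values here are placeholders.
import os

# Either CI=true or GITHUB_ACTIONS=true counts as "running in CI".
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)

# Optional import: record availability instead of failing at import time.
try:
    from vllm import _custom_ops as vllm_ops  # optional reference backend

    VLLM_AVAILABLE = True
except ImportError:
    vllm_ops = None
    VLLM_AVAILABLE = False

# Shrink the sweep in CI so the script finishes inside the per-file timeout.
batch_sizes = [1] if IS_CI else [1, 16, 64, 256, 1024]
providers = ["vllm", "sglang"] if VLLM_AVAILABLE else ["sglang"]
```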
import argparse import itertools +import os import re from typing import List, Tuple @@ -11,7 +12,21 @@ import torch.nn.functional as F import triton import triton.testing from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul -from vllm import _custom_ops as vllm_ops + +# Optional vLLM import +try: + from vllm import _custom_ops as vllm_ops + + VLLM_AVAILABLE = True +except ImportError: + vllm_ops = None + VLLM_AVAILABLE = False + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) # gelu_quick is only available on HIP/ROCm platforms try: @@ -22,7 +37,7 @@ except ImportError: GELU_QUICK_AVAILABLE = False gelu_quick = None -if not hasattr(vllm_ops, "silu_and_mul"): +if VLLM_AVAILABLE and not hasattr(vllm_ops, "silu_and_mul"): vllm_ops = torch.ops._C @@ -40,6 +55,13 @@ def calculate_diff( """Compare vLLM with SGLang for one shape.""" device = torch.device("cuda") + if not VLLM_AVAILABLE: + print( + f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | " + f"L={seq_len:3d} | D={dim:5d}] ⚠️ vLLM not available, skipping comparison" + ) + return True + # activation-only quick GELU if kernel == "gelu_quick": if not GELU_QUICK_AVAILABLE: @@ -68,19 +90,30 @@ def calculate_diff( return ok -kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"] -if GELU_QUICK_AVAILABLE: - kernels.append("gelu_quick") -dtypes = [torch.float16, torch.bfloat16] +# CI environment uses simplified parameters for kernels and dtypes too +if IS_CI: + kernels = ["silu_and_mul"] # Only test one kernel in CI + dtypes = [torch.float16] # Only test one dtype in CI +else: + kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"] + if GELU_QUICK_AVAILABLE: + kernels.append("gelu_quick") + dtypes = [torch.float16, torch.bfloat16] def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]: return list(itertools.product(kernels, dtypes, bsizes, slens, dims_)) -default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16 -default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64 -default_dims = [2**i for i in range(10, 15)] # 1024...16384 +# CI environment uses simplified parameters +if IS_CI: + default_batch_sizes = [1] # Single batch size for CI + default_seq_lens = [1] # Single sequence length for CI + default_dims = [1024] # Single dimension for CI +else: + default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16 + default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64 + default_dims = [2**i for i in range(10, 15)] # 1024...16384 @triton.testing.perf_report( @@ -102,16 +135,24 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider): x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device) y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device) - vllm_kernel = getattr(vllm_ops, kernel) + if not VLLM_AVAILABLE and provider in ["vllm", "speedup"]: + # Skip vLLM-related benchmarks if vLLM is not available + return (0, 0, 0) + + if VLLM_AVAILABLE: + vllm_kernel = getattr(vllm_ops, kernel) if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE: # Skip benchmark for gelu_quick if not available return (0, 0, 0) sglang_kernel = getattr(sgl_kernel, kernel) def baseline(): - tmp = y0.clone() - vllm_kernel(tmp, x) - return tmp + if VLLM_AVAILABLE: + tmp = y0.clone() + vllm_kernel(tmp, x) + return tmp + else: + return torch.zeros_like(y0) def sglang(): return sglang_kernel(x) @@ -134,7 +175,7 @@ def benchmark(kernel, dtype, 
batch_size, seq_len, dim, provider): # provider == "speedup" t_ref, _, _ = timed(baseline) t_sgl, _, _ = timed(sglang) - spd = t_ref / t_sgl + spd = t_ref / t_sgl if t_ref > 0 else 1.0 return (spd, spd, spd) diff --git a/sgl-kernel/benchmark/bench_awq_dequant.py b/sgl-kernel/benchmark/bench_awq_dequant.py index a906894c9..6bd03ab8a 100644 --- a/sgl-kernel/benchmark/bench_awq_dequant.py +++ b/sgl-kernel/benchmark/bench_awq_dequant.py @@ -1,16 +1,34 @@ import itertools +import os from typing import List, Tuple import torch import triton import triton.testing from sgl_kernel import awq_dequantize -from vllm import _custom_ops as ops + +# Optional vLLM import +try: + from vllm import _custom_ops as ops + + VLLM_AVAILABLE = True +except ImportError: + ops = None + VLLM_AVAILABLE = False + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) def vllm_awq_dequantize( qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: + if not VLLM_AVAILABLE: + # Fallback to SGLang implementation + return sglang_awq_dequantize(qweight, scales, qzeros) return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0) @@ -43,6 +61,10 @@ def calculate_diff(qweight_row: int, qweight_col: int): device=device, ) + if not VLLM_AVAILABLE: + print("⚠️ vLLM not available, skipping comparison") + return + vllm_out = vllm_awq_dequantize(qweight, scales, qzeros) sglang_out = sglang_awq_dequantize(qweight, scales, qzeros) @@ -56,8 +78,13 @@ def calculate_diff(qweight_row: int, qweight_col: int): print("❌ Implementations differ") -qweight_row_range = [3584, 18944, 128, 256, 512, 1024] -qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128] +# CI environment uses simplified parameters +if IS_CI: + qweight_row_range = [128] # Single row size for CI + qweight_cols_range = [16] # Single column size for CI +else: + qweight_row_range = [3584, 18944, 128, 256, 512, 1024] + qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128] configs = list(itertools.product(qweight_row_range, qweight_cols_range)) @@ -67,9 +94,9 @@ configs = list(itertools.product(qweight_row_range, qweight_cols_range)) x_names=["qweight_row", "qweight_col"], x_vals=configs, line_arg="provider", - line_vals=["vllm", "sglang"], - line_names=["VLLM", "SGL Kernel"], - styles=[("blue", "-"), ("green", "-")], + line_vals=["vllm", "sglang"] if VLLM_AVAILABLE else ["sglang"], + line_names=["VLLM", "SGL Kernel"] if VLLM_AVAILABLE else ["SGL Kernel"], + styles=[("blue", "-"), ("green", "-")] if VLLM_AVAILABLE else [("green", "-")], ylabel="us", plot_name="awq-dequantize-performance", args={}, @@ -100,6 +127,8 @@ def benchmark(qweight_row, qweight_col, provider): quantiles = [0.5, 0.2, 0.8] if provider == "vllm": + if not VLLM_AVAILABLE: + return (0, 0, 0) fn = lambda: vllm_awq_dequantize( qweight.clone(), scales.clone(), qzeros.clone() ) @@ -114,5 +143,11 @@ def benchmark(qweight_row, qweight_col, provider): if __name__ == "__main__": - calculate_diff(qweight_row=3584, qweight_col=448) + # Simplify for CI environment + if IS_CI: + qweight_row, qweight_col = 128, 16 # Smaller values for CI + else: + qweight_row, qweight_col = 3584, 448 + + calculate_diff(qweight_row=qweight_row, qweight_col=qweight_col) benchmark.run(print_data=True) diff --git a/sgl-kernel/benchmark/bench_cutlass_mla.py b/sgl-kernel/benchmark/bench_cutlass_mla.py index e01fdc110..6947f309d 100644 --- a/sgl-kernel/benchmark/bench_cutlass_mla.py +++ 
b/sgl-kernel/benchmark/bench_cutlass_mla.py @@ -1,13 +1,27 @@ import argparse import copy import itertools +import os import torch import triton from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size -bs_range = [1, 8, 32, 64, 128, 256] -qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192] +from sglang.srt.utils import get_device_capability + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + +# CI environment uses simplified parameters +if IS_CI: + bs_range = [1] # Single batch size for CI + qlen_range = [64] # Single sequence length for CI +else: + bs_range = [1, 8, 32, 64, 128, 256] + qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192] configs = list(itertools.product(bs_range, qlen_range)) @@ -131,13 +145,34 @@ if __name__ == "__main__": ) args = parser.parse_args() - for block_size in args.block_sizes: - for kv_split in args.num_kv_splits: - print(f"block_size={block_size}, num_kv_splits={kv_split}: ") - benchmark.run( - print_data=True, - block_size=block_size, - num_kv_splits=kv_split, - ) - - print("Benchmark finished!") + # Skip in CI environment or unsupported architectures + if IS_CI: + major, minor = get_device_capability() + if major is None or major < 10: # Requires compute capability 10.0+ + print("Skipping Cutlass MLA benchmark in CI environment") + if major is not None: + print( + f"Cutlass MLA requires compute capability 10.0+, but found {major}.{minor}" + ) + else: + print("Could not determine device capability") + else: + for block_size in args.block_sizes: + for kv_split in args.num_kv_splits: + print(f"block_size={block_size}, num_kv_splits={kv_split}: ") + benchmark.run( + print_data=True, + block_size=block_size, + num_kv_splits=kv_split, + ) + print("Benchmark finished!") + else: + for block_size in args.block_sizes: + for kv_split in args.num_kv_splits: + print(f"block_size={block_size}, num_kv_splits={kv_split}: ") + benchmark.run( + print_data=True, + block_size=block_size, + num_kv_splits=kv_split, + ) + print("Benchmark finished!") diff --git a/sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py b/sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py index a00c59ad0..bdf7f85de 100644 --- a/sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py +++ b/sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py @@ -1,4 +1,11 @@ import argparse +import os + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) import torch import torch.nn.functional as F @@ -6,16 +13,28 @@ import triton import triton.testing from sgl_kernel import dsv3_fused_a_gemm +# CI environment uses simplified parameters +if IS_CI: + num_tokens_vals = [1] # Only test 1 value in CI + line_vals = ["sgl-kernel"] # Only test sgl-kernel implementation in CI +else: + num_tokens_vals = [i + 1 for i in range(16)] # Test 1-16 in full mode + line_vals = ["torch", "sgl-kernel"] + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["num_tokens"], - x_vals=[i + 1 for i in range(16)], + x_vals=num_tokens_vals, x_log=False, line_arg="impl", - line_vals=["torch", "sgl-kernel"], - line_names=["torch (bf16)", "dsv3_fused_a_gemm"], - styles=[("blue", "-"), ("orange", "-")], + line_vals=line_vals, + line_names=( + ["torch (bf16)", "dsv3_fused_a_gemm"] + if not IS_CI + else ["dsv3_fused_a_gemm"] + ), + styles=[("blue", "-"), ("orange", "-")] if not IS_CI else [("orange", "-")], ylabel="TFLOPs", 
plot_name="bf16 dsv3 fused a GEMM throughput", args={}, diff --git a/sgl-kernel/benchmark/bench_dsv3_router_gemm.py b/sgl-kernel/benchmark/bench_dsv3_router_gemm.py index b699d2580..2daee279f 100644 --- a/sgl-kernel/benchmark/bench_dsv3_router_gemm.py +++ b/sgl-kernel/benchmark/bench_dsv3_router_gemm.py @@ -1,4 +1,11 @@ import argparse +import os + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) import torch import torch.nn.functional as F @@ -6,21 +13,37 @@ import triton import triton.testing from sgl_kernel import dsv3_router_gemm +# CI environment uses simplified parameters +if IS_CI: + num_tokens_vals = [1] # Only test 1 value in CI + line_vals = ["sgl-kernel-256"] # Only test one implementation in CI +else: + num_tokens_vals = [i + 1 for i in range(16)] # Test 1-16 in full mode + line_vals = ["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"] + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["num_tokens"], - x_vals=[i + 1 for i in range(16)], + x_vals=num_tokens_vals, x_log=False, line_arg="impl", - line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"], - line_names=[ - "torch-256", - "dsv3_router_gemm-256", - "torch-384", - "dsv3_router_gemm-384", - ], - styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")], + line_vals=line_vals, + line_names=( + [ + "torch-256", + "dsv3_router_gemm-256", + "torch-384", + "dsv3_router_gemm-384", + ] + if not IS_CI + else ["dsv3_router_gemm-256"] + ), + styles=( + [("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")] + if not IS_CI + else [("orange", "-")] + ), ylabel="TFLOPs", plot_name="input-bf16-output-bf16 dsv3 router gemm throughput", args={}, @@ -64,17 +87,25 @@ def benchmark_bf16_output(num_tokens, impl): @triton.testing.perf_report( triton.testing.Benchmark( x_names=["num_tokens"], - x_vals=[i + 1 for i in range(16)], + x_vals=num_tokens_vals, x_log=False, line_arg="impl", - line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"], - line_names=[ - "torch-256", - "dsv3_router_gemm-256", - "torch-384", - "dsv3_router_gemm-384", - ], - styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")], + line_vals=line_vals, + line_names=( + [ + "torch-256", + "dsv3_router_gemm-256", + "torch-384", + "dsv3_router_gemm-384", + ] + if not IS_CI + else ["dsv3_router_gemm-256"] + ), + styles=( + [("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")] + if not IS_CI + else [("orange", "-")] + ), ylabel="TFLOPs", plot_name="input-bf16-output-fp32 dsv3 router gemm throughput", args={}, diff --git a/sgl-kernel/benchmark/bench_fp4_gemm.py b/sgl-kernel/benchmark/bench_fp4_gemm.py index e8d467702..0323fde22 100755 --- a/sgl-kernel/benchmark/bench_fp4_gemm.py +++ b/sgl-kernel/benchmark/bench_fp4_gemm.py @@ -2,6 +2,7 @@ import argparse import copy import csv import itertools +import os import pytest import torch @@ -9,6 +10,14 @@ import triton from flashinfer import mm_fp4 from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant +from sglang.srt.utils import get_device_capability + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + FLOAT4_E2M1_MAX = 6.0 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max @@ -33,27 +42,34 @@ def get_weight_shapes(args): ] +# CI environment uses simplified parameters +if IS_CI: + batch_sizes = [1, 8] # Simplified for CI +else: 
+ batch_sizes = [ + 1, + 2, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 3072, + 4096, + 8192, + 16384, + ] + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["batch_size"], - x_vals=[ - 1, - 2, - 4, - 8, - 16, - 32, - 64, - 128, - 256, - 512, - 1024, - 2048, - 3072, - 4096, - 8192, - 16384, - ], + x_vals=batch_sizes, # x_vals = [64], x_log=False, line_arg="provider", @@ -188,21 +204,38 @@ if __name__ == "__main__": ) args = parser.parse_args() + # Simplify for CI environment + if IS_CI: + args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size + if args.csv: with open(args.csv, "w", newline="") as f: writer = csv.writer(f) writer.writerow(["provider", "m", "n", "k", "time_ms"]) - NKs = get_weight_shapes(args) - for N, K in NKs: - print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ") - benchmark.run( - print_data=True, - N=N, - K=K, - dtype=args.dtype, - correctness=args.correctness, - csv_file=args.csv, - ) + # Check architecture compatibility - FP4 operations require sm100a/sm103a + major, minor = get_device_capability() + if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a) + print("Skipping FP4 GEMM benchmark") + if major is not None: + print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}") + else: + print("Could not determine device capability") + else: + NKs = get_weight_shapes(args) - print("Benchmark finished!") + # Limit iterations in CI + if IS_CI: + NKs = NKs[:2] # Only test first 2 shapes in CI + + for N, K in NKs: + print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ") + benchmark.run( + print_data=True, + N=N, + K=K, + dtype=args.dtype, + correctness=args.correctness, + csv_file=args.csv, + ) + print("Benchmark finished!") diff --git a/sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py b/sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py index 8843457f9..70766df94 100644 --- a/sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py +++ b/sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py @@ -1,18 +1,33 @@ import argparse import copy import itertools +import os import deep_gemm import torch import triton from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor from sgl_kernel import fp8_blockwise_scaled_mm -from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm + +# Optional vLLM import +try: + from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm + + VLLM_AVAILABLE = True +except ImportError: + vllm_scaled_mm = None + VLLM_AVAILABLE = False from sglang.srt.layers.quantization.fp8_kernel import ( w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul, ) +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + def get_weight_shapes(args): models_tps = list(itertools.product(args.models, args.tp_sizes)) @@ -80,15 +95,46 @@ def scale_shape(shape, group_shape): return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape))) +# CI environment uses simplified parameters +if IS_CI: + batch_sizes = [1, 8] # Simplified for CI +else: + batch_sizes = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] + +# Filter providers based on availability +available_providers = ["sgl-kernel"] +available_names = ["sgl-kernel"] +available_styles = [("orange", "-")] + +if VLLM_AVAILABLE: + available_providers.insert(0, "vllm") + available_names.insert(0, "vllm") + available_styles.insert(0, ("blue", "-")) + +available_providers.append("triton") +available_names.append("sglang triton") 
+available_styles.append(("red", "-")) + +# Add deepgemm if available +try: + import deep_gemm + + available_providers.append("deepgemm") + available_names.append("deepgemm") + available_styles.append(("yellow", "-")) +except ImportError: + pass + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["batch_size"], - x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], + x_vals=batch_sizes, x_log=False, line_arg="provider", - line_vals=["vllm", "sgl-kernel", "triton", "deepgemm"], - line_names=["vllm", "sgl-kernel", "sglang triton", "deepgemm"], - styles=[("blue", "-"), ("orange", "-"), ("red", "-"), ("yellow", "-")], + line_vals=available_providers, + line_names=available_names, + styles=available_styles, ylabel="GB/s", plot_name="fp8 blockwise scaled matmul", args={}, @@ -123,14 +169,16 @@ def benchmark(batch_size, provider, N, K): ), quantiles=quantiles, ) - if provider == "vllm": + elif provider == "vllm": + if not VLLM_AVAILABLE: + return (0, 0, 0) scale_a = scale_a.t().contiguous().t() b_fp8, scale_b = b_fp8.t(), scale_b.t() ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16), quantiles=quantiles, ) - if provider == "triton": + elif provider == "triton": ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: w8a8_block_fp8_matmul( a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16 @@ -166,7 +214,17 @@ if __name__ == "__main__": ) args = parser.parse_args() + # Simplify for CI environment + if IS_CI: + args.models = [args.models[0]] # Use only first model + args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size + NK_model_names = get_weight_shapes(args) + + # Limit iterations in CI + if IS_CI: + NK_model_names = NK_model_names[:2] # Only test first 2 shapes in CI + for N, K, model_name in NK_model_names: if N % 128 != 0 or K % 128 != 0: print(f"Skip {N=}, {K=} now") diff --git a/sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py b/sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py index 6aa131244..19e425b52 100644 --- a/sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py +++ b/sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py @@ -1,4 +1,11 @@ import argparse +import os + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) import random from dataclasses import dataclass from typing import List, Tuple @@ -290,36 +297,44 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument("--num-warmup", type=int, default=3) parser.add_argument("--num-run", type=int, default=10) - shape_args = [ - # Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8 - ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256), - # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 8 - ShapeArg(expected_m_per_group=256, n=512, k=7168, num_groups=256), - # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 16 - ShapeArg(expected_m_per_group=256, n=256, k=7168, num_groups=256), - # Prefill, DeepSeek-R1, gateup, chunk_size = 16384, TP = 16 - ShapeArg(expected_m_per_group=512, n=256, k=7168, num_groups=256), - # Decode, DeepSeek-R1, gateup, bs = 32, TP = 8 - ShapeArg(expected_m_per_group=1, n=512, k=7168, num_groups=256), - # Decode, DeepSeek-R1, gateup, bs = 64, TP = 16 - ShapeArg(expected_m_per_group=2, n=256, k=7168, num_groups=256), - # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, EP = 8 - ShapeArg(expected_m_per_group=256, n=4096, k=7168, num_groups=32), - # 
Prefill, DeepSeek-R1, gateup, chunk_size = 16384, EP = 16 - ShapeArg(expected_m_per_group=512, n=4096, k=7168, num_groups=16), - # Decode, DeepSeek-R1, gateup, bs = 128, EP = 8 - ShapeArg(expected_m_per_group=4, n=4096, k=7168, num_groups=32), - # Decode, DeepSeek-R1, gateup, bs = 256, EP = 16 - ShapeArg(expected_m_per_group=8, n=4096, k=7168, num_groups=16), - # Prefill, Qwen3-235B-A22B-FP8, gateup, chunk_size = 16384, TP = 4 - ShapeArg(expected_m_per_group=1024, n=768, k=4096, num_groups=128), - # Prefill, Qwen3-235B-A22B-FP8, down, chunk_size = 16384, TP = 4 - ShapeArg(expected_m_per_group=1024, n=4096, k=384, num_groups=128), - # Decode, Qwen3-235B-A22B-FP8, gateup, bs = 256, TP = 4 - ShapeArg(expected_m_per_group=16, n=768, k=4096, num_groups=128), - # Decode, Qwen3-235B-A22B-FP8, down, bs = 256, TP = 4 - ShapeArg(expected_m_per_group=16, n=4096, k=384, num_groups=128), - ] + + # CI environment uses simplified parameters + if IS_CI: + shape_args = [ + # Only test one simple shape in CI + ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256), + ] + else: + shape_args = [ + # Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8 + ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256), + # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 8 + ShapeArg(expected_m_per_group=256, n=512, k=7168, num_groups=256), + # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 16 + ShapeArg(expected_m_per_group=256, n=256, k=7168, num_groups=256), + # Prefill, DeepSeek-R1, gateup, chunk_size = 16384, TP = 16 + ShapeArg(expected_m_per_group=512, n=256, k=7168, num_groups=256), + # Decode, DeepSeek-R1, gateup, bs = 32, TP = 8 + ShapeArg(expected_m_per_group=1, n=512, k=7168, num_groups=256), + # Decode, DeepSeek-R1, gateup, bs = 64, TP = 16 + ShapeArg(expected_m_per_group=2, n=256, k=7168, num_groups=256), + # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, EP = 8 + ShapeArg(expected_m_per_group=256, n=4096, k=7168, num_groups=32), + # Prefill, DeepSeek-R1, gateup, chunk_size = 16384, EP = 16 + ShapeArg(expected_m_per_group=512, n=4096, k=7168, num_groups=16), + # Decode, DeepSeek-R1, gateup, bs = 128, EP = 8 + ShapeArg(expected_m_per_group=4, n=4096, k=7168, num_groups=32), + # Decode, DeepSeek-R1, gateup, bs = 256, EP = 16 + ShapeArg(expected_m_per_group=8, n=4096, k=7168, num_groups=16), + # Prefill, Qwen3-235B-A22B-FP8, gateup, chunk_size = 16384, TP = 4 + ShapeArg(expected_m_per_group=1024, n=768, k=4096, num_groups=128), + # Prefill, Qwen3-235B-A22B-FP8, down, chunk_size = 16384, TP = 4 + ShapeArg(expected_m_per_group=1024, n=4096, k=384, num_groups=128), + # Decode, Qwen3-235B-A22B-FP8, gateup, bs = 256, TP = 4 + ShapeArg(expected_m_per_group=16, n=768, k=4096, num_groups=128), + # Decode, Qwen3-235B-A22B-FP8, down, bs = 256, TP = 4 + ShapeArg(expected_m_per_group=16, n=4096, k=384, num_groups=128), + ] args = parser.parse_args() benchmark_one_shape(shape_args, args.num_warmup, args.num_run) diff --git a/sgl-kernel/benchmark/bench_fp8_gemm.py b/sgl-kernel/benchmark/bench_fp8_gemm.py index f94c32bec..a49f3b06f 100644 --- a/sgl-kernel/benchmark/bench_fp8_gemm.py +++ b/sgl-kernel/benchmark/bench_fp8_gemm.py @@ -1,14 +1,30 @@ import argparse import copy import itertools +import os from typing import Optional, Tuple import torch import triton from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm from sgl_kernel import sgl_per_tensor_quant_fp8 -from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm -from vllm._custom_ops import scaled_fp8_quant as 
vllm_scaled_fp8_quant + +# Optional vLLM import +try: + from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm + from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant + + VLLM_AVAILABLE = True +except ImportError: + vllm_scaled_mm = None + vllm_scaled_fp8_quant = None + VLLM_AVAILABLE = False + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) # Weight Shapes are in the format # ([K, N], TP_SPLIT_DIM) @@ -86,25 +102,48 @@ def sglang_scaled_fp8_quant( return output, scale +# CI environment uses simplified parameters +if IS_CI: + batch_sizes = [1] # Single batch size for CI +else: + batch_sizes = [1, 16, 64, 128, 256, 512, 1024, 2048] + +# Filter line_vals based on vLLM availability +if VLLM_AVAILABLE: + line_vals = [ + "vllm-fp8-fp16", + "vllm-fp8-bf16", + "sglang-fp8-fp16", + "sglang-fp8-bf16", + ] + line_names = [ + "vllm-fp8-fp16", + "vllm-fp8-bf16", + "sglang-fp8-fp16", + "sglang-fp8-bf16", + ] + styles = [("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")] +else: + line_vals = [ + "sglang-fp8-fp16", + "sglang-fp8-bf16", + ] + line_names = [ + "sglang-fp8-fp16", + "sglang-fp8-bf16", + ] + styles = [("blue", "-"), ("blue", "--")] + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["batch_size"], - x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048], + x_vals=batch_sizes, x_log=False, line_arg="provider", - line_vals=[ - "vllm-fp8-fp16", - "vllm-fp8-bf16", - "sglang-fp8-fp16", - "sglang-fp8-bf16", - ], - line_names=[ - "vllm-fp8-fp16", - "vllm-fp8-bf16", - "sglang-fp8-fp16", - "sglang-fp8-bf16", - ], - styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")], + line_vals=line_vals, + line_names=line_names, + styles=styles, ylabel="GB/s", plot_name="fp8 scaled matmul", args={}, @@ -115,6 +154,9 @@ def benchmark(batch_size, provider, N, K): M = batch_size a = torch.ones((M, K), device="cuda") * 5.0 b = torch.ones((N, K), device="cuda") * 5.0 + # vLLM expects scalar scales, while sglang can handle per-token scales + scale_a_scalar = torch.randn(1, device="cuda", dtype=torch.float32) + scale_b_scalar = torch.randn(1, device="cuda", dtype=torch.float32) scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) quantiles = [0.5, 0.2, 0.8] @@ -122,8 +164,11 @@ def benchmark(batch_size, provider, N, K): dtype = torch.float16 if "fp16" in provider else torch.bfloat16 if "vllm-fp8" in provider: - a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) - b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + if not VLLM_AVAILABLE: + # Return zero if vLLM is not available + return (0, 0, 0) + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_scalar) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b_scalar) b_fp8 = b_fp8.t() ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype), @@ -174,6 +219,11 @@ if __name__ == "__main__": ) args = parser.parse_args() + # Simplify for CI environment + if IS_CI: + args.models = [args.models[0]] # Use only first model + args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size + KN_model_names = prepare_shapes(args) for K, N, model_name in KN_model_names: print(f"{model_name} N={N} K={K}: ") diff --git a/sgl-kernel/benchmark/bench_int8_gemm.py b/sgl-kernel/benchmark/bench_int8_gemm.py index 8a3c5d3f3..95f0f3bb8 100644 --- 
a/sgl-kernel/benchmark/bench_int8_gemm.py +++ b/sgl-kernel/benchmark/bench_int8_gemm.py @@ -1,11 +1,26 @@ import argparse import copy import itertools +import os import torch import triton from sgl_kernel import int8_scaled_mm -from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm + +# Optional vLLM import +try: + from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm + + VLLM_AVAILABLE = True +except ImportError: + vllm_scaled_mm = None + VLLM_AVAILABLE = False + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) def to_int8(tensor: torch.Tensor) -> torch.Tensor: @@ -62,15 +77,32 @@ WEIGHT_SHAPES = { } +# CI environment uses simplified parameters +if IS_CI: + batch_sizes = [1] # Single batch size for CI +else: + batch_sizes = [1, 16, 32, 64, 128, 256, 512, 1024, 2048] + +# Filter providers based on vLLM availability +if VLLM_AVAILABLE: + line_vals = ["vllm", "sgl-kernel"] + line_names = ["vllm int8 gemm", "sgl-kernel int8 gemm"] + styles = [("blue", "-"), ("orange", "-")] +else: + line_vals = ["sgl-kernel"] + line_names = ["sgl-kernel int8 gemm"] + styles = [("orange", "-")] + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["batch_size"], - x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048], + x_vals=batch_sizes, x_log=False, line_arg="provider", - line_vals=["vllm", "sgl-kernel"], - line_names=["vllm int8 gemm", "sgl-kernel int8 gemm"], - styles=[("blue", "-"), ("orange", "-")], + line_vals=line_vals, + line_names=line_names, + styles=styles, ylabel="GB/s", plot_name="int8 scaled matmul", args={}, @@ -90,7 +122,9 @@ def benchmark(batch_size, provider, N, K): lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias), quantiles=quantiles, ) - if provider == "vllm": + elif provider == "vllm": + if not VLLM_AVAILABLE: + return (0, 0, 0) ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias), quantiles=quantiles, @@ -136,9 +170,16 @@ if __name__ == "__main__": ) args = parser.parse_args() - KN_model_names = prepare_shapes(args) - for K, N, model_name in KN_model_names: - print(f"{model_name} N={N} K={K}: ") - benchmark.run(print_data=True, N=N, K=K) + # Skip in CI environment due to architecture compatibility issues + if IS_CI: + print( + "Skipping INT8 GEMM benchmark in CI environment due to architecture compatibility issues" + ) + print("INT8 operations may not be supported on all GPU architectures") + else: + KN_model_names = prepare_shapes(args) + for K, N, model_name in KN_model_names: + print(f"{model_name} N={N} K={K}: ") + benchmark.run(print_data=True, N=N, K=K) - print("Benchmark finished!") + print("Benchmark finished!") diff --git a/sgl-kernel/benchmark/bench_lightning_attention_decode.py b/sgl-kernel/benchmark/bench_lightning_attention_decode.py index 6097e966d..db0ef05bd 100644 --- a/sgl-kernel/benchmark/bench_lightning_attention_decode.py +++ b/sgl-kernel/benchmark/bench_lightning_attention_decode.py @@ -1,11 +1,18 @@ import itertools import math +import os import torch import triton import triton.language as tl from sgl_kernel import lightning_attention_decode +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + def next_power_of_2(n): return 2 ** (int(math.ceil(math.log(n, 2)))) @@ -207,7 +214,12 @@ def calculate_diff(batch_size): print("❌ Implementations differ") 
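As in the `bench_int8_gemm.py` hunk above, benchmarks that keep a vLLM series filter the `triton.testing.perf_report` configuration so that `line_vals`, `line_names`, and `styles` stay in lockstep, and additionally return a zero triple from the measurement function if the vLLM provider is somehow selected without the dependency. A hedged sketch of how those two guards combine; the timing body is a placeholder rather than a real kernel launch:

```python
# Illustrative provider-filtering sketch for triton.testing.perf_report.
# VLLM_AVAILABLE would normally come from the optional-import block shown
# earlier; it is hard-coded here only to keep the sketch self-contained.
import triton
import triton.testing

VLLM_AVAILABLE = False  # assumption for this sketch

line_vals = ["vllm", "sgl-kernel"] if VLLM_AVAILABLE else ["sgl-kernel"]
line_names = (
    ["vllm int8 gemm", "sgl-kernel int8 gemm"]
    if VLLM_AVAILABLE
    else ["sgl-kernel int8 gemm"]
)
styles = [("blue", "-"), ("orange", "-")] if VLLM_AVAILABLE else [("orange", "-")]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16],
        x_log=False,
        line_arg="provider",
        line_vals=line_vals,  # kept in lockstep with line_names and styles
        line_names=line_names,
        styles=styles,
        ylabel="GB/s",
        plot_name="provider-filtering-example",
        args={},
    )
)
def benchmark(batch_size, provider):
    if provider == "vllm" and not VLLM_AVAILABLE:
        return (0, 0, 0)  # second line of defense: report zeros, don't crash
    ms = 0.1 * batch_size  # placeholder timing
    return ms, ms, ms
```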
-batch_size_range = [i for i in range(1, 65)] # 1 to 128 +# Simplified for CI environment +if IS_CI: + batch_size_range = [1] # Single batch size for CI +else: + batch_size_range = [i for i in range(1, 65)] # 1 to 64 + configs = [(bs,) for bs in batch_size_range] @@ -292,8 +304,9 @@ if __name__ == "__main__": ) args = parser.parse_args() - # Run correctness test - calculate_diff(batch_size=4) + # Run correctness test - simplified for CI + test_batch_size = 1 if IS_CI else 4 + calculate_diff(batch_size=test_batch_size) # Run performance benchmark benchmark.run(print_data=True) diff --git a/sgl-kernel/benchmark/bench_moe_align_block_size.py b/sgl-kernel/benchmark/bench_moe_align_block_size.py index 749d97d8f..2156c5cd4 100644 --- a/sgl-kernel/benchmark/bench_moe_align_block_size.py +++ b/sgl-kernel/benchmark/bench_moe_align_block_size.py @@ -1,5 +1,6 @@ import argparse import itertools +import os import torch import triton @@ -8,8 +9,17 @@ from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size try: from vllm import _custom_ops as ops + + VLLM_AVAILABLE = True except ImportError: ops = None + VLLM_AVAILABLE = False + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) USE_RANDOM_PERM = False @@ -197,19 +207,23 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8): num_tokens_post_pad_triton, ) - try: - ops.moe_align_block_size( - topk_ids, - num_experts, - block_size, - sorted_ids_vllm, - expert_ids_vllm, - num_tokens_post_pad_vllm, - ) - print(f"✅ VLLM implementation works with {num_experts} experts!") - vllm_works = True - except Exception as e: - print(f"❌ VLLM implementation failed with {num_experts} experts: {e}") + if VLLM_AVAILABLE: + try: + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids_vllm, + expert_ids_vllm, + num_tokens_post_pad_vllm, + ) + print(f"✅ VLLM implementation works with {num_experts} experts!") + vllm_works = True + except Exception as e: + print(f"❌ VLLM implementation failed with {num_experts} experts: {e}") + vllm_works = False + else: + print("⚠️ vLLM not available, skipping vLLM test") vllm_works = False if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose( @@ -394,8 +408,18 @@ if __name__ == "__main__": ) args = parser.parse_args() - calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk) + # Simplify for CI environment + if IS_CI: + num_tokens = 256 # Smaller for CI + num_experts = 8 # Smaller for CI + topk = 2 # Smaller for CI + else: + num_tokens = 1024 + num_experts = args.num_experts + topk = args.topk - if not args.skip_full_benchmark: + calculate_diff(num_tokens=num_tokens, num_experts=num_experts, topk=topk) + + if not args.skip_full_benchmark and not IS_CI: # Skip full benchmark in CI print(f"\n📊 Running performance benchmark for {args.num_experts} experts...") benchmark.run(print_data=True) diff --git a/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py b/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py index 38b07e83a..2a617d72d 100644 --- a/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py +++ b/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py @@ -1,9 +1,22 @@ +import os + import torch + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) import triton from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel -batch_sizes = [64, 128, 
256, 512, 640, 768, 1024, 2048, 4096] +# CI environment uses simplified parameters +if IS_CI: + batch_sizes = [64, 128] # Only test 2 values in CI +else: + batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096] + configs = [(bs,) for bs in batch_sizes] diff --git a/sgl-kernel/benchmark/bench_moe_fused_gate.py b/sgl-kernel/benchmark/bench_moe_fused_gate.py index 4455a91b6..cb5ac1760 100644 --- a/sgl-kernel/benchmark/bench_moe_fused_gate.py +++ b/sgl-kernel/benchmark/bench_moe_fused_gate.py @@ -1,5 +1,6 @@ import itertools import math +import os import torch import triton @@ -8,6 +9,12 @@ from sgl_kernel import moe_fused_gate from sglang.srt.layers.moe.topk import biased_grouped_topk +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk): return biased_grouped_topk( @@ -28,7 +35,12 @@ def biased_grouped_topk_org_fuse_kernel( return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk) -seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000] +# CI environment uses simplified parameters +if IS_CI: + seq_length_range = [5000] # Only test one sequence length in CI +else: + seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000] + configs = [(sq,) for sq in seq_length_range] diff --git a/sgl-kernel/benchmark/bench_moe_topk_softmax.py b/sgl-kernel/benchmark/bench_moe_topk_softmax.py index 36d89eb82..e065981b8 100644 --- a/sgl-kernel/benchmark/bench_moe_topk_softmax.py +++ b/sgl-kernel/benchmark/bench_moe_topk_softmax.py @@ -1,13 +1,32 @@ import itertools +import os import pytest import torch import triton from sgl_kernel import topk_softmax -from vllm import _custom_ops as vllm_custom_ops + +# Optional vLLM import +try: + from vllm import _custom_ops as vllm_custom_ops + + VLLM_AVAILABLE = True +except ImportError: + vllm_custom_ops = None + VLLM_AVAILABLE = False + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) def vllm_topk_softmax(gating_output, topk): + if not VLLM_AVAILABLE: + # Fallback to SGLang implementation if vLLM is not available + return sglang_topk_softmax(gating_output, topk) + num_tokens, num_experts = gating_output.shape topk_weights = torch.empty( @@ -54,6 +73,10 @@ def calculate_diff(num_tokens, num_experts, topk): weights_diff = torch.abs(weights_vllm - weights_sglang).mean().item() indices_match = torch.equal(indices_vllm, indices_sglang) + if not VLLM_AVAILABLE: + print("⚠️ vLLM not available, skipping comparison") + return + if ( torch.allclose(weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3) and indices_match @@ -65,21 +88,38 @@ def calculate_diff(num_tokens, num_experts, topk): ) -num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768] -num_experts_range = [32, 64, 128, 256, 12, 512] -topk_range = [1, 2, 4, 8] +# CI environment uses simplified parameters +if IS_CI: + num_tokens_range = [128] # Single value for CI + num_experts_range = [32] # Single value for CI + topk_range = [2] # Single value for CI +else: + num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768] + num_experts_range = [32, 64, 128, 256, 12, 512] + topk_range = [1, 2, 4, 8] configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range)) +# Filter providers based on vLLM availability +if VLLM_AVAILABLE: + line_vals = ["sglang", "vllm"] + 
line_names = ["SGLang", "VLLM"] + styles = [("blue", "-"), ("green", "-")] +else: + line_vals = ["sglang"] + line_names = ["SGLang"] + styles = [("blue", "-")] + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["num_tokens", "num_experts", "topk"], x_vals=configs, line_arg="provider", - line_vals=["sglang", "vllm"], - line_names=["SGLang", "VLLM"], - styles=[("blue", "-"), ("green", "-")], + line_vals=line_vals, + line_names=line_names, + styles=styles, ylabel="Latency (us)", plot_name="topk-softmax-performance", args={}, @@ -92,6 +132,8 @@ def benchmark(num_tokens, num_experts, topk, provider): ) if provider == "vllm" or provider == "vllm1": + if not VLLM_AVAILABLE: + return (0, 0, 0) fn = lambda: vllm_topk_softmax(gating_output, topk) elif provider == "sglang" or provider == "sglang1": fn = lambda: sglang_topk_softmax(gating_output, topk) @@ -103,14 +145,19 @@ def benchmark(num_tokens, num_experts, topk, provider): if __name__ == "__main__": - configs = [ - (20, 256, 4), - (20, 256, 8), - (20, 12, 4), - (20, 12, 1), - (20, 512, 4), - (20, 512, 1), - ] - for num_tokens, num_experts, topk in configs: + # Simplify configs for CI environment + if IS_CI: + test_configs = [(20, 32, 2)] # Single config for CI + else: + test_configs = [ + (20, 256, 4), + (20, 256, 8), + (20, 12, 4), + (20, 12, 1), + (20, 512, 4), + (20, 512, 1), + ] + + for num_tokens, num_experts, topk in test_configs: calculate_diff(num_tokens, num_experts, topk) benchmark.run(print_data=True) diff --git a/sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py b/sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py index e7fdcc7ae..3867f6093 100644 --- a/sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py +++ b/sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py @@ -1,11 +1,20 @@ import argparse import copy import itertools +import os import torch import triton from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant +from sglang.srt.utils import get_device_capability + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + FLOAT4_E2M1_MAX = 6.0 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max @@ -162,9 +171,22 @@ if __name__ == "__main__": ) args = parser.parse_args() - KN_model_names = prepare_shapes(args) - for K, N, model_name in KN_model_names: - print(f"{model_name} N={N} K={K}: ") - benchmark.run(print_data=True, N=N, K=K) + # Check architecture compatibility - FP4 operations require sm100a/sm103a + major, minor = get_device_capability() + if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a) + print("Skipping NVIDIA FP4 scaled GEMM benchmark") + if major is not None: + print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}") + else: + print("Could not determine device capability") + else: + KN_model_names = prepare_shapes(args) - print("Benchmark finished!") + # Limit iterations in CI + if IS_CI: + KN_model_names = KN_model_names[:2] # Only test first 2 shapes in CI + + for K, N, model_name in KN_model_names: + print(f"{model_name} N={N} K={K}: ") + benchmark.run(print_data=True, N=N, K=K) + print("Benchmark finished!") diff --git a/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py b/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py index d42007819..ead9e9aa1 100644 --- a/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py +++ b/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py @@ -1,5 +1,6 @@ import itertools import math +import os from typing import Any, Dict, List, 
Optional, Tuple import numpy as np @@ -7,11 +8,26 @@ import torch import triton import triton.testing from sgl_kernel import sgl_per_tensor_quant_fp8 -from vllm import _custom_ops as ops + +# Optional imports +try: + from vllm import _custom_ops as ops + + VLLM_AVAILABLE = True +except ImportError: + ops = None + VLLM_AVAILABLE = False from sglang.srt.utils import is_hip _is_hip = is_hip() + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn @@ -19,6 +35,9 @@ def vllm_scaled_fp8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + if not VLLM_AVAILABLE: + # Fallback to SGLang implementation + return sglang_scaled_fp8_quant(input, scale) return ops.scaled_fp8_quant(input, scale) @@ -42,6 +61,10 @@ def calculate_diff(batch_size: int, seq_len: int): device = torch.device("cuda") x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device) + if not VLLM_AVAILABLE: + print("⚠️ vLLM not available, skipping comparison") + return + vllm_out, vllm_scale = vllm_scaled_fp8_quant(x) sglang_out, sglang_scale = sglang_scaled_fp8_quant(x) @@ -56,8 +79,13 @@ def calculate_diff(batch_size: int, seq_len: int): print("❌ Implementations differ") -batch_size_range = [16, 32, 64, 128] -seq_len_range = [64, 128, 256, 512, 1024, 2048] +# CI environment uses simplified parameters +if IS_CI: + batch_size_range = [16] # Single batch size for CI + seq_len_range = [64] # Single sequence length for CI +else: + batch_size_range = [16, 32, 64, 128] + seq_len_range = [64, 128, 256, 512, 1024, 2048] configs = list(itertools.product(batch_size_range, seq_len_range)) diff --git a/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py index 3f37a3248..558b7486e 100644 --- a/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py +++ b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py @@ -1,4 +1,5 @@ import itertools +import os import time from functools import partial from pathlib import Path @@ -16,15 +17,28 @@ from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit from sglang.srt.utils import is_hip +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + _is_hip = is_hip() fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn -num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384] -hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1 -group_size_range = [128] # For DeepSeek V3/R1 -# TODO test int8 -dst_dtype_range = [fp8_type_] +# CI environment uses simplified parameters +if IS_CI: + num_tokens_range = [64] # Single value for CI + hidden_dim_range = [1536] # Single value for CI + group_size_range = [128] # Keep as is + dst_dtype_range = [fp8_type_] # Keep as is +else: + num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384] + hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1 + group_size_range = [128] # For DeepSeek V3/R1 + # TODO test int8 + dst_dtype_range = [fp8_type_] flags_range = [ dict( column_major_scales=False, @@ -82,7 +96,7 @@ def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider): x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) fn, 
kernel_names = { - "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"), + "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_8bit"), "sglang": ( sglang_per_token_group_quant_8bit, "per_token_group_quant_8bit_kernel", diff --git a/sgl-kernel/benchmark/bench_per_token_quant_fp8.py b/sgl-kernel/benchmark/bench_per_token_quant_fp8.py index e77410f92..8db1869d1 100644 --- a/sgl-kernel/benchmark/bench_per_token_quant_fp8.py +++ b/sgl-kernel/benchmark/bench_per_token_quant_fp8.py @@ -1,15 +1,31 @@ import itertools +import os from typing import Optional, Tuple import torch import triton import triton.testing from sgl_kernel import sgl_per_token_quant_fp8 -from vllm import _custom_ops as ops + +# Optional vLLM import +try: + from vllm import _custom_ops as ops + + VLLM_AVAILABLE = True +except ImportError: + ops = None + VLLM_AVAILABLE = False from sglang.srt.utils import is_hip _is_hip = is_hip() + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn # Get correct FP8 E4M3 maximum value @@ -49,6 +65,9 @@ def torch_per_token_quant_fp8( def vllm_per_token_quant_fp8( input: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: + if not VLLM_AVAILABLE: + # Fallback to SGLang implementation + return sglang_per_token_quant_fp8(input) return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True) @@ -74,6 +93,17 @@ def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int): vllm_out, vllm_scale = vllm_per_token_quant_fp8(x) sglang_out, sglang_scale = sglang_per_token_quant_fp8(x) + if not VLLM_AVAILABLE: + print("⚠️ vLLM not available, skipping vLLM comparison") + # Only compare Torch vs SGLang + torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item() + torch_sglang_out_diff = ( + torch.abs(torch_out.float() - sglang_out.float()).mean().item() + ) + print(f"Scale difference (Torch vs SGLang): {torch_sglang_scale_diff:.8f}") + print(f"Output difference (Torch vs SGLang): {torch_sglang_out_diff:.8f}") + return + print(f"\n=== Comparison for hidden_dim={hidden_dim} ===") # Compare scales @@ -125,9 +155,15 @@ def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int): print(f" VLLM vs SGLang: {'✅' if vllm_sglang_match else '❌'}") -batch_size_range = [16, 32, 64, 128] -seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096] -hidden_dim_range = [1368, 2048, 4096] +# CI environment uses simplified parameters +if IS_CI: + batch_size_range = [16] # Single batch size for CI + seq_len_range = [64] # Single sequence length for CI + hidden_dim_range = [2048] # Single hidden dimension for CI +else: + batch_size_range = [16, 32, 64, 128] + seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096] + hidden_dim_range = [1368, 2048, 4096] configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range)) @@ -137,9 +173,19 @@ configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_ran x_names=["batch_size", "seq_len", "hidden_dim"], x_vals=configs, line_arg="provider", - line_vals=["torch", "vllm", "sglang"], - line_names=["Torch Reference", "VLLM", "SGL Kernel"], - styles=[("red", "-"), ("blue", "-"), ("green", "-")], + line_vals=( + ["torch", "vllm", "sglang"] if VLLM_AVAILABLE else ["torch", "sglang"] + ), + line_names=( + ["Torch Reference", "VLLM", "SGL Kernel"] + if VLLM_AVAILABLE + else ["Torch Reference", "SGL Kernel"] + ), + styles=( + 
[("red", "-"), ("blue", "-"), ("green", "-")] + if VLLM_AVAILABLE + else [("red", "-"), ("green", "-")] + ), ylabel="us", plot_name="per-token-dynamic-quant-fp8-performance", args={}, @@ -156,6 +202,8 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider): if provider == "torch": fn = lambda: torch_per_token_quant_fp8(x.clone()) elif provider == "vllm": + if not VLLM_AVAILABLE: + return (0, 0, 0) fn = lambda: vllm_per_token_quant_fp8(x.clone()) elif provider == "sglang": fn = lambda: sglang_per_token_quant_fp8(x.clone()) @@ -166,11 +214,16 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider): if __name__ == "__main__": - # Test various hidden dimensions for correctness - test_dims = [1368, 2048, 4096] + # Test various hidden dimensions for correctness - simplified for CI + if IS_CI: + test_dims = [2048] # Single dimension for CI + batch_size, seq_len = 4, 64 # Smaller values for CI + else: + test_dims = [1368, 2048, 4096] + batch_size, seq_len = 4, 4096 for dim in test_dims: - calculate_diff(batch_size=4, seq_len=4096, hidden_dim=dim) + calculate_diff(batch_size=batch_size, seq_len=seq_len, hidden_dim=dim) print("\n" + "=" * 60) print("Starting performance benchmark...") diff --git a/sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py b/sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py index bcf3d7fba..5827fa993 100644 --- a/sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py +++ b/sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py @@ -1,6 +1,7 @@ import argparse import copy import itertools +import os import torch import triton @@ -10,6 +11,12 @@ from sgl_kernel import ( qserve_w4a8_per_group_gemm, ) +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + def to_int8(tensor: torch.Tensor) -> torch.Tensor: return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) @@ -65,10 +72,17 @@ WEIGHT_SHAPES = { } +# CI environment uses simplified parameters +if IS_CI: + batch_sizes = [1, 16] # Simplified for CI +else: + batch_sizes = [1, 16, 32, 64, 128, 256, 512, 1024, 2048] + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["batch_size"], - x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048], + x_vals=batch_sizes, x_log=False, line_arg="provider", line_vals=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"], @@ -184,13 +198,19 @@ if __name__ == "__main__": ) args = parser.parse_args() - KN_model_names = prepare_shapes(args) - for K, N, model_name in KN_model_names: - print(f"{model_name} N={N} K={K}: ") - benchmark.run( - print_data=True, - N=N, - K=K, - ) + # Skip in CI environment + if IS_CI: + print("Skipping QServe W4A8 GEMM benchmark in CI environment") + print("QServe operations may have compatibility issues in CI") + else: + KN_model_names = prepare_shapes(args) - print("Benchmark finished!") + for K, N, model_name in KN_model_names: + print(f"{model_name} N={N} K={K}: ") + benchmark.run( + print_data=True, + N=N, + K=K, + ) + + print("Benchmark finished!") diff --git a/sgl-kernel/benchmark/bench_rmsnorm.py b/sgl-kernel/benchmark/bench_rmsnorm.py index fc3b732c5..d521ab05f 100644 --- a/sgl-kernel/benchmark/bench_rmsnorm.py +++ b/sgl-kernel/benchmark/bench_rmsnorm.py @@ -2,6 +2,7 @@ # (batch_size, seq_len, hidden_size) and prints speed-up. 
import argparse import itertools +import os import re from typing import List, Optional, Tuple, Union @@ -10,9 +11,31 @@ import torch import torch.nn as nn import triton import triton.testing -from flashinfer.norm import fused_add_rmsnorm, rmsnorm from sgl_kernel.utils import is_arch_support_pdl -from vllm import _custom_ops as vllm_ops + +# Optional imports +try: + from flashinfer.norm import fused_add_rmsnorm, rmsnorm + + FLASHINFER_AVAILABLE = True +except ImportError: + fused_add_rmsnorm = None + rmsnorm = None + FLASHINFER_AVAILABLE = False + +try: + from vllm import _custom_ops as vllm_ops + + VLLM_AVAILABLE = True +except ImportError: + vllm_ops = None + VLLM_AVAILABLE = False + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) def str2int_list(arg: str) -> List[int]: @@ -79,6 +102,10 @@ def rmsnorm_flashinfer( residual: Optional[torch.Tensor] = None, eps: float = 1e-6, ): + if not FLASHINFER_AVAILABLE: + # Fallback to naive implementation if FlashInfer is not available + return rmsnorm_naive(x, weight, residual, eps) + orig_shape = x.shape x = x.view(-1, x.shape[-1]) if residual is not None: @@ -103,6 +130,10 @@ def rmsnorm_vllm( residual: Optional[torch.Tensor] = None, eps: float = 1e-6, ): + if not VLLM_AVAILABLE: + # Fallback to naive implementation if vLLM is not available + return rmsnorm_naive(x, weight, residual, eps) + orig_shape = x.shape x = x.view(-1, x.shape[-1]) if residual is not None: @@ -179,37 +210,72 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): output_sglang = output_sglang[0] print(f"Naive output={output_naive}") - print(f"FlashInfer output={output_flashinfer}") - print(f"VLLM output={output_vllm}") + if FLASHINFER_AVAILABLE: + print(f"FlashInfer output={output_flashinfer}") + else: + print("FlashInfer not available, skipped") + if VLLM_AVAILABLE: + print(f"VLLM output={output_vllm}") + else: + print("vLLM not available, skipped") print(f"SGLang output={output_sglang}") - if ( - torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2) - and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2) - and torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2) - ): - print("✅ All implementations match") + # Only compare available implementations + all_match = torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2) + if FLASHINFER_AVAILABLE: + all_match = all_match and torch.allclose( + output_naive, output_flashinfer, atol=1e-2, rtol=1e-2 + ) + if VLLM_AVAILABLE: + all_match = all_match and torch.allclose( + output_naive, output_vllm, atol=1e-2, rtol=1e-2 + ) + + if all_match: + print("✅ All available implementations match") else: print("❌ Implementations differ") -default_batch_sizes = [2**i for i in range(0, 7, 2)] # 1, 4, 16, 64 -default_seq_lens = [2**i for i in range(6, 11, 1)] # 64, 128, 256, 512, 1024 -default_hidden_sizes = [32 * 128, 48 * 128] # 4096, 6144 +# CI environment uses simplified parameters +if IS_CI: + default_batch_sizes = [1] # Single batch size for CI + default_seq_lens = [64] # Single sequence length for CI + default_hidden_sizes = [4096] # Single hidden size for CI +else: + default_batch_sizes = [2**i for i in range(0, 7, 2)] # 1, 4, 16, 64 + default_seq_lens = [2**i for i in range(6, 11, 1)] # 64, 128, 256, 512, 1024 + default_hidden_sizes = [32 * 128, 48 * 128] # 4096, 6144 def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]: return 
 
 
+# Filter providers based on availability
+available_providers = ["huggingface", "sglang"]
+available_names = ["HuggingFace", "SGL Kernel"]
+available_styles = [("blue", "-"), ("orange", "-")]
+
+if FLASHINFER_AVAILABLE:
+    available_providers.insert(-1, "flashinfer")
+    available_names.insert(-1, "FlashInfer")
+    available_styles.insert(-1, ("green", "-"))
+
+if VLLM_AVAILABLE:
+    available_providers.insert(-1, "vllm")
+    available_names.insert(-1, "vLLM")
+    available_styles.insert(-1, ("red", "-"))
+
+
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size", "seq_len", "hidden_size"],
         x_vals=[],
         line_arg="provider",
-        line_vals=["huggingface", "flashinfer", "vllm", "sglang"],
-        line_names=["HuggingFace", "FlashInfer", "vLLM", "SGL Kernel"],
-        styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("orange", "-")],
+        line_vals=available_providers,
+        line_names=available_names,
+        styles=available_styles,
         ylabel="µs (median) or × (speed-up)",
         plot_name="rmsnorm-performance",
         args={},
@@ -242,6 +308,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
             )
         )
     elif provider == "flashinfer":
+        if not FLASHINFER_AVAILABLE:
+            return (0, 0, 0)
         return timed(
             lambda: rmsnorm_flashinfer(
                 x.clone(),
@@ -250,6 +318,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
             )
         )
     elif provider == "vllm":
+        if not VLLM_AVAILABLE:
+            return (0, 0, 0)
         return timed(
             lambda: rmsnorm_vllm(
                 x.clone(),
@@ -267,13 +337,22 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
         )
 
     # provider == "speedup"
-    t_ref, _, _ = timed(
-        lambda: rmsnorm_vllm(
-            x.clone(),
-            weight,
-            residual.clone() if residual is not None else None,
+    if VLLM_AVAILABLE:
+        t_ref, _, _ = timed(
+            lambda: rmsnorm_vllm(
+                x.clone(),
+                weight,
+                residual.clone() if residual is not None else None,
+            )
+        )
+    else:
+        t_ref, _, _ = timed(
+            lambda: rmsnorm_naive(
+                x.clone(),
+                weight,
+                residual.clone() if residual is not None else None,
+            )
         )
-    )
     t_sgl, _, _ = timed(
         lambda: rmsnorm_sglang(
             x.clone(),
@@ -281,7 +360,7 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
             residual.clone() if residual is not None else None,
         )
     )
-    spd = t_ref / t_sgl
+    spd = t_ref / t_sgl if t_ref > 0 else 1.0
     return (spd, spd, spd)
 
 
diff --git a/sgl-kernel/benchmark/bench_rotary_embedding.py b/sgl-kernel/benchmark/bench_rotary_embedding.py
index b4e0f5e0b..418fcd7dd 100644
--- a/sgl-kernel/benchmark/bench_rotary_embedding.py
+++ b/sgl-kernel/benchmark/bench_rotary_embedding.py
@@ -1,4 +1,5 @@
 import itertools
+import os
 
 import torch
 import triton
@@ -12,17 +13,31 @@ from sgl_kernel.testing.rotary_embedding import (
 
 from sglang.srt.bench_utils import bench_kineto
 
-configs = [
-    (batch_size, seq_len, save_kv_cache)
-    for batch_size, seq_len in (
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+
+# CI environment uses simplified parameters
+if IS_CI:
+    batch_seq_configs = [(1, 1)]  # Single config for CI
+    save_kv_configs = [False]  # Single option for CI
+else:
+    batch_seq_configs = [
         (1, 1),
         (32, 1),
         (128, 1),
         (512, 1),
         (2, 512),
         (4, 4096),
-    )
-    for save_kv_cache in (False, True)
+    ]
+    save_kv_configs = [False, True]
+
+configs = [
+    (batch_size, seq_len, save_kv_cache)
+    for batch_size, seq_len in batch_seq_configs
+    for save_kv_cache in save_kv_configs
 ]
 
 
diff --git a/sgl-kernel/benchmark/bench_top_k_top_p_sampling.py b/sgl-kernel/benchmark/bench_top_k_top_p_sampling.py
index 254584837..278356c38 100644
--- a/sgl-kernel/benchmark/bench_top_k_top_p_sampling.py
+++ b/sgl-kernel/benchmark/bench_top_k_top_p_sampling.py
@@ -1,10 +1,17 @@
 import itertools
+import os
 
 import sgl_kernel
 import torch
 import triton
 import triton.testing
 
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+
 
 def torch_top_k_top_p_joint_sampling_from_probs(
     normalized_prob, top_k, top_p, eps=1e-4
@@ -67,10 +74,16 @@ def calculate_diff(batch_size, vocab_size, p):
     )
 
 
-# parameter space
-batch_size_range = [16, 64, 128]
-vocab_size_range = [111, 32000]
-p_range = [0.1, 0.5]
+# parameter space - simplified for CI
+if IS_CI:
+    batch_size_range = [16]  # Single batch size for CI
+    vocab_size_range = [111]  # Single vocab size for CI
+    p_range = [0.1]  # Single p value for CI
+else:
+    batch_size_range = [16, 64, 128]
+    vocab_size_range = [111, 32000]
+    p_range = [0.1, 0.5]
+
 
 configs = list(itertools.product(batch_size_range, vocab_size_range, p_range))
@@ -114,15 +127,19 @@ def benchmark_sampling(batch_size, vocab_size, p, provider):
             filter_apply_order="joint",
         )
 
-    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
-        fn, quantiles=[0.5, 0.2, 0.8]
-    )
+    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
 
 
 if __name__ == "__main__":
-    # Correctness check
-    for cfg in configs:
+    # Correctness check - simplified for CI
+    if IS_CI:
+        # Only test one configuration in CI
+        test_configs = [configs[0]] if configs else [(16, 111, 0.1)]
+    else:
+        test_configs = configs
+
+    for cfg in test_configs:
         calculate_diff(*cfg)
 
     print("\n" + "=" * 60)