Fix sgl-kernel benchmark dead code (#11022)

This commit is contained in:
Xiaoyu Zhang
2025-09-29 15:06:40 +08:00
committed by GitHub
parent 71959545df
commit 11965b0daf
25 changed files with 1019 additions and 260 deletions

View File

@@ -1,5 +1,6 @@
import argparse
import itertools
import os
import torch
import triton
@@ -8,8 +9,17 @@ from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
try:
from vllm import _custom_ops as ops
VLLM_AVAILABLE = True
except ImportError:
ops = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
USE_RANDOM_PERM = False
@@ -197,19 +207,23 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
num_tokens_post_pad_triton,
)
try:
ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids_vllm,
expert_ids_vllm,
num_tokens_post_pad_vllm,
)
print(f"✅ VLLM implementation works with {num_experts} experts!")
vllm_works = True
except Exception as e:
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
if VLLM_AVAILABLE:
try:
ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids_vllm,
expert_ids_vllm,
num_tokens_post_pad_vllm,
)
print(f"✅ VLLM implementation works with {num_experts} experts!")
vllm_works = True
except Exception as e:
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
vllm_works = False
else:
print("⚠️ vLLM not available, skipping vLLM test")
vllm_works = False
if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
@@ -394,8 +408,18 @@ if __name__ == "__main__":
)
args = parser.parse_args()
calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
# Simplify for CI environment
if IS_CI:
num_tokens = 256 # Smaller for CI
num_experts = 8 # Smaller for CI
topk = 2 # Smaller for CI
else:
num_tokens = 1024
num_experts = args.num_experts
topk = args.topk
if not args.skip_full_benchmark:
calculate_diff(num_tokens=num_tokens, num_experts=num_experts, topk=topk)
if not args.skip_full_benchmark and not IS_CI: # Skip full benchmark in CI
print(f"\n📊 Running performance benchmark for {args.num_experts} experts...")
benchmark.run(print_data=True)