Fix sgl-kernel benchmark dead code (#11022)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -8,8 +9,17 @@ from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||
|
||||
try:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
ops = None
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
USE_RANDOM_PERM = False
|
||||
|
||||
@@ -197,19 +207,23 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
|
||||
num_tokens_post_pad_triton,
|
||||
)
|
||||
|
||||
try:
|
||||
ops.moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids_vllm,
|
||||
expert_ids_vllm,
|
||||
num_tokens_post_pad_vllm,
|
||||
)
|
||||
print(f"✅ VLLM implementation works with {num_experts} experts!")
|
||||
vllm_works = True
|
||||
except Exception as e:
|
||||
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
|
||||
if VLLM_AVAILABLE:
|
||||
try:
|
||||
ops.moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids_vllm,
|
||||
expert_ids_vllm,
|
||||
num_tokens_post_pad_vllm,
|
||||
)
|
||||
print(f"✅ VLLM implementation works with {num_experts} experts!")
|
||||
vllm_works = True
|
||||
except Exception as e:
|
||||
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
|
||||
vllm_works = False
|
||||
else:
|
||||
print("⚠️ vLLM not available, skipping vLLM test")
|
||||
vllm_works = False
|
||||
|
||||
if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
|
||||
@@ -394,8 +408,18 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
|
||||
# Simplify for CI environment
|
||||
if IS_CI:
|
||||
num_tokens = 256 # Smaller for CI
|
||||
num_experts = 8 # Smaller for CI
|
||||
topk = 2 # Smaller for CI
|
||||
else:
|
||||
num_tokens = 1024
|
||||
num_experts = args.num_experts
|
||||
topk = args.topk
|
||||
|
||||
if not args.skip_full_benchmark:
|
||||
calculate_diff(num_tokens=num_tokens, num_experts=num_experts, topk=topk)
|
||||
|
||||
if not args.skip_full_benchmark and not IS_CI: # Skip full benchmark in CI
|
||||
print(f"\n📊 Running performance benchmark for {args.num_experts} experts...")
|
||||
benchmark.run(print_data=True)
|
||||
|
||||
Reference in New Issue
Block a user