164 lines
4.7 KiB
Python
164 lines
4.7 KiB
Python
import itertools
|
|
import os
|
|
|
|
import pytest
|
|
import torch
|
|
import triton
|
|
from sgl_kernel import topk_softmax
|
|
|
|
# Optional vLLM import
|
|
try:
|
|
from vllm import _custom_ops as vllm_custom_ops
|
|
|
|
VLLM_AVAILABLE = True
|
|
except ImportError:
|
|
vllm_custom_ops = None
|
|
VLLM_AVAILABLE = False
|
|
|
|
# CI environment detection: treated as CI when either the CI or the
# GITHUB_ACTIONS environment variable is set to "true" (case-insensitive).
IS_CI = any(
    os.getenv(env_name, "false").lower() == "true"
    for env_name in ("CI", "GITHUB_ACTIONS")
)
|
|
|
|
|
|
def vllm_topk_softmax(gating_output, topk):
    """Run the vLLM ``topk_softmax`` op, or the SGLang one if vLLM is missing.

    Args:
        gating_output: Tensor whose shape unpacks into exactly two
            dimensions, ``(num_tokens, num_experts)``.
        topk: Width of the second dimension of the output buffers.

    Returns:
        A ``(topk_weights, topk_indices)`` tuple. On the vLLM path these are
        a float32 and an int32 tensor of shape ``(num_tokens, topk)`` living
        on ``gating_output``'s device; on the fallback path they are whatever
        ``sglang_topk_softmax`` returns.
    """
    if VLLM_AVAILABLE:
        # Unpacking (rather than indexing) keeps the implicit 2-D check.
        n_rows, _n_cols = gating_output.shape
        out_shape = (n_rows, topk)
        dev = gating_output.device

        weights = torch.empty(out_shape, dtype=torch.float32, device=dev)
        expert_ids = torch.empty(out_shape, dtype=torch.int32, device=dev)
        # Third output buffer taken by the vLLM op; it is filled by the call
        # but not returned to the caller.
        token_expert_ids = torch.empty(out_shape, dtype=torch.int32, device=dev)

        torch.ops._moe_C.topk_softmax(
            weights, expert_ids, token_expert_ids, gating_output
        )
        return weights, expert_ids

    # Fallback to SGLang implementation if vLLM is not available
    return sglang_topk_softmax(gating_output, topk)
|
|
|
|
|
|
def sglang_topk_softmax(gating_output, topk):
    """Run the ``sgl_kernel`` ``topk_softmax`` kernel on ``gating_output``.

    Args:
        gating_output: Tensor whose shape unpacks into exactly two
            dimensions, ``(num_tokens, num_experts)``.
        topk: Width of the second dimension of the output buffers.

    Returns:
        A ``(topk_weights, topk_indices)`` tuple: a float32 and an int32
        tensor, both of shape ``(num_tokens, topk)`` on ``gating_output``'s
        device, passed to the kernel as its output arguments.
    """
    # Unpacking (rather than indexing) keeps the implicit 2-D check.
    n_rows, _n_cols = gating_output.shape
    out_shape = (n_rows, topk)
    dev = gating_output.device

    weights = torch.empty(out_shape, dtype=torch.float32, device=dev)
    expert_ids = torch.empty(out_shape, dtype=torch.int32, device=dev)

    topk_softmax(
        topk_weights=weights,
        topk_ids=expert_ids,
        gating_output=gating_output,
    )
    return weights, expert_ids
|
|
|
|
|
|
def calculate_diff(num_tokens, num_experts, topk):
    """Compare the vLLM and SGLang ``topk_softmax`` outputs and print a verdict.

    A random float32 ``(num_tokens, num_experts)`` tensor is created on the
    CUDA device and a separate clone is handed to each implementation. The
    implementations are considered matching when the weights agree within
    ``atol=1e-3, rtol=1e-3`` and the index tensors are exactly equal.

    Args:
        num_tokens: Number of rows of the random gating tensor.
        num_experts: Number of columns of the random gating tensor.
        topk: Forwarded to both implementations.

    Returns:
        None. The result is reported through ``print`` only.
    """
    # Check availability first: without vLLM, ``vllm_topk_softmax`` falls back
    # to the SGLang kernel, so the comparison would be SGLang against itself
    # and there is no point allocating GPU memory or launching kernels.
    if not VLLM_AVAILABLE:
        print("⚠️ vLLM not available, skipping comparison")
        return

    gating_output = torch.randn(
        (num_tokens, num_experts), device="cuda", dtype=torch.float32
    )
    weights_vllm, indices_vllm = vllm_topk_softmax(gating_output.clone(), topk)
    weights_sglang, indices_sglang = sglang_topk_softmax(gating_output.clone(), topk)

    indices_match = torch.equal(indices_vllm, indices_sglang)
    weights_match = torch.allclose(
        weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3
    )

    if weights_match and indices_match:
        print("✅ VLLM and SGLang topk_softmax implementations match")
        return

    # The mean absolute difference is only needed for the failure report.
    weights_diff = torch.abs(weights_vllm - weights_sglang).mean().item()
    print(
        f"❌ Implementations differ: Weights diff={weights_diff}, Indices match={indices_match}"
    )
|
|
|
|
|
|
# Benchmark sweep axes. CI environment uses simplified parameters (a single
# value per axis); otherwise the full ranges are swept.
num_tokens_range = (
    [128] if IS_CI else [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
)
num_experts_range = [32] if IS_CI else [32, 64, 128, 256, 12, 512]
topk_range = [2] if IS_CI else [1, 2, 4, 8]

# Cartesian product of the three axes as (num_tokens, num_experts, topk)
# tuples, with topk varying fastest.
configs = [
    (n_tok, n_exp, k)
    for n_tok in num_tokens_range
    for n_exp in num_experts_range
    for k in topk_range
]
|
|
|
|
|
|
# Providers plotted by the benchmark. SGLang is always present; the vLLM
# line is added only when vLLM could be imported.
line_vals = ["sglang"]
line_names = ["SGLang"]
styles = [("blue", "-")]
if VLLM_AVAILABLE:
    line_vals.append("vllm")
    line_names.append("VLLM")
    styles.append(("green", "-"))
|
|
|
|
|
|
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk"],
        x_vals=configs,
        line_arg="provider",
        line_vals=line_vals,
        line_names=line_names,
        styles=styles,
        ylabel="Latency (us)",
        plot_name="topk-softmax-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, provider):
    """Time one ``topk_softmax`` implementation for a single configuration.

    Args:
        num_tokens: Number of rows of the random float32 gating tensor.
        num_experts: Number of columns of the random float32 gating tensor.
        topk: Forwarded to the implementation under test.
        provider: ``"vllm"``/``"vllm1"`` or ``"sglang"``/``"sglang1"``.

    Returns:
        ``(median, low, high)`` latencies: the 0.5, 0.2 and 0.8 quantiles
        reported by ``triton.testing.do_bench_cudagraph``, each multiplied by
        1000 to match the "Latency (us)" axis label. ``(0, 0, 0)`` is returned
        when a vLLM provider is requested but vLLM is not available.

    Raises:
        ValueError: If ``provider`` is not one of the recognised names.
    """
    # Resolve the implementation before touching the GPU so that an unknown
    # provider fails with a clear error instead of an unbound local, and the
    # "vLLM missing" early return does not allocate anything.
    if provider in ("vllm", "vllm1"):
        if not VLLM_AVAILABLE:
            return (0, 0, 0)
        impl = vllm_topk_softmax
    elif provider in ("sglang", "sglang1"):
        impl = sglang_topk_softmax
    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    gating_output = torch.randn(
        (num_tokens, num_experts), device="cuda", dtype=torch.float32
    )

    def fn():
        return impl(gating_output, topk)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

    # perf_report unpacks the result as (y_mean, y_min, y_max). For a latency
    # metric the low quantile is the minimum, so it must come second.
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
|
|
|
|
|
|
if __name__ == "__main__":
    # (num_tokens, num_experts, topk) cases for the correctness comparison;
    # a single config is used in the CI environment.
    test_configs = (
        [(20, 32, 2)]
        if IS_CI
        else [
            (20, 256, 4),
            (20, 256, 8),
            (20, 12, 4),
            (20, 12, 1),
            (20, 512, 4),
            (20, 512, 1),
        ]
    )

    # Compare the implementations first, then run the latency sweep.
    for case in test_configs:
        calculate_diff(*case)

    benchmark.run(print_data=True)
|