Fix sgl-kernel benchmark dead code (#11022)
This commit is contained in:
46
.github/workflows/pr-test.yml
vendored
46
.github/workflows/pr-test.yml
vendored
@@ -155,6 +155,50 @@ jobs:
|
||||
cd test/srt
|
||||
python3 test_mla_deepseek_v3.py
|
||||
|
||||
sgl-kernel-benchmark-test:
|
||||
needs: [check-changes, sgl-kernel-build-wheels]
|
||||
if: always() && !failure() && !cancelled()
|
||||
runs-on: 1-gpu-runner
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
CI: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Cleanup
|
||||
run: |
|
||||
ls -alh sgl-kernel/dist || true
|
||||
rm -rf sgl-kernel/dist/* || true
|
||||
|
||||
- name: Download artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: sgl-kernel/dist/
|
||||
merge-multiple: true
|
||||
pattern: wheel-python3.10-cuda12.9
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/ci_install_dependency.sh
|
||||
|
||||
- name: Run benchmark tests
|
||||
timeout-minutes: 45
|
||||
run: |
|
||||
cd sgl-kernel/benchmark
|
||||
echo "Running sgl-kernel benchmark tests in CI mode..."
|
||||
|
||||
echo "CI environment variable: $CI"
|
||||
echo "GITHUB_ACTIONS environment variable: $GITHUB_ACTIONS"
|
||||
|
||||
for bench_file in bench_*.py; do
|
||||
echo "Testing $bench_file..."
|
||||
timeout 60 python3 "$bench_file" || echo "Warning: $bench_file timed out or failed, continuing..."
|
||||
echo "Completed $bench_file"
|
||||
echo "---"
|
||||
done
|
||||
|
||||
echo "All benchmark tests completed!"
|
||||
|
||||
# =============================================== primary ====================================================
|
||||
|
||||
unit-test-frontend:
|
||||
@@ -647,7 +691,7 @@ jobs:
|
||||
check-changes,
|
||||
|
||||
sgl-kernel-build-wheels,
|
||||
sgl-kernel-unit-test, sgl-kernel-mla-test,
|
||||
sgl-kernel-unit-test, sgl-kernel-mla-test, sgl-kernel-benchmark-test,
|
||||
|
||||
unit-test-frontend, unit-test-backend-1-gpu,
|
||||
unit-test-backend-2-gpu, unit-test-backend-4-gpu, unit-test-backend-8-gpu,
|
||||
|
||||
@@ -2460,7 +2460,7 @@ class BumpAllocator:
|
||||
def log_info_on_rank0(logger, msg):
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
|
||||
logger.info(msg)
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
|
||||
import argparse
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
@@ -11,7 +12,21 @@ import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
# Optional vLLM import
|
||||
try:
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
vllm_ops = None
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
# gelu_quick is only available on HIP/ROCm platforms
|
||||
try:
|
||||
@@ -22,7 +37,7 @@ except ImportError:
|
||||
GELU_QUICK_AVAILABLE = False
|
||||
gelu_quick = None
|
||||
|
||||
if not hasattr(vllm_ops, "silu_and_mul"):
|
||||
if VLLM_AVAILABLE and not hasattr(vllm_ops, "silu_and_mul"):
|
||||
vllm_ops = torch.ops._C
|
||||
|
||||
|
||||
@@ -40,6 +55,13 @@ def calculate_diff(
|
||||
"""Compare vLLM with SGLang for one shape."""
|
||||
device = torch.device("cuda")
|
||||
|
||||
if not VLLM_AVAILABLE:
|
||||
print(
|
||||
f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
|
||||
f"L={seq_len:3d} | D={dim:5d}] ⚠️ vLLM not available, skipping comparison"
|
||||
)
|
||||
return True
|
||||
|
||||
# activation-only quick GELU
|
||||
if kernel == "gelu_quick":
|
||||
if not GELU_QUICK_AVAILABLE:
|
||||
@@ -68,19 +90,30 @@ def calculate_diff(
|
||||
return ok
|
||||
|
||||
|
||||
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
|
||||
if GELU_QUICK_AVAILABLE:
|
||||
kernels.append("gelu_quick")
|
||||
dtypes = [torch.float16, torch.bfloat16]
|
||||
# CI environment uses simplified parameters for kernels and dtypes too
|
||||
if IS_CI:
|
||||
kernels = ["silu_and_mul"] # Only test one kernel in CI
|
||||
dtypes = [torch.float16] # Only test one dtype in CI
|
||||
else:
|
||||
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
|
||||
if GELU_QUICK_AVAILABLE:
|
||||
kernels.append("gelu_quick")
|
||||
dtypes = [torch.float16, torch.bfloat16]
|
||||
|
||||
|
||||
def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
|
||||
return list(itertools.product(kernels, dtypes, bsizes, slens, dims_))
|
||||
|
||||
|
||||
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16
|
||||
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64
|
||||
default_dims = [2**i for i in range(10, 15)] # 1024...16384
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
default_batch_sizes = [1] # Single batch size for CI
|
||||
default_seq_lens = [1] # Single sequence length for CI
|
||||
default_dims = [1024] # Single dimension for CI
|
||||
else:
|
||||
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16
|
||||
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64
|
||||
default_dims = [2**i for i in range(10, 15)] # 1024...16384
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
@@ -102,16 +135,24 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
|
||||
x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
|
||||
y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
|
||||
|
||||
vllm_kernel = getattr(vllm_ops, kernel)
|
||||
if not VLLM_AVAILABLE and provider in ["vllm", "speedup"]:
|
||||
# Skip vLLM-related benchmarks if vLLM is not available
|
||||
return (0, 0, 0)
|
||||
|
||||
if VLLM_AVAILABLE:
|
||||
vllm_kernel = getattr(vllm_ops, kernel)
|
||||
if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE:
|
||||
# Skip benchmark for gelu_quick if not available
|
||||
return (0, 0, 0)
|
||||
sglang_kernel = getattr(sgl_kernel, kernel)
|
||||
|
||||
def baseline():
|
||||
tmp = y0.clone()
|
||||
vllm_kernel(tmp, x)
|
||||
return tmp
|
||||
if VLLM_AVAILABLE:
|
||||
tmp = y0.clone()
|
||||
vllm_kernel(tmp, x)
|
||||
return tmp
|
||||
else:
|
||||
return torch.zeros_like(y0)
|
||||
|
||||
def sglang():
|
||||
return sglang_kernel(x)
|
||||
@@ -134,7 +175,7 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
|
||||
# provider == "speedup"
|
||||
t_ref, _, _ = timed(baseline)
|
||||
t_sgl, _, _ = timed(sglang)
|
||||
spd = t_ref / t_sgl
|
||||
spd = t_ref / t_sgl if t_ref > 0 else 1.0
|
||||
return (spd, spd, spd)
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,34 @@
|
||||
import itertools
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import awq_dequantize
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
# Optional vLLM import
|
||||
try:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
ops = None
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
def vllm_awq_dequantize(
|
||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if not VLLM_AVAILABLE:
|
||||
# Fallback to SGLang implementation
|
||||
return sglang_awq_dequantize(qweight, scales, qzeros)
|
||||
return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
||||
|
||||
|
||||
@@ -43,6 +61,10 @@ def calculate_diff(qweight_row: int, qweight_col: int):
|
||||
device=device,
|
||||
)
|
||||
|
||||
if not VLLM_AVAILABLE:
|
||||
print("⚠️ vLLM not available, skipping comparison")
|
||||
return
|
||||
|
||||
vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
|
||||
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
|
||||
|
||||
@@ -56,8 +78,13 @@ def calculate_diff(qweight_row: int, qweight_col: int):
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
|
||||
qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
qweight_row_range = [128] # Single row size for CI
|
||||
qweight_cols_range = [16] # Single column size for CI
|
||||
else:
|
||||
qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
|
||||
qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]
|
||||
|
||||
configs = list(itertools.product(qweight_row_range, qweight_cols_range))
|
||||
|
||||
@@ -67,9 +94,9 @@ configs = list(itertools.product(qweight_row_range, qweight_cols_range))
|
||||
x_names=["qweight_row", "qweight_col"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["vllm", "sglang"],
|
||||
line_names=["VLLM", "SGL Kernel"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
line_vals=["vllm", "sglang"] if VLLM_AVAILABLE else ["sglang"],
|
||||
line_names=["VLLM", "SGL Kernel"] if VLLM_AVAILABLE else ["SGL Kernel"],
|
||||
styles=[("blue", "-"), ("green", "-")] if VLLM_AVAILABLE else [("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="awq-dequantize-performance",
|
||||
args={},
|
||||
@@ -100,6 +127,8 @@ def benchmark(qweight_row, qweight_col, provider):
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "vllm":
|
||||
if not VLLM_AVAILABLE:
|
||||
return (0, 0, 0)
|
||||
fn = lambda: vllm_awq_dequantize(
|
||||
qweight.clone(), scales.clone(), qzeros.clone()
|
||||
)
|
||||
@@ -114,5 +143,11 @@ def benchmark(qweight_row, qweight_col, provider):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
calculate_diff(qweight_row=3584, qweight_col=448)
|
||||
# Simplify for CI environment
|
||||
if IS_CI:
|
||||
qweight_row, qweight_col = 128, 16 # Smaller values for CI
|
||||
else:
|
||||
qweight_row, qweight_col = 3584, 448
|
||||
|
||||
calculate_diff(qweight_row=qweight_row, qweight_col=qweight_col)
|
||||
benchmark.run(print_data=True)
|
||||
|
||||
@@ -1,13 +1,27 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
|
||||
|
||||
bs_range = [1, 8, 32, 64, 128, 256]
|
||||
qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||
from sglang.srt.utils import get_device_capability
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
bs_range = [1] # Single batch size for CI
|
||||
qlen_range = [64] # Single sequence length for CI
|
||||
else:
|
||||
bs_range = [1, 8, 32, 64, 128, 256]
|
||||
qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||
|
||||
configs = list(itertools.product(bs_range, qlen_range))
|
||||
|
||||
@@ -131,13 +145,34 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
for block_size in args.block_sizes:
|
||||
for kv_split in args.num_kv_splits:
|
||||
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
block_size=block_size,
|
||||
num_kv_splits=kv_split,
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
||||
# Skip in CI environment or unsupported architectures
|
||||
if IS_CI:
|
||||
major, minor = get_device_capability()
|
||||
if major is None or major < 10: # Requires compute capability 10.0+
|
||||
print("Skipping Cutlass MLA benchmark in CI environment")
|
||||
if major is not None:
|
||||
print(
|
||||
f"Cutlass MLA requires compute capability 10.0+, but found {major}.{minor}"
|
||||
)
|
||||
else:
|
||||
print("Could not determine device capability")
|
||||
else:
|
||||
for block_size in args.block_sizes:
|
||||
for kv_split in args.num_kv_splits:
|
||||
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
block_size=block_size,
|
||||
num_kv_splits=kv_split,
|
||||
)
|
||||
print("Benchmark finished!")
|
||||
else:
|
||||
for block_size in args.block_sizes:
|
||||
for kv_split in args.num_kv_splits:
|
||||
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
block_size=block_size,
|
||||
num_kv_splits=kv_split,
|
||||
)
|
||||
print("Benchmark finished!")
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -6,16 +13,28 @@ import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import dsv3_fused_a_gemm
|
||||
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
num_tokens_vals = [1] # Only test 1 value in CI
|
||||
line_vals = ["sgl-kernel"] # Only test sgl-kernel implementation in CI
|
||||
else:
|
||||
num_tokens_vals = [i + 1 for i in range(16)] # Test 1-16 in full mode
|
||||
line_vals = ["torch", "sgl-kernel"]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens"],
|
||||
x_vals=[i + 1 for i in range(16)],
|
||||
x_vals=num_tokens_vals,
|
||||
x_log=False,
|
||||
line_arg="impl",
|
||||
line_vals=["torch", "sgl-kernel"],
|
||||
line_names=["torch (bf16)", "dsv3_fused_a_gemm"],
|
||||
styles=[("blue", "-"), ("orange", "-")],
|
||||
line_vals=line_vals,
|
||||
line_names=(
|
||||
["torch (bf16)", "dsv3_fused_a_gemm"]
|
||||
if not IS_CI
|
||||
else ["dsv3_fused_a_gemm"]
|
||||
),
|
||||
styles=[("blue", "-"), ("orange", "-")] if not IS_CI else [("orange", "-")],
|
||||
ylabel="TFLOPs",
|
||||
plot_name="bf16 dsv3 fused a GEMM throughput",
|
||||
args={},
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -6,21 +13,37 @@ import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import dsv3_router_gemm
|
||||
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
num_tokens_vals = [1] # Only test 1 value in CI
|
||||
line_vals = ["sgl-kernel-256"] # Only test one implementation in CI
|
||||
else:
|
||||
num_tokens_vals = [i + 1 for i in range(16)] # Test 1-16 in full mode
|
||||
line_vals = ["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens"],
|
||||
x_vals=[i + 1 for i in range(16)],
|
||||
x_vals=num_tokens_vals,
|
||||
x_log=False,
|
||||
line_arg="impl",
|
||||
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
|
||||
line_names=[
|
||||
"torch-256",
|
||||
"dsv3_router_gemm-256",
|
||||
"torch-384",
|
||||
"dsv3_router_gemm-384",
|
||||
],
|
||||
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
|
||||
line_vals=line_vals,
|
||||
line_names=(
|
||||
[
|
||||
"torch-256",
|
||||
"dsv3_router_gemm-256",
|
||||
"torch-384",
|
||||
"dsv3_router_gemm-384",
|
||||
]
|
||||
if not IS_CI
|
||||
else ["dsv3_router_gemm-256"]
|
||||
),
|
||||
styles=(
|
||||
[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]
|
||||
if not IS_CI
|
||||
else [("orange", "-")]
|
||||
),
|
||||
ylabel="TFLOPs",
|
||||
plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
|
||||
args={},
|
||||
@@ -64,17 +87,25 @@ def benchmark_bf16_output(num_tokens, impl):
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens"],
|
||||
x_vals=[i + 1 for i in range(16)],
|
||||
x_vals=num_tokens_vals,
|
||||
x_log=False,
|
||||
line_arg="impl",
|
||||
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
|
||||
line_names=[
|
||||
"torch-256",
|
||||
"dsv3_router_gemm-256",
|
||||
"torch-384",
|
||||
"dsv3_router_gemm-384",
|
||||
],
|
||||
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
|
||||
line_vals=line_vals,
|
||||
line_names=(
|
||||
[
|
||||
"torch-256",
|
||||
"dsv3_router_gemm-256",
|
||||
"torch-384",
|
||||
"dsv3_router_gemm-384",
|
||||
]
|
||||
if not IS_CI
|
||||
else ["dsv3_router_gemm-256"]
|
||||
),
|
||||
styles=(
|
||||
[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]
|
||||
if not IS_CI
|
||||
else [("orange", "-")]
|
||||
),
|
||||
ylabel="TFLOPs",
|
||||
plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
|
||||
args={},
|
||||
|
||||
@@ -2,6 +2,7 @@ import argparse
|
||||
import copy
|
||||
import csv
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -9,6 +10,14 @@ import triton
|
||||
from flashinfer import mm_fp4
|
||||
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
|
||||
from sglang.srt.utils import get_device_capability
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
FLOAT4_E2M1_MAX = 6.0
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
@@ -33,27 +42,34 @@ def get_weight_shapes(args):
|
||||
]
|
||||
|
||||
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
batch_sizes = [1, 8] # Simplified for CI
|
||||
else:
|
||||
batch_sizes = [
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
8192,
|
||||
16384,
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
8192,
|
||||
16384,
|
||||
],
|
||||
x_vals=batch_sizes,
|
||||
# x_vals = [64],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
@@ -188,21 +204,38 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Simplify for CI environment
|
||||
if IS_CI:
|
||||
args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size
|
||||
|
||||
if args.csv:
|
||||
with open(args.csv, "w", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(["provider", "m", "n", "k", "time_ms"])
|
||||
|
||||
NKs = get_weight_shapes(args)
|
||||
for N, K in NKs:
|
||||
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
N=N,
|
||||
K=K,
|
||||
dtype=args.dtype,
|
||||
correctness=args.correctness,
|
||||
csv_file=args.csv,
|
||||
)
|
||||
# Check architecture compatibility - FP4 operations require sm100a/sm103a
|
||||
major, minor = get_device_capability()
|
||||
if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a)
|
||||
print("Skipping FP4 GEMM benchmark")
|
||||
if major is not None:
|
||||
print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}")
|
||||
else:
|
||||
print("Could not determine device capability")
|
||||
else:
|
||||
NKs = get_weight_shapes(args)
|
||||
|
||||
print("Benchmark finished!")
|
||||
# Limit iterations in CI
|
||||
if IS_CI:
|
||||
NKs = NKs[:2] # Only test first 2 shapes in CI
|
||||
|
||||
for N, K in NKs:
|
||||
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
N=N,
|
||||
K=K,
|
||||
dtype=args.dtype,
|
||||
correctness=args.correctness,
|
||||
csv_file=args.csv,
|
||||
)
|
||||
print("Benchmark finished!")
|
||||
|
||||
@@ -1,18 +1,33 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import deep_gemm
|
||||
import torch
|
||||
import triton
|
||||
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
||||
from sgl_kernel import fp8_blockwise_scaled_mm
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
|
||||
# Optional vLLM import
|
||||
try:
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
vllm_scaled_mm = None
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
|
||||
)
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
def get_weight_shapes(args):
|
||||
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||
@@ -80,15 +95,46 @@ def scale_shape(shape, group_shape):
|
||||
return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
|
||||
|
||||
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
batch_sizes = [1, 8] # Simplified for CI
|
||||
else:
|
||||
batch_sizes = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||
|
||||
# Filter providers based on availability
|
||||
available_providers = ["sgl-kernel"]
|
||||
available_names = ["sgl-kernel"]
|
||||
available_styles = [("orange", "-")]
|
||||
|
||||
if VLLM_AVAILABLE:
|
||||
available_providers.insert(0, "vllm")
|
||||
available_names.insert(0, "vllm")
|
||||
available_styles.insert(0, ("blue", "-"))
|
||||
|
||||
available_providers.append("triton")
|
||||
available_names.append("sglang triton")
|
||||
available_styles.append(("red", "-"))
|
||||
|
||||
# Add deepgemm if available
|
||||
try:
|
||||
import deep_gemm
|
||||
|
||||
available_providers.append("deepgemm")
|
||||
available_names.append("deepgemm")
|
||||
available_styles.append(("yellow", "-"))
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
|
||||
x_vals=batch_sizes,
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=["vllm", "sgl-kernel", "triton", "deepgemm"],
|
||||
line_names=["vllm", "sgl-kernel", "sglang triton", "deepgemm"],
|
||||
styles=[("blue", "-"), ("orange", "-"), ("red", "-"), ("yellow", "-")],
|
||||
line_vals=available_providers,
|
||||
line_names=available_names,
|
||||
styles=available_styles,
|
||||
ylabel="GB/s",
|
||||
plot_name="fp8 blockwise scaled matmul",
|
||||
args={},
|
||||
@@ -123,14 +169,16 @@ def benchmark(batch_size, provider, N, K):
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "vllm":
|
||||
elif provider == "vllm":
|
||||
if not VLLM_AVAILABLE:
|
||||
return (0, 0, 0)
|
||||
scale_a = scale_a.t().contiguous().t()
|
||||
b_fp8, scale_b = b_fp8.t(), scale_b.t()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "triton":
|
||||
elif provider == "triton":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: w8a8_block_fp8_matmul(
|
||||
a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
|
||||
@@ -166,7 +214,17 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Simplify for CI environment
|
||||
if IS_CI:
|
||||
args.models = [args.models[0]] # Use only first model
|
||||
args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size
|
||||
|
||||
NK_model_names = get_weight_shapes(args)
|
||||
|
||||
# Limit iterations in CI
|
||||
if IS_CI:
|
||||
NK_model_names = NK_model_names[:2] # Only test first 2 shapes in CI
|
||||
|
||||
for N, K, model_name in NK_model_names:
|
||||
if N % 128 != 0 or K % 128 != 0:
|
||||
print(f"Skip {N=}, {K=} now")
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
@@ -290,36 +297,44 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-warmup", type=int, default=3)
|
||||
parser.add_argument("--num-run", type=int, default=10)
|
||||
shape_args = [
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8
|
||||
ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 8
|
||||
ShapeArg(expected_m_per_group=256, n=512, k=7168, num_groups=256),
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 16
|
||||
ShapeArg(expected_m_per_group=256, n=256, k=7168, num_groups=256),
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 16384, TP = 16
|
||||
ShapeArg(expected_m_per_group=512, n=256, k=7168, num_groups=256),
|
||||
# Decode, DeepSeek-R1, gateup, bs = 32, TP = 8
|
||||
ShapeArg(expected_m_per_group=1, n=512, k=7168, num_groups=256),
|
||||
# Decode, DeepSeek-R1, gateup, bs = 64, TP = 16
|
||||
ShapeArg(expected_m_per_group=2, n=256, k=7168, num_groups=256),
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, EP = 8
|
||||
ShapeArg(expected_m_per_group=256, n=4096, k=7168, num_groups=32),
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 16384, EP = 16
|
||||
ShapeArg(expected_m_per_group=512, n=4096, k=7168, num_groups=16),
|
||||
# Decode, DeepSeek-R1, gateup, bs = 128, EP = 8
|
||||
ShapeArg(expected_m_per_group=4, n=4096, k=7168, num_groups=32),
|
||||
# Decode, DeepSeek-R1, gateup, bs = 256, EP = 16
|
||||
ShapeArg(expected_m_per_group=8, n=4096, k=7168, num_groups=16),
|
||||
# Prefill, Qwen3-235B-A22B-FP8, gateup, chunk_size = 16384, TP = 4
|
||||
ShapeArg(expected_m_per_group=1024, n=768, k=4096, num_groups=128),
|
||||
# Prefill, Qwen3-235B-A22B-FP8, down, chunk_size = 16384, TP = 4
|
||||
ShapeArg(expected_m_per_group=1024, n=4096, k=384, num_groups=128),
|
||||
# Decode, Qwen3-235B-A22B-FP8, gateup, bs = 256, TP = 4
|
||||
ShapeArg(expected_m_per_group=16, n=768, k=4096, num_groups=128),
|
||||
# Decode, Qwen3-235B-A22B-FP8, down, bs = 256, TP = 4
|
||||
ShapeArg(expected_m_per_group=16, n=4096, k=384, num_groups=128),
|
||||
]
|
||||
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
shape_args = [
|
||||
# Only test one simple shape in CI
|
||||
ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
|
||||
]
|
||||
else:
|
||||
shape_args = [
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8
|
||||
ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 8
|
||||
ShapeArg(expected_m_per_group=256, n=512, k=7168, num_groups=256),
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 16
|
||||
ShapeArg(expected_m_per_group=256, n=256, k=7168, num_groups=256),
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 16384, TP = 16
|
||||
ShapeArg(expected_m_per_group=512, n=256, k=7168, num_groups=256),
|
||||
# Decode, DeepSeek-R1, gateup, bs = 32, TP = 8
|
||||
ShapeArg(expected_m_per_group=1, n=512, k=7168, num_groups=256),
|
||||
# Decode, DeepSeek-R1, gateup, bs = 64, TP = 16
|
||||
ShapeArg(expected_m_per_group=2, n=256, k=7168, num_groups=256),
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, EP = 8
|
||||
ShapeArg(expected_m_per_group=256, n=4096, k=7168, num_groups=32),
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 16384, EP = 16
|
||||
ShapeArg(expected_m_per_group=512, n=4096, k=7168, num_groups=16),
|
||||
# Decode, DeepSeek-R1, gateup, bs = 128, EP = 8
|
||||
ShapeArg(expected_m_per_group=4, n=4096, k=7168, num_groups=32),
|
||||
# Decode, DeepSeek-R1, gateup, bs = 256, EP = 16
|
||||
ShapeArg(expected_m_per_group=8, n=4096, k=7168, num_groups=16),
|
||||
# Prefill, Qwen3-235B-A22B-FP8, gateup, chunk_size = 16384, TP = 4
|
||||
ShapeArg(expected_m_per_group=1024, n=768, k=4096, num_groups=128),
|
||||
# Prefill, Qwen3-235B-A22B-FP8, down, chunk_size = 16384, TP = 4
|
||||
ShapeArg(expected_m_per_group=1024, n=4096, k=384, num_groups=128),
|
||||
# Decode, Qwen3-235B-A22B-FP8, gateup, bs = 256, TP = 4
|
||||
ShapeArg(expected_m_per_group=16, n=768, k=4096, num_groups=128),
|
||||
# Decode, Qwen3-235B-A22B-FP8, down, bs = 256, TP = 4
|
||||
ShapeArg(expected_m_per_group=16, n=4096, k=384, num_groups=128),
|
||||
]
|
||||
args = parser.parse_args()
|
||||
benchmark_one_shape(shape_args, args.num_warmup, args.num_run)
|
||||
|
||||
|
||||
@@ -1,14 +1,30 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
|
||||
|
||||
# Optional vLLM import
|
||||
try:
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
vllm_scaled_mm = None
|
||||
vllm_scaled_fp8_quant = None
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
# Weight Shapes are in the format
|
||||
# ([K, N], TP_SPLIT_DIM)
|
||||
@@ -86,25 +102,48 @@ def sglang_scaled_fp8_quant(
|
||||
return output, scale
|
||||
|
||||
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
batch_sizes = [1] # Single batch size for CI
|
||||
else:
|
||||
batch_sizes = [1, 16, 64, 128, 256, 512, 1024, 2048]
|
||||
|
||||
# Filter line_vals based on vLLM availability
|
||||
if VLLM_AVAILABLE:
|
||||
line_vals = [
|
||||
"vllm-fp8-fp16",
|
||||
"vllm-fp8-bf16",
|
||||
"sglang-fp8-fp16",
|
||||
"sglang-fp8-bf16",
|
||||
]
|
||||
line_names = [
|
||||
"vllm-fp8-fp16",
|
||||
"vllm-fp8-bf16",
|
||||
"sglang-fp8-fp16",
|
||||
"sglang-fp8-bf16",
|
||||
]
|
||||
styles = [("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")]
|
||||
else:
|
||||
line_vals = [
|
||||
"sglang-fp8-fp16",
|
||||
"sglang-fp8-bf16",
|
||||
]
|
||||
line_names = [
|
||||
"sglang-fp8-fp16",
|
||||
"sglang-fp8-bf16",
|
||||
]
|
||||
styles = [("blue", "-"), ("blue", "--")]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
|
||||
x_vals=batch_sizes,
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=[
|
||||
"vllm-fp8-fp16",
|
||||
"vllm-fp8-bf16",
|
||||
"sglang-fp8-fp16",
|
||||
"sglang-fp8-bf16",
|
||||
],
|
||||
line_names=[
|
||||
"vllm-fp8-fp16",
|
||||
"vllm-fp8-bf16",
|
||||
"sglang-fp8-fp16",
|
||||
"sglang-fp8-bf16",
|
||||
],
|
||||
styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
|
||||
line_vals=line_vals,
|
||||
line_names=line_names,
|
||||
styles=styles,
|
||||
ylabel="GB/s",
|
||||
plot_name="fp8 scaled matmul",
|
||||
args={},
|
||||
@@ -115,6 +154,9 @@ def benchmark(batch_size, provider, N, K):
|
||||
M = batch_size
|
||||
a = torch.ones((M, K), device="cuda") * 5.0
|
||||
b = torch.ones((N, K), device="cuda") * 5.0
|
||||
# vLLM expects scalar scales, while sglang can handle per-token scales
|
||||
scale_a_scalar = torch.randn(1, device="cuda", dtype=torch.float32)
|
||||
scale_b_scalar = torch.randn(1, device="cuda", dtype=torch.float32)
|
||||
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
@@ -122,8 +164,11 @@ def benchmark(batch_size, provider, N, K):
|
||||
dtype = torch.float16 if "fp16" in provider else torch.bfloat16
|
||||
|
||||
if "vllm-fp8" in provider:
|
||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
|
||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
||||
if not VLLM_AVAILABLE:
|
||||
# Return zero if vLLM is not available
|
||||
return (0, 0, 0)
|
||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_scalar)
|
||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b_scalar)
|
||||
b_fp8 = b_fp8.t()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
|
||||
@@ -174,6 +219,11 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Simplify for CI environment
|
||||
if IS_CI:
|
||||
args.models = [args.models[0]] # Use only first model
|
||||
args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size
|
||||
|
||||
KN_model_names = prepare_shapes(args)
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
|
||||
@@ -1,11 +1,26 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import int8_scaled_mm
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
|
||||
# Optional vLLM import
|
||||
try:
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
vllm_scaled_mm = None
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
@@ -62,15 +77,32 @@ WEIGHT_SHAPES = {
|
||||
}
|
||||
|
||||
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
batch_sizes = [1] # Single batch size for CI
|
||||
else:
|
||||
batch_sizes = [1, 16, 32, 64, 128, 256, 512, 1024, 2048]
|
||||
|
||||
# Filter providers based on vLLM availability
|
||||
if VLLM_AVAILABLE:
|
||||
line_vals = ["vllm", "sgl-kernel"]
|
||||
line_names = ["vllm int8 gemm", "sgl-kernel int8 gemm"]
|
||||
styles = [("blue", "-"), ("orange", "-")]
|
||||
else:
|
||||
line_vals = ["sgl-kernel"]
|
||||
line_names = ["sgl-kernel int8 gemm"]
|
||||
styles = [("orange", "-")]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
|
||||
x_vals=batch_sizes,
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=["vllm", "sgl-kernel"],
|
||||
line_names=["vllm int8 gemm", "sgl-kernel int8 gemm"],
|
||||
styles=[("blue", "-"), ("orange", "-")],
|
||||
line_vals=line_vals,
|
||||
line_names=line_names,
|
||||
styles=styles,
|
||||
ylabel="GB/s",
|
||||
plot_name="int8 scaled matmul",
|
||||
args={},
|
||||
@@ -90,7 +122,9 @@ def benchmark(batch_size, provider, N, K):
|
||||
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "vllm":
|
||||
elif provider == "vllm":
|
||||
if not VLLM_AVAILABLE:
|
||||
return (0, 0, 0)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
|
||||
quantiles=quantiles,
|
||||
@@ -136,9 +170,16 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
KN_model_names = prepare_shapes(args)
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(print_data=True, N=N, K=K)
|
||||
# Skip in CI environment due to architecture compatibility issues
|
||||
if IS_CI:
|
||||
print(
|
||||
"Skipping INT8 GEMM benchmark in CI environment due to architecture compatibility issues"
|
||||
)
|
||||
print("INT8 operations may not be supported on all GPU architectures")
|
||||
else:
|
||||
KN_model_names = prepare_shapes(args)
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(print_data=True, N=N, K=K)
|
||||
|
||||
print("Benchmark finished!")
|
||||
print("Benchmark finished!")
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import lightning_attention_decode
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
return 2 ** (int(math.ceil(math.log(n, 2))))
|
||||
@@ -207,7 +214,12 @@ def calculate_diff(batch_size):
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [i for i in range(1, 65)] # 1 to 128
|
||||
# Simplified for CI environment
|
||||
if IS_CI:
|
||||
batch_size_range = [1] # Single batch size for CI
|
||||
else:
|
||||
batch_size_range = [i for i in range(1, 65)] # 1 to 64
|
||||
|
||||
configs = [(bs,) for bs in batch_size_range]
|
||||
|
||||
|
||||
@@ -292,8 +304,9 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run correctness test
|
||||
calculate_diff(batch_size=4)
|
||||
# Run correctness test - simplified for CI
|
||||
test_batch_size = 1 if IS_CI else 4
|
||||
calculate_diff(batch_size=test_batch_size)
|
||||
|
||||
# Run performance benchmark
|
||||
benchmark.run(print_data=True)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -8,8 +9,17 @@ from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||
|
||||
try:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
ops = None
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
USE_RANDOM_PERM = False
|
||||
|
||||
@@ -197,19 +207,23 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
|
||||
num_tokens_post_pad_triton,
|
||||
)
|
||||
|
||||
try:
|
||||
ops.moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids_vllm,
|
||||
expert_ids_vllm,
|
||||
num_tokens_post_pad_vllm,
|
||||
)
|
||||
print(f"✅ VLLM implementation works with {num_experts} experts!")
|
||||
vllm_works = True
|
||||
except Exception as e:
|
||||
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
|
||||
if VLLM_AVAILABLE:
|
||||
try:
|
||||
ops.moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids_vllm,
|
||||
expert_ids_vllm,
|
||||
num_tokens_post_pad_vllm,
|
||||
)
|
||||
print(f"✅ VLLM implementation works with {num_experts} experts!")
|
||||
vllm_works = True
|
||||
except Exception as e:
|
||||
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
|
||||
vllm_works = False
|
||||
else:
|
||||
print("⚠️ vLLM not available, skipping vLLM test")
|
||||
vllm_works = False
|
||||
|
||||
if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
|
||||
@@ -394,8 +408,18 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
|
||||
# Simplify for CI environment
|
||||
if IS_CI:
|
||||
num_tokens = 256 # Smaller for CI
|
||||
num_experts = 8 # Smaller for CI
|
||||
topk = 2 # Smaller for CI
|
||||
else:
|
||||
num_tokens = 1024
|
||||
num_experts = args.num_experts
|
||||
topk = args.topk
|
||||
|
||||
if not args.skip_full_benchmark:
|
||||
calculate_diff(num_tokens=num_tokens, num_experts=num_experts, topk=topk)
|
||||
|
||||
if not args.skip_full_benchmark and not IS_CI: # Skip full benchmark in CI
|
||||
print(f"\n📊 Running performance benchmark for {args.num_experts} experts...")
|
||||
benchmark.run(print_data=True)
|
||||
|
||||
@@ -1,9 +1,22 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
import triton
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
|
||||
|
||||
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
batch_sizes = [64, 128] # Only test 2 values in CI
|
||||
else:
|
||||
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
|
||||
|
||||
configs = [(bs,) for bs in batch_sizes]
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -8,6 +9,12 @@ from sgl_kernel import moe_fused_gate
|
||||
|
||||
from sglang.srt.layers.moe.topk import biased_grouped_topk
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
|
||||
return biased_grouped_topk(
|
||||
@@ -28,7 +35,12 @@ def biased_grouped_topk_org_fuse_kernel(
|
||||
return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)
|
||||
|
||||
|
||||
seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
seq_length_range = [5000] # Only test one sequence length in CI
|
||||
else:
|
||||
seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
|
||||
|
||||
configs = [(sq,) for sq in seq_length_range]
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,32 @@
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import topk_softmax
|
||||
from vllm import _custom_ops as vllm_custom_ops
|
||||
|
||||
# Optional vLLM import
|
||||
try:
|
||||
from vllm import _custom_ops as vllm_custom_ops
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
vllm_custom_ops = None
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
def vllm_topk_softmax(gating_output, topk):
|
||||
if not VLLM_AVAILABLE:
|
||||
# Fallback to SGLang implementation if vLLM is not available
|
||||
return sglang_topk_softmax(gating_output, topk)
|
||||
|
||||
num_tokens, num_experts = gating_output.shape
|
||||
|
||||
topk_weights = torch.empty(
|
||||
@@ -54,6 +73,10 @@ def calculate_diff(num_tokens, num_experts, topk):
|
||||
weights_diff = torch.abs(weights_vllm - weights_sglang).mean().item()
|
||||
indices_match = torch.equal(indices_vllm, indices_sglang)
|
||||
|
||||
if not VLLM_AVAILABLE:
|
||||
print("⚠️ vLLM not available, skipping comparison")
|
||||
return
|
||||
|
||||
if (
|
||||
torch.allclose(weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3)
|
||||
and indices_match
|
||||
@@ -65,21 +88,38 @@ def calculate_diff(num_tokens, num_experts, topk):
|
||||
)
|
||||
|
||||
|
||||
num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
|
||||
num_experts_range = [32, 64, 128, 256, 12, 512]
|
||||
topk_range = [1, 2, 4, 8]
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
num_tokens_range = [128] # Single value for CI
|
||||
num_experts_range = [32] # Single value for CI
|
||||
topk_range = [2] # Single value for CI
|
||||
else:
|
||||
num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
|
||||
num_experts_range = [32, 64, 128, 256, 12, 512]
|
||||
topk_range = [1, 2, 4, 8]
|
||||
|
||||
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
|
||||
|
||||
|
||||
# Filter providers based on vLLM availability
|
||||
if VLLM_AVAILABLE:
|
||||
line_vals = ["sglang", "vllm"]
|
||||
line_names = ["SGLang", "VLLM"]
|
||||
styles = [("blue", "-"), ("green", "-")]
|
||||
else:
|
||||
line_vals = ["sglang"]
|
||||
line_names = ["SGLang"]
|
||||
styles = [("blue", "-")]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens", "num_experts", "topk"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["sglang", "vllm"],
|
||||
line_names=["SGLang", "VLLM"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
line_vals=line_vals,
|
||||
line_names=line_names,
|
||||
styles=styles,
|
||||
ylabel="Latency (us)",
|
||||
plot_name="topk-softmax-performance",
|
||||
args={},
|
||||
@@ -92,6 +132,8 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
)
|
||||
|
||||
if provider == "vllm" or provider == "vllm1":
|
||||
if not VLLM_AVAILABLE:
|
||||
return (0, 0, 0)
|
||||
fn = lambda: vllm_topk_softmax(gating_output, topk)
|
||||
elif provider == "sglang" or provider == "sglang1":
|
||||
fn = lambda: sglang_topk_softmax(gating_output, topk)
|
||||
@@ -103,14 +145,19 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
configs = [
|
||||
(20, 256, 4),
|
||||
(20, 256, 8),
|
||||
(20, 12, 4),
|
||||
(20, 12, 1),
|
||||
(20, 512, 4),
|
||||
(20, 512, 1),
|
||||
]
|
||||
for num_tokens, num_experts, topk in configs:
|
||||
# Simplify configs for CI environment
|
||||
if IS_CI:
|
||||
test_configs = [(20, 32, 2)] # Single config for CI
|
||||
else:
|
||||
test_configs = [
|
||||
(20, 256, 4),
|
||||
(20, 256, 8),
|
||||
(20, 12, 4),
|
||||
(20, 12, 1),
|
||||
(20, 512, 4),
|
||||
(20, 512, 1),
|
||||
]
|
||||
|
||||
for num_tokens, num_experts, topk in test_configs:
|
||||
calculate_diff(num_tokens, num_experts, topk)
|
||||
benchmark.run(print_data=True)
|
||||
|
||||
@@ -1,11 +1,20 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
|
||||
from sglang.srt.utils import get_device_capability
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
FLOAT4_E2M1_MAX = 6.0
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
@@ -162,9 +171,22 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
KN_model_names = prepare_shapes(args)
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(print_data=True, N=N, K=K)
|
||||
# Check architecture compatibility - FP4 operations require sm100a/sm103a
|
||||
major, minor = get_device_capability()
|
||||
if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a)
|
||||
print("Skipping NVIDIA FP4 scaled GEMM benchmark")
|
||||
if major is not None:
|
||||
print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}")
|
||||
else:
|
||||
print("Could not determine device capability")
|
||||
else:
|
||||
KN_model_names = prepare_shapes(args)
|
||||
|
||||
print("Benchmark finished!")
|
||||
# Limit iterations in CI
|
||||
if IS_CI:
|
||||
KN_model_names = KN_model_names[:2] # Only test first 2 shapes in CI
|
||||
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(print_data=True, N=N, K=K)
|
||||
print("Benchmark finished!")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -7,11 +8,26 @@ import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
# Optional imports
|
||||
try:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
ops = None
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
|
||||
|
||||
@@ -19,6 +35,9 @@ def vllm_scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if not VLLM_AVAILABLE:
|
||||
# Fallback to SGLang implementation
|
||||
return sglang_scaled_fp8_quant(input, scale)
|
||||
return ops.scaled_fp8_quant(input, scale)
|
||||
|
||||
|
||||
@@ -42,6 +61,10 @@ def calculate_diff(batch_size: int, seq_len: int):
|
||||
device = torch.device("cuda")
|
||||
x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)
|
||||
|
||||
if not VLLM_AVAILABLE:
|
||||
print("⚠️ vLLM not available, skipping comparison")
|
||||
return
|
||||
|
||||
vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
|
||||
sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)
|
||||
|
||||
@@ -56,8 +79,13 @@ def calculate_diff(batch_size: int, seq_len: int):
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [16, 32, 64, 128]
|
||||
seq_len_range = [64, 128, 256, 512, 1024, 2048]
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
batch_size_range = [16] # Single batch size for CI
|
||||
seq_len_range = [64] # Single sequence length for CI
|
||||
else:
|
||||
batch_size_range = [16, 32, 64, 128]
|
||||
seq_len_range = [64, 128, 256, 512, 1024, 2048]
|
||||
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range))
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import itertools
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
@@ -16,15 +17,28 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
|
||||
|
||||
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
|
||||
hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1
|
||||
group_size_range = [128] # For DeepSeek V3/R1
|
||||
# TODO test int8
|
||||
dst_dtype_range = [fp8_type_]
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
num_tokens_range = [64] # Single value for CI
|
||||
hidden_dim_range = [1536] # Single value for CI
|
||||
group_size_range = [128] # Keep as is
|
||||
dst_dtype_range = [fp8_type_] # Keep as is
|
||||
else:
|
||||
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
|
||||
hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1
|
||||
group_size_range = [128] # For DeepSeek V3/R1
|
||||
# TODO test int8
|
||||
dst_dtype_range = [fp8_type_]
|
||||
flags_range = [
|
||||
dict(
|
||||
column_major_scales=False,
|
||||
@@ -82,7 +96,7 @@ def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
|
||||
x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
|
||||
|
||||
fn, kernel_names = {
|
||||
"triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
|
||||
"triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_8bit"),
|
||||
"sglang": (
|
||||
sglang_per_token_group_quant_8bit,
|
||||
"per_token_group_quant_8bit_kernel",
|
||||
|
||||
@@ -1,15 +1,31 @@
|
||||
import itertools
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import sgl_per_token_quant_fp8
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
# Optional vLLM import
|
||||
try:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
ops = None
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
|
||||
# Get correct FP8 E4M3 maximum value
|
||||
@@ -49,6 +65,9 @@ def torch_per_token_quant_fp8(
|
||||
def vllm_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if not VLLM_AVAILABLE:
|
||||
# Fallback to SGLang implementation
|
||||
return sglang_per_token_quant_fp8(input)
|
||||
return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
|
||||
|
||||
|
||||
@@ -74,6 +93,17 @@ def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
|
||||
vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
|
||||
sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
|
||||
|
||||
if not VLLM_AVAILABLE:
|
||||
print("⚠️ vLLM not available, skipping vLLM comparison")
|
||||
# Only compare Torch vs SGLang
|
||||
torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item()
|
||||
torch_sglang_out_diff = (
|
||||
torch.abs(torch_out.float() - sglang_out.float()).mean().item()
|
||||
)
|
||||
print(f"Scale difference (Torch vs SGLang): {torch_sglang_scale_diff:.8f}")
|
||||
print(f"Output difference (Torch vs SGLang): {torch_sglang_out_diff:.8f}")
|
||||
return
|
||||
|
||||
print(f"\n=== Comparison for hidden_dim={hidden_dim} ===")
|
||||
|
||||
# Compare scales
|
||||
@@ -125,9 +155,15 @@ def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
|
||||
print(f" VLLM vs SGLang: {'✅' if vllm_sglang_match else '❌'}")
|
||||
|
||||
|
||||
batch_size_range = [16, 32, 64, 128]
|
||||
seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
|
||||
hidden_dim_range = [1368, 2048, 4096]
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
batch_size_range = [16] # Single batch size for CI
|
||||
seq_len_range = [64] # Single sequence length for CI
|
||||
hidden_dim_range = [2048] # Single hidden dimension for CI
|
||||
else:
|
||||
batch_size_range = [16, 32, 64, 128]
|
||||
seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
|
||||
hidden_dim_range = [1368, 2048, 4096]
|
||||
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range))
|
||||
|
||||
@@ -137,9 +173,19 @@ configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_ran
|
||||
x_names=["batch_size", "seq_len", "hidden_dim"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["torch", "vllm", "sglang"],
|
||||
line_names=["Torch Reference", "VLLM", "SGL Kernel"],
|
||||
styles=[("red", "-"), ("blue", "-"), ("green", "-")],
|
||||
line_vals=(
|
||||
["torch", "vllm", "sglang"] if VLLM_AVAILABLE else ["torch", "sglang"]
|
||||
),
|
||||
line_names=(
|
||||
["Torch Reference", "VLLM", "SGL Kernel"]
|
||||
if VLLM_AVAILABLE
|
||||
else ["Torch Reference", "SGL Kernel"]
|
||||
),
|
||||
styles=(
|
||||
[("red", "-"), ("blue", "-"), ("green", "-")]
|
||||
if VLLM_AVAILABLE
|
||||
else [("red", "-"), ("green", "-")]
|
||||
),
|
||||
ylabel="us",
|
||||
plot_name="per-token-dynamic-quant-fp8-performance",
|
||||
args={},
|
||||
@@ -156,6 +202,8 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
|
||||
if provider == "torch":
|
||||
fn = lambda: torch_per_token_quant_fp8(x.clone())
|
||||
elif provider == "vllm":
|
||||
if not VLLM_AVAILABLE:
|
||||
return (0, 0, 0)
|
||||
fn = lambda: vllm_per_token_quant_fp8(x.clone())
|
||||
elif provider == "sglang":
|
||||
fn = lambda: sglang_per_token_quant_fp8(x.clone())
|
||||
@@ -166,11 +214,16 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test various hidden dimensions for correctness
|
||||
test_dims = [1368, 2048, 4096]
|
||||
# Test various hidden dimensions for correctness - simplified for CI
|
||||
if IS_CI:
|
||||
test_dims = [2048] # Single dimension for CI
|
||||
batch_size, seq_len = 4, 64 # Smaller values for CI
|
||||
else:
|
||||
test_dims = [1368, 2048, 4096]
|
||||
batch_size, seq_len = 4, 4096
|
||||
|
||||
for dim in test_dims:
|
||||
calculate_diff(batch_size=4, seq_len=4096, hidden_dim=dim)
|
||||
calculate_diff(batch_size=batch_size, seq_len=seq_len, hidden_dim=dim)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Starting performance benchmark...")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -10,6 +11,12 @@ from sgl_kernel import (
|
||||
qserve_w4a8_per_group_gemm,
|
||||
)
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
@@ -65,10 +72,17 @@ WEIGHT_SHAPES = {
|
||||
}
|
||||
|
||||
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
batch_sizes = [1, 16] # Simplified for CI
|
||||
else:
|
||||
batch_sizes = [1, 16, 32, 64, 128, 256, 512, 1024, 2048]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
|
||||
x_vals=batch_sizes,
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
|
||||
@@ -184,13 +198,19 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
KN_model_names = prepare_shapes(args)
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
N=N,
|
||||
K=K,
|
||||
)
|
||||
# Skip in CI environment
|
||||
if IS_CI:
|
||||
print("Skipping QServe W4A8 GEMM benchmark in CI environment")
|
||||
print("QServe operations may have compatibility issues in CI")
|
||||
else:
|
||||
KN_model_names = prepare_shapes(args)
|
||||
|
||||
print("Benchmark finished!")
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
N=N,
|
||||
K=K,
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# (batch_size, seq_len, hidden_size) and prints speed-up.
|
||||
import argparse
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
@@ -10,9 +11,31 @@ import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
import triton.testing
|
||||
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
|
||||
from sgl_kernel.utils import is_arch_support_pdl
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
# Optional imports
|
||||
try:
|
||||
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
|
||||
|
||||
FLASHINFER_AVAILABLE = True
|
||||
except ImportError:
|
||||
fused_add_rmsnorm = None
|
||||
rmsnorm = None
|
||||
FLASHINFER_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
vllm_ops = None
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
def str2int_list(arg: str) -> List[int]:
|
||||
@@ -79,6 +102,10 @@ def rmsnorm_flashinfer(
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
if not FLASHINFER_AVAILABLE:
|
||||
# Fallback to naive implementation if FlashInfer is not available
|
||||
return rmsnorm_naive(x, weight, residual, eps)
|
||||
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
@@ -103,6 +130,10 @@ def rmsnorm_vllm(
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
if not VLLM_AVAILABLE:
|
||||
# Fallback to naive implementation if vLLM is not available
|
||||
return rmsnorm_naive(x, weight, residual, eps)
|
||||
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
@@ -179,37 +210,72 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
||||
output_sglang = output_sglang[0]
|
||||
|
||||
print(f"Naive output={output_naive}")
|
||||
print(f"FlashInfer output={output_flashinfer}")
|
||||
print(f"VLLM output={output_vllm}")
|
||||
if FLASHINFER_AVAILABLE:
|
||||
print(f"FlashInfer output={output_flashinfer}")
|
||||
else:
|
||||
print("FlashInfer not available, skipped")
|
||||
if VLLM_AVAILABLE:
|
||||
print(f"VLLM output={output_vllm}")
|
||||
else:
|
||||
print("vLLM not available, skipped")
|
||||
print(f"SGLang output={output_sglang}")
|
||||
|
||||
if (
|
||||
torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2)
|
||||
and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2)
|
||||
and torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
|
||||
):
|
||||
print("✅ All implementations match")
|
||||
# Only compare available implementations
|
||||
all_match = torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
|
||||
if FLASHINFER_AVAILABLE:
|
||||
all_match = all_match and torch.allclose(
|
||||
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
if VLLM_AVAILABLE:
|
||||
all_match = all_match and torch.allclose(
|
||||
output_naive, output_vllm, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
|
||||
if all_match:
|
||||
print("✅ All available implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
default_batch_sizes = [2**i for i in range(0, 7, 2)] # 1, 4, 16, 64
|
||||
default_seq_lens = [2**i for i in range(6, 11, 1)] # 64, 128, 256, 512, 1024
|
||||
default_hidden_sizes = [32 * 128, 48 * 128] # 4096, 6144
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
default_batch_sizes = [1] # Single batch size for CI
|
||||
default_seq_lens = [64] # Single sequence length for CI
|
||||
default_hidden_sizes = [4096] # Single hidden size for CI
|
||||
else:
|
||||
default_batch_sizes = [2**i for i in range(0, 7, 2)] # 1, 4, 16, 64
|
||||
default_seq_lens = [2**i for i in range(6, 11, 1)] # 64, 128, 256, 512, 1024
|
||||
default_hidden_sizes = [32 * 128, 48 * 128] # 4096, 6144
|
||||
|
||||
|
||||
def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]:
|
||||
return list(itertools.product(bsizes, slens, hsizes))
|
||||
|
||||
|
||||
# Filter providers based on availability
|
||||
available_providers = ["huggingface", "sglang"]
|
||||
available_names = ["HuggingFace", "SGL Kernel"]
|
||||
available_styles = [("blue", "-"), ("orange", "-")]
|
||||
|
||||
if FLASHINFER_AVAILABLE:
|
||||
available_providers.insert(-1, "flashinfer")
|
||||
available_names.insert(-1, "FlashInfer")
|
||||
available_styles.insert(-1, ("green", "-"))
|
||||
|
||||
if VLLM_AVAILABLE:
|
||||
available_providers.insert(-1, "vllm")
|
||||
available_names.insert(-1, "vLLM")
|
||||
available_styles.insert(-1, ("red", "-"))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len", "hidden_size"],
|
||||
x_vals=[],
|
||||
line_arg="provider",
|
||||
line_vals=["huggingface", "flashinfer", "vllm", "sglang"],
|
||||
line_names=["HuggingFace", "FlashInfer", "vLLM", "SGL Kernel"],
|
||||
styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("orange", "-")],
|
||||
line_vals=available_providers,
|
||||
line_names=available_names,
|
||||
styles=available_styles,
|
||||
ylabel="µs (median) or × (speed-up)",
|
||||
plot_name="rmsnorm-performance",
|
||||
args={},
|
||||
@@ -242,6 +308,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
|
||||
)
|
||||
)
|
||||
elif provider == "flashinfer":
|
||||
if not FLASHINFER_AVAILABLE:
|
||||
return (0, 0, 0)
|
||||
return timed(
|
||||
lambda: rmsnorm_flashinfer(
|
||||
x.clone(),
|
||||
@@ -250,6 +318,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
|
||||
)
|
||||
)
|
||||
elif provider == "vllm":
|
||||
if not VLLM_AVAILABLE:
|
||||
return (0, 0, 0)
|
||||
return timed(
|
||||
lambda: rmsnorm_vllm(
|
||||
x.clone(),
|
||||
@@ -267,13 +337,22 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
|
||||
)
|
||||
|
||||
# provider == "speedup"
|
||||
t_ref, _, _ = timed(
|
||||
lambda: rmsnorm_vllm(
|
||||
x.clone(),
|
||||
weight,
|
||||
residual.clone() if residual is not None else None,
|
||||
if VLLM_AVAILABLE:
|
||||
t_ref, _, _ = timed(
|
||||
lambda: rmsnorm_vllm(
|
||||
x.clone(),
|
||||
weight,
|
||||
residual.clone() if residual is not None else None,
|
||||
)
|
||||
)
|
||||
else:
|
||||
t_ref, _, _ = timed(
|
||||
lambda: rmsnorm_naive(
|
||||
x.clone(),
|
||||
weight,
|
||||
residual.clone() if residual is not None else None,
|
||||
)
|
||||
)
|
||||
)
|
||||
t_sgl, _, _ = timed(
|
||||
lambda: rmsnorm_sglang(
|
||||
x.clone(),
|
||||
@@ -281,7 +360,7 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
|
||||
residual.clone() if residual is not None else None,
|
||||
)
|
||||
)
|
||||
spd = t_ref / t_sgl
|
||||
spd = t_ref / t_sgl if t_ref > 0 else 1.0
|
||||
return (spd, spd, spd)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -12,17 +13,31 @@ from sgl_kernel.testing.rotary_embedding import (
|
||||
|
||||
from sglang.srt.bench_utils import bench_kineto
|
||||
|
||||
configs = [
|
||||
(batch_size, seq_len, save_kv_cache)
|
||||
for batch_size, seq_len in (
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
batch_seq_configs = [(1, 1)] # Single config for CI
|
||||
save_kv_configs = [False] # Single option for CI
|
||||
else:
|
||||
batch_seq_configs = [
|
||||
(1, 1),
|
||||
(32, 1),
|
||||
(128, 1),
|
||||
(512, 1),
|
||||
(2, 512),
|
||||
(4, 4096),
|
||||
)
|
||||
for save_kv_cache in (False, True)
|
||||
]
|
||||
save_kv_configs = [False, True]
|
||||
|
||||
configs = [
|
||||
(batch_size, seq_len, save_kv_cache)
|
||||
for batch_size, seq_len in batch_seq_configs
|
||||
for save_kv_cache in save_kv_configs
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
def torch_top_k_top_p_joint_sampling_from_probs(
|
||||
normalized_prob, top_k, top_p, eps=1e-4
|
||||
@@ -67,10 +74,16 @@ def calculate_diff(batch_size, vocab_size, p):
|
||||
)
|
||||
|
||||
|
||||
# parameter space
|
||||
batch_size_range = [16, 64, 128]
|
||||
vocab_size_range = [111, 32000]
|
||||
p_range = [0.1, 0.5]
|
||||
# parameter space - simplified for CI
|
||||
if IS_CI:
|
||||
batch_size_range = [16] # Single batch size for CI
|
||||
vocab_size_range = [111] # Single vocab size for CI
|
||||
p_range = [0.1] # Single p value for CI
|
||||
else:
|
||||
batch_size_range = [16, 64, 128]
|
||||
vocab_size_range = [111, 32000]
|
||||
p_range = [0.1, 0.5]
|
||||
|
||||
configs = list(itertools.product(batch_size_range, vocab_size_range, p_range))
|
||||
|
||||
|
||||
@@ -114,15 +127,19 @@ def benchmark_sampling(batch_size, vocab_size, p, provider):
|
||||
filter_apply_order="joint",
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
fn, quantiles=[0.5, 0.2, 0.8]
|
||||
)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Correctness check
|
||||
for cfg in configs:
|
||||
# Correctness check - simplified for CI
|
||||
if IS_CI:
|
||||
# Only test one configuration in CI
|
||||
test_configs = [configs[0]] if configs else [(16, 111, 0.1)]
|
||||
else:
|
||||
test_configs = configs
|
||||
|
||||
for cfg in test_configs:
|
||||
calculate_diff(*cfg)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
|
||||
Reference in New Issue
Block a user