adapt to sglang v0.5.2rc1 on dcu

This commit is contained in:
maxiao
2025-09-04 15:56:33 +08:00
commit 909abb58f5
2320 changed files with 489411 additions and 0 deletions

View File

@@ -0,0 +1,153 @@
# Benchmarks SGLang kernels versus vLLM across
# (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
import argparse
import itertools
import re
from typing import List, Tuple
import sgl_kernel
import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import gelu_quick # activation-only kernel
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from vllm import _custom_ops as vllm_ops
# When `vllm._custom_ops` does not expose `silu_and_mul`, resolve the reference
# kernels through `torch.ops._C` instead (presumably where such vLLM builds
# register them -- TODO confirm).
vllm_ops = vllm_ops if hasattr(vllm_ops, "silu_and_mul") else torch.ops._C
def str2int_list(arg: str) -> List[int]:
    """Parse a comma-separated list of non-negative integers.

    Args:
        arg: Text such as ``"1,4,16"``. Surrounding whitespace is tolerated.
            ``""`` and ``None`` both mean "no values".

    Returns:
        The parsed integers, in input order (empty list for empty input).

    Raises:
        argparse.ArgumentTypeError: If ``arg`` is not digits separated by
            single commas.
    """
    if arg is None or arg == "":
        return []
    int_list_pattern = re.compile(r"\d+(,\d+)*")
    if not int_list_pattern.fullmatch(arg.strip()):
        raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
    # int() ignores the outer whitespace that strip() allowed through above.
    return list(map(int, arg.split(",")))
def calculate_diff(
    kernel: str, dtype: torch.dtype, batch_size: int, seq_len: int, dim: int
) -> bool:
    """Run one kernel through vLLM and SGLang and report whether they agree.

    Args:
        kernel: Attribute name looked up on both ``vllm_ops`` and ``sgl_kernel``.
        dtype: Element type of the random input.
        batch_size, seq_len, dim: Output shape is ``(batch_size, seq_len, dim)``.

    Returns:
        True when ``torch.allclose`` (rtol=1e-3, atol=1e-5) accepts the pair.
    """
    cuda = torch.device("cuda")
    # "gelu_quick" is activation-only (input width == dim); every other kernel
    # is a fused activation-and-multiply that consumes 2*dim and emits dim.
    in_width = dim if kernel == "gelu_quick" else 2 * dim
    x = torch.randn(batch_size, seq_len, in_width, dtype=dtype, device=cuda)
    ref_out = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=cuda)

    # vLLM writes into a caller-supplied buffer; SGLang returns its result.
    getattr(vllm_ops, kernel)(ref_out, x)
    test_out = getattr(sgl_kernel, kernel)(x)

    ok = torch.allclose(ref_out, test_out, rtol=1e-3, atol=1e-5)
    if ok:
        tag = "✅ match"
    else:
        tag = "❌ mismatch"
    print(
        f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
        f"L={seq_len:3d} | D={dim:5d}] {tag}"
    )
    return ok
# Kernel names and element types swept by the benchmark grid.
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"]
dtypes = [torch.float16, torch.bfloat16]


def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
    """Return every (kernel, dtype, batch_size, seq_len, dim) combination.

    The kernel axis varies slowest and the dim axis fastest.
    """
    axes = (kernels, dtypes, bsizes, slens, dims_)
    return list(itertools.product(*axes))


# Default sweep: powers of two.
default_batch_sizes = [1 << e for e in (0, 2, 4)]  # 1, 4, 16
default_seq_lens = [1 << e for e in (0, 2, 4, 6)]  # 1, 4, 16, 64
default_dims = [1 << e for e in range(7, 15)]  # 128 ... 16384
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["kernel", "dtype", "batch_size", "seq_len", "dim"],
        x_vals=[],
        line_arg="provider",
        line_vals=["vllm", "sglang", "speedup"],
        line_names=["vLLM", "SGL Kernel", "Speed-up (x)"],
        styles=[("blue", "-"), ("green", "-"), ("red", "--")],
        ylabel="µs (median) or × (speed-up)",
        plot_name="activation-performance",
        args={},
    )
)
def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
    """Time one activation kernel configuration for one provider.

    ``provider`` is "vllm", "sglang", or anything else (treated as "speedup").
    For the two timing providers the result is ``(median, q80, q20)`` in
    microseconds; for speed-up it is the vLLM/SGLang median ratio repeated
    three times. ``x_vals`` is empty here and is filled in by the script's
    entry point before ``run`` is called.

    Raises:
        ValueError: If the vLLM-provider correctness check fails.
    """
    cuda = torch.device("cuda")
    # gelu_quick maps dim -> dim; the fused act*mul kernels take 2*dim inputs.
    width = dim if kernel == "gelu_quick" else 2 * dim
    x = torch.randn(batch_size, seq_len, width, dtype=dtype, device=cuda)
    y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=cuda)
    vllm_kernel = getattr(vllm_ops, kernel)
    sglang_kernel = getattr(sgl_kernel, kernel)

    def run_vllm():
        # vLLM writes in place, so each call gets a fresh output buffer.
        out = y0.clone()
        vllm_kernel(out, x)
        return out

    def run_sglang():
        return sglang_kernel(x)

    # Verify once per configuration (only on the "vllm" line) before timing.
    if provider == "vllm" and not calculate_diff(
        kernel, dtype, batch_size, seq_len, dim
    ):
        raise ValueError("Mismatch abort benchmark")

    def measure_us(fn):
        # Five untimed calls, then let do_bench collect the quantiles (in ms).
        for _ in range(5):
            fn()
        torch.cuda.synchronize()
        median_ms, q20_ms, q80_ms = triton.testing.do_bench(
            fn, quantiles=[0.5, 0.2, 0.8]
        )
        return 1000 * median_ms, 1000 * q80_ms, 1000 * q20_ms

    if provider == "vllm":
        return measure_us(run_vllm)
    if provider == "sglang":
        return measure_us(run_sglang)

    # Remaining case: speed-up of SGLang over vLLM, from the two medians.
    ref_us = measure_us(run_vllm)[0]
    sgl_us = measure_us(run_sglang)[0]
    ratio = ref_us / sgl_us
    return (ratio, ratio, ratio)
if __name__ == "__main__":
    p = argparse.ArgumentParser("Activation kernel benchmark")
    for flag, fallback in (
        ("--batch_sizes", default_batch_sizes),
        ("--seq_lens", default_seq_lens),
        ("--dims", default_dims),
    ):
        p.add_argument(flag, type=str2int_list, default=fallback)
    p.add_argument("--verify_only", action="store_true")
    args = p.parse_args()

    # Defensive coercion: parse any list option that is still a raw string.
    for field in ("batch_sizes", "seq_lens", "dims"):
        raw = getattr(args, field)
        if isinstance(raw, str):
            setattr(args, field, str2int_list(raw))

    # The perf_report decorator was built with an empty grid; install the real
    # one on whichever attribute this triton version exposes.
    benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.dims)
    if hasattr(benchmark, "benchmarks"):
        report_spec = benchmark.benchmarks
    else:
        report_spec = benchmark.benchmark
    report_spec.x_vals = benchmark_grid

    if not args.verify_only:
        benchmark.run(print_data=True)
    else:
        # Quick sanity check: one tiny gelu_quick comparison, no timing.
        ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0])
        print("✅ sanity pass" if ok else "❌ mismatch")

View File

@@ -0,0 +1,118 @@
import itertools
from typing import List, Tuple
import torch
import triton
import triton.testing
from sgl_kernel import awq_dequantize
from vllm import _custom_ops as ops
def vllm_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    """Dequantize AWQ weights through vLLM's custom op.

    The three trailing zeros are forwarded to ``ops.awq_dequantize`` as-is;
    presumably they are kernel tuning arguments left at "default" -- TODO
    confirm against the vLLM signature.

    Returns:
        The op's output. It is annotated as a single tensor because
        ``calculate_diff`` in this file calls ``.float()`` on it directly.
    """
    return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
def sglang_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    """Dequantize AWQ weights through sgl_kernel's ``awq_dequantize``.

    Returns:
        The kernel's output. It is annotated as a single tensor because
        ``calculate_diff`` in this file calls ``.float()`` on it directly.
    """
    return awq_dequantize(qweight, scales, qzeros)
def calculate_diff(qweight_row: int, qweight_col: int):
    """Compare the vLLM and SGLang AWQ dequantize outputs for one shape.

    Builds random packed weights, scales and zero points on the GPU, runs both
    implementations and prints whether the float32 results agree within
    rtol=1e-3 / atol=1e-5.

    Args:
        qweight_row: Rows of the packed int32 weight matrix.
        qweight_col: Columns of the packed int32 weight matrix.
    """
    device = torch.device("cuda")
    int32_max = torch.iinfo(torch.int32).max
    qweight = torch.randint(
        0,
        int32_max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    # One quantization group spans every row, so scales/zeros have one row.
    group_size = qweight_row
    scales_row = qweight_row // group_size
    # Scales are 8x wider than the packed weights (presumably eight 4-bit
    # values per int32 -- TODO confirm).
    scales_col = qweight_col * 8
    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
    qzeros = torch.randint(
        0,
        int32_max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
    sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)

    # The previously computed mean absolute difference was never used; the
    # verdict depends on allclose only.
    if torch.allclose(
        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    ):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
# Packed-weight shapes to sweep: every (rows, cols) pair, rows varying slowest.
qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]
configs = [(rows, cols) for rows in qweight_row_range for cols in qweight_cols_range]
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["qweight_row", "qweight_col"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm", "sglang"],
        line_names=["VLLM", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="awq-dequantize-performance",
        args={},
    )
)
def benchmark(qweight_row, qweight_col, provider):
    """Time AWQ dequantization for one packed-weight shape and one provider.

    Args:
        qweight_row: Rows of the packed int32 weight matrix.
        qweight_col: Columns of the packed int32 weight matrix.
        provider: "vllm" or "sglang".

    Returns:
        ``(median, q80, q20)`` of the measured time in microseconds.

    Raises:
        ValueError: If ``provider`` is not one of the two known names
            (previously this surfaced as an unbound-variable error).
    """
    device = torch.device("cuda")
    int32_max = torch.iinfo(torch.int32).max
    qweight = torch.randint(
        0,
        int32_max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    # One quantization group spans every row, so scales/zeros have one row.
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8
    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
    qzeros = torch.randint(
        0,
        int32_max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    if provider == "vllm":
        dequantize = vllm_awq_dequantize
    elif provider == "sglang":
        dequantize = sglang_awq_dequantize
    else:
        raise ValueError(f"Unknown provider: {provider}")

    def fn():
        # NOTE: the three clones run inside the timed region for both providers.
        return dequantize(qweight.clone(), scales.clone(), qzeros.clone())

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
    # do_bench reports milliseconds; scale to microseconds.
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
    # Check one representative shape for agreement, then run the full sweep.
    calculate_diff(qweight_row=3584, qweight_col=448)
    benchmark.run(print_data=True)

View File

@@ -0,0 +1,145 @@
import argparse
import copy
import itertools
import torch
import triton
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
# Sweep grid: every (batch_size, seq_len) pair, batch size varying slowest.
bs_range = [1, 8, 32, 64, 128, 256]
qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
configs = [(bs, qlen) for bs in bs_range for qlen in qlen_range]
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        x_log=False,
        line_arg="provider",
        line_vals=["128 heads", "64 heads", "32 heads", "16 heads"],
        line_names=["128 heads", "64 heads", "32 heads", "16 heads"],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="cutlass mla",
        args={},
    )
)
def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
    """Measure ``cutlass_mla_decode`` bandwidth for one configuration.

    Args:
        batch_size: Number of sequences.
        seq_len: Length assigned to every sequence.
        provider: Line label; the head count is the first known number
            ("128", "64", "32", "16") found inside it.
        block_size: KV-cache block size.
        num_kv_splits: Forwarded to the workspace query and the kernel.

    Returns:
        ``(median, from q80 time, from q20 time)`` in GB/s.

    Raises:
        ValueError: If no known head count appears in ``provider``.
    """
    # Head dims: d is the full width, dn the trailing slice held in `qr`,
    # dv the width used for the output-traffic estimate below.
    d = 576
    dn = 64
    dv = 512

    head_counts = {"128": 128, "64": 64, "32": 32, "16": 16}
    h_q = None
    for token, heads in head_counts.items():
        if token in provider:
            h_q = heads
            break
    if h_q is None:
        raise ValueError(f"Unknown head configuration in provider: {provider}")

    seq_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
    max_seq_len = seq_lens.max().item()
    block_num = (max_seq_len + block_size - 1) // block_size
    # Pad block_num so that small blocks can be packed into full 128-sized
    # CUTLASS tiles; one 128-wide tile holds (128 // block_size) small blocks.
    pack_factor = 128 // block_size
    block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

    # qn is allocated head-major and handed to the kernel as a transposed view.
    qn = (
        torch.randn(h_q, batch_size, d - dn, dtype=torch.bfloat16, device="cuda")
        * 100.0
    )
    qr = torch.randn(batch_size, h_q, dn, dtype=torch.bfloat16, device="cuda") * 100.0
    block_table = torch.randint(
        0,
        batch_size * block_num,
        (batch_size, block_num),
        dtype=torch.int32,
        device="cuda",
    )
    kv_cache = torch.randn(
        block_table.numel(), block_size, d, dtype=torch.bfloat16, device="cuda"
    )

    workspace_size = cutlass_mla_get_workspace_size(
        block_num * block_size, batch_size, num_kv_splits=num_kv_splits
    )
    workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

    def run_decode():
        return cutlass_mla_decode(
            qn.transpose(0, 1),
            qr,
            kv_cache,
            seq_lens,
            block_table,
            workspace,
            1.44,
            num_kv_splits,
        )

    ms, min_ms, max_ms = triton.testing.do_bench(run_decode, quantiles=[0.5, 0.2, 0.8])

    # Traffic model: queries in, an output scaled by dv/d, and the whole cache.
    q_size = qn.numel() * qn.element_size() + qr.numel() * qr.element_size()

    def gbps(t_ms):
        total_bytes = (
            q_size + q_size * dv / d + kv_cache.numel() * kv_cache.element_size()
        )
        return total_bytes * 1e-9 / (t_ms * 1e-3)

    # Throughput is inverse to time, so the slow quantile comes second.
    return gbps(ms), gbps(max_ms), gbps(min_ms)
if __name__ == "__main__":
    # Sweep every (block_size, num_kv_splits) pair; each pair produces one
    # perf report over the module-level (batch_size, seq_len) grid.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--block-sizes",
        nargs="+",
        type=int,
        default=[1, 32, 64, 128],
        # Help text previously said "List of batch sizes" (copy-paste slip).
        help="List of KV-cache block sizes",
    )
    parser.add_argument(
        "--num-kv-splits",
        nargs="+",
        type=int,
        default=[-1],
        # Help text previously said "List of batch sizes" (copy-paste slip).
        help="List of num_kv_splits values",
    )
    args = parser.parse_args()
    for block_size in args.block_sizes:
        for kv_split in args.num_kv_splits:
            print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
            benchmark.run(
                print_data=True,
                show_plots=True,
                save_path="bench_blackwell_mla_res",
                block_size=block_size,
                num_kv_splits=kv_split,
            )
    print("Benchmark finished!")

View File

@@ -0,0 +1,57 @@
import argparse
import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import dsv3_fused_a_gemm
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=[i + 1 for i in range(16)],
        x_log=False,
        line_arg="impl",
        line_vals=["torch", "sgl-kernel"],
        line_names=["torch (bf16)", "dsv3_fused_a_gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="TFLOPs",
        plot_name="bf16 dsv3 fused a GEMM throughput",
        args={},
    )
)
def benchmark(num_tokens, impl):
    """Measure bf16 GEMM throughput for ``num_tokens`` rows.

    Args:
        num_tokens: M dimension of the (M, 7168) x (7168, 2112) product.
        impl: "torch" (``F.linear``) or "sgl-kernel" (``dsv3_fused_a_gemm``).

    Returns:
        ``(median, from q80 time, from q20 time)`` in TFLOPs.

    Raises:
        ValueError: If ``impl`` is not a known name (previously this surfaced
            as an unbound-variable error).
    """
    kHdIn = 7168
    kHdOut = 2112
    M, K, N = num_tokens, kHdIn, kHdOut
    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    # Allocated as (N, K) and viewed as (K, N), i.e. a transposed,
    # non-contiguous view of row-major storage.
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").transpose(0, 1)
    quantiles = [0.5, 0.2, 0.8]

    if impl == "torch":

        def runner():
            # F.linear expects an (N, K) weight, so undo the view's transpose.
            F.linear(mat_a, mat_b.T)

    elif impl == "sgl-kernel":

        def runner():
            dsv3_fused_a_gemm(mat_a, mat_b)

    else:
        raise ValueError(f"Unknown impl: {impl}")

    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)

    def tflops(t_ms):
        flops = 2 * M * K * N
        return flops / (t_ms * 1e-3) / 1e12

    # Throughput is inverse to time, so the slow quantile comes second.
    return tflops(ms), tflops(max_ms), tflops(min_ms)
if __name__ == "__main__":
    # No options are defined; the parser only provides -h/--help and rejects
    # stray arguments. `args` is not used afterwards.
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_gemm")

View File

@@ -0,0 +1,127 @@
import argparse
import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import dsv3_router_gemm
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=[i + 1 for i in range(16)],
        x_log=False,
        line_arg="impl",
        line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
        line_names=[
            "torch-256",
            "dsv3_router_gemm-256",
            "torch-384",
            "dsv3_router_gemm-384",
        ],
        styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
        ylabel="TFLOPs",
        plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
        args={},
    )
)
def benchmark_bf16_output(num_tokens, impl):
    """Measure router-GEMM throughput with bf16 inputs and bf16 output.

    Args:
        num_tokens: M dimension (rows of the activation matrix).
        impl: "<backend>-<num_experts>", backend "torch" or "sgl-kernel",
            expert count 256 or 384.

    Returns:
        ``(median, from q80 time, from q20 time)`` in TFLOPs.

    Raises:
        ValueError: If ``impl`` is not one of the four known names.
    """
    # M: num_tokens, K: hidden_dim, N: num_experts
    experts_by_impl = {
        "torch-256": 256,
        "sgl-kernel-256": 256,
        "torch-384": 384,
        "sgl-kernel-384": 384,
    }
    if impl not in experts_by_impl:
        raise ValueError(f"Unknown impl: {impl}")
    M, K, N = num_tokens, 7168, experts_by_impl[impl]

    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()

    if impl in ("torch-256", "torch-384"):

        def runner():
            F.linear(mat_a, mat_b)

    else:

        def runner():
            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)

    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=[0.5, 0.2, 0.8])

    total_flops = 2 * M * K * N

    def tflops(t_ms):
        return total_flops / (t_ms * 1e-3) / 1e12

    # Throughput is inverse to time, so the slow quantile comes second.
    return tflops(ms), tflops(max_ms), tflops(min_ms)
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=[i + 1 for i in range(16)],
        x_log=False,
        line_arg="impl",
        line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
        line_names=[
            "torch-256",
            "dsv3_router_gemm-256",
            "torch-384",
            "dsv3_router_gemm-384",
        ],
        styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
        ylabel="TFLOPs",
        plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
        args={},
    )
)
def benchmark_float_output(num_tokens, impl):
    """Measure router-GEMM throughput with bf16 inputs and fp32 output.

    The torch baseline runs ``F.linear`` and then casts the result to float32,
    so the cast is part of its measured time.

    Args:
        num_tokens: M dimension (rows of the activation matrix).
        impl: "<backend>-<num_experts>", backend "torch" or "sgl-kernel",
            expert count 256 or 384.

    Returns:
        ``(median, from q80 time, from q20 time)`` in TFLOPs.

    Raises:
        ValueError: If ``impl`` is not one of the four known names.
    """
    # M: num_tokens, K: hidden_dim, N: num_experts
    experts_by_impl = {
        "torch-256": 256,
        "sgl-kernel-256": 256,
        "torch-384": 384,
        "sgl-kernel-384": 384,
    }
    if impl not in experts_by_impl:
        raise ValueError(f"Unknown impl: {impl}")
    M, K, N = num_tokens, 7168, experts_by_impl[impl]

    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()

    if impl in ("torch-256", "torch-384"):

        def runner():
            F.linear(mat_a, mat_b).to(torch.float32)

    else:

        def runner():
            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)

    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=[0.5, 0.2, 0.8])

    total_flops = 2 * M * K * N

    def tflops(t_ms):
        return total_flops / (t_ms * 1e-3) / 1e12

    # Throughput is inverse to time, so the slow quantile comes second.
    return tflops(ms), tflops(max_ms), tflops(min_ms)
if __name__ == "__main__":
    # No options are defined; the parser only provides -h/--help and rejects
    # stray arguments. `args` is not used afterwards.
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    # Both reports are saved under the same directory; their plot names differ.
    benchmark_bf16_output.run(
        print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
    )
    benchmark_float_output.run(
        print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
    )

View File

@@ -0,0 +1,210 @@
import argparse
import copy
import csv
import itertools
import pytest
import torch
import triton
from flashinfer import mm_fp4
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
# Numerators for the global quantization scales computed in `benchmark`:
# the FP4 (E2M1) maximum, hard-coded, and the FP8 (E4M3FN) maximum from torch.
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def get_weight_shapes(args):
    """Return the ``[N, K]`` weight shapes to benchmark.

    ``args.tp_sizes == [4]`` selects the first table and ``[8]`` the second;
    any other value returns both tables, the ``[4]`` table first.
    """
    shapes_tp4 = [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]]
    shapes_tp8 = [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]]

    requested = args.tp_sizes
    if requested == [4]:
        return shapes_tp4
    if requested == [8]:
        return shapes_tp8
    return shapes_tp4 + shapes_tp8
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1,
            2,
            4,
            8,
            16,
            32,
            64,
            128,
            256,
            512,
            1024,
            2048,
            3072,
            4096,
            8192,
            16384,
        ],
        # x_vals = [64],
        x_log=False,
        line_arg="provider",
        line_vals=["cutlass", "cudnn", "trtllm"],
        line_names=["baseline cutlass fp4", "cudnn fp4", "trtllm fp4"],
        styles=[("red", "solid"), ("blue", "solid"), ("green", "solid")],
        ylabel="latency (ms)",
        plot_name="fp4_gemm_benchmark",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
    """Time one FP4 GEMM backend for M = ``batch_size`` under CUDA graphs.

    Args:
        batch_size: M dimension of the activation matrix.
        provider: "cutlass", "cudnn" or "trtllm".
        N: Rows of the weight matrix.
        K: Packed K; the unquantized tensors are created with ``2 * K`` columns.
        dtype: Element type of the unquantized inputs and of the output.
        correctness: When true, also compare the cudnn and trtllm results
            against the cutlass result (asserts on mismatch).
        csv_file: If truthy, a ``[provider, M, N, K, ms]`` row is appended.

    Returns:
        ``(median, q20, q80)`` latency in ms as reported by
        ``do_bench_cudagraph``.
    """
    M = batch_size
    # The argument is the packed width; K is rebound to the unpacked width.
    packed_k = K
    K = 2 * packed_k
    a_dtype = torch.randn((M, K), dtype=dtype, device="cuda")
    b_dtype = torch.randn((N, K), dtype=dtype, device="cuda")
    # Per-tensor global scales derived from each tensor's maximum value.
    a_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    b_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    # alpha undoes both global scales in the GEMM output.
    alpha = 1.0 / (a_global_scale * b_global_scale)
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
    # print("a_fp4", a_fp4)
    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
    # Preallocated output buffer handed to the flashinfer `mm_fp4` calls.
    res_fi = torch.empty((M, N), dtype=dtype, device="cuda")
    quantiles = [0.5, 0.2, 0.8]
    if provider == "cutlass":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: cutlass_scaled_fp4_mm(
                a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
            ),
            quantiles=quantiles,
        )
    if provider == "cudnn":
        # No `backend=` is passed here; this relies on mm_fp4's default
        # (presumably cudnn -- TODO confirm against flashinfer).
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: mm_fp4(
                a_fp4,
                b_fp4.T,
                a_scale_interleaved,
                b_scale_interleaved.T,
                alpha,
                dtype,
                res_fi,
            ),
            quantiles=quantiles,
        )
    if provider == "trtllm":
        # NOTE(review): these rebinds persist, so the correctness block below
        # also feeds the uint8 scales to cutlass/cudnn when provider is
        # "trtllm" -- confirm that is intended.
        a_scale_interleaved = a_scale_interleaved.to(torch.uint8)
        b_scale_interleaved = b_scale_interleaved.to(torch.uint8)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: mm_fp4(
                a_fp4,
                b_fp4.T,
                a_scale_interleaved,
                b_scale_interleaved.T,
                alpha,
                dtype,
                res_fi,
                backend="trtllm",
            ),
            quantiles=quantiles,
        )
    if correctness:
        # cutlass is the reference; each mm_fp4 call overwrites res_fi.
        res_cutlass = cutlass_scaled_fp4_mm(
            a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
        )
        mm_fp4(
            a_fp4,
            b_fp4.T,
            a_scale_interleaved,
            b_scale_interleaved.T,
            alpha,
            dtype,
            res_fi,
            backend="cudnn",
        )
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "cudnn fp4 doesn't match cutlass fp4"
        mm_fp4(
            a_fp4,
            b_fp4.T,
            a_scale_interleaved,
            b_scale_interleaved.T,
            alpha,
            dtype,
            res_fi,
            backend="trtllm",
        )
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "trtllm fp4 doesn't match cutlass fp4"
    if csv_file:
        # Append mode: the header row is written once by the script entry point.
        with open(csv_file, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([provider, M, N, K, ms])
    return ms, min_ms, max_ms
if __name__ == "__main__":

    def _parse_dtype(name):
        """Resolve ``"bfloat16"`` or ``"torch.bfloat16"`` to a ``torch.dtype``.

        ``torch.dtype`` cannot be called with a string, so using it directly as
        the argparse ``type`` rejected every explicit ``--dtype`` value.
        """
        resolved = getattr(torch, name.split(".")[-1], None)
        if not isinstance(resolved, torch.dtype):
            raise argparse.ArgumentTypeError(f"Unknown torch dtype: {name}")
        return resolved

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    parser.add_argument(
        "--dtype",
        type=_parse_dtype,
        default=torch.bfloat16,
        help="Data type",
    )
    parser.add_argument(
        "--correctness",
        action="store_true",
        help="Check correctness",
    )
    parser.add_argument(
        "--csv",
        type=str,
        default="results_cutlass_cudnn.csv",
        help="CSV file to save results",
    )
    args = parser.parse_args()

    # Start the CSV fresh with a header; `benchmark` appends one row per run.
    if args.csv:
        with open(args.csv, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["provider", "m", "n", "k", "time_ms"])

    NKs = get_weight_shapes(args)
    for N, K in NKs:
        print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp4_res",
            N=N,
            K=K,
            dtype=args.dtype,
            correctness=args.correctness,
            csv_file=args.csv,
        )
    print("Benchmark finished!")

View File

@@ -0,0 +1,183 @@
import argparse
import copy
import itertools
import deep_gemm
import torch
import triton
from deep_gemm import get_col_major_tma_aligned_tensor
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from sglang.srt.layers.quantization.fp8_kernel import (
w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
)
def get_weight_shapes(args):
    """Build ``[N, K, model]`` entries for every (model, tp_size) pair.

    For each pair the result holds, in order: the shapes that are not split,
    the shapes whose N is divided by ``tp_size``, and the shapes whose K is
    divided by ``tp_size`` (integer division in both cases).

    Raises:
        AssertionError: If a requested model is not in the supported list.
    """
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    # NOTE(HandH1998): The weight shapes only work for DeepSeek-V3. Modify
    # them if you tune for a different model.
    # cannot TP
    total = [
        (512 + 64, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can TP
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # K can TP
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
    # only support Deepseek-V3
    SUPPORT_MODEL = ["deepseek-ai/DeepSeek-V3"]

    weight_shapes = []
    for model, tp_size in models_tps:
        assert model in SUPPORT_MODEL
        weight_shapes.extend([n, k, model] for n, k in total)
        weight_shapes.extend([n // tp_size, k, model] for n, k in n_tp)
        weight_shapes.extend([n, k // tp_size, model] for n, k in k_tp)
    return weight_shapes
def cdiv(a: int, b: int) -> int:
    """Return ``a / b`` rounded toward positive infinity."""
    quotient, remainder = divmod(a, b)
    # divmod floors; any non-zero remainder means the ceiling is one higher.
    return quotient + 1 if remainder else quotient
def fp8_gemm_deepgemm(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
    m: int,
    n: int,
    k: int,
):
    """Run DeepGEMM's FP8 x FP8 -> bf16 GEMM and return the result.

    Args:
        x_fp8, x_scale: Left operand and its scales.
        y_fp8, y_scale: Right operand and its scales.
        m, n: Shape of the freshly allocated ``(m, n)`` bf16 output.
        k: Accepted for signature symmetry; not used here.
    """
    result = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
    lhs = (x_fp8, x_scale)
    rhs = (y_fp8, y_scale)
    # The kernel writes into `result` in place.
    deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, result)
    return result
def scale_shape(shape, group_shape):
    """Return how many groups cover each dimension of ``shape``.

    Each entry is ``shape[i]`` divided by ``group_shape[i]``, rounded up.
    Both arguments must have the same length.
    """
    assert len(shape) == len(group_shape)
    return tuple(cdiv(extent, group) for extent, group in zip(shape, group_shape))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
        x_log=False,
        line_arg="provider",
        line_vals=["vllm", "sgl-kernel", "triton", "deepgemm"],
        line_names=["vllm", "sgl-kernel", "sglang triton", "deepgemm"],
        styles=[("blue", "-"), ("orange", "-"), ("red", "-"), ("yellow", "-")],
        # The function returns latency in microseconds, not bandwidth; the
        # label previously read "GB/s".
        ylabel="us",
        plot_name="fp8 blockwise scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Time one FP8 block-wise scaled GEMM backend for M = ``batch_size``.

    Args:
        batch_size: M dimension of the activation matrix.
        provider: "sgl-kernel", "vllm", "triton" or "deepgemm".
        N, K: Weight matrix is ``(N, K)``.

    Returns:
        ``(median, q80, q20)`` latency in microseconds.

    Raises:
        ValueError: If ``provider`` is not a known name (previously this
            surfaced as an unbound-variable error).
    """
    M = batch_size
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    # Uniform values spanning the FP8 range, clamped and cast to FP8.
    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    # Activations: one scale per (1 x 128) group; weights: per (128 x 128) block.
    scale_a_group_shape = (1, 128)
    scale_b_group_shape = (128, 128)
    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
    scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
    scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)

    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl-kernel":
        # Column-major activation scales; weight and its scales as transposed views.
        scale_a = scale_a.t().contiguous().t()
        b_fp8, scale_b = b_fp8.t(), scale_b.t()
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: fp8_blockwise_scaled_mm(
                a_fp8, b_fp8, scale_a, scale_b, torch.float16
            ),
            quantiles=quantiles,
        )
    elif provider == "vllm":
        scale_a = scale_a.t().contiguous().t()
        b_fp8, scale_b = b_fp8.t(), scale_b.t()
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
            quantiles=quantiles,
        )
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: w8a8_block_fp8_matmul(
                a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
            ),
            quantiles=quantiles,
        )
    elif provider == "deepgemm":
        # The scale layout conversion happens once, outside the timed region.
        scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: fp8_gemm_deepgemm(
                a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
            ),
            quantiles=quantiles,
        )
    else:
        raise ValueError(f"Unknown provider: {provider}")
    # do_bench reports milliseconds; scale to microseconds.
    return ms * 1000, max_ms * 1000, min_ms * 1000
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["deepseek-ai/DeepSeek-V3"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    NK_model_names = get_weight_shapes(args)
    for N, K, model_name in NK_model_names:
        # Only shapes whose N and K are both multiples of 128 are benchmarked.
        if N % 128 == 0 and K % 128 == 0:
            print(f"{model_name} N={N} K={K}: ")
            benchmark.run(
                print_data=True,
                show_plots=True,
                save_path="bench_fp8_blockwise_res",
                N=N,
                K=K,
            )
        else:
            print(f"Skip {N=}, {K=} now")
    print("Benchmark finished!")

View File

@@ -0,0 +1,328 @@
import argparse
import random
from dataclasses import dataclass
from typing import List, Tuple
import deep_gemm
import torch
from sgl_kernel import fp8_blockwise_scaled_grouped_mm
def get_m_alignment_for_contiguous_layout():
    """Return the row alignment (128) each group's M is padded to."""
    m_alignment = 128
    return m_alignment
def ceil_div(x: int, y: int) -> int:
    """Return ``x / y`` rounded up (intended for a positive divisor ``y``)."""
    biased = x + y - 1
    return biased // y
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (128 - (n % 128)) % 128
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 with one scale per 128 x 128 block.

    The input is copied into a zero buffer whose sides are rounded up to
    multiples of 128, quantized block by block, and trimmed back.

    Returns:
        ``(x_fp8, scales)``: contiguous quantized values with the input's
        shape, and float scales of shape ``(ceil(m / 128), ceil(n / 128))``
        where each scale is the block's absolute maximum (at least 1e-4)
        divided by 448.
    """
    assert x.dim() == 2
    m, n = x.shape
    padded_rows = ceil_div(m, 128) * 128
    padded_cols = ceil_div(n, 128) * 128
    x_padded = torch.zeros((padded_rows, padded_cols), dtype=x.dtype, device=x.device)
    x_padded[:m, :n] = x

    # Axes: (row block, row within block, column block, column within block).
    blocks = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    block_amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    quantized = (blocks * (448.0 / block_amax)).to(torch.float8_e4m3fn)
    scales = (block_amax / 448.0).view(blocks.size(0), blocks.size(2))
    return quantized.view_as(x_padded)[:m, :n].contiguous(), scales
def construct_contiguous_grouped(
    num_groups: int, expected_m_per_group: int, k: int, n: int
) -> Tuple[
    int,
    Tuple[torch.Tensor, torch.Tensor],
    Tuple[torch.Tensor, torch.Tensor],
    torch.Tensor,
    torch.Tensor,
]:
    """Build the inputs for a contiguous-layout grouped FP8 GEMM.

    Every group gets ``expected_m_per_group`` rows, padded up to the alignment
    returned by ``get_m_alignment_for_contiguous_layout``.

    Returns:
        A 5-tuple ``(m, x_fp8, y_fp8, m_indices, out)``:
        ``m`` is the total padded row count; ``x_fp8`` and ``y_fp8`` are
        ``(data, scales)`` pairs; ``m_indices`` holds the group index of each
        row (``-1`` for padding rows); ``out`` is an uninitialized ``(m, n)``
        bf16 buffer. (The annotation previously listed a sixth element that
        was never returned.)
    """
    alignment = get_m_alignment_for_contiguous_layout()
    group_ms = [int(expected_m_per_group) for _ in range(num_groups)]
    m = sum(ceil_div(group_m, alignment) * alignment for group_m in group_ms)

    x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    y = torch.randn((num_groups, n, k), device="cuda", dtype=torch.bfloat16)
    m_indices = torch.empty(m, device="cuda", dtype=torch.int32)
    out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)

    # Label each group's real rows with its index and its padding with -1.
    start = 0
    for i, group_m in enumerate(group_ms):
        actual_end = start + group_m
        aligned_end = start + ceil_div(group_m, alignment) * alignment
        m_indices[start:actual_end] = i
        m_indices[actual_end:aligned_end] = -1
        start = aligned_end

    assert m % 4 == 0, f"TMA alignment error: {m}"

    x_fp8 = per_token_cast_to_fp8(x)
    y_fp8 = (
        torch.empty_like(y, dtype=torch.float8_e4m3fn),
        torch.empty(
            (num_groups, ceil_div(n, 128), k // 128), device="cuda", dtype=torch.float
        ),
    )
    # Weights are quantized one group at a time, per 128 x 128 block.
    for i in range(num_groups):
        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])

    return m, x_fp8, y_fp8, m_indices, out
def bench_deepgemm(
    expected_m_per_group: int,
    n: int,
    k: int,
    num_groups: int,
    num_warmup: int,
    num_run: int,
) -> Tuple[float, int]:
    """Time DeepGEMM's contiguous grouped FP8 GEMM.

    Args:
        expected_m_per_group: Rows assigned to each group before padding.
        n, k: Per-group weight shape is ``(n, k)``.
        num_groups: Number of groups.
        num_warmup: Untimed iterations run first.
        num_run: Timed iterations, measured as one batch between CUDA events.

    Returns:
        ``(average microseconds per call, total padded row count)``.
    """
    # construct tensors
    m, x_fp8, y_fp8, m_indices, out = construct_contiguous_grouped(
        num_groups, expected_m_per_group, k, n
    )

    def run_deepgemm():
        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(x_fp8, y_fp8, out, m_indices)

    # warmup
    for _ in range(num_warmup):
        run_deepgemm()
    torch.cuda.synchronize()

    # run (the unused `latencies` list was removed; only the events are needed)
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_run):
        run_deepgemm()
    end_event.record()
    end_event.synchronize()
    torch.cuda.synchronize()

    # elapsed_time is in milliseconds; report microseconds per call.
    avg = start_event.elapsed_time(end_event) / num_run * 1000  # us
    return avg, m
def bench_cutlass(
    expected_m_per_group: int,
    n: int,
    k: int,
    num_groups: int,
    num_warmup: int,
    num_run: int,
) -> Tuple[float, int]:
    """Time sgl_kernel's ``fp8_blockwise_scaled_grouped_mm``.

    Args:
        expected_m_per_group: Rows per group before rounding up to 16.
        n, k: Per-group weight shape; both are rounded up to multiples of 16.
        num_groups: Number of groups (experts).
        num_warmup: Untimed iterations run first.
        num_run: Timed iterations, measured as one batch between CUDA events.

    Returns:
        ``(average microseconds per call, total row count)``.
        NOTE(review): the second element is ``expert_offsets[-1]``, a 0-d
        int32 tensor rather than the annotated ``int``.
    """
    device = "cuda"
    alignment = 16
    n_g = ceil_div(n, alignment) * alignment
    k_g = ceil_div(k, alignment) * alignment
    out_dtype = torch.bfloat16
    # Metadata buffers. expert_offsets and problem_sizes are filled below;
    # layout_sfa/layout_sfb are only zero-initialized here (presumably written
    # by the kernel -- TODO confirm).
    expert_offsets = torch.zeros((num_groups + 1), device=device, dtype=torch.int32)
    problem_sizes = torch.zeros((num_groups, 3), device=device, dtype=torch.int32)
    layout_sfa = torch.zeros((num_groups, 5), device=device, dtype=torch.int32)
    layout_sfb = torch.zeros((num_groups, 5), device=device, dtype=torch.int32)
    a_tensors = []
    b_tensors = []
    a_scales_tensors = []
    b_scales_tensors = []
    # TODO(@TianQiLin666666): Unique group_ms in all bench function
    group_ms = [
        alignment * ceil_div(int(expected_m_per_group), alignment)
        for _ in range(num_groups)
    ]
    # Per group: record the running row offset and (m, n, k), then quantize
    # random activations per token and random (transposed) weights per block.
    for g in range(num_groups):
        m_g = group_ms[g]
        expert_offsets[g + 1] = expert_offsets[g] + m_g
        problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)
        a_g, a_scale = per_token_cast_to_fp8(torch.randn((m_g, k_g), device=device))
        b_g, b_scale = per_block_cast_to_fp8(torch.randn((n_g, k_g), device=device).t())
        a_tensors.append(a_g)
        b_tensors.append(b_g)
        a_scales_tensors.append(a_scale)
        b_scales_tensors.append(b_scale)
    # Stack activations row-wise and weights along a leading group axis.
    a_stack = torch.empty(
        (expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
    )
    b_stack = torch.empty(
        (num_groups, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
    )
    for g in range(num_groups):
        a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
        b_stack[g] = b_tensors[g].t()
    # Stored as (groups, n, k), then exposed as a (groups, k, n) view.
    b_stack = b_stack.transpose(1, 2)
    # Scale stacks assume k_g and n_g are multiples of 128 (floor division).
    a_scale_stack = torch.empty(
        (expert_offsets[-1], k_g // 128), device=device, dtype=torch.float32
    )
    b_scale_stack = torch.empty(
        (num_groups, n_g // 128, k_g // 128), device=device, dtype=torch.float32
    )
    for g in range(num_groups):
        a_scale_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_scales_tensors[g]
        b_scale_stack[g] = b_scales_tensors[g].t()
    b_scale_stack = b_scale_stack.transpose(1, 2)
    c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
    a_strides = torch.full(
        (num_groups,), a_stack.stride(0), device=device, dtype=torch.int64
    )
    c_strides = torch.full(
        (num_groups,), c_out.stride(0), device=device, dtype=torch.int64
    )
    # 1 GiB scratch buffer plus uninitialized per-group pointer arrays.
    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
    a_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
    b_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
    out_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
    a_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
    b_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)

    def run_cutlass():
        fp8_blockwise_scaled_grouped_mm(
            c_out,
            a_ptrs,
            b_ptrs,
            out_ptrs,
            a_scales_ptrs,
            b_scales_ptrs,
            a_stack,
            b_stack,
            a_scale_stack,
            b_scale_stack,
            a_strides,
            # NOTE(review): a_strides is passed twice, the second time
            # presumably as the B strides -- confirm this is intended.
            a_strides,
            c_strides,
            layout_sfa,
            layout_sfb,
            problem_sizes,
            expert_offsets[:-1],
            workspace,
        )

    # warmup
    for _ in range(num_warmup):
        run_cutlass()
    torch.cuda.synchronize()
    # run
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_run):
        run_cutlass()
    end_event.record()
    end_event.synchronize()
    torch.cuda.synchronize()
    avg = start_event.elapsed_time(end_event) / num_run * 1000  # us
    return avg, expert_offsets[-1]
def bench_sglang_triton(
    expected_m_per_group: int,
    n: int,
    k: int,
    num_groups: int,
    num_warmup: int,
    num_run: int,
) -> Tuple[float, int]:
    """Placeholder for a Triton grouped-GEMM benchmark.

    Not implemented: every argument is ignored and ``None`` is returned.
    """
    return None
# Display name -> benchmark function, in the order they are run.
# "triton" (bench_sglang_triton) is deliberately not registered.
benchmark_kernels = dict(
    deepgemm=bench_deepgemm,
    cutlass=bench_cutlass,
)
@dataclass
class ShapeArg:
    """One grouped-GEMM problem shape to benchmark."""

    # Rows assigned to each group before alignment padding.
    expected_m_per_group: int
    # Per-group weight shape is (n, k).
    n: int
    k: int
    # Number of groups in the problem.
    num_groups: int
def benchmark_one_shape(
    shape_args: List[ShapeArg],
    num_warmup: int,
    num_run: int,
):
    """Run every registered kernel on every shape and print the averages.

    Args:
        shape_args: Problem shapes to benchmark, in order.
        num_warmup: Untimed iterations per kernel and shape.
        num_run: Timed iterations per kernel and shape.
    """
    for shape in shape_args:
        header = f"\nBenchmark: expected_m_per_group={shape.expected_m_per_group}, n={shape.n}, k={shape.k}, num_groups={shape.num_groups}"
        print(header)
        for kernel_name, kernel_func in benchmark_kernels.items():
            # Each kernel returns (average time, row count); only time is shown.
            average_time, _ = kernel_func(
                shape.expected_m_per_group,
                shape.n,
                shape.k,
                shape.num_groups,
                num_warmup,
                num_run,
            )
            print(f"{kernel_name}: {average_time} us")
def main():
    """Parse the CLI options and benchmark the built-in list of shapes."""
    shape_args = [
        # Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8
        ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
        # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 8
        ShapeArg(expected_m_per_group=256, n=512, k=7168, num_groups=256),
        # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 16
        ShapeArg(expected_m_per_group=256, n=256, k=7168, num_groups=256),
        # Prefill, DeepSeek-R1, gateup, chunk_size = 16384, TP = 16
        ShapeArg(expected_m_per_group=512, n=256, k=7168, num_groups=256),
        # Decode, DeepSeek-R1, gateup, bs = 32, TP = 8
        ShapeArg(expected_m_per_group=1, n=512, k=7168, num_groups=256),
        # Decode, DeepSeek-R1, gateup, bs = 64, TP = 16
        ShapeArg(expected_m_per_group=2, n=256, k=7168, num_groups=256),
        # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, EP = 8
        ShapeArg(expected_m_per_group=256, n=4096, k=7168, num_groups=32),
        # Prefill, DeepSeek-R1, gateup, chunk_size = 16384, EP = 16
        ShapeArg(expected_m_per_group=512, n=4096, k=7168, num_groups=16),
        # Decode, DeepSeek-R1, gateup, bs = 128, EP = 8
        ShapeArg(expected_m_per_group=4, n=4096, k=7168, num_groups=32),
        # Decode, DeepSeek-R1, gateup, bs = 256, EP = 16
        ShapeArg(expected_m_per_group=8, n=4096, k=7168, num_groups=16),
        # Prefill, Qwen3-235B-A22B-FP8, gateup, chunk_size = 16384, TP = 4
        ShapeArg(expected_m_per_group=1024, n=768, k=4096, num_groups=128),
        # Prefill, Qwen3-235B-A22B-FP8, down, chunk_size = 16384, TP = 4
        ShapeArg(expected_m_per_group=1024, n=4096, k=384, num_groups=128),
        # Decode, Qwen3-235B-A22B-FP8, gateup, bs = 256, TP = 4
        ShapeArg(expected_m_per_group=16, n=768, k=4096, num_groups=128),
        # Decode, Qwen3-235B-A22B-FP8, down, bs = 256, TP = 4
        ShapeArg(expected_m_per_group=16, n=4096, k=384, num_groups=128),
    ]

    # Both options are plain integers; only the flag and default differ.
    cli = argparse.ArgumentParser()
    for flag, default in (("--num-warmup", 3), ("--num-run", 10)):
        cli.add_argument(flag, type=int, default=default)
    options = cli.parse_args()

    benchmark_one_shape(shape_args, options.num_warmup, options.num_run)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,184 @@
import argparse
import copy
import itertools
from typing import Optional, Tuple
import torch
import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from sgl_kernel import sgl_per_tensor_quant_fp8
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
# Weight shapes are stored in the format
#   ([K, N], TP_SPLIT_DIM)
# where TP_SPLIT_DIM is the index (0 -> K, 1 -> N) of the dimension that
# prepare_shapes divides by the tensor-parallel size.
# Example:
#   A shape of ([14336, 4096], 0) indicates the following GEMM shape,
#     - TP1 : K = 14336, N = 4096
#     - TP2 : K = 7168,  N = 4096
#   A shape of ([4096, 6144], 1) indicates the following GEMM shape,
#     - TP1 : K = 4096, N = 6144
#     - TP4 : K = 4096, N = 1536
# All entries below are TP1 (unsplit) shapes, keyed by model name.
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}
def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` to ``float8_e4m3fn`` with the SGL per-tensor kernel.

    Args:
        input: Tensor to quantize; the output has the same shape and device.
        scale: Optional scale tensor. If given, it is passed to the kernel
            with ``is_static=True``. If ``None``, a zero-filled one-element
            float32 tensor is allocated on ``input``'s device and the kernel
            is called with ``is_static=False`` (presumably the kernel then
            computes and writes the scale — confirm against sgl_kernel).

    Returns:
        ``(quantized, scale)``: the fp8 tensor and the scale tensor handed to
        the kernel.
    """
    is_static = scale is not None
    if not is_static:
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)

    quantized = torch.empty_like(
        input, device=input.device, dtype=torch.float8_e4m3fn
    )
    sgl_per_tensor_quant_fp8(input, quantized, scale, is_static)
    return quantized, scale
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        line_names=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="fp8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Time one fp8 scaled matmul of shape (M=batch_size, K) x (K, N).

    Args:
        batch_size: Number of rows M of the activation matrix.
        provider: One of the ``line_vals`` above. The substring ``vllm-fp8`` /
            ``sglang-fp8`` selects the implementation; ``fp16`` in the name
            selects float16 output, otherwise bfloat16.
        N: Output feature dimension.
        K: Reduction dimension.

    Returns:
        ``(y, y_min, y_max)`` throughput values for ``perf_report``.

    Raises:
        ValueError: If ``provider`` names neither implementation.
    """
    M = batch_size
    a = torch.ones((M, K), device="cuda") * 5.0
    b = torch.ones((N, K), device="cuda") * 5.0
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    # do_bench returns the timings at these quantiles: median, p20, p80.
    quantiles = [0.5, 0.2, 0.8]

    dtype = torch.float16 if "fp16" in provider else torch.bfloat16

    if "vllm-fp8" in provider:
        a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
        # Quantization happens outside the timed lambda; only the matmul is timed.
        b_fp8 = b_fp8.t()
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
            quantiles=quantiles,
        )
    elif "sglang-fp8" in provider:
        a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
        b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
        b_fp8 = b_fp8.t()
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sgl_scaled_mm(
                a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
            ),
            quantiles=quantiles,
        )
    else:
        # Without this branch an unknown provider failed later with an
        # unbound-variable error on `ms`.
        raise ValueError(f"Unknown provider: {provider}")

    def gbps(elapsed_ms):
        # NOTE(review): this scales an operation count (2*M*N*K + M*N) by the
        # element size of the unquantized `a`, yet is labelled GB/s — confirm
        # that this is the intended metric.
        return (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (elapsed_ms * 1e-3)

    # Throughput is inversely proportional to time, so the slowest quantile
    # gives the lower bound and the fastest gives the upper bound.
    return gbps(ms), gbps(max_ms), gbps(min_ms)
def prepare_shapes(args):
    """Expand the requested models and TP sizes into ``[K, N, model]`` entries.

    For every (model, tp_size) pair taken from ``args.models`` x
    ``args.tp_sizes`` and every weight shape registered for that model, the
    dimension selected by the shape's split index is floor-divided by
    ``tp_size``. ``WEIGHT_SHAPES`` itself is never modified.

    Returns:
        A list of ``[K, N, model_name]`` lists.
    """
    shapes = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        assert model in WEIGHT_SHAPES
        for base_kn, split_dim in WEIGHT_SHAPES[model]:
            # Work on a copy so the shared table keeps its TP1 values.
            entry = list(base_kn)
            entry[split_dim] = entry[split_dim] // tp_size
            entry.append(model)
            shapes.append(entry)
    return shapes
if __name__ == "__main__":
    # Both options take one or more values and differ only in type/default/help.
    cli = argparse.ArgumentParser()
    list_options = (
        (
            "--models",
            str,
            ["meta-llama/Llama-3.1-8B-Instruct"],
            "List of models to benchmark",
        ),
        ("--tp-sizes", int, [1], "List of tensor parallel sizes"),
    )
    for flag, value_type, default, help_text in list_options:
        cli.add_argument(
            flag, nargs="+", type=value_type, default=default, help=help_text
        )
    args = cli.parse_args()

    # One perf_report run (table, plot, saved results) per weight shape.
    run_options = dict(print_data=True, show_plots=True, save_path="bench_fp8_res")
    for K, N, model_name in prepare_shapes(args):
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(N=N, K=K, **run_options)

    print("Benchmark finished!")

View File

@@ -0,0 +1,146 @@
import argparse
import copy
import itertools
import torch
import triton
from sgl_kernel import int8_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    """Clamp ``tensor`` to [-128, 127], round it, and cast to ``torch.int8``."""
    clipped = torch.clamp(tensor, min=-128, max=127)
    return clipped.round().to(torch.int8)
# TP1 (unsplit) weight shapes per model, stored as ([K, N], split_dim).
# split_dim is the index (0 -> K, 1 -> N) of the dimension that prepare_shapes
# floor-divides by the tensor-parallel size.
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["vllm", "sgl-kernel"],
        line_names=["vllm int8 gemm", "sgl-kernel int8 gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="GB/s",
        plot_name="int8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Time one int8 scaled matmul (with bias) for the given provider.

    Args:
        batch_size: Number of rows M of the activation matrix.
        provider: ``"sgl-kernel"`` or ``"vllm"``.
        N: Output feature dimension.
        K: Reduction dimension.

    Returns:
        ``(y, y_min, y_max)`` throughput values for ``perf_report``.

    Raises:
        ValueError: If ``provider`` is not one of the two known names.
    """
    M = batch_size
    a = to_int8(torch.randn((M, K), device="cuda") * 5)
    # b is generated as (N, K) and transposed before quantization.
    b = to_int8(torch.randn((N, K), device="cuda").t() * 5)
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    bias = torch.randn((N,), device="cuda", dtype=torch.float16)

    # do_bench returns the timings at these quantiles: median, p20, p80.
    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl-kernel":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
            quantiles=quantiles,
        )
    elif provider == "vllm":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
            quantiles=quantiles,
        )
    else:
        # Previously an unknown provider fell through both `if`s and failed
        # later with an unbound-variable error on `ms`.
        raise ValueError(f"Unknown provider: {provider}")

    def gbps(elapsed_ms):
        # NOTE(review): mixes an operation count with element sizes under a
        # GB/s label; formula kept unchanged — confirm the intended metric.
        return (
            (
                (2 * M * N * K - M * N) * a.element_size()
                + (3 * M * N) * scale_a.element_size()
            )
            * 1e-9
            / (elapsed_ms * 1e-3)
        )

    # Throughput is inversely proportional to time, so the slowest quantile
    # gives the lower bound and the fastest gives the upper bound.
    return gbps(ms), gbps(max_ms), gbps(min_ms)
def prepare_shapes(args):
    """Expand the requested models and TP sizes into ``[K, N, model]`` entries.

    For every (model, tp_size) pair taken from ``args.models`` x
    ``args.tp_sizes`` and every weight shape registered for that model, the
    dimension selected by the shape's split index is floor-divided by
    ``tp_size``. ``WEIGHT_SHAPES`` itself is never modified.

    Returns:
        A list of ``[K, N, model_name]`` lists.
    """
    shapes = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        assert model in WEIGHT_SHAPES
        for base_kn, split_dim in WEIGHT_SHAPES[model]:
            # Work on a copy so the shared table keeps its TP1 values.
            entry = list(base_kn)
            entry[split_dim] = entry[split_dim] // tp_size
            entry.append(model)
            shapes.append(entry)
    return shapes
if __name__ == "__main__":
    # Both options take one or more values and differ only in type/default/help.
    cli = argparse.ArgumentParser()
    list_options = (
        (
            "--models",
            str,
            ["meta-llama/Llama-3.1-8B-Instruct"],
            "List of models to benchmark",
        ),
        ("--tp-sizes", int, [1], "List of tensor parallel sizes"),
    )
    for flag, value_type, default, help_text in list_options:
        cli.add_argument(
            flag, nargs="+", type=value_type, default=default, help=help_text
        )
    args = cli.parse_args()

    # One perf_report run (table, plot, saved results) per weight shape.
    run_options = dict(print_data=True, show_plots=True, save_path="bench_int8_res")
    for K, N, model_name in prepare_shapes(args):
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(N=N, K=K, **run_options)

    print("Benchmark finished!")

View File

@@ -0,0 +1,299 @@
import itertools
import math
import torch
import triton
import triton.language as tl
from sgl_kernel import lightning_attention_decode
def next_power_of_2(n):
    """Return the smallest power of two that is greater than or equal to ``n``.

    Uses integer bit arithmetic rather than ``2 ** ceil(log(n, 2))``: the
    floating-point logarithm can land slightly above an integer for inputs
    that are exact powers of two, which made the old formula return twice
    the correct value for those inputs.

    Args:
        n: A positive number (callers in this file pass tensor dimensions).

    Raises:
        ValueError: If ``n`` is not positive (``math.log`` raised the same
            exception type for such inputs).
    """
    if n <= 0:
        raise ValueError("math domain error")
    # For m >= 1, (m - 1).bit_length() is the exponent of the smallest power
    # of two that is >= m; math.ceil keeps non-integer inputs working.
    return 1 << (math.ceil(n) - 1).bit_length()
@triton.jit
def _decode_kernel(
    Q,
    K,
    V,
    KV,
    Out,
    S,
    b: tl.constexpr,
    h: tl.constexpr,
    n: tl.constexpr,
    d: tl.constexpr,
    d_original: tl.constexpr,
    e: tl.constexpr,
    e_original: tl.constexpr,
):
    # Single decode step of lightning attention for one (batch, head) pair:
    #   KV <- exp(-S[head]) * KV + outer(k, v)      (written back in place)
    #   Out <- q @ KV
    # d / e are the extents of the buffers passed in; d_original / e_original
    # are the valid extents, and everything beyond them is masked out.
    pid_bh = tl.program_id(0)
    head = pid_bh % h

    # Per-head decay factor.
    decay = tl.exp(-tl.load(S + head))

    rows = tl.arange(0, d)
    cols = tl.arange(0, e)
    row_ok = rows < d_original
    col_ok = cols < e_original
    tile_ok = row_ok[:, None] & col_ok[None, :]

    # Masked loads: lanes outside the valid extent read as 0.
    qk_base = pid_bh * n * d
    q_vec = tl.load(Q + qk_base + rows, mask=row_ok, other=0.0)
    k_vec = tl.load(K + qk_base + rows, mask=row_ok, other=0.0)
    v_vec = tl.load(V + pid_bh * n * e + cols, mask=col_ok, other=0.0)

    # The same (d, e) tile of the state is read, updated and written back.
    state_ptrs = KV + pid_bh * d * e + rows[:, None] * e + cols[None, :]
    state = tl.load(state_ptrs, mask=tile_ok, other=0.0)

    # Decay the old state and add the outer product k v^T.
    state = decay * state + k_vec[:, None] * v_vec[None, :]
    tl.store(state_ptrs, state.to(KV.dtype.element_ty), mask=tile_ok)

    # Output = q^T @ state, as a broadcast multiply reduced over the d axis.
    out_vec = tl.sum(q_vec[:, None] * state, axis=0)
    tl.store(
        Out + pid_bh * n * e + cols,
        out_vec.to(Out.dtype.element_ty),
        mask=col_ok,
    )
def triton_lightning_attn_decode(q, k, v, kv, s):
    """Triton implementation of Lightning Attention decode operation

    The last dimensions are rounded up to powers of two; the inputs are
    copied into buffers of the rounded size and the kernel masks off the
    excess, so the uninitialised padding is never read.

    Args:
        q, k: Tensors of shape (b, h, 1, d).
        v: Tensor of shape (b, h, 1, e).
        kv: Previous state, copied into a float32 (b, h, d, e) working buffer.
        s: Per-head slopes, indexed by head inside the kernel.

    Returns:
        ``(o, kv_out)``: views of the padded output and updated-state buffers
        restricted to the original ``e`` and ``(d, e)`` extents.
    """
    b, h, n, d = q.shape
    e = v.shape[-1]
    assert n == 1, "Sequence length must be 1 in decode mode"

    d_pad, e_pad = next_power_of_2(d), next_power_of_2(e)

    def _widen(src, valid, width):
        # Uninitialised (b, h, n, width) buffer whose first `valid` columns
        # hold `src`.
        buf = torch.empty(b, h, n, width, dtype=src.dtype, device=src.device)
        buf[..., :valid] = src
        return buf

    q_buf = _widen(q, d, d_pad)
    k_buf = _widen(k, d, d_pad)
    v_buf = _widen(v, e, e_pad)
    out_buf = torch.empty(b, h, n, e_pad, dtype=v.dtype, device=v.device)

    # The state is always processed in float32.
    state_buf = torch.empty(b, h, d_pad, e_pad, dtype=torch.float32, device=kv.device)
    state_buf[..., :d, :e] = kv

    # One program per (batch, head) pair.
    _decode_kernel[(b * h, 1)](
        q_buf,
        k_buf,
        v_buf,
        state_buf,
        out_buf,
        s,
        b=b,
        h=h,
        n=n,
        d=d_pad,
        d_original=d,
        e=e_pad,
        e_original=e,
    )

    # Strip the padding again.
    return out_buf[..., :e], state_buf[..., :d, :e]
def lightning_attention_decode_naive(q, k, v, past_kv, slope):
    """Naive implementation of lightning attention decode

    For each position along the sequence axis of ``q`` the state is updated
    as ``state = exp(-slope) * state + k_t^T v_t`` and the output for that
    position is ``q_t @ state``, both computed in float32.

    Returns:
        ``(output, state)``: the outputs concatenated along the sequence axis
        and cast back to ``q``'s dtype, and the final state.
    """
    out_dtype = q.dtype
    decay = torch.exp(-slope)  # [h, 1, 1]
    state = past_kv
    _, _, steps, _ = q.shape

    per_step = []
    for t in range(steps):
        q_t = q[:, :, t : t + 1]
        k_t = k[:, :, t : t + 1]
        v_t = v[:, :, t : t + 1]

        # The outer product is formed in the input dtype; the sum with the
        # float32 decayed state promotes it.
        kv_update = torch.einsum("... n d, ... n e -> ... d e", k_t, v_t)
        state = decay * state.to(torch.float32) + kv_update

        step_out = torch.einsum(
            "... n e, ... e d -> ... n d",
            q_t.to(torch.float32),
            state.to(torch.float32),
        )
        per_step.append(step_out)

    stacked = torch.cat(per_step, dim=-2)
    return stacked.to(out_dtype), state
def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv):
    """Call the sgl_kernel ``lightning_attention_decode`` op and return its result.

    ``output`` and ``new_kv`` are caller-allocated tensors passed straight
    through; presumably the op writes its results into them — confirm
    against sgl_kernel.
    """
    result = lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
    return result
def calculate_diff(batch_size):
    """Check that the naive, SGL-kernel and Triton decode implementations agree.

    Random bfloat16 q/k/v (64 heads, head_dim 96, one position) plus a random
    float32 state and slope are fed to all three implementations. Outputs and
    updated states are compared against the naive result with
    ``atol = rtol = 1e-2`` and a single verdict line is printed.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")
    num_heads, head_dim, seq_len = 64, 96, 1

    qkv_shape = (batch_size, num_heads, seq_len, head_dim)
    q = torch.randn(*qkv_shape, device=device, dtype=dtype)
    k = torch.randn(*qkv_shape, device=device, dtype=dtype)
    v = torch.randn(*qkv_shape, device=device, dtype=dtype)
    past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)

    def fresh_inputs():
        # Every implementation gets its own copies of the inputs.
        return q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()

    output_naive, new_kv_naive = lightning_attention_decode_naive(*fresh_inputs())

    output_kernel = torch.empty_like(output_naive)
    new_kv_kernel = torch.empty_like(new_kv_naive)
    lightning_attention_decode_kernel(*fresh_inputs(), output_kernel, new_kv_kernel)

    output_triton, new_kv_triton = triton_lightning_attn_decode(*fresh_inputs())

    # (reference, candidate) pairs; checked in order, stopping at the first miss.
    comparisons = (
        (output_naive, output_kernel),
        (output_naive, output_triton),
        (new_kv_naive, new_kv_kernel),
        (new_kv_naive, new_kv_triton),
    )
    all_match = all(
        torch.allclose(ref, got, atol=1e-2, rtol=1e-2) for ref, got in comparisons
    )
    if all_match:
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
# Batch sizes swept by the benchmark: 1 through 64 inclusive.
batch_size_range = list(range(1, 65))
# One 1-tuple per batch size; converted to lists for the benchmark's x_vals.
configs = [(bs,) for bs in batch_size_range]
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["naive", "kernel", "triton"],
        line_names=["PyTorch Naive", "SGL Kernel", "Triton"],
        styles=[("blue", "-"), ("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="lightning-attention-decode-performance",
        args={},
    )
)
def benchmark(batch_size, provider):
    """Time one lightning-attention decode step for the given provider.

    Args:
        batch_size: Batch dimension of the random inputs.
        provider: ``"naive"``, ``"kernel"`` or ``"triton"``.

    Returns:
        ``(median, p20, p80)`` latency in microseconds, in the
        ``(y, y_min, y_max)`` order that ``perf_report`` unpacks.

    Raises:
        ValueError: If ``provider`` is not one of the known names.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")
    num_heads = 64
    head_dim = 96
    seq_len = 1

    q = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    k = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    v = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)

    # do_bench returns the timings at these quantiles: median, p20, p80.
    quantiles = [0.5, 0.2, 0.8]

    # NOTE: the input clones run inside the timed callables, so their cost is
    # part of every measurement.
    if provider == "naive":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: lightning_attention_decode_naive(
                q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
            ),
            quantiles=quantiles,
        )
    elif provider == "kernel":
        output = torch.empty(
            batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
        )
        new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: lightning_attention_decode_kernel(
                q.clone(),
                k.clone(),
                v.clone(),
                past_kv.clone(),
                slope.clone(),
                output,
                new_kv,
            ),
            quantiles=quantiles,
        )
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_lightning_attn_decode(
                q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
            ),
            quantiles=quantiles,
        )
    else:
        # Previously an unknown provider failed with an unbound-variable error.
        raise ValueError(f"Unknown provider: {provider}")

    # Latency is reported directly, so the p20 time is the lower bound and the
    # p80 time the upper bound (the old code returned them swapped).
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    save_path_spec = dict(
        type=str,
        default="./configs/benchmark_ops/lightning_attention_decode_sgl/",
        help="Path to save lightning attention decode benchmark results",
    )
    cli.add_argument("--save_path", **save_path_spec)
    # NOTE(review): --save_path is parsed but never passed to benchmark.run.
    args = cli.parse_args()

    # Correctness check first, then the performance sweep.
    calculate_diff(batch_size=4)
    benchmark.run(print_data=True)

View File

@@ -0,0 +1,401 @@
import argparse
import itertools
import torch
import triton
import triton.language as tl
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
try:
from vllm import _custom_ops as ops
except ImportError:
ops = None
# Selects how the benchmark draws expert ids: True uses get_topk_ids (a
# per-token randperm truncated to topk), False uses a single torch.randint call.
USE_RANDOM_PERM = False
def ceil_div(a, b):
    """Return ``a / b`` rounded up, for non-negative ``a`` and positive ``b``."""
    # Adding b - 1 before flooring pushes any remainder up to the next quotient.
    bumped = a + b - 1
    return bumped // b
# Stage 1: per-chunk expert histogram.
# Program `pid` walks the flat token chunk
# [pid * tokens_per_thread, (pid + 1) * tokens_per_thread), bounded by `numel`,
# and for every expert id it reads increments the matching counter in row
# pid + 1 of the counts buffer (rows are `num_experts` entries wide).
# Row 0 is never written by this stage.
@triton.jit
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    start_idx = pid * tokens_per_thread
    # Offset of row pid + 1 in the row-major counts buffer.
    off_c = (pid + 1) * num_experts
    for i in range(tokens_per_thread):
        # The last chunk may extend past the end of topk_ids.
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
# Stage 2: in-place running sum down one column of the counts buffer.
# Program `pid` owns column `pid` and replaces each entry in rows
# 1..num_experts with the sum of that entry and all earlier rows of the same
# column, so the last row ends up holding the column total.
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)
    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
# Stage 3: padded cumulative offsets.
# Reads the last row of the counts buffer (row index num_experts), rounds each
# entry up to a multiple of block_size and accumulates the results:
# cumsum[i] receives the running total after the first i entries (cumsum[0] is
# not written here), and the grand total goes to total_tokens_post_pad_ptr.
# There is no program_id lookup, so every launched program does the same work.
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    last_cumsum = 0
    # Offset of row `num_experts` in the row-major counts buffer.
    off_cnt = num_experts * num_experts
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
        # Pad this expert's token count up to a whole number of blocks.
        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
        tl.store(cumsum_ptr + i, last_cumsum)
    tl.store(total_tokens_post_pad_ptr, last_cumsum)
# Stage 4: fill expert_ids and scatter token indices.
# `pid` plays two roles here:
#   1. as an expert id: every block in the padded range
#      [cumsum[pid], cumsum[pid + 1]) is labelled with `pid` in expert_ids;
#   2. as a chunk index: each token of chunk `pid` is written to
#      sorted_token_ids at cumsum[expert] + tokens_cnts[pid, expert], and that
#      counter is then bumped so the next token of the same expert in this
#      chunk lands one slot further.
@triton.jit
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)
    # Role 1: one expert_ids entry per block owned by expert `pid`.
    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)
    # Role 2: start_idx is reused, now as the first flat token index of chunk `pid`.
    start_idx = pid * tokens_per_thread
    # Offset of row `pid` in the row-major counts buffer.
    off_t = pid * num_experts
    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Run the four-stage Triton implementation of MoE block alignment.

    Results are written into the caller-provided ``sorted_token_ids``,
    ``expert_ids`` and ``num_tokens_post_pad`` tensors; nothing is returned.

    Args:
        topk_ids: Expert id chosen for each (token, slot); read as a flat array.
        num_experts: Number of experts; also the launch size of stages 1, 2, 4.
        block_size: Each expert's token count is padded to a multiple of this.
        sorted_token_ids: Output buffer for the scattered flat token indices.
        expert_ids: Output buffer holding one expert id per block.
        num_tokens_post_pad: Output tensor receiving the padded token total.
    """
    device = topk_ids.device
    numel = topk_ids.numel()
    # The flat token list is split into num_experts chunks, one per program.
    tokens_per_thread = ceil_div(numel, num_experts)
    per_expert_grid = (num_experts,)

    # Scratch buffers: an extra leading row of counts and an extra leading
    # cumsum entry, both starting at zero.
    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=device
    )
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=device)

    # Stage 1: per-chunk expert histograms.
    moe_align_block_size_stage1[per_expert_grid](
        topk_ids, tokens_cnts, num_experts, numel, tokens_per_thread
    )
    # Stage 2: running sums across chunks, one column per program.
    moe_align_block_size_stage2[per_expert_grid](tokens_cnts, num_experts)
    # Stage 3: padded cumulative offsets, single program.
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad, tokens_cnts, cumsum, num_experts, block_size
    )
    # Stage 4: label blocks and scatter token indices.
    moe_align_block_size_stage4[per_expert_grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )
def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
    """Cross-check the SGL, Triton and vLLM block-alignment implementations.

    Each token gets ``topk`` distinct experts (a truncated random
    permutation). All three implementations write into identically
    initialised buffers; ``expert_ids`` and the padded token count of the
    Triton and vLLM runs are compared against the SGL run and the verdicts
    are printed. A failing vLLM call is reported and its comparison skipped.
    """
    per_token_experts = [
        torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
        for _ in range(num_tokens)
    ]
    topk_ids = torch.stack(per_token_experts)
    device = topk_ids.device
    fill_value = topk_ids.numel()

    max_num_tokens_padded = fill_value + num_experts * (block_size - 1)
    max_num_m_blocks = max_num_tokens_padded // block_size

    def new_outputs():
        # sorted ids pre-filled with numel, expert ids zeroed, count left
        # uninitialised: the same starting state for every implementation.
        sorted_ids = torch.empty(
            (max_num_tokens_padded,), dtype=torch.int32, device=device
        )
        sorted_ids.fill_(fill_value)
        expert_ids = torch.zeros(
            (max_num_m_blocks,), dtype=torch.int32, device=device
        )
        num_post_pad = torch.empty((1), dtype=torch.int32, device=device)
        return sorted_ids, expert_ids, num_post_pad

    sorted_ids_cuda, expert_ids_cuda, num_tokens_post_pad_cuda = new_outputs()
    sorted_ids_triton, expert_ids_triton, num_tokens_post_pad_triton = new_outputs()
    sorted_ids_vllm, expert_ids_vllm, num_tokens_post_pad_vllm = new_outputs()
    cumsum_buffer = torch.zeros(num_experts + 1, dtype=torch.int32, device=device)

    # Run the SGL, Triton and (if possible) vLLM implementations.
    sgl_moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_cuda,
        expert_ids_cuda,
        num_tokens_post_pad_cuda,
        cumsum_buffer,
    )
    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_triton,
        expert_ids_triton,
        num_tokens_post_pad_triton,
    )

    try:
        ops.moe_align_block_size(
            topk_ids,
            num_experts,
            block_size,
            sorted_ids_vllm,
            expert_ids_vllm,
            num_tokens_post_pad_vllm,
        )
        print(f"✅ VLLM implementation works with {num_experts} experts!")
        vllm_works = True
    except Exception as e:
        # Also reached when vllm failed to import and `ops` is None.
        print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
        vllm_works = False

    triton_matches = torch.allclose(
        expert_ids_cuda, expert_ids_triton
    ) and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton)
    if triton_matches:
        print("✅ SGL and Triton implementations match")
    else:
        print("❌ SGL and Triton implementations do not match")
        print("SGL expert_ids:", expert_ids_cuda)
        print("Triton expert_ids:", expert_ids_triton)
        print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
        print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)

    if not vllm_works:
        print("⚠️ VLLM comparison skipped due to failure")
    elif torch.allclose(expert_ids_cuda, expert_ids_vllm) and torch.allclose(
        num_tokens_post_pad_cuda, num_tokens_post_pad_vllm
    ):
        print("✅ SGL and VLLM implementations match")
    else:
        print("❌ SGL and VLLM implementations do not match")
        print("SGL expert_ids:", expert_ids_cuda)
        print("VLLM expert_ids:", expert_ids_vllm)
        print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
        print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
# Benchmark sweep: every combination of (num_tokens, num_experts, topk) drawn
# from the three ranges below becomes one x value of the benchmark.
num_tokens_range = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
num_experts_range = [8, 32, 64, 128, 256]
topk_range = [1, 2, 4, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
    """Return a (num_tokens, topk) int32 CUDA tensor of distinct experts per row.

    Each row holds the first ``topk`` entries of a fresh random permutation of
    ``range(num_experts)``.
    """
    ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="cuda")
    for row in range(num_tokens):
        permutation = torch.randperm(num_experts, dtype=torch.int32, device="cuda")
        ids[row, :] = permutation[:topk]
    return ids
def sgl_moe_align_block_size_with_empty(
    topk_ids,
    num_experts,
    block_size,
    sorted_ids,
    expert_ids,
    num_tokens_post_pad,
    pad_sorted_token_ids=False,
):
    """Benchmark wrapper around the SGL ``moe_align_block_size`` kernel.

    When ``pad_sorted_token_ids`` is false, ``sorted_ids`` is first filled
    with ``topk_ids.numel()`` on the host side; otherwise that fill is skipped
    and the flag is forwarded to the kernel (presumably the kernel then does
    the padding itself — confirm against sgl_kernel). The kernel receives
    clones of the three output tensors and a fresh uninitialised cumsum
    buffer, so the caller's tensors are only touched by the optional fill.
    """
    if not pad_sorted_token_ids:
        sorted_ids.fill_(topk_ids.numel())

    cumsum_buffer = torch.empty(
        num_experts + 1, dtype=torch.int32, device=topk_ids.device
    )

    # The kernel writes into throw-away copies.
    scratch_sorted = sorted_ids.clone()
    scratch_experts = expert_ids.clone()
    scratch_total = num_tokens_post_pad.clone()
    sgl_moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        scratch_sorted,
        scratch_experts,
        scratch_total,
        cumsum_buffer,
        pad_sorted_token_ids,
    )
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["sgl", "sgl_fusion", "triton"],
        line_names=["SGL", "SGL Fusion", "Triton"],
        styles=[("blue", "-"), ("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="moe-align-block-size-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, provider):
    """Time one MoE block-alignment call (block size 128) for the provider.

    Args:
        num_tokens: Number of tokens.
        num_experts: Number of experts.
        topk: Experts selected per token.
        provider: ``"sgl"`` (host-side fill plus kernel), ``"sgl_fusion"``
            (kernel called with ``pad_sorted_token_ids=True``) or ``"triton"``.

    Returns:
        ``(median, p20, p80)`` latency in microseconds, in the
        ``(y, y_min, y_max)`` order that ``perf_report`` unpacks.

    Raises:
        ValueError: If ``provider`` is not one of the known names.
    """
    block_size = 128

    if USE_RANDOM_PERM:
        topk_ids = get_topk_ids(num_tokens, num_experts, topk)
    else:
        # Independent draws: one token may pick the same expert more than once.
        topk_ids = torch.randint(
            0,
            num_experts,
            (num_tokens, topk),
            dtype=torch.int32,
            device="cuda",
        )

    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)

    # do_bench returns the timings at these quantiles: median, p20, p80.
    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sgl_moe_align_block_size_with_empty(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
            ),
            quantiles=quantiles,
        )
    elif provider == "sgl_fusion":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sgl_moe_align_block_size_with_empty(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
                pad_sorted_token_ids=True,
            ),
            quantiles=quantiles,
        )
    elif provider == "triton":
        # Filled once, outside the timed callable; each timed run works on clones.
        sorted_ids.fill_(topk_ids.numel())
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: moe_align_block_size_triton(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids.clone(),
                expert_ids.clone(),
                num_tokens_post_pad.clone(),
            ),
            quantiles=quantiles,
        )
    else:
        # Previously an unknown provider failed with an unbound-variable error.
        raise ValueError(f"Unknown provider: {provider}")

    # Latency is reported directly, so the p20 time is the lower bound and the
    # p80 time the upper bound (the old code returned them swapped).
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
if __name__ == "__main__":
    # Options in declaration order: flag -> add_argument keyword arguments.
    option_table = {
        "--save_path": dict(
            type=str,
            default="./configs/benchmark_ops/moe_align_blocks/",
            help="Path to save moe align benchmark results",
        ),
        "--num_experts": dict(
            type=int,
            default=256,
            choices=[8, 16, 32, 64, 128, 256],
            help="Number of experts for benchmark",
        ),
        "--topk": dict(
            type=int,
            default=8,
            choices=[2, 4, 8],
            help="Top-k value for benchmark",
        ),
        "--skip_full_benchmark": dict(
            action="store_true",
            help="Only run the calculate_diff function, skip full benchmarking",
        ),
    }
    cli = argparse.ArgumentParser()
    for flag, spec in option_table.items():
        cli.add_argument(flag, **spec)
    # NOTE(review): --save_path is parsed but never passed to benchmark.run.
    args = cli.parse_args()

    # The correctness check always runs; the full sweep is optional.
    calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)

    run_sweep = not args.skip_full_benchmark
    if run_sweep:
        print(f"\n📊 Running performance benchmark for {args.num_experts} experts...")
        benchmark.run(print_data=True)

View File

@@ -0,0 +1,93 @@
import torch
import triton
from sgl_kernel import ep_moe_post_reorder
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
# Batch sizes swept on the benchmark's x axis; `configs` wraps each one in a
# 1-tuple, which the Benchmark definition converts to a list for x_vals.
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
configs = [(bs,) for bs in batch_sizes]
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["cuda", "triton"],
        line_names=["CUDA Kernel", "Triton Kernel"],
        styles=[("green", "-"), ("orange", "-")],
        ylabel="us",
        plot_name="ep-moe-post-reorder-performance",
        args={},
    )
)
def benchmark(batch_size, provider):
    """Time the EP-MoE post-reorder step for the given provider.

    Args:
        batch_size: Number of tokens.
        provider: ``"cuda"`` (sgl_kernel op) or ``"triton"`` (SGLang Triton kernel).

    Returns:
        ``(median, p20, p80)`` latency in microseconds, in the
        ``(y, y_min, y_max)`` order that ``perf_report`` unpacks.

    Raises:
        ValueError: If ``provider`` is not one of the known names.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")
    hidden_size, topk, start_expert_id, end_expert_id, block_size = 4096, 8, 0, 255, 512

    def alloc_tensors():
        # Random inputs; called once per provider run, so all timed
        # iterations reuse the same tensors.
        down_output = torch.randn(
            batch_size * topk, hidden_size, dtype=dtype, device=device
        )
        output = torch.zeros(batch_size, hidden_size, dtype=dtype, device=device)
        src2dst = torch.randint(
            0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device
        )
        topk_ids = torch.randint(
            start_expert_id,
            end_expert_id + 1,
            (batch_size, topk),
            dtype=torch.int32,
            device=device,
        )
        topk_weights = torch.rand(batch_size, topk, dtype=dtype, device=device)
        return down_output, output, src2dst, topk_ids, topk_weights

    # do_bench returns the timings at these quantiles: median, p20, p80.
    quantiles = [0.5, 0.2, 0.8]

    if provider == "cuda":
        d_out, out, s2d, tk_ids, tk_weights = alloc_tensors()

        def run_cuda():
            ep_moe_post_reorder(
                d_out,
                out,
                s2d,
                tk_ids,
                tk_weights,
                start_expert_id,
                end_expert_id,
                topk,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)

    elif provider == "triton":
        d_out, out, s2d, tk_ids, tk_weights = alloc_tensors()

        def run_triton():
            # One program per token; tensors are passed flattened.
            post_reorder_triton_kernel[(batch_size,)](
                d_out.view(-1),
                out.view(-1),
                s2d.view(-1),
                tk_ids.view(-1),
                tk_weights.view(-1),
                start_expert_id,
                end_expert_id,
                topk,
                hidden_size,
                0,
                block_size,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)

    else:
        raise ValueError(f"Unknown provider: {provider}")

    # Latency is reported directly, so the p20 time is the lower bound and the
    # p80 time the upper bound (the old code returned them swapped).
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms


if __name__ == "__main__":
    benchmark.run(print_data=True)

View File

@@ -0,0 +1,103 @@
import torch
import triton
from sgl_kernel import ep_moe_pre_reorder
from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel
# Batch sizes swept on the benchmark's x axis; `configs` wraps each one in a
# 1-tuple, which the Benchmark definition converts to a list for x_vals.
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
configs = [(bs,) for bs in batch_sizes]
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["cuda", "triton"],
        line_names=["CUDA Kernel", "Triton Kernel"],
        styles=[("green", "-"), ("orange", "-")],
        ylabel="us",
        plot_name="ep-moe-pre-reorder-performance",
        args={},
    )
)
def benchmark(batch_size, provider):
    """Time the EP-MoE pre-reorder step for the given provider.

    Args:
        batch_size: Number of tokens.
        provider: ``"cuda"`` (sgl_kernel op) or ``"triton"`` (SGLang Triton kernel).

    Returns:
        ``(median, p20, p80)`` latency in microseconds, in the
        ``(y, y_min, y_max)`` order that ``perf_report`` unpacks.

    Raises:
        ValueError: If ``provider`` is not one of the known names.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")
    hidden_size, topk, start_expert_id, end_expert_id, block_size = (
        4096,
        8,
        0,
        255,
        512,
    )

    def alloc_tensors():
        # Random inputs; called once per provider run (not once per timed
        # iteration), so all timed iterations reuse the same tensors.
        input_ = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
        gateup_input = torch.zeros(
            batch_size * topk, hidden_size, dtype=dtype, device=device
        )
        src2dst = torch.randint(
            0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device
        )
        topk_ids = torch.randint(
            start_expert_id,
            end_expert_id + 1,
            (batch_size, topk),
            dtype=torch.int32,
            device=device,
        )
        # One scale per expert in [start_expert_id, end_expert_id].
        a1_scales = torch.rand(
            end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device
        )
        return input_, gateup_input, src2dst, topk_ids, a1_scales

    # do_bench returns the timings at these quantiles: median, p20, p80.
    quantiles = [0.5, 0.2, 0.8]

    if provider == "cuda":
        inp, gout, s2d, tk_ids, scales = alloc_tensors()

        def run_cuda():
            # NOTE(review): the meaning of the trailing True flag is not
            # visible here — confirm against the sgl_kernel signature.
            ep_moe_pre_reorder(
                inp,
                gout,
                s2d,
                tk_ids,
                scales,
                start_expert_id,
                end_expert_id,
                topk,
                True,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)

    elif provider == "triton":
        inp, gout, s2d, tk_ids, scales = alloc_tensors()

        def run_triton():
            # One program per token; tensors are passed flattened.
            pre_reorder_triton_kernel[(batch_size,)](
                inp.view(-1),
                gout.view(-1),
                s2d.view(-1),
                tk_ids.view(-1),
                scales,
                start_expert_id,
                end_expert_id,
                topk,
                hidden_size,
                block_size,
                True,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)

    else:
        raise ValueError(f"Unknown provider: {provider}")

    # Latency is reported directly, so the p20 time is the lower bound and the
    # p80 time the upper bound (the old code returned them swapped).
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms


if __name__ == "__main__":
    benchmark.run(print_data=True)

View File

@@ -0,0 +1,77 @@
import itertools
import math
import torch
import triton
import triton.language as tl
from sgl_kernel import moe_fused_gate
from sglang.srt.layers.moe.topk import biased_grouped_topk
def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
    """Reference path: call SGLang's ``biased_grouped_topk`` with fixed settings.

    ``scores`` is passed for both of the first two positional parameters
    (presumably hidden states and gating output — confirm against
    sglang.srt.layers.moe.topk). Renormalisation is enabled and the routed
    scaling factor is fixed at 2.5.
    """
    routing_options = dict(
        topk=topk,
        renormalize=True,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        routed_scaling_factor=2.5,  # DeepSeek-R1 : 2.5, Kimi K2: 2.872
    )
    return biased_grouped_topk(scores, scores, bias, **routing_options)
def biased_grouped_topk_org_fuse_kernel(
    scores, bias, num_expert_group, topk_group, topk
):
    """Fused path: forward the arguments unchanged to sgl_kernel's ``moe_fused_gate``."""
    fused_result = moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)
    return fused_result
# Sequence lengths (rows of the score matrix) swept on the benchmark's x axis;
# `configs` wraps each one in a 1-tuple, later converted to a list for x_vals.
seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
configs = [(sq,) for sq in seq_length_range]
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_length"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["original", "kernel"],
        line_names=["Original", "SGL Kernel"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="us",
        plot_name="moe-fused-gate-performance",
        args={},
    )
)
def benchmark(seq_length, provider):
    """Time one gating implementation for a given number of score rows.

    Args:
        seq_length: number of rows in the random ``(seq_length, 256)`` bf16
            score matrix.
        provider: ``"original"`` (``biased_grouped_topk_org``) or ``"kernel"``
            (``biased_grouped_topk_org_fuse_kernel``).

    Returns:
        ``(median, q20, q80)`` latency in microseconds, i.e. the
        ``(mean, min, max)`` order that ``perf_report`` unpacks.

    Raises:
        ValueError: if ``provider`` is not one of the two known values.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")
    num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8

    scores = torch.randn((seq_length, num_experts), device=device, dtype=dtype)
    bias = torch.rand(num_experts, device=device, dtype=dtype)

    if provider == "original":
        gate_fn = biased_grouped_topk_org
    elif provider == "kernel":
        gate_fn = biased_grouped_topk_org_fuse_kernel
    else:
        # Previously an unknown provider fell through and crashed with an
        # UnboundLocalError on `ms`; fail with a clear message instead.
        raise ValueError(f"Unknown provider: {provider}")

    quantiles = [0.5, 0.2, 0.8]
    # Inputs are cloned on every call, so the clone cost is part of the
    # measured time for both providers.
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: gate_fn(
            scores.clone(), bias.clone(), num_expert_group, topk_group, topk
        ),
        quantiles=quantiles,
    )

    # do_bench returns the quantiles in the requested order (0.5, 0.2, 0.8);
    # for a latency metric the 0.2 quantile is the lower bound, so it must be
    # reported second (the original returned max before min).
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
# Script entry point: run the sweep defined by the Benchmark decorator and
# print the collected data (print_data=True).
if __name__ == "__main__":
    benchmark.run(print_data=True)

View File

@@ -0,0 +1,92 @@
import itertools
import torch
import triton
from sgl_kernel import ep_moe_silu_and_mul
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_triton_kernel
# Sweep axes for the benchmark.
batch_size_range = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
hidden_size_range = [1024, 2048, 4096, 8192]
block_size_range = [128, 256, 512]
# Full cartesian product: one (batch_size, hidden_size, block_size) tuple per run.
configs = list(itertools.product(batch_size_range, hidden_size_range, block_size_range))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "hidden_size", "block_size"],
        x_vals=[list(cfg) for cfg in configs],
        line_arg="provider",
        line_vals=["cuda", "triton"],
        line_names=["CUDA Kernel", "Triton Kernel"],
        styles=[("green", "-"), ("orange", "-")],
        ylabel="us",
        plot_name="ep-moe-silu-and-mul-performance",
        args={},
    )
)
def benchmark(batch_size, hidden_size, block_size, provider):
    """Time ``ep_moe_silu_and_mul`` against ``silu_and_mul_triton_kernel``.

    Args:
        batch_size: number of rows in the gate/up tensor; also the Triton
            launch grid size.
        hidden_size: width of the gate/up tensor; the output buffer has half
            this width.
        block_size: block size forwarded to the Triton kernel. It is not
            passed to the CUDA provider.
        provider: ``"cuda"`` or ``"triton"``.

    Returns:
        ``(median, q20, q80)`` latency in microseconds, i.e. the
        ``(mean, min, max)`` order that ``perf_report`` unpacks.

    Raises:
        ValueError: if ``provider`` is unknown.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")
    half_hidden_size = hidden_size // 2
    start_expert_id, end_expert_id = 0, 255
    # NOTE: the original overwrote `block_size` with 512 here, which made the
    # swept block_size axis a no-op (three identical rows per shape). The
    # swept value is now honoured.
    quantiles = [0.5, 0.2, 0.8]

    def alloc_tensors():
        """Allocate fresh random inputs and an uninitialized output buffer."""
        gateup_output = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
        down_input = torch.empty(
            batch_size, half_hidden_size, dtype=dtype, device=device
        )
        # One expert id per row, drawn from [start_expert_id, end_expert_id].
        reorder_topk_ids = torch.randint(
            start_expert_id,
            end_expert_id + 1,
            (batch_size,),
            dtype=torch.int32,
            device=device,
        )
        # One float32 scale per expert id in the range.
        scales = torch.rand(
            end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device
        )
        return gateup_output, down_input, reorder_topk_ids, scales

    if provider == "cuda":
        gateup, down, ids, scales = alloc_tensors()

        def run_cuda():
            ep_moe_silu_and_mul(
                gateup,
                down,
                ids,
                scales,
                start_expert_id,
                end_expert_id,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)
    elif provider == "triton":
        gateup, down, ids, scales = alloc_tensors()

        def run_triton():
            # One program per row; tensors are passed flattened.
            silu_and_mul_triton_kernel[(batch_size,)](
                gateup.view(-1),
                down.view(-1),
                hidden_size,
                ids,
                scales,
                start_expert_id,
                end_expert_id,
                block_size,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
    else:
        raise ValueError(f"Unknown provider: {provider}")

    # Quantiles come back in the requested order (0.5, 0.2, 0.8); report the
    # lower quantile as "min" and the upper as "max" (the original swapped them).
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
# Script entry point: run the sweep defined by the Benchmark decorator and
# print the collected data (print_data=True).
if __name__ == "__main__":
    benchmark.run(print_data=True)

View File

@@ -0,0 +1,116 @@
import itertools
import pytest
import torch
import triton
from sgl_kernel import topk_softmax
from vllm import _custom_ops as vllm_custom_ops
def vllm_topk_softmax(gating_output, topk):
    """Run the ``torch.ops._moe_C.topk_softmax`` op on ``gating_output``.

    Allocates the three ``(num_tokens, topk)`` buffers the op is called with
    (float32 weights, int32 indices, int32 token/expert indices) on the input's
    device and returns ``(topk_weights, topk_indices)``.

    ``gating_output`` must be 2-D; its second dimension is not otherwise used.
    """
    num_tokens, _num_experts = gating_output.shape
    target_device = gating_output.device
    out_shape = (num_tokens, topk)

    topk_weights = torch.empty(out_shape, device=target_device, dtype=torch.float32)
    topk_indices = torch.empty(out_shape, dtype=torch.int32, device=target_device)
    # Passed to the op as an extra buffer, but not returned to the caller.
    token_expert_indices = torch.empty(
        out_shape, dtype=torch.int32, device=target_device
    )

    torch.ops._moe_C.topk_softmax(
        topk_weights, topk_indices, token_expert_indices, gating_output
    )
    return topk_weights, topk_indices
def sglang_topk_softmax(gating_output, topk):
    """Run sgl_kernel's ``topk_softmax`` on ``gating_output``.

    Allocates ``(num_tokens, topk)`` buffers (float32 weights, int32 indices)
    on the input's device, passes them to the kernel and returns
    ``(topk_weights, topk_indices)``.

    ``gating_output`` must be 2-D; its second dimension is not otherwise used.
    """
    num_tokens, _num_experts = gating_output.shape
    target_device = gating_output.device
    out_shape = (num_tokens, topk)

    topk_weights = torch.empty(out_shape, device=target_device, dtype=torch.float32)
    topk_indices = torch.empty(out_shape, dtype=torch.int32, device=target_device)

    topk_softmax(
        topk_weights=topk_weights,
        topk_ids=topk_indices,
        gating_output=gating_output,
    )
    return topk_weights, topk_indices
def calculate_diff(num_tokens, num_experts, topk):
    """Print whether the vLLM and SGLang top-k softmax results agree.

    Both wrappers are fed clones of the same random float32
    ``(num_tokens, num_experts)`` CUDA tensor. They "match" when the weights
    are close (atol=rtol=1e-3) and the index tensors are exactly equal.
    """
    gating_output = torch.randn(
        (num_tokens, num_experts), device="cuda", dtype=torch.float32
    )
    weights_vllm, indices_vllm = vllm_topk_softmax(gating_output.clone(), topk)
    weights_sglang, indices_sglang = sglang_topk_softmax(gating_output.clone(), topk)

    # Mean absolute weight difference is only reported on mismatch.
    weights_diff = (weights_vllm - weights_sglang).abs().mean().item()
    indices_match = torch.equal(indices_vllm, indices_sglang)
    weights_close = torch.allclose(weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3)

    if not (weights_close and indices_match):
        print(
            f"❌ Implementations differ: Weights diff={weights_diff}, Indices match={indices_match}"
        )
        return
    print("✅ VLLM and SGLang topk_softmax implementations match")
# Sweep axes for the performance benchmark.
num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
# Includes 12, the only non-power-of-two expert count in the sweep.
num_experts_range = [32, 64, 128, 256, 12, 512]
topk_range = [1, 2, 4, 8]
# Full cartesian product: one (num_tokens, num_experts, topk) tuple per run.
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["sglang", "vllm"],
        line_names=["SGLang", "VLLM"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="Latency (us)",
        plot_name="topk-softmax-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, provider):
    """Time one top-k softmax wrapper on a random float32 gating tensor.

    Args:
        num_tokens, num_experts: shape of the random CUDA input.
        topk: number of experts selected per token.
        provider: ``"sglang"``/``"sglang1"`` or ``"vllm"``/``"vllm1"``.

    Returns:
        ``(median, q20, q80)`` latency in microseconds, i.e. the
        ``(mean, min, max)`` order that ``perf_report`` unpacks.

    Raises:
        ValueError: if ``provider`` is unknown.
    """
    gating_output = torch.randn(
        (num_tokens, num_experts), device="cuda", dtype=torch.float32
    )

    # The "...1" aliases are not in line_vals but were accepted by the
    # original code, so they are kept for backward compatibility.
    if provider in ("vllm", "vllm1"):
        impl = vllm_topk_softmax
    elif provider in ("sglang", "sglang1"):
        impl = sglang_topk_softmax
    else:
        # Previously `fn` was simply left unbound for an unknown provider.
        raise ValueError(f"Unknown provider: {provider}")

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: impl(gating_output, topk), quantiles=quantiles
    )

    # Quantiles come back in the requested order (0.5, 0.2, 0.8); report the
    # lower quantile as "min" and the upper as "max" (the original swapped them).
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
# Script entry point: spot-check correctness on a few small shapes, then run
# the performance sweep.
if __name__ == "__main__":
    # NOTE: this rebinds the module-level name `configs`. The performance
    # sweep is unaffected because the Benchmark decorator above already
    # received the original list as x_vals when the module was loaded.
    configs = [
        (20, 256, 4),
        (20, 256, 8),
        (20, 12, 4),
        (20, 12, 1),
        (20, 512, 4),
        (20, 512, 1),
    ]
    # Each tuple is (num_tokens, num_experts, topk).
    for num_tokens, num_experts, topk in configs:
        calculate_diff(num_tokens, num_experts, topk)
    benchmark.run(print_data=True)

View File

@@ -0,0 +1,172 @@
import argparse
import copy
import itertools
import torch
import triton
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
# Range constants; their product is the numerator of the per-tensor global
# scales computed in `benchmark`.
FLOAT4_E2M1_MAX = 6.0
# Largest finite value of torch.float8_e4m3fn, queried from torch.finfo.
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
# Weight shapes are given per model in the format
#   ([K, N], TP_SPLIT_DIM)
# where TP_SPLIT_DIM is the index (0 -> K, 1 -> N) of the dimension that is
# divided by the tensor-parallel size in `prepare_shapes`.
# Example:
#   A shape of ([14336, 4096], 0) indicates the following GEMM shape,
#     - TP1 : K = 14336, N = 4096
#     - TP2 : K = 7168, N = 4096
#   A shape of ([4096, 6144], 1) indicates the following GEMM shape,
#     - TP1 : K = 4096, N = 6144
#     - TP4 : K = 4096, N = 1536
# The values below are the TP1 (unsplit) shapes.
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=[
            "sglang-fp4-fp16",
            "sglang-fp4-bf16",
        ],
        line_names=[
            "sglang-fp4-fp16",
            "sglang-fp4-bf16",
        ],
        styles=[("green", "-"), ("blue", "-")],
        ylabel="TFLOPS",
        plot_name="fp4 block scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Measure ``cutlass_scaled_fp4_mm`` throughput for an (M, K) x (N, K) GEMM.

    Args:
        batch_size: M, the number of rows of the activation matrix.
        provider: line label; float16 is used when it contains ``"fp16"``,
            bfloat16 otherwise.
        N, K: weight dimensions, supplied through ``benchmark.run(N=..., K=...)``.

    Returns:
        Achieved TFLOPS, computed as ``2*M*N*K`` operations over the average
        time of the timed launches.
    """
    # M, N, K = batch_size, 4096, 8192
    timed_iters = 100
    dtype = torch.float16 if "fp16" in provider else torch.bfloat16
    M = batch_size

    a = torch.randn((M, K), dtype=dtype, device="cuda")
    b = torch.randn((N, K), dtype=dtype, device="cuda")

    # Global scales: (FP8 max * FP4 max) divided by the largest element.
    fp4_scale_numerator = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX
    a_global_scale = (fp4_scale_numerator / torch.amax(a.flatten(), dim=-1)).to(
        torch.float32
    )
    b_global_scale = (fp4_scale_numerator / torch.amax(b.flatten(), dim=-1)).to(
        torch.float32
    )
    alpha = 1.0 / (a_global_scale * b_global_scale)

    a_fp4, a_scale_interleaved = scaled_fp4_quant(a, a_global_scale)
    b_fp4, b_scale_interleaved = scaled_fp4_quant(b, b_global_scale)
    gemm_args = (a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Bridging the gap between CPU and GPU: queue some dense matmuls first.
    for _ in range(25):
        c = a @ b.t()

    # Warmup
    for _ in range(5):
        cutlass_scaled_fp4_mm(*gemm_args)

    start_event.record()
    for _ in range(timed_iters):
        cutlass_scaled_fp4_mm(*gemm_args)
    end_event.record()
    end_event.synchronize()
    torch.cuda.synchronize()

    ms = start_event.elapsed_time(end_event) / timed_iters
    # FLOPs / (ms * 1e-3 s) / 1e12 == FLOPs * 1e-9 / ms
    return (2 * M * N * K) * 1e-9 / ms
def prepare_shapes(args):
    """Expand the selected models and TP sizes into ``[K, N, model]`` entries.

    For every (model, tp_size) combination, in model-major order, each weight
    shape of the model is copied, the dimension named by its split index is
    floor-divided by ``tp_size``, and the model name is appended.

    Args:
        args: namespace with ``models`` and ``tp_sizes`` sequences.

    Returns:
        A list of ``[K, N, model_name]`` lists.

    Raises:
        AssertionError: if a model is not a key of ``WEIGHT_SHAPES``.
    """
    expanded = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        assert model in WEIGHT_SHAPES
        for dims, tp_split_dim in WEIGHT_SHAPES[model]:
            # Work on a copy so the shared WEIGHT_SHAPES table is never mutated.
            entry = list(dims)
            entry[tp_split_dim] = entry[tp_split_dim] // tp_size
            entry.append(model)
            expanded.append(entry)
    return expanded
# Script entry point: parse the model / TP-size selection, expand it into GEMM
# shapes and run the benchmark once per shape.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    # Each entry is [K, N, model_name] with the TP split already applied.
    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name} N={N} K={K}: ")
        # N and K are forwarded to `benchmark` as extra keyword arguments.
        benchmark.run(
            print_data=True, show_plots=True, save_path="bench_fp4_res", N=N, K=K
        )
    print("Benchmark finished!")

View File

@@ -0,0 +1,98 @@
import itertools
import math
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
import triton
import triton.testing
from sgl_kernel import sgl_per_tensor_quant_fp8
from vllm import _custom_ops as ops
from sglang.srt.utils import is_hip
# Pick the FP8 E4M3 dtype by platform: the "fnuz" variant when is_hip()
# reports True, the plain "fn" variant otherwise.
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
def vllm_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` with vLLM's ``scaled_fp8_quant`` custom op.

    ``scale`` is forwarded unchanged (``None`` by default); the op's result is
    returned as-is.
    """
    quantized_and_scale = ops.scaled_fp8_quant(input, scale)
    return quantized_and_scale
def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` to FP8 with ``sgl_per_tensor_quant_fp8``.

    Args:
        input: tensor to quantize.
        scale: optional pre-computed scale. When given, the kernel is called
            in static mode; when ``None``, a one-element float32 zero tensor
            is allocated and the kernel is called in dynamic mode.

    Returns:
        ``(output, scale)`` where ``output`` has the shape of ``input`` and
        the module-level ``fp8_type_`` dtype.
    """
    # Use the platform-aware module-level `fp8_type_`. The original shadowed
    # it with a hard-coded torch.float8_e4m3fn, so the HIP (fnuz) selection at
    # the top of the file was never applied. Behaviour on non-HIP platforms is
    # unchanged.
    # NOTE(review): assumes the kernel writes the fnuz format whenever
    # is_hip() is true -- confirm for the target device.
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)

    is_static = scale is not None
    if not is_static:
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)

    sgl_per_tensor_quant_fp8(input, output, scale, is_static)
    return output, scale
def calculate_diff(batch_size: int, seq_len: int):
    """Calculate difference between VLLM and SGLang implementations.

    Quantizes the same random float16 ``(batch_size, seq_len)`` CUDA tensor
    with both wrappers (dynamic scale) and prints whether the outputs and the
    scales agree within rtol=1e-3 / atol=1e-5. On mismatch the absolute scale
    difference and the mean absolute output difference are included.
    """
    device = torch.device("cuda")
    x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)

    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
    sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)

    # These were computed but never used in the original; they are now part
    # of the mismatch report so a failure shows how far apart the results are.
    scale_diff = torch.abs(vllm_scale - sglang_scale).item()
    output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()

    outputs_match = torch.allclose(
        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    )
    scales_match = torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5)

    if outputs_match and scales_match:
        print("✅ All implementations match")
    else:
        print(
            f"❌ Implementations differ: scale diff={scale_diff}, "
            f"mean output diff={output_diff}"
        )
# Sweep axes; `benchmark` multiplies the two to get the number of input rows.
batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048]
# Full cartesian product: one (batch_size, seq_len) tuple per run.
configs = list(itertools.product(batch_size_range, seq_len_range))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm", "sglang"],
        line_names=["VLLM", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-tensor-quant-fp8-performance",
        args={},
    )
)
def benchmark(batch_size, seq_len, provider):
    """Time per-tensor FP8 quantization of a ``(batch_size*seq_len, 4096)`` tensor.

    Args:
        batch_size, seq_len: multiplied to give the number of rows of the
            random float16 CUDA input.
        provider: ``"vllm"`` or ``"sglang"``.

    Returns:
        ``(median, q20, q80)`` latency in microseconds, i.e. the
        ``(mean, min, max)`` order that ``perf_report`` unpacks.

    Raises:
        ValueError: if ``provider`` is unknown.
    """
    dtype = torch.float16
    device = torch.device("cuda")
    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)

    if provider == "vllm":
        quant_fn = vllm_scaled_fp8_quant
    elif provider == "sglang":
        quant_fn = sglang_scaled_fp8_quant
    else:
        # Previously `fn` was simply left unbound for an unknown provider.
        raise ValueError(f"Unknown provider: {provider}")

    quantiles = [0.5, 0.2, 0.8]
    # The input is cloned on every call, so the clone is part of the timing.
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: quant_fn(x.clone()), quantiles=quantiles
    )

    # Quantiles come back in the requested order (0.5, 0.2, 0.8); report the
    # lower quantile as "min" and the upper as "max" (the original swapped them).
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
# Script entry point: one correctness comparison, then the performance sweep.
if __name__ == "__main__":
    calculate_diff(batch_size=4, seq_len=4096)
    benchmark.run(print_data=True)

View File

@@ -0,0 +1,98 @@
import itertools
import time
from functools import partial
from pathlib import Path
import torch
import triton
from sglang.srt.bench_utils import bench_kineto
from sglang.srt.layers.quantization.fp8_kernel import (
create_per_token_group_quant_fp8_output_scale,
)
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
)
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
from sglang.srt.utils import is_hip
# Pick the FP8 E4M3 dtype by platform: the "fnuz" variant when is_hip()
# reports True, the plain "fn" variant otherwise.
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

# Sweep axes.
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
hidden_dim_range = [1536, 7168, 18432]  # For DeepSeek V3/R1
group_size_range = [128]  # For DeepSeek V3/R1
# TODO test int8
dst_dtype_range = [fp8_type_]
# Keyword-argument sets splatted into the quantization call; each entry turns
# on one more flag than the previous one.
flags_range = [
    dict(
        column_major_scales=False,
        scale_tma_aligned=False,
        scale_ue8m0=False,
    ),
    dict(
        column_major_scales=True,
        scale_tma_aligned=False,
        scale_ue8m0=False,
    ),
    dict(
        column_major_scales=True,
        scale_tma_aligned=True,
        scale_ue8m0=False,
    ),
    dict(
        column_major_scales=True,
        scale_tma_aligned=True,
        scale_ue8m0=True,
    ),
]

# Full cartesian product of all axes, one tuple per benchmark run.
configs = list(
    itertools.product(
        num_tokens_range,
        hidden_dim_range,
        group_size_range,
        dst_dtype_range,
        flags_range,
    )
)
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["triton", "sglang"],
        line_names=["Triton", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-group-quant-8bit-performance",
        args={},
    )
)
def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
    """Profile one per-token-group 8-bit quantization implementation.

    Args:
        num_tokens, hidden_dim: shape of the random bfloat16 CUDA input.
        group_size: forwarded to the quantization function.
        dst_dtype: destination dtype forwarded to the quantization function.
        flags: dict of extra keyword arguments splatted into the call.
        provider: ``"triton"`` or ``"sglang"``; any other value raises
            ``KeyError``.

    Returns:
        The ``bench_kineto`` time multiplied by 1e6 (seconds to microseconds,
        going by the variable name), or ``None`` when ``scale_ue8m0`` is
        requested with a group size other than 128.
    """
    # Skip the ue8m0 scale format for any group size other than 128.
    if flags["scale_ue8m0"] and group_size != 128:
        return None

    device = torch.device("cuda")
    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)

    # provider -> (callable, kernel name handed to bench_kineto)
    implementations = {
        "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
        "sglang": (
            sglang_per_token_group_quant_8bit,
            "per_token_group_quant_8bit_kernel",
        ),
    }
    quant_fn, kernel_names = implementations[provider]

    def run_quant():
        return quant_fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)

    time_s = bench_kineto(run_quant, kernel_names=kernel_names)
    return time_s * 1e6
# Script entry point: run the sweep defined by the Benchmark decorator and
# print the collected data (print_data=True).
if __name__ == "__main__":
    benchmark.run(print_data=True)

View File

@@ -0,0 +1,177 @@
import itertools
from typing import Optional, Tuple
import torch
import triton
import triton.testing
from sgl_kernel import sgl_per_token_quant_fp8
from vllm import _custom_ops as ops
from sglang.srt.utils import is_hip
# Pick the FP8 E4M3 dtype by platform: the "fnuz" variant when is_hip()
# reports True, the plain "fn" variant otherwise.
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

# Get correct FP8 E4M3 maximum value
if _is_hip:
    FP8_E4M3_MAX = 224.0  # ROCM uses 224.0
else:
    # For CUDA, get the actual max value from the type
    FP8_E4M3_MAX = float(torch.finfo(fp8_type_).max)
def torch_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pure PyTorch reference implementation for per-token FP8 quantization.

    Args:
        input: 2-D tensor; each row is quantized with its own scale.

    Returns:
        ``(quantized, scales)``: ``quantized`` has the module-level
        ``fp8_type_`` dtype and the shape of ``input``; ``scales`` has one
        entry per row, equal to ``max(|row|) / FP8_E4M3_MAX``.

    Note:
        An all-zero row yields a zero scale and therefore a division by zero.
        This is deliberate: the original author's comments say the reference
        mirrors the CUDA kernel, which has no special zero handling.
    """
    # Max absolute value per row -> [num_tokens]
    max_vals = torch.abs(input).max(dim=1)[0]

    # scale = max_value / FP8_E4M3_MAX -> [num_tokens]
    scales = max_vals / FP8_E4M3_MAX

    # No special zero handling - directly compute 1.0 / scale.
    scale_inv = 1.0 / scales

    # Quantize: scale each row (broadcast over columns), then clamp to the
    # FP8 range before casting.
    quantized_float = input * scale_inv.unsqueeze(1)
    quantized_float = torch.clamp(quantized_float, -FP8_E4M3_MAX, FP8_E4M3_MAX)

    quantized_fp8 = quantized_float.to(fp8_type_)
    return quantized_fp8, scales
def vllm_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` with vLLM's ``scaled_fp8_quant``.

    No scale is passed and ``use_per_token_if_dynamic=True`` is set; the op's
    result is returned as-is.
    """
    quantized_and_scale = ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
    return quantized_and_scale
def sglang_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` with ``sgl_per_token_quant_fp8``.

    Allocates an output buffer shaped like ``input`` with the module-level
    ``fp8_type_`` dtype and a zero-initialized float32 scale vector with one
    entry per row, passes both to the kernel, and returns ``(output, scale)``.
    """
    target_device = input.device
    num_rows = input.size(0)

    output = torch.empty_like(input, device=target_device, dtype=fp8_type_)
    scale = torch.zeros(num_rows, device=target_device, dtype=torch.float32)

    sgl_per_token_quant_fp8(input, output, scale)
    return output, scale
def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
    """Compare Torch reference, VLLM, and SGLang implementations.

    Quantizes one random float16 ``(batch_size * seq_len, hidden_dim)`` CUDA
    tensor with all three implementations, then prints the mean absolute
    scale and output differences for each pair and a pass/fail marker for
    each pair.
    """
    device = torch.device("cuda")
    x = torch.rand(
        (batch_size * seq_len, hidden_dim), dtype=torch.float16, device=device
    )

    # Get all three implementations
    torch_out, torch_scale = torch_per_token_quant_fp8(x)
    vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
    sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)

    # Bring every scale to a flat float32 vector before comparing. The Torch
    # reference returns float16 scales (same dtype as `x`) while the SGLang
    # wrapper allocates float32 ones, and torch.allclose rejects mixed dtypes.
    # Flattening also guards against an (N,) vs (N, 1) layout broadcasting
    # into an N x N comparison.
    # NOTE(review): the vLLM scale layout is not visible here -- confirm.
    torch_scale = torch_scale.float().reshape(-1)
    vllm_scale = vllm_scale.float().reshape(-1)
    sglang_scale = sglang_scale.float().reshape(-1)

    print(f"\n=== Comparison for hidden_dim={hidden_dim} ===")

    # Compare scales
    torch_vllm_scale_diff = torch.abs(torch_scale - vllm_scale).mean().item()
    torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item()
    vllm_sglang_scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()

    print(f"Scale differences:")
    print(f" Torch vs VLLM: {torch_vllm_scale_diff:.8f}")
    print(f" Torch vs SGLang: {torch_sglang_scale_diff:.8f}")
    print(f" VLLM vs SGLang: {vllm_sglang_scale_diff:.8f}")

    # Compare outputs
    torch_vllm_out_diff = torch.abs(torch_out.float() - vllm_out.float()).mean().item()
    torch_sglang_out_diff = (
        torch.abs(torch_out.float() - sglang_out.float()).mean().item()
    )
    vllm_sglang_out_diff = (
        torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
    )

    print(f"Output differences:")
    print(f" Torch vs VLLM: {torch_vllm_out_diff:.8f}")
    print(f" Torch vs SGLang: {torch_sglang_out_diff:.8f}")
    print(f" VLLM vs SGLang: {vllm_sglang_out_diff:.8f}")

    # Check tolerances
    rtol, atol = 1e-3, 1e-5

    torch_vllm_match = torch.allclose(
        torch_out.float(), vllm_out.float(), rtol=rtol, atol=atol
    ) and torch.allclose(torch_scale, vllm_scale, rtol=rtol, atol=atol)
    torch_sglang_match = torch.allclose(
        torch_out.float(), sglang_out.float(), rtol=rtol, atol=atol
    ) and torch.allclose(torch_scale, sglang_scale, rtol=rtol, atol=atol)

    if hidden_dim == 1368:
        rtol = 1e-2
        # we found vllm sglang has diff when hidden dim is not divisible by 16
        # and we believe SGLang is closer to Torch implementation
    vllm_sglang_match = torch.allclose(
        vllm_out.float(), sglang_out.float(), rtol=rtol, atol=atol
    ) and torch.allclose(vllm_scale, sglang_scale, rtol=rtol, atol=atol)

    # The rtol printed here is the one in effect for the VLLM-vs-SGLang check
    # (relaxed for hidden_dim == 1368); the two Torch comparisons use 1e-3.
    print(f"Matches (rtol={rtol}, atol={atol}):")
    # The pass/fail markers were empty strings in both branches, so these
    # lines showed no result at all; use the same markers as the sibling
    # benchmark scripts.
    print(f" Torch vs VLLM: {'✅' if torch_vllm_match else '❌'}")
    print(f" Torch vs SGLang: {'✅' if torch_sglang_match else '❌'}")
    print(f" VLLM vs SGLang: {'✅' if vllm_sglang_match else '❌'}")
# Sweep axes; `benchmark_quantization` multiplies batch_size by seq_len to get
# the number of input rows.
batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
# 1368 is the only value here that is not a multiple of 16.
hidden_dim_range = [1368, 2048, 4096]
# Full cartesian product: one (batch_size, seq_len, hidden_dim) tuple per run.
configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "hidden_dim"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "vllm", "sglang"],
        line_names=["Torch Reference", "VLLM", "SGL Kernel"],
        styles=[("red", "-"), ("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        args={},
    )
)
def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
    """Time one per-token FP8 quantization implementation.

    Args:
        batch_size, seq_len: multiplied to give the number of rows of the
            random float16 CUDA input.
        hidden_dim: number of columns of the input.
        provider: ``"torch"``, ``"vllm"`` or ``"sglang"``.

    Returns:
        ``(median, q20, q80)`` latency in microseconds, i.e. the
        ``(mean, min, max)`` order that ``perf_report`` unpacks.

    Raises:
        ValueError: if ``provider`` is unknown.
    """
    dtype = torch.float16
    device = torch.device("cuda")
    x = torch.randn(batch_size * seq_len, hidden_dim, device=device, dtype=dtype)

    if provider == "torch":
        quant_fn = torch_per_token_quant_fp8
    elif provider == "vllm":
        quant_fn = vllm_per_token_quant_fp8
    elif provider == "sglang":
        quant_fn = sglang_per_token_quant_fp8
    else:
        # Previously `fn` was simply left unbound for an unknown provider.
        raise ValueError(f"Unknown provider: {provider}")

    quantiles = [0.5, 0.2, 0.8]
    # The input is cloned on every call, so the clone is part of the timing.
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: quant_fn(x.clone()), quantiles=quantiles
    )

    # Quantiles come back in the requested order (0.5, 0.2, 0.8); report the
    # lower quantile as "min" and the upper as "max" (the original swapped them).
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
# Script entry point: correctness comparison for a few hidden sizes, then the
# performance sweep.
if __name__ == "__main__":
    # Test various hidden dimensions for correctness
    test_dims = [1368, 2048, 4096]
    for dim in test_dims:
        calculate_diff(batch_size=4, seq_len=4096, hidden_dim=dim)

    print("\n" + "=" * 60)
    print("Starting performance benchmark...")
    benchmark_quantization.run(print_data=True)

View File

@@ -0,0 +1,198 @@
import argparse
import copy
import itertools
import torch
import triton
from sgl_kernel import (
int8_scaled_mm,
qserve_w4a8_per_chn_gemm,
qserve_w4a8_per_group_gemm,
)
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    """Saturate ``tensor`` to [-128, 127], round it, and cast to ``torch.int8``."""
    saturated = torch.clamp(tensor, min=-128, max=127)
    return saturated.round().to(torch.int8)
# Per-model weight shapes in the format ([K, N], TP_SPLIT_DIM).
# TP_SPLIT_DIM is the index (0 -> K, 1 -> N) of the dimension that
# `prepare_shapes` floor-divides by the tensor-parallel size; the values
# listed here are the unsplit shapes.
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
        line_names=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
        styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
        ylabel="ms",
        plot_name="FP16_vs_W8A8_vs_Qserve_W4A8_GEMM",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Time one GEMM variant for an (M, K) x (K, N) problem with random data.

    Args:
        batch_size: M, the number of activation rows.
        provider: ``"FP16"``, ``"W8A8"``, ``"Qserve_W4A8_Per_Channel"`` or
            ``"Qserve_W4A8_Per_Group"``.
        N, K: weight dimensions, supplied through ``benchmark.run(N=..., K=...)``.

    Returns:
        ``(median, q20, q80)`` latency in milliseconds, i.e. the
        ``(mean, min, max)`` order that ``perf_report`` unpacks.

    Raises:
        AssertionError: if ``K`` is not divisible by the group size (128).
        ValueError: if ``provider`` is unknown.
    """
    M = batch_size

    # For W8A8
    a = to_int8(torch.randn((M, K), device="cuda") * 5)
    b = to_int8(torch.randn((N, K), device="cuda").t() * 5)
    a_fp16 = a.to(torch.float16)
    b_fp16 = b.to(torch.float16)
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)

    # For Qserve W4A8 per channel
    a_qserve_chn = a
    # two int4s pack into one int8
    b_qserve_chn = to_int8(torch.randn((N, K // 2), device="cuda") * 5)
    scale_a_qserve_chn = scale_a.to(torch.float16)
    scale_b_qserve_chn = scale_b.to(torch.float16)
    szero_b_qserve_chn = torch.randn((N,), device="cuda", dtype=torch.float16)
    a_sum_qserve_chn = torch.randn((M,), device="cuda", dtype=torch.float16)

    # For Qserve W4A8 per group
    group_size = 128
    assert K % group_size == 0, "K must be divisible by group_size"
    a_qserve_group = a
    # two int4s pack into one int8
    b_qserve_group = to_int8(torch.randn((N, K // 2), device="cuda") * 5)
    scale_a_qserve_group = scale_a.to(torch.float16)
    scale_b_qserve_group = scale_b.to(torch.float16)
    scale_i8_b_qserve_group = to_int8(
        torch.randn((K // group_size, N), device="cuda", dtype=torch.float16)
    )
    zero_i8_b_qserve_group = to_int8(
        torch.randn((K // group_size, N), device="cuda", dtype=torch.float16)
    )

    # One callable per provider. The original used four independent `if`
    # statements with no fallback, so an unknown provider crashed with an
    # UnboundLocalError on `ms`.
    runners = {
        "FP16": lambda: torch.matmul(a_fp16, b_fp16),
        "W8A8": lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16),
        "Qserve_W4A8_Per_Channel": lambda: qserve_w4a8_per_chn_gemm(
            a_qserve_chn,
            b_qserve_chn,
            scale_b_qserve_chn,
            scale_a_qserve_chn,
            szero_b_qserve_chn,
            a_sum_qserve_chn,
        ),
        "Qserve_W4A8_Per_Group": lambda: qserve_w4a8_per_group_gemm(
            a_qserve_group,
            b_qserve_group,
            zero_i8_b_qserve_group,
            scale_i8_b_qserve_group,
            scale_b_qserve_group,
            scale_a_qserve_group,
        ),
    }
    if provider not in runners:
        raise ValueError(f"Unknown provider: {provider}")

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        runners[provider], quantiles=quantiles
    )

    # Quantiles come back in the requested order (0.5, 0.2, 0.8); report the
    # lower quantile as "min" and the upper as "max" (the original swapped them).
    return ms, min_ms, max_ms
def prepare_shapes(args):
    """Expand the selected models and TP sizes into ``[K, N, model]`` entries.

    Combinations are visited in model-major order. For each weight shape of
    the model, the dimension named by its split index is floor-divided by the
    TP size and the model name is appended to the copied ``[K, N]`` pair.

    Args:
        args: namespace with ``models`` and ``tp_sizes`` sequences.

    Returns:
        A list of ``[K, N, model_name]`` lists.

    Raises:
        AssertionError: if a model is not a key of ``WEIGHT_SHAPES``.
    """
    expanded = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        assert model in WEIGHT_SHAPES
        for dims, tp_split_dim in WEIGHT_SHAPES[model]:
            # Copy first so the shared WEIGHT_SHAPES table is never mutated.
            entry = list(dims)
            entry[tp_split_dim] = entry[tp_split_dim] // tp_size
            entry.append(model)
            expanded.append(entry)
    return expanded
# Script entry point: parse the model / TP-size selection, expand it into GEMM
# shapes and run the benchmark once per shape.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    # Each entry is [K, N, model_name] with the TP split already applied.
    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name} N={N} K={K}: ")
        # N and K are forwarded to `benchmark` as extra keyword arguments.
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_qserve_w4a8_gemm_res",
            N=N,
            K=K,
        )
    print("Benchmark finished!")

View File

@@ -0,0 +1,96 @@
import itertools
import torch
import triton
from sgl_kernel import FusedSetKVBufferArg
from sgl_kernel.testing.rotary_embedding import (
FlashInferRotaryEmbedding,
MHATokenToKVPool,
RotaryEmbedding,
create_inputs,
)
from sglang.srt.bench_utils import bench_kineto
# Benchmark grid: every (batch_size, seq_len) pair listed below, each run once
# with save_kv_cache=False and once with save_kv_cache=True.
configs = [
    (batch_size, seq_len, save_kv_cache)
    for batch_size, seq_len in (
        (1, 1),
        (32, 1),
        (128, 1),
        (512, 1),
        (2, 512),
        (4, 4096),
    )
    for save_kv_cache in (False, True)
]
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "save_kv_cache"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["sglang"],
        line_names=["SGL Kernel"],
        styles=[("green", "-")],
        ylabel="us",
        plot_name="bench_rotary_embedding",
        args={},
    )
)
def benchmark(batch_size, seq_len, save_kv_cache, provider):
    """Profile ``FlashInferRotaryEmbedding.forward_cuda``.

    Args:
        batch_size, seq_len: forwarded to ``create_inputs``.
        save_kv_cache: when true, a ``FusedSetKVBufferArg`` built from the KV
            pool is passed to ``forward_cuda``; otherwise ``None`` is passed.
        provider: supplied by ``perf_report``; not used (single provider).

    Returns:
        The ``bench_kineto`` time for the ``BatchQKApplyRotaryPosIds`` kernel
        multiplied by 1e6 (seconds to microseconds, going by the variable name).
    """
    device = torch.device("cuda")
    dtype = torch.bfloat16
    num_q_heads = 32
    num_kv_heads = 8
    head_size = 64

    rope_flashinfer = FlashInferRotaryEmbedding(
        head_size=head_size,
        rotary_dim=64,
        max_position_embeddings=4096,
        base=8000,
        is_neox_style=True,
        dtype=dtype,
    ).to(device)
    pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)

    inputs = create_inputs(
        head_size=head_size,
        batch_size=batch_size,
        seq_len=seq_len,
        device=device,
        dtype=dtype,
        num_q_heads=num_q_heads,
        num_kv_heads=num_kv_heads,
    )
    query_flashinfer = inputs["query"].clone()
    key_flashinfer = inputs["key"].clone()

    def run_rope():
        # The fused KV-buffer argument is rebuilt on every invocation, exactly
        # as the original lambda did.
        fused_arg = None
        if save_kv_cache:
            fused_arg = FusedSetKVBufferArg(
                value=inputs["value"],
                k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
                v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
                k_scale=None,
                v_scale=None,
                cache_loc=inputs["out_cache_loc"],
            )
        return rope_flashinfer.forward_cuda(
            inputs["pos_ids"],
            query_flashinfer,
            key_flashinfer,
            fused_set_kv_buffer_arg=fused_arg,
        )

    time_s = bench_kineto(run_rope, kernel_names="BatchQKApplyRotaryPosIds")
    return time_s * 1e6
# Script entry point: run the sweep defined by the Benchmark decorator and
# print the collected data (print_data=True).
if __name__ == "__main__":
    benchmark.run(print_data=True)

View File

@@ -0,0 +1,128 @@
import itertools
import sgl_kernel
import torch
import triton
import triton.testing
def torch_top_k_top_p_joint_sampling_from_probs(
    normalized_prob, top_k, top_p, eps=1e-4
):
    """Reference PyTorch implementation of joint top-k top-p sampling.

    Each row is processed independently:
      * top-p mask: keep tokens whose ascending cumulative probability
        exceeds ``(1 - top_p) - eps``;
      * top-k mask: keep tokens whose probability is at least the k-th
        largest value in the row;
      * the two masks are intersected, the surviving probabilities are
        renormalized, and one token is drawn with ``torch.multinomial``.

    Args:
        normalized_prob: 2-D ``(batch_size, vocab_size)`` probability tensor.
        top_k: per-row k values (indexable, each with ``.item()``).
        top_p: per-row p values (indexable, each with ``.item()``).
        eps: slack subtracted from the top-p threshold.

    Returns:
        int64 tensor of shape ``(batch_size,)`` with the sampled token ids.
    """
    batch_size, vocab_size = normalized_prob.shape
    device = normalized_prob.device
    samples = torch.empty(batch_size, dtype=torch.int64, device=device)

    for row, probs in enumerate(normalized_prob):
        p_val = top_p[row].item()
        k_val = top_k[row].item()

        # top-p mask, computed in sorted order and scattered back to token order
        ascending, order = torch.sort(probs, descending=False)
        cumulative = torch.cumsum(ascending, dim=-1)
        keep_sorted = (cumulative > (1 - p_val) - eps).int()
        nucleus_mask = torch.zeros(vocab_size, dtype=torch.int32, device=device)
        nucleus_mask.scatter_add_(0, order, keep_sorted)

        # top-k mask: everything at least as large as the k-th largest value
        descending, _ = torch.sort(probs, descending=True)
        kth_largest = descending[k_val - 1]
        topk_mask = (probs >= kth_largest).int()

        # joint mask: a token survives only if both masks keep it
        keep = torch.minimum(nucleus_mask, topk_mask).bool()

        # sample from the renormalized surviving probabilities
        filtered = probs * keep
        filtered = filtered / filtered.sum()
        samples[row] = torch.multinomial(filtered, 1)

    return samples
def calculate_diff(batch_size, vocab_size, p):
    """Compare Torch reference and SGLang kernel for correctness.

    Runs both samplers on the same seeded random distribution. The two sample
    tensors are currently computed and then discarded: nothing is compared,
    printed or returned, so this only checks that both calls execute.

    Raises:
        ValueError: if ``p`` is neither 0.1 nor 0.5.
    """
    torch.manual_seed(42)

    # top-k is tied to the chosen top-p: half the vocabulary for p=0.1,
    # a tenth of it for p=0.5.
    if p == 0.1:
        k_fraction = 0.5
    elif p == 0.5:
        k_fraction = 0.1
    else:
        raise ValueError("p not recognized")
    k = int(vocab_size * k_fraction)

    device = torch.device("cuda")
    raw_prob = torch.rand(batch_size, vocab_size, device=device)
    normalized_prob = raw_prob / raw_prob.sum(dim=-1, keepdim=True)
    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)

    torch_samples = torch_top_k_top_p_joint_sampling_from_probs(
        normalized_prob, top_k_tensor, top_p_tensor
    )
    sglang_samples = sgl_kernel.top_k_top_p_sampling_from_probs(
        normalized_prob, top_k_tensor, top_p_tensor, filter_apply_order="joint"
    )
# parameter space
batch_size_range = [16, 64, 128]
vocab_size_range = [111, 32000]
# Only 0.1 and 0.5 are accepted by calculate_diff / benchmark_sampling;
# any other value raises ValueError there.
p_range = [0.1, 0.5]
# Full cartesian product: one (batch_size, vocab_size, p) tuple per run.
configs = list(itertools.product(batch_size_range, vocab_size_range, p_range))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "vocab_size", "p"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "sglang"],
        line_names=["Torch Reference", "SGL Kernel"],
        styles=[("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="top-k-top-p-joint-sampling-performance",
        args={},
    )
)
def benchmark_sampling(batch_size, vocab_size, p, provider):
    """Time joint top-k/top-p sampling for one implementation.

    Args:
        batch_size, vocab_size: shape of the random probability matrix.
        p: top-p value; must be 0.1 (k = 50% of vocab) or 0.5 (k = 10%).
        provider: ``"torch"`` or ``"sglang"``.

    Returns:
        ``(median, q20, q80)`` latency in microseconds, i.e. the
        ``(mean, min, max)`` order that ``perf_report`` unpacks.

    Raises:
        ValueError: if ``p`` or ``provider`` is not recognized.
    """
    torch.manual_seed(42)
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")

    device = torch.device("cuda")
    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)

    # The probabilities are cloned on every call, so the clone is part of the
    # measured time for both providers.
    if provider == "torch":
        fn = lambda: torch_top_k_top_p_joint_sampling_from_probs(
            normalized_prob.clone(), top_k_tensor, top_p_tensor
        )
    elif provider == "sglang":
        fn = lambda: sgl_kernel.top_k_top_p_sampling_from_probs(
            normalized_prob.clone(),
            top_k_tensor,
            top_p_tensor,
            filter_apply_order="joint",
        )
    else:
        # Previously `fn` was simply left unbound for an unknown provider.
        raise ValueError(f"Unknown provider: {provider}")

    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])

    # Quantiles come back in the requested order (0.5, 0.2, 0.8); report the
    # lower quantile as "min" and the upper as "max" (the original swapped them).
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
if __name__ == "__main__":
# Correctness check
for cfg in configs:
calculate_diff(*cfg)
print("\n" + "=" * 60)
print("Starting performance benchmark...")
benchmark_sampling.run(print_data=True)