adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
153
sgl-kernel/benchmark/bench_activation.py
Normal file
153
sgl-kernel/benchmark/bench_activation.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# Benchmarks SGLang kernels versus vLLM across
|
||||
# (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
|
||||
import argparse
|
||||
import itertools
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import gelu_quick # activation-only kernel
|
||||
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
# When the imported `_custom_ops` module has no `silu_and_mul` attribute,
# look the kernels up on `torch.ops._C` instead.
vllm_ops = vllm_ops if hasattr(vllm_ops, "silu_and_mul") else torch.ops._C
|
||||
|
||||
|
||||
def str2int_list(arg: str) -> List[int]:
    """Parse a comma-separated string of non-negative integers.

    Returns an empty list for ``""`` or ``None``.

    Raises:
        argparse.ArgumentTypeError: if the stripped text is not one or more
            digit runs separated by single commas.
    """
    if arg is None or arg == "":
        return []
    well_formed = re.fullmatch(r"\d+(,\d+)*", arg.strip())
    if not well_formed:
        raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
    return list(map(int, arg.split(",")))
|
||||
|
||||
|
||||
def calculate_diff(
    kernel: str, dtype: torch.dtype, batch_size: int, seq_len: int, dim: int
) -> bool:
    """Run one kernel through vLLM and SGLang on random input and compare.

    Prints a one-line verdict for the shape and returns True when the two
    outputs agree under ``torch.allclose(rtol=1e-3, atol=1e-5)``.
    """
    device = torch.device("cuda")

    # "gelu_quick" is activation-only (input width == dim); every other kernel
    # is a fused activation-and-multiply fed an input of width 2 * dim.
    in_width = dim if kernel == "gelu_quick" else 2 * dim
    x = torch.randn(batch_size, seq_len, in_width, dtype=dtype, device=device)

    # The vLLM op is called with a pre-allocated output buffer; the SGLang op
    # returns its result.
    ref_out = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
    vllm_fn = getattr(vllm_ops, kernel)
    vllm_fn(ref_out, x)
    sgl_fn = getattr(sgl_kernel, kernel)
    test_out = sgl_fn(x)

    ok = torch.allclose(ref_out, test_out, rtol=1e-3, atol=1e-5)
    verdict = "✅ match" if ok else "❌ mismatch"
    shape_desc = (
        f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
        f"L={seq_len:3d} | D={dim:5d}]"
    )
    print(f"{shape_desc} {verdict}")
    return ok
|
||||
|
||||
|
||||
# Kernel names to sweep; each is looked up by name on both `vllm_ops` and
# `sgl_kernel`.
kernels = [
    "silu_and_mul",
    "gelu_and_mul",
    "gelu_tanh_and_mul",
    "gelu_quick",
]
# Element types to sweep.
dtypes = [
    torch.float16,
    torch.bfloat16,
]
|
||||
|
||||
|
||||
def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
    """Build the benchmark grid.

    Returns the Cartesian product of the module-level ``kernels`` and
    ``dtypes`` with the given batch sizes, sequence lengths and dims, as a
    list of ``(kernel, dtype, batch_size, seq_len, dim)`` tuples.
    """
    grid = itertools.product(kernels, dtypes, bsizes, slens, dims_)
    return [*grid]
|
||||
|
||||
|
||||
# Default sweep grids, all powers of two.
default_batch_sizes = [1 << exp for exp in (0, 2, 4)]  # 1, 4, 16
default_seq_lens = [1 << exp for exp in (0, 2, 4, 6)]  # 1, 4, 16, 64
default_dims = [1 << exp for exp in range(7, 15)]  # 128 ... 16384
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["kernel", "dtype", "batch_size", "seq_len", "dim"],
        # Left empty here; the __main__ section fills in the grid before run().
        x_vals=[],
        line_arg="provider",
        line_vals=["vllm", "sglang", "speedup"],
        line_names=["vLLM", "SGL Kernel", "Speed-up (x)"],
        styles=[("blue", "-"), ("green", "-"), ("red", "--")],
        ylabel="µs (median) or × (speed-up)",
        plot_name="activation-performance",
        args={},
    )
)
def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
    """Time one (kernel, dtype, shape) grid point for the given provider.

    Returns a ``(value, low, high)`` triple:
      * ``"vllm"`` / ``"sglang"``: (median, 20th percentile, 80th percentile)
        latency in microseconds.
      * ``"speedup"``: the vLLM/SGLang median-latency ratio, repeated three
        times.

    Raises:
        ValueError: if the vLLM and SGLang outputs disagree for this shape
            (checked once per grid point, on the ``"vllm"`` pass).
    """
    device = torch.device("cuda")
    # "gelu_quick" takes an input of width dim; the fused *_and_mul kernels
    # take width 2 * dim.
    in_mult = 1 if kernel == "gelu_quick" else 2
    x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
    y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)

    vllm_kernel = getattr(vllm_ops, kernel)
    sglang_kernel = getattr(sgl_kernel, kernel)

    def baseline():
        # The vLLM op writes into a caller-provided buffer, so a fresh output
        # tensor is cloned on every call (the clone is part of the timing).
        tmp = y0.clone()
        vllm_kernel(tmp, x)
        return tmp

    def sglang():
        return sglang_kernel(x)

    # one-time correctness check
    if provider == "vllm" and not calculate_diff(
        kernel, dtype, batch_size, seq_len, dim
    ):
        raise ValueError("Mismatch – abort benchmark")

    # timing helper
    def timed(fn):
        for _ in range(5):
            fn()
        torch.cuda.synchronize()
        # Quantiles come back in the order requested: median, p20, p80.
        ms, qmin, qmax = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
        # perf_report unpacks the result as (y_mean, y_min, y_max); for a
        # latency metric the low quantile must therefore come second.
        return 1000 * ms, 1000 * qmin, 1000 * qmax

    if provider == "vllm":
        return timed(baseline)
    if provider == "sglang":
        return timed(sglang)

    # provider == "speedup"
    t_ref, _, _ = timed(baseline)
    t_sgl, _, _ = timed(sglang)
    spd = t_ref / t_sgl
    return (spd, spd, spd)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    p = argparse.ArgumentParser("Activation kernel benchmark")
    for flag, fallback in (
        ("--batch_sizes", default_batch_sizes),
        ("--seq_lens", default_seq_lens),
        ("--dims", default_dims),
    ):
        p.add_argument(flag, type=str2int_list, default=fallback)
    p.add_argument("--verify_only", action="store_true")
    args = p.parse_args()

    # Coerce any option that is still a raw string into a list of ints.
    for field in ("batch_sizes", "seq_lens", "dims"):
        raw = getattr(args, field)
        if isinstance(raw, str):
            setattr(args, field, str2int_list(raw))

    # The decorator was given an empty x_vals; inject the real grid here.
    # The attribute holding the Benchmark object is probed under both names.
    benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.dims)
    holder = "benchmarks" if hasattr(benchmark, "benchmarks") else "benchmark"
    getattr(benchmark, holder).x_vals = benchmark_grid

    if not args.verify_only:
        benchmark.run(print_data=True)
    else:
        # Single sanity check: gelu_quick, fp16, 1x1 tokens, first requested dim.
        ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0])
        print("✅ sanity pass" if ok else "❌ mismatch")
|
||||
118
sgl-kernel/benchmark/bench_awq_dequant.py
Normal file
118
sgl-kernel/benchmark/bench_awq_dequant.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import itertools
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import awq_dequantize
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
|
||||
def vllm_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    """Dequantize AWQ weights with vLLM's custom op.

    The three trailing positional arguments are passed as 0.
    NOTE(review): presumably these are vLLM's split-k / thread-shape hints
    with 0 meaning "default" — confirm against the vLLM signature.

    The result is treated as a single tensor by the callers in this file
    (``.float()`` / ``.to(torch.float32)`` are applied to it directly).
    """
    return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
||||
|
||||
|
||||
def sglang_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    """Dequantize AWQ weights with the sgl_kernel implementation.

    The result is treated as a single tensor by the callers in this file
    (``.float()`` / ``.to(torch.float32)`` are applied to it directly).
    """
    return awq_dequantize(qweight, scales, qzeros)
|
||||
|
||||
|
||||
def calculate_diff(qweight_row: int, qweight_col: int):
    """Compare the vLLM and SGLang AWQ dequantize outputs on random input.

    Builds random int32 ``qweight`` / ``qzeros`` and fp16 ``scales`` on the
    CUDA device, runs both implementations, and prints whether they agree
    under ``torch.allclose(rtol=1e-3, atol=1e-5)``. On mismatch the mean
    absolute difference is printed as well.
    """
    device = torch.device("cuda")
    int32_max = torch.iinfo(torch.int32).max
    qweight = torch.randint(
        0,
        int32_max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    # A single quantization group spans all rows, so scales/qzeros have 1 row.
    group_size = qweight_row
    scales_row = qweight_row // group_size
    # Scales have 8 columns per qweight column.
    # NOTE(review): presumably each int32 packs eight 4-bit weights — confirm.
    scales_col = qweight_col * 8
    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
    qzeros = torch.randint(
        0,
        int32_max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
    sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)

    vllm_f32 = vllm_out.to(torch.float32)
    sglang_f32 = sglang_out.to(torch.float32)

    if torch.allclose(vllm_f32, sglang_f32, rtol=1e-3, atol=1e-5):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
        # Only pay for the reduction and device sync when there is something
        # to report.
        output_diff = torch.abs(vllm_f32 - sglang_f32).mean().item()
        print(f"   mean abs diff: {output_diff:.6e}")
|
||||
|
||||
|
||||
# Shapes swept by the benchmark: every (rows, cols) pairing of the two lists.
qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]

configs = [
    (n_rows, n_cols)
    for n_rows in qweight_row_range
    for n_cols in qweight_cols_range
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["qweight_row", "qweight_col"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm", "sglang"],
        line_names=["VLLM", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="awq-dequantize-performance",
        args={},
    )
)
def benchmark(qweight_row, qweight_col, provider):
    """Time AWQ dequantization for one shape and provider.

    Returns (median, 20th percentile, 80th percentile) latency in
    microseconds.

    Raises:
        ValueError: if ``provider`` is neither ``"vllm"`` nor ``"sglang"``.
    """
    device = torch.device("cuda")
    int32_max = torch.iinfo(torch.int32).max
    qweight = torch.randint(
        0,
        int32_max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    # A single quantization group spans all rows, so scales/qzeros have 1 row.
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8
    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
    qzeros = torch.randint(
        0,
        int32_max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    quantiles = [0.5, 0.2, 0.8]

    # NOTE(review): the inputs are cloned inside the timed callable, so the
    # clone cost is part of every measurement for both providers.
    if provider == "vllm":
        fn = lambda: vllm_awq_dequantize(
            qweight.clone(), scales.clone(), qzeros.clone()
        )
    elif provider == "sglang":
        fn = lambda: sglang_awq_dequantize(
            qweight.clone(), scales.clone(), qzeros.clone()
        )
    else:
        # Previously fell through and failed later with an unbound `fn`.
        raise ValueError(f"Unknown provider: {provider}")

    # Quantiles come back in the order requested: median, p20, p80.
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    # perf_report unpacks the result as (y_mean, y_min, y_max); for a latency
    # metric the low quantile must come second.
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Check one shape for agreement first, then run the full sweep.
    sanity_shape = {"qweight_row": 3584, "qweight_col": 448}
    calculate_diff(**sanity_shape)
    benchmark.run(print_data=True)
|
||||
145
sgl-kernel/benchmark/bench_cutlass_mla.py
Normal file
145
sgl-kernel/benchmark/bench_cutlass_mla.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
|
||||
|
||||
# Sweep grid: every (batch_size, seq_len) pairing of the two lists.
bs_range = [1, 8, 32, 64, 128, 256]
qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

configs = [(bs, qlen) for bs in bs_range for qlen in qlen_range]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        x_log=False,
        line_arg="provider",
        line_vals=["128 heads", "64 heads", "32 heads", "16 heads"],
        line_names=["128 heads", "64 heads", "32 heads", "16 heads"],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="cutlass mla",
        args={},
    )
)
def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
    """Measure ``cutlass_mla_decode`` throughput for one grid point.

    ``provider`` selects the number of query heads (it must contain one of
    "128", "64", "32", "16"). Returns a (median, low, high) triple in GB/s;
    the low value comes from the slow time quantile and the high value from
    the fast one.

    Raises:
        ValueError: if no known head count appears in ``provider``.
    """
    d = 576
    dn = 64
    dv = 512

    # First head-count token found inside the provider label wins.
    h_q = None
    for token, heads in (("128", 128), ("64", 64), ("32", 32), ("16", 16)):
        if token in provider:
            h_q = heads
            break
    if h_q is None:
        raise ValueError(f"Unknown head configuration in provider: {provider}")

    seq_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
    max_seq_len = seq_lens.max().item()
    blocks_needed = (max_seq_len + block_size - 1) // block_size

    # Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles.
    # One 128-wide tile can hold (128 // block_size) small blocks.
    pack_factor = 128 // block_size
    block_num = ((blocks_needed + pack_factor - 1) // pack_factor) * pack_factor

    bf16_cuda = {"dtype": torch.bfloat16, "device": "cuda"}
    # qn is built head-major and transposed at call time.
    qn = torch.randn(h_q, batch_size, d - dn, **bf16_cuda) * 100.0
    qr = torch.randn(batch_size, h_q, dn, **bf16_cuda) * 100.0
    block_table = torch.randint(
        0,
        batch_size * block_num,
        (batch_size, block_num),
        dtype=torch.int32,
        device="cuda",
    )

    kv_cache = torch.randn(block_table.numel(), block_size, d, **bf16_cuda)

    workspace_size = cutlass_mla_get_workspace_size(
        block_num * block_size, batch_size, num_kv_splits=num_kv_splits
    )
    workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

    def run_decode():
        return cutlass_mla_decode(
            qn.transpose(0, 1),
            qr,
            kv_cache,
            seq_lens,
            block_table,
            workspace,
            1.44,
            num_kv_splits,
        )

    ms, min_ms, max_ms = triton.testing.do_bench(run_decode, quantiles=[0.5, 0.2, 0.8])

    q_size = qn.numel() * qn.element_size() + qr.numel() * qr.element_size()
    # Bytes counted per call: the queries, an output term scaled by dv / d,
    # and the whole KV cache.
    bytes_moved = q_size + q_size * dv / d + kv_cache.numel() * kv_cache.element_size()

    def gbps(elapsed_ms):
        return bytes_moved * 1e-9 / (elapsed_ms * 1e-3)

    return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--block-sizes",
        nargs="+",
        type=int,
        default=[1, 32, 64, 128],
        # Was "List of batch sizes" (copy-paste); these are the block_size
        # values passed to the benchmark.
        help="List of KV-cache block sizes to benchmark",
    )
    parser.add_argument(
        "--num-kv-splits",
        nargs="+",
        type=int,
        default=[-1],
        # Was "List of batch sizes" (copy-paste).
        help="List of num_kv_splits values to benchmark",
    )
    args = parser.parse_args()

    # One full (batch_size, seq_len) sweep per (block_size, num_kv_splits) pair.
    for block_size in args.block_sizes:
        for kv_split in args.num_kv_splits:
            print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
            benchmark.run(
                print_data=True,
                show_plots=True,
                save_path="bench_blackwell_mla_res",
                block_size=block_size,
                num_kv_splits=kv_split,
            )

    print("Benchmark finished!")
|
||||
57
sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py
Normal file
57
sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import dsv3_fused_a_gemm
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=list(range(1, 17)),
        x_log=False,
        line_arg="impl",
        line_vals=["torch", "sgl-kernel"],
        line_names=["torch (bf16)", "dsv3_fused_a_gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="TFLOPs",
        plot_name="bf16 dsv3 fused a GEMM throughput",
        args={},
    )
)
def benchmark(num_tokens, impl):
    """Measure bf16 GEMM throughput of ``F.linear`` vs ``dsv3_fused_a_gemm``.

    The problem is (num_tokens x 7168) @ (7168 x 2112). Returns a
    (median, low, high) triple in TFLOPs; the low value comes from the slow
    time quantile and the high value from the fast one.
    """
    kHdIn = 7168
    kHdOut = 2112
    M, K, N = num_tokens, kHdIn, kHdOut

    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    # Allocated as (N, K) and transposed, so mat_b is a (K, N) view.
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").transpose(0, 1)

    if impl == "torch":

        def runner():
            # F.linear is handed the (N, K) view of the weight.
            F.linear(mat_a, mat_b.T)

    elif impl == "sgl-kernel":

        def runner():
            dsv3_fused_a_gemm(mat_a, mat_b)

    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=[0.5, 0.2, 0.8])

    def tflops(t_ms):
        flops = 2 * M * K * N
        return flops / (t_ms * 1e-3) / 1e12

    return tuple(tflops(t) for t in (ms, max_ms, min_ms))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # No options are defined; parsing still provides --help and rejects
    # unknown arguments.
    parser = argparse.ArgumentParser()
    args = parser.parse_args()

    run_options = {
        "print_data": True,
        "show_plots": True,
        "save_path": "bench_dsv3_gemm",
    }
    benchmark.run(**run_options)
|
||||
127
sgl-kernel/benchmark/bench_dsv3_router_gemm.py
Normal file
127
sgl-kernel/benchmark/bench_dsv3_router_gemm.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import dsv3_router_gemm
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=list(range(1, 17)),
        x_log=False,
        line_arg="impl",
        line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
        line_names=[
            "torch-256",
            "dsv3_router_gemm-256",
            "torch-384",
            "dsv3_router_gemm-384",
        ],
        styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
        ylabel="TFLOPs",
        plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
        args={},
    )
)
def benchmark_bf16_output(num_tokens, impl):
    """Measure router-GEMM throughput with bf16 inputs and bf16 output.

    ``impl`` picks both the implementation (torch / sgl-kernel) and the
    expert count (256 / 384). Returns a (median, low, high) triple in TFLOPs;
    the low value comes from the slow time quantile.

    Raises:
        ValueError: if ``impl`` is not one of the four known values.
    """
    # M: num_tokens, K: hidden_dim, N: num_experts
    M, K = num_tokens, 7168

    if impl in ("torch-256", "sgl-kernel-256"):
        N = 256
    elif impl in ("torch-384", "sgl-kernel-384"):
        N = 384
    else:
        raise ValueError(f"Unknown impl: {impl}")

    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()

    # impl is one of the four validated names here, so anything that is not a
    # torch variant is an sgl-kernel variant.
    if impl in ("torch-256", "torch-384"):

        def runner():
            F.linear(mat_a, mat_b)

    else:

        def runner():
            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)

    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=[0.5, 0.2, 0.8])

    def tflops(t_ms):
        flops = 2 * M * K * N
        return flops / (t_ms * 1e-3) / 1e12

    return tuple(tflops(t) for t in (ms, max_ms, min_ms))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=list(range(1, 17)),
        x_log=False,
        line_arg="impl",
        line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
        line_names=[
            "torch-256",
            "dsv3_router_gemm-256",
            "torch-384",
            "dsv3_router_gemm-384",
        ],
        styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
        ylabel="TFLOPs",
        plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
        args={},
    )
)
def benchmark_float_output(num_tokens, impl):
    """Measure router-GEMM throughput with bf16 inputs and fp32 output.

    ``impl`` picks both the implementation (torch / sgl-kernel) and the
    expert count (256 / 384). The torch variant runs a bf16 ``F.linear`` and
    then casts the result to fp32. Returns a (median, low, high) triple in
    TFLOPs; the low value comes from the slow time quantile.

    Raises:
        ValueError: if ``impl`` is not one of the four known values.
    """
    # M: num_tokens, K: hidden_dim, N: num_experts
    M, K = num_tokens, 7168

    if impl in ("torch-256", "sgl-kernel-256"):
        N = 256
    elif impl in ("torch-384", "sgl-kernel-384"):
        N = 384
    else:
        raise ValueError(f"Unknown impl: {impl}")

    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()

    # impl is one of the four validated names here, so anything that is not a
    # torch variant is an sgl-kernel variant.
    if impl in ("torch-256", "torch-384"):

        def runner():
            F.linear(mat_a, mat_b).to(torch.float32)

    else:

        def runner():
            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)

    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=[0.5, 0.2, 0.8])

    def tflops(t_ms):
        flops = 2 * M * K * N
        return flops / (t_ms * 1e-3) / 1e12

    return tuple(tflops(t) for t in (ms, max_ms, min_ms))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # No options are defined; parsing still provides --help and rejects
    # unknown arguments.
    parser = argparse.ArgumentParser()
    args = parser.parse_args()

    # bf16-output report first, then fp32-output; both save to the same path.
    for report in (benchmark_bf16_output, benchmark_float_output):
        report.run(
            print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
        )
|
||||
210
sgl-kernel/benchmark/bench_fp4_gemm.py
Executable file
210
sgl-kernel/benchmark/bench_fp4_gemm.py
Executable file
@@ -0,0 +1,210 @@
|
||||
import argparse
|
||||
import copy
|
||||
import csv
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
from flashinfer import mm_fp4
|
||||
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
|
||||
FLOAT4_E2M1_MAX = 6.0
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
|
||||
def get_weight_shapes(args):
    """Return the ``[N, K]`` weight shapes to benchmark for ``args.tp_sizes``.

    ``[4]`` and ``[8]`` each select their own four shapes; any other value
    (including the default) returns the TP=4 shapes followed by the TP=8
    shapes.
    """
    tp4_shapes = [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]]
    tp8_shapes = [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]]

    requested = args.tp_sizes
    if requested == [4]:
        return tp4_shapes
    if requested == [8]:
        return tp8_shapes
    return tp4_shapes + tp8_shapes
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1,
            2,
            4,
            8,
            16,
            32,
            64,
            128,
            256,
            512,
            1024,
            2048,
            3072,
            4096,
            8192,
            16384,
        ],
        x_log=False,
        line_arg="provider",
        line_vals=["cutlass", "cudnn", "trtllm"],
        line_names=["baseline cutlass fp4", "cudnn fp4", "trtllm fp4"],
        styles=[("red", "solid"), ("blue", "solid"), ("green", "solid")],
        ylabel="latency (ms)",
        plot_name="fp4_gemm_benchmark",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
    """Time one FP4 GEMM provider for a (batch_size, N, K) problem.

    ``K`` is doubled before the operands are created. Both operands are
    quantized with ``scaled_fp4_quant`` and the chosen provider is timed with
    ``do_bench_cudagraph``. When ``correctness`` is truthy the cudnn and
    trtllm results are asserted close to the cutlass result; when
    ``csv_file`` is truthy one row is appended to it.

    Returns (median, 20th percentile, 80th percentile) latency as reported by
    ``do_bench_cudagraph``.
    """
    M = batch_size
    packed_k = K
    K = 2 * packed_k
    a_dtype = torch.randn((M, K), dtype=dtype, device="cuda")
    b_dtype = torch.randn((N, K), dtype=dtype, device="cuda")

    # Per-tensor global scales: (fp8 max * fp4 max) / max element.
    scale_numerator = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX
    a_global_scale = (
        scale_numerator / torch.amax(a_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    b_global_scale = (
        scale_numerator / torch.amax(b_dtype.flatten(), dim=-1)
    ).to(torch.float32)

    alpha = 1.0 / (a_global_scale * b_global_scale)
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
    res_fi = torch.empty((M, N), dtype=dtype, device="cuda")

    def run_cutlass():
        return cutlass_scaled_fp4_mm(
            a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
        )

    def run_flashinfer(**backend_kw):
        # Reads the scale tensors at call time, so it sees the uint8 versions
        # once the trtllm branch below has rebound them. Writes into res_fi.
        return mm_fp4(
            a_fp4,
            b_fp4.T,
            a_scale_interleaved,
            b_scale_interleaved.T,
            alpha,
            dtype,
            res_fi,
            **backend_kw,
        )

    quantiles = [0.5, 0.2, 0.8]
    if provider == "cutlass":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            run_cutlass, quantiles=quantiles
        )
    elif provider == "cudnn":
        # No backend argument is passed here, so mm_fp4's default is used.
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_flashinfer(), quantiles=quantiles
        )
    elif provider == "trtllm":
        # NOTE(review): `.to(torch.uint8)` converts values rather than
        # reinterpreting bits; confirm this is what the trtllm backend expects.
        a_scale_interleaved = a_scale_interleaved.to(torch.uint8)
        b_scale_interleaved = b_scale_interleaved.to(torch.uint8)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_flashinfer(backend="trtllm"), quantiles=quantiles
        )

    if correctness:
        # NOTE(review): on the trtllm pass the scales are already uint8 at
        # this point, including for the cutlass and cudnn calls below.
        res_cutlass = run_cutlass()
        run_flashinfer(backend="cudnn")
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "cudnn fp4 doesn't match cutlass fp4"
        run_flashinfer(backend="trtllm")
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "trtllm fp4 doesn't match cutlass fp4"

    if csv_file:
        with open(csv_file, "a", newline="") as f:
            csv.writer(f).writerow([provider, M, N, K, ms])

    return ms, min_ms, max_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":

    def _dtype_from_name(name: str) -> torch.dtype:
        """Map a CLI string such as ``bfloat16`` or ``torch.float16`` to a dtype.

        ``torch.dtype`` itself cannot be used as the argparse ``type`` because
        it cannot be constructed from a string.
        """
        resolved = getattr(torch, name.split(".")[-1], None)
        if not isinstance(resolved, torch.dtype):
            raise argparse.ArgumentTypeError(f"Unknown torch dtype: {name}")
        return resolved

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    parser.add_argument(
        "--dtype",
        type=_dtype_from_name,
        default=torch.bfloat16,
        help="Data type",
    )
    parser.add_argument(
        "--correctness",
        action="store_true",
        help="Check correctness",
    )
    parser.add_argument(
        "--csv",
        type=str,
        default="results_cutlass_cudnn.csv",
        help="CSV file to save results",
    )
    args = parser.parse_args()

    # Start the CSV fresh with a header; benchmark() appends one row per
    # measured point.
    if args.csv:
        with open(args.csv, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["provider", "m", "n", "k", "time_ms"])

    NKs = get_weight_shapes(args)
    for N, K in NKs:
        print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp4_res",
            N=N,
            K=K,
            dtype=args.dtype,
            correctness=args.correctness,
            csv_file=args.csv,
        )

    print("Benchmark finished!")
|
||||
183
sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py
Normal file
183
sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import deep_gemm
|
||||
import torch
|
||||
import triton
|
||||
from deep_gemm import get_col_major_tma_aligned_tensor
|
||||
from sgl_kernel import fp8_blockwise_scaled_mm
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
|
||||
)
|
||||
|
||||
|
||||
def get_weight_shapes(args):
    """List the ``[N, K, model]`` weight shapes to benchmark.

    For every (model, tp_size) pair from ``args.models`` x ``args.tp_sizes``
    this emits, in order: the shapes that are not sharded, the shapes whose N
    is divided by ``tp_size``, and the shapes whose K is divided by
    ``tp_size``. Asserts that each model is a supported one.
    """
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    # NOTE(HandH1998): The weight shapes only works for DeepSeek-V3. Modify them, if you tune for another different model.
    # cannot TP
    unsharded = [
        (512 + 64, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can TP
    n_sharded = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # K can TP
    k_sharded = [(7168, 18432), (7168, 16384), (7168, 2048)]
    # only support Deepseek-V3
    SUPPORT_MODEL = ["deepseek-ai/DeepSeek-V3"]

    weight_shapes = []
    for model, tp_size in models_tps:
        assert model in SUPPORT_MODEL
        weight_shapes.extend([n, k, model] for n, k in unsharded)
        weight_shapes.extend([n // tp_size, k, model] for n, k in n_sharded)
        weight_shapes.extend([n, k // tp_size, model] for n, k in k_sharded)
    return weight_shapes
|
||||
|
||||
|
||||
def cdiv(a: int, b: int) -> int:
    """Return ``a / b`` rounded up to the nearest integer."""
    quotient, remainder = divmod(a, b)
    # divmod floors; any non-zero remainder means the ceiling is one higher.
    return quotient + 1 if remainder else quotient
|
||||
|
||||
|
||||
def fp8_gemm_deepgemm(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
    m: int,
    n: int,
    k: int,
):
    """Run DeepGEMM's FP8 GEMM and return a freshly allocated (m, n) bf16 result.

    ``k`` is accepted for signature symmetry but not used here.
    """
    lhs = (x_fp8, x_scale)
    rhs = (y_fp8, y_scale)
    result = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)

    # DeepGEMM writes into the buffer it is given.
    deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, result)
    return result
|
||||
|
||||
|
||||
def scale_shape(shape, group_shape):
    """Return the per-dimension number of groups needed to cover ``shape``.

    Each dimension is divided by its group size, rounding up. Both arguments
    must have the same length.
    """
    assert len(shape) == len(group_shape)
    return tuple(cdiv(extent, group) for extent, group in zip(shape, group_shape))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
        x_log=False,
        line_arg="provider",
        line_vals=["vllm", "sgl-kernel", "triton", "deepgemm"],
        line_names=["vllm", "sgl-kernel", "sglang triton", "deepgemm"],
        styles=[("blue", "-"), ("orange", "-"), ("red", "-"), ("yellow", "-")],
        # The function returns latency in microseconds, not bandwidth.
        ylabel="us",
        plot_name="fp8 blockwise scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Time one blockwise-scaled FP8 GEMM provider for (batch_size, N, K).

    Operands are random FP8 tensors; A uses 1x128 scale groups and B uses
    128x128 scale groups, with random fp32 scales.

    Returns (median, 20th percentile, 80th percentile) latency in
    microseconds.

    Raises:
        ValueError: if ``provider`` is not one of the four known values.
    """
    M = batch_size
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    # Uniform values across the FP8 range, clamped and cast.
    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    scale_a_group_shape = (1, 128)
    scale_b_group_shape = (128, 128)
    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)

    scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
    scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)

    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl-kernel":
        # Same shape, column-major storage for scale_a; B and its scales are
        # passed as transposed views.
        scale_a = scale_a.t().contiguous().t()
        b_fp8, scale_b = b_fp8.t(), scale_b.t()
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: fp8_blockwise_scaled_mm(
                a_fp8, b_fp8, scale_a, scale_b, torch.float16
            ),
            quantiles=quantiles,
        )
    elif provider == "vllm":
        scale_a = scale_a.t().contiguous().t()
        b_fp8, scale_b = b_fp8.t(), scale_b.t()
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
            quantiles=quantiles,
        )
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: w8a8_block_fp8_matmul(
                a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
            ),
            quantiles=quantiles,
        )
    elif provider == "deepgemm":
        # NOTE(review): this path produces bf16 output while the other three
        # request fp16.
        scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: fp8_gemm_deepgemm(
                a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
            ),
            quantiles=quantiles,
        )
    else:
        # Previously fell through and failed at the return with unbound `ms`.
        raise ValueError(f"Unknown provider: {provider}")

    # Convert ms -> us. perf_report unpacks (y_mean, y_min, y_max), so for a
    # latency metric the low quantile must come second.
    return ms * 1000, min_ms * 1000, max_ms * 1000
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Both options take one or more values.
    for flag, value_type, fallback, help_text in (
        ("--models", str, ["deepseek-ai/DeepSeek-V3"], "List of models to benchmark"),
        ("--tp-sizes", int, [1], "List of tensor parallel sizes"),
    ):
        parser.add_argument(
            flag, nargs="+", type=value_type, default=fallback, help=help_text
        )
    args = parser.parse_args()

    NK_model_names = get_weight_shapes(args)
    for N, K, model_name in NK_model_names:
        # Only shapes where both N and K are multiples of 128 are benchmarked.
        if N % 128 or K % 128:
            print(f"Skip {N=}, {K=} now")
            continue
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp8_blockwise_res",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
|
||||
328
sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py
Normal file
328
sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py
Normal file
@@ -0,0 +1,328 @@
|
||||
import argparse
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
import deep_gemm
|
||||
import torch
|
||||
from sgl_kernel import fp8_blockwise_scaled_grouped_mm
|
||||
|
||||
|
||||
def get_m_alignment_for_contiguous_layout():
    """Return the row alignment (128) used for the contiguous grouped layout."""
    return 128
|
||||
|
||||
|
||||
def ceil_div(x: int, y: int) -> int:
    """Return the ceiling of ``x / y`` using integer arithmetic (positive ``y``)."""
    return (x + y - 1) // y
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
pad_size = (128 - (n % 128)) % 128
|
||||
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
|
||||
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 (e4m3fn) with one scale per 128x128 block.

    The input is copied into a zero-filled buffer whose sides are rounded up
    to multiples of 128, viewed as 128x128 blocks, and each block is scaled so
    that its absolute maximum (clamped to at least 1e-4) maps to 448.0.

    Returns:
        ``(fp8_data, scales)``: the quantized tensor cropped back to ``(m, n)``
        (contiguous), and the per-block scales (``amax / 448.0``) with one
        entry per 128x128 block.
    """
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
    )
    x_padded[:m, :n] = x
    # Shape: (row_blocks, 128, col_blocks, 128).
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
        x_view.size(0), x_view.size(2)
    )
|
||||
|
||||
|
||||
def construct_contiguous_grouped(
    num_groups: int, expected_m_per_group: int, k: int, n: int
) -> Tuple[
    int,
    Tuple[torch.Tensor, torch.Tensor],
    Tuple[torch.Tensor, torch.Tensor],
    torch.Tensor,
    torch.Tensor,
]:
    """Build random CUDA inputs for a contiguous-layout grouped FP8 GEMM.

    Every group gets ``expected_m_per_group`` rows, and each group's row range
    is padded up to the alignment returned by
    ``get_m_alignment_for_contiguous_layout()``.

    Returns:
        ``(m, x_fp8, y_fp8, m_indices, out)`` where ``m`` is the total padded
        row count, ``x_fp8`` / ``y_fp8`` are ``(data, scales)`` pairs for the
        activations and per-group weights, ``m_indices`` maps every row to its
        group index (``-1`` for padding rows), and ``out`` is an uninitialized
        bfloat16 output buffer of shape ``(m, n)``.
    """
    alignment = get_m_alignment_for_contiguous_layout()
    group_ms = [int(expected_m_per_group) for _ in range(num_groups)]
    m = sum(ceil_div(x, alignment) * alignment for x in group_ms)

    x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    y = torch.randn((num_groups, n, k), device="cuda", dtype=torch.bfloat16)
    m_indices = torch.empty(m, device="cuda", dtype=torch.int32)
    out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)

    start = 0
    for i, group_m in enumerate(group_ms):
        actual_end = start + group_m
        aligned_end = start + ceil_div(group_m, alignment) * alignment
        m_indices[start:actual_end] = i
        # Rows added only for alignment are marked with -1.
        m_indices[actual_end:aligned_end] = -1
        start = aligned_end

    assert m % 4 == 0, f"TMA alignment error: {m}"
    x_fp8 = per_token_cast_to_fp8(x)
    y_fp8 = (
        torch.empty_like(y, dtype=torch.float8_e4m3fn),
        torch.empty(
            (num_groups, ceil_div(n, 128), k // 128), device="cuda", dtype=torch.float
        ),
    )
    for i in range(num_groups):
        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])

    return m, x_fp8, y_fp8, m_indices, out
|
||||
|
||||
|
||||
def bench_deepgemm(
    expected_m_per_group: int,
    n: int,
    k: int,
    num_groups: int,
    num_warmup: int,
    num_run: int,
) -> Tuple[float, int]:
    """Time DeepGEMM's contiguous grouped FP8 GEMM for one problem shape.

    Returns:
        ``(avg, m)``: the average time per call in microseconds over
        ``num_run`` calls (after ``num_warmup`` untimed calls), and the total
        padded row count of the constructed input.
    """
    # construct tensors
    m, x_fp8, y_fp8, m_indices, out = construct_contiguous_grouped(
        num_groups, expected_m_per_group, k, n
    )

    def run_deepgemm():
        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(x_fp8, y_fp8, out, m_indices)

    # warmup
    for _ in range(num_warmup):
        run_deepgemm()
    torch.cuda.synchronize()

    # run
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_run):
        run_deepgemm()
    end_event.record()
    end_event.synchronize()
    torch.cuda.synchronize()
    # elapsed_time() is in milliseconds; convert the per-call average to us.
    avg = start_event.elapsed_time(end_event) / num_run * 1000  # us

    return avg, m
|
||||
|
||||
|
||||
def bench_cutlass(
    expected_m_per_group: int,
    n: int,
    k: int,
    num_groups: int,
    num_warmup: int,
    num_run: int,
) -> Tuple[float, int]:
    """Time ``fp8_blockwise_scaled_grouped_mm`` for one problem shape.

    ``n`` and ``k`` are rounded up to multiples of 16, and every group gets
    ``expected_m_per_group`` rows rounded up to a multiple of 16. Random
    per-group inputs are quantized and packed into stacked buffers before the
    kernel is timed.

    Returns:
        ``(avg, total_m)``: the average time per call in microseconds over
        ``num_run`` calls (after ``num_warmup`` untimed calls), and the total
        number of activation rows across all groups as a Python int.
    """
    device = "cuda"
    alignment = 16
    n_g = ceil_div(n, alignment) * alignment
    k_g = ceil_div(k, alignment) * alignment
    out_dtype = torch.bfloat16

    expert_offsets = torch.zeros((num_groups + 1), device=device, dtype=torch.int32)
    problem_sizes = torch.zeros((num_groups, 3), device=device, dtype=torch.int32)
    layout_sfa = torch.zeros((num_groups, 5), device=device, dtype=torch.int32)
    layout_sfb = torch.zeros((num_groups, 5), device=device, dtype=torch.int32)

    a_tensors = []
    b_tensors = []
    a_scales_tensors = []
    b_scales_tensors = []

    # TODO(@TianQiLin666666): Unique group_ms in all bench function
    group_ms = [
        alignment * ceil_div(int(expected_m_per_group), alignment)
        for _ in range(num_groups)
    ]
    for g in range(num_groups):
        m_g = group_ms[g]
        # expert_offsets holds the running row offset of each group.
        expert_offsets[g + 1] = expert_offsets[g] + m_g
        problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)

        a_g, a_scale = per_token_cast_to_fp8(torch.randn((m_g, k_g), device=device))
        b_g, b_scale = per_block_cast_to_fp8(torch.randn((n_g, k_g), device=device).t())
        a_tensors.append(a_g)
        b_tensors.append(b_g)
        a_scales_tensors.append(a_scale)
        b_scales_tensors.append(b_scale)

    # Read the total row count back once as a Python int; it is used for the
    # buffer shapes below and returned to the caller.
    total_m = int(expert_offsets[-1])

    a_stack = torch.empty((total_m, k_g), device=device, dtype=torch.float8_e4m3fn)
    b_stack = torch.empty(
        (num_groups, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
    )

    for g in range(num_groups):
        a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
        b_stack[g] = b_tensors[g].t()
    b_stack = b_stack.transpose(1, 2)

    a_scale_stack = torch.empty(
        (total_m, k_g // 128), device=device, dtype=torch.float32
    )
    b_scale_stack = torch.empty(
        (num_groups, n_g // 128, k_g // 128), device=device, dtype=torch.float32
    )

    for g in range(num_groups):
        a_scale_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_scales_tensors[g]
        b_scale_stack[g] = b_scales_tensors[g].t()
    b_scale_stack = b_scale_stack.transpose(1, 2)

    c_out = torch.empty((total_m, n_g), device=device, dtype=out_dtype)
    a_strides = torch.full(
        (num_groups,), a_stack.stride(0), device=device, dtype=torch.int64
    )
    c_strides = torch.full(
        (num_groups,), c_out.stride(0), device=device, dtype=torch.int64
    )
    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
    # Uninitialized per-group buffers handed to the kernel.
    # NOTE(review): presumably filled with device pointers by the kernel itself
    # - confirm against the sgl_kernel API.
    a_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
    b_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
    out_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
    a_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
    b_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)

    def run_cutlass():
        fp8_blockwise_scaled_grouped_mm(
            c_out,
            a_ptrs,
            b_ptrs,
            out_ptrs,
            a_scales_ptrs,
            b_scales_ptrs,
            a_stack,
            b_stack,
            a_scale_stack,
            b_scale_stack,
            a_strides,
            a_strides,
            c_strides,
            layout_sfa,
            layout_sfb,
            problem_sizes,
            expert_offsets[:-1],
            workspace,
        )

    # warmup
    for _ in range(num_warmup):
        run_cutlass()
    torch.cuda.synchronize()

    # run
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_run):
        run_cutlass()
    end_event.record()
    end_event.synchronize()
    torch.cuda.synchronize()
    # elapsed_time() is in milliseconds; convert the per-call average to us.
    avg = start_event.elapsed_time(end_event) / num_run * 1000  # us

    return avg, total_m
|
||||
|
||||
|
||||
def bench_sglang_triton(
    expected_m_per_group: int,
    n: int,
    k: int,
    num_groups: int,
    num_warmup: int,
    num_run: int,
) -> Tuple[float, int]:
    """Placeholder for a Triton grouped-GEMM benchmark (not implemented).

    The signature mirrors the other ``bench_*`` functions in this file.

    Raises:
        NotImplementedError: always. Failing loudly avoids returning ``None``
            where a ``(time, m)`` pair is expected.
    """
    raise NotImplementedError("bench_sglang_triton is not implemented yet")
|
||||
|
||||
|
||||
# Registry of benchmark entry points, keyed by the name printed in the report.
benchmark_kernels = {
    "deepgemm": bench_deepgemm,
    "cutlass": bench_cutlass,
    # "triton": bench_sglang_triton,
}
|
||||
|
||||
|
||||
@dataclass
class ShapeArg:
    """One grouped-GEMM problem shape to benchmark."""

    # Number of rows requested for every group.
    expected_m_per_group: int
    # Output width of each group's GEMM.
    n: int
    # Reduction dimension shared by all groups.
    k: int
    # Number of groups in the grouped GEMM.
    num_groups: int
|
||||
|
||||
|
||||
def benchmark_one_shape(
    shape_args: List[ShapeArg],
    num_warmup: int,
    num_run: int,
):
    """Run every registered kernel benchmark on each shape and print the timings.

    Args:
        shape_args: Problem shapes to benchmark, in order.
        num_warmup: Untimed warm-up calls per kernel and shape.
        num_run: Timed calls per kernel and shape.
    """
    for shape in shape_args:
        print(
            f"\nBenchmark: expected_m_per_group={shape.expected_m_per_group}, n={shape.n}, k={shape.k}, num_groups={shape.num_groups}"
        )
        for kernel_name, kernel_func in benchmark_kernels.items():
            # The second return value (total row count) is not reported.
            average_time, _ = kernel_func(
                shape.expected_m_per_group,
                shape.n,
                shape.k,
                shape.num_groups,
                num_warmup,
                num_run,
            )
            print(f"{kernel_name}: {average_time} us")
|
||||
|
||||
|
||||
def main():
    """Parse the warm-up/run counts and benchmark the built-in shape list."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-warmup", type=int, default=3)
    parser.add_argument("--num-run", type=int, default=10)
    shape_args = [
        # Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8
        ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
        # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 8
        ShapeArg(expected_m_per_group=256, n=512, k=7168, num_groups=256),
        # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 16
        ShapeArg(expected_m_per_group=256, n=256, k=7168, num_groups=256),
        # Prefill, DeepSeek-R1, gateup, chunk_size = 16384, TP = 16
        ShapeArg(expected_m_per_group=512, n=256, k=7168, num_groups=256),
        # Decode, DeepSeek-R1, gateup, bs = 32, TP = 8
        ShapeArg(expected_m_per_group=1, n=512, k=7168, num_groups=256),
        # Decode, DeepSeek-R1, gateup, bs = 64, TP = 16
        ShapeArg(expected_m_per_group=2, n=256, k=7168, num_groups=256),
        # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, EP = 8
        ShapeArg(expected_m_per_group=256, n=4096, k=7168, num_groups=32),
        # Prefill, DeepSeek-R1, gateup, chunk_size = 16384, EP = 16
        ShapeArg(expected_m_per_group=512, n=4096, k=7168, num_groups=16),
        # Decode, DeepSeek-R1, gateup, bs = 128, EP = 8
        ShapeArg(expected_m_per_group=4, n=4096, k=7168, num_groups=32),
        # Decode, DeepSeek-R1, gateup, bs = 256, EP = 16
        ShapeArg(expected_m_per_group=8, n=4096, k=7168, num_groups=16),
        # Prefill, Qwen3-235B-A22B-FP8, gateup, chunk_size = 16384, TP = 4
        ShapeArg(expected_m_per_group=1024, n=768, k=4096, num_groups=128),
        # Prefill, Qwen3-235B-A22B-FP8, down, chunk_size = 16384, TP = 4
        ShapeArg(expected_m_per_group=1024, n=4096, k=384, num_groups=128),
        # Decode, Qwen3-235B-A22B-FP8, gateup, bs = 256, TP = 4
        ShapeArg(expected_m_per_group=16, n=768, k=4096, num_groups=128),
        # Decode, Qwen3-235B-A22B-FP8, down, bs = 256, TP = 4
        ShapeArg(expected_m_per_group=16, n=4096, k=384, num_groups=128),
    ]
    args = parser.parse_args()
    benchmark_one_shape(shape_args, args.num_warmup, args.num_run)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point.
    main()
|
||||
184
sgl-kernel/benchmark/bench_fp8_gemm.py
Normal file
184
sgl-kernel/benchmark/bench_fp8_gemm.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
|
||||
|
||||
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
#  A shape of ([14336, 4096], 0) indicates the following GEMM shape,
#   - TP1 : K = 14336, N = 4096
#   - TP2 : K = 7168, N = 4096
#  A shape of ([4096, 6144], 1) indicates the following GEMM shape,
#   - TP1 : K = 4096, N = 6144
#   - TP4 : K = 4096, N = 1536

# TP1 shapes
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}
|
||||
|
||||
|
||||
def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` to FP8 (e4m3fn) with ``sgl_per_tensor_quant_fp8``.

    Args:
        input: Tensor to quantize.
        scale: Optional scale tensor. When given, the kernel is called with
            ``is_static=True``; when omitted, a one-element float32 buffer is
            allocated and the kernel is called with ``is_static=False``.

    Returns:
        ``(output, scale)``: the FP8 tensor and the scale tensor that was
        passed to the kernel.
    """
    fp8_type_: torch.dtype = torch.float8_e4m3fn
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    is_static = True
    if scale is None:
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        is_static = False
    sgl_per_tensor_quant_fp8(input, output, scale, is_static)

    return output, scale
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        line_names=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="fp8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Benchmark one FP8 scaled matmul of shape (batch_size, K) x (K, N).

    ``provider`` selects the implementation (vLLM or sgl-kernel) and the
    output dtype (fp16 when the name contains "fp16", otherwise bf16).

    Returns:
        Three throughput figures derived from the 0.5, 0.8 and 0.2 timing
        quantiles reported by ``triton.testing.do_bench``, in that order.
    """
    # M, N, K = batch_size, 4096, 8192
    M = batch_size
    a = torch.ones((M, K), device="cuda") * 5.0
    b = torch.ones((N, K), device="cuda") * 5.0
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]

    dtype = torch.float16 if "fp16" in provider else torch.bfloat16

    if "vllm-fp8" in provider:
        a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
        b_fp8 = b_fp8.t()
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
            quantiles=quantiles,
        )
    elif "sglang-fp8" in provider:
        a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
        b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
        b_fp8 = b_fp8.t()
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sgl_scaled_mm(
                a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
            ),
            quantiles=quantiles,
        )

    def gbps(ms):
        # NOTE(review): uses the element size of the unquantized input ``a``,
        # not of the FP8 operands - confirm this is the intended metric.
        return (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3)

    return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
def prepare_shapes(args):
    """Expand the selected models and TP sizes into ``[K, N, model]`` entries.

    For every (model, tp_size) pair, each weight shape of the model is copied
    and the dimension named by its TP split index is divided by ``tp_size``.

    Args:
        args: Parsed arguments providing ``models`` and ``tp_sizes``.

    Returns:
        A list of ``[K, N, model_name]`` lists.
    """
    KN_model_names = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        assert model in WEIGHT_SHAPES, f"Unknown model: {model}"
        # Deep-copy so the shared WEIGHT_SHAPES table is never mutated.
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: benchmark every weight shape of the selected
    # models at the selected tensor-parallel sizes.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K
        )

    print("Benchmark finished!")
|
||||
146
sgl-kernel/benchmark/bench_int8_gemm.py
Normal file
146
sgl-kernel/benchmark/bench_int8_gemm.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import int8_scaled_mm
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    """Clamp ``tensor`` to [-128, 127], round it, and cast it to ``torch.int8``."""
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
# Per-model weight shapes as ([K, N], tp_split_dim) pairs; tp_split_dim is the
# index of the dimension that is divided by the tensor-parallel size.
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["vllm", "sgl-kernel"],
        line_names=["vllm int8 gemm", "sgl-kernel int8 gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="GB/s",
        plot_name="int8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Benchmark one int8 scaled matmul of shape (batch_size, K) x (K, N).

    ``provider`` selects the implementation: "sgl-kernel" or "vllm".

    Returns:
        Three throughput figures derived from the 0.5, 0.8 and 0.2 timing
        quantiles reported by ``triton.testing.do_bench``, in that order.
    """
    M = batch_size
    a = to_int8(torch.randn((M, K), device="cuda") * 5)
    b = to_int8(torch.randn((N, K), device="cuda").t() * 5)
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    bias = torch.randn((N,), device="cuda", dtype=torch.float16)

    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl-kernel":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
            quantiles=quantiles,
        )
    if provider == "vllm":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
            quantiles=quantiles,
        )

    def gbps(ms):
        # Operation-count based figure scaled by element sizes, per second.
        return (
            (
                (2 * M * N * K - M * N) * a.element_size()
                + (3 * M * N) * scale_a.element_size()
            )
            * 1e-9
            / (ms * 1e-3)
        )

    return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
def prepare_shapes(args):
    """Expand the selected models and TP sizes into ``[K, N, model]`` entries.

    For every (model, tp_size) pair, each weight shape of the model is copied
    and the dimension named by its TP split index is divided by ``tp_size``.

    Args:
        args: Parsed arguments providing ``models`` and ``tp_sizes``.

    Returns:
        A list of ``[K, N, model_name]`` lists.
    """
    KN_model_names = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        assert model in WEIGHT_SHAPES, f"Unknown model: {model}"
        # Deep-copy so the shared WEIGHT_SHAPES table is never mutated.
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: benchmark every weight shape of the selected
    # models at the selected tensor-parallel sizes.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True, show_plots=True, save_path="bench_int8_res", N=N, K=K
        )

    print("Benchmark finished!")
|
||||
299
sgl-kernel/benchmark/bench_lightning_attention_decode.py
Normal file
299
sgl-kernel/benchmark/bench_lightning_attention_decode.py
Normal file
@@ -0,0 +1,299 @@
|
||||
import itertools
|
||||
import math
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import lightning_attention_decode
|
||||
|
||||
|
||||
def next_power_of_2(n):
    """Return the smallest power of two that is >= ``n`` (for ``n >= 1``).

    Integer bit arithmetic is used instead of a floating-point logarithm,
    whose rounding can overshoot on exact powers of two.

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError("math domain error")
    # (x - 1).bit_length() is the exponent of the next power of two >= x.
    return 1 << (math.ceil(n) - 1).bit_length()
|
||||
|
||||
|
||||
@triton.jit
def _decode_kernel(
    Q,
    K,
    V,
    KV,
    Out,
    S,
    b: tl.constexpr,
    h: tl.constexpr,
    n: tl.constexpr,
    d: tl.constexpr,
    d_original: tl.constexpr,
    e: tl.constexpr,
    e_original: tl.constexpr,
):
    # One program per (batch, head) pair. It decays the stored KV state by
    # exp(-s), adds the outer product of the current k and v, writes the state
    # back, and stores q @ KV as the output. `d` / `e` are the buffer dims the
    # offsets are computed with; `d_original` / `e_original` bound the masks.
    off_bh = tl.program_id(0)
    off_h = off_bh % h

    qk_offset = off_bh * n * d
    v_offset = off_bh * n * e
    o_offset = off_bh * n * e
    kv_offset = off_bh * d * e

    # Per-head slope and its decay ratio.
    s = tl.load(S + off_h)
    ratio = tl.exp(-s)

    d_idx = tl.arange(0, d)
    e_idx = tl.arange(0, e)

    # Create masks for original dimensions
    d_mask = d_idx < d_original
    e_mask = e_idx < e_original

    # Load with masking
    q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
    k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
    v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)

    # Load KV with 2D masking
    kv = tl.load(
        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
        mask=(d_mask[:, None] & e_mask[None, :]),
        other=0.0,
    )

    # Compute outer product using element-wise operations
    k_v_prod = k[:, None] * v[None, :]
    kv = ratio * kv + k_v_prod

    # Store KV with 2D masking
    tl.store(
        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
        kv.to(KV.dtype.element_ty),
        mask=(d_mask[:, None] & e_mask[None, :]),
    )

    # Compute matrix-vector multiplication using element-wise operations and reduction
    o = tl.sum(q[:, None] * kv, axis=0)

    # Store output with masking
    tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
|
||||
|
||||
|
||||
def triton_lightning_attn_decode(q, k, v, kv, s):
    """Triton implementation of Lightning Attention decode operation.

    The inputs are copied into buffers whose last dimensions are rounded up
    to powers of two, ``_decode_kernel`` is launched with one program per
    (batch, head) pair, and the results are sliced back to the original sizes.

    Returns:
        ``(o, kv_out)``: the attention output and the updated KV state.
    """
    b, h, n, d = q.shape
    e = v.shape[-1]
    assert n == 1, "Sequence length must be 1 in decode mode"

    # Get padded dimensions (power of 2)
    d_padded = next_power_of_2(d)
    e_padded = next_power_of_2(e)

    # Create output tensor (padded)
    o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)

    # Create padded tensors without actually padding the data
    # (the tail stays uninitialized; the kernel masks it out).
    q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
    k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
    v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
    kv_padded = torch.empty(
        b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
    )

    # Copy data to padded tensors
    q_padded[..., :d] = q
    k_padded[..., :d] = k
    v_padded[..., :e] = v
    kv_padded[..., :d, :e] = kv

    # Launch kernel
    grid = (b * h, 1)
    _decode_kernel[grid](
        q_padded,
        k_padded,
        v_padded,
        kv_padded,
        o_padded,
        s,
        b=b,
        h=h,
        n=n,
        d=d_padded,
        d_original=d,
        e=e_padded,
        e_original=e,
    )

    # Get unpadded outputs
    o = o_padded[..., :e]
    kv_out = kv_padded[..., :d, :e]

    return o, kv_out
|
||||
|
||||
|
||||
def lightning_attention_decode_naive(q, k, v, past_kv, slope):
    """Naive implementation of lightning attention decode.

    For each position along the third dimension of ``q`` the KV state is
    decayed by ``exp(-slope)``, the outer product of that position's ``k`` and
    ``v`` is added, and the output is ``q`` multiplied by the updated state.

    Returns:
        ``(output, kv)``: the outputs concatenated along the sequence axis,
        cast back to ``q``'s dtype, and the final KV state.
    """
    original_dtype = q.dtype
    ratio = torch.exp(-slope)  # [h, 1, 1]

    kv = past_kv
    b, h, n, d = q.shape

    output = []
    for i in range(n):
        kv = ratio * kv.to(torch.float32) + torch.einsum(
            "... n d, ... n e -> ... d e",
            k[:, :, i : i + 1],
            v[:, :, i : i + 1],
        )
        qkv = torch.einsum(
            "... n e, ... e d -> ... n d",
            q[:, :, i : i + 1].to(torch.float32),
            kv.to(torch.float32),
        )
        output.append(qkv)
    output = torch.cat(output, dim=-2)

    return output.to(original_dtype), kv
|
||||
|
||||
|
||||
def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv):
    """Forward all arguments unchanged to sgl_kernel's ``lightning_attention_decode``."""
    return lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
|
||||
|
||||
|
||||
def calculate_diff(batch_size):
    """Check that the naive, sgl-kernel and Triton decode implementations agree.

    Random bfloat16 inputs (64 heads, head dim 96, sequence length 1) are run
    through all three implementations; outputs and updated KV states are
    compared with ``torch.allclose`` (atol = rtol = 1e-2) and the verdict is
    printed.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")
    num_heads = 64
    head_dim = 96
    seq_len = 1

    q = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    k = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    v = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)

    output_naive, new_kv_naive = lightning_attention_decode_naive(
        q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
    )

    # The sgl-kernel entry point writes into caller-provided buffers.
    output_kernel = torch.empty_like(output_naive)
    new_kv_kernel = torch.empty_like(new_kv_naive)
    lightning_attention_decode_kernel(
        q.clone(),
        k.clone(),
        v.clone(),
        past_kv.clone(),
        slope.clone(),
        output_kernel,
        new_kv_kernel,
    )

    output_triton, new_kv_triton = triton_lightning_attn_decode(
        q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
    )

    if (
        torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2)
        and torch.allclose(output_naive, output_triton, atol=1e-2, rtol=1e-2)
        and torch.allclose(new_kv_naive, new_kv_kernel, atol=1e-2, rtol=1e-2)
        and torch.allclose(new_kv_naive, new_kv_triton, atol=1e-2, rtol=1e-2)
    ):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = list(range(1, 65))  # 1 to 64
# One single-element tuple per batch size, as consumed by the benchmark sweep.
configs = [(bs,) for bs in batch_size_range]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["naive", "kernel", "triton"],
        line_names=["PyTorch Naive", "SGL Kernel", "Triton"],
        styles=[("blue", "-"), ("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="lightning-attention-decode-performance",
        args={},
    )
)
def benchmark(batch_size, provider):
    """Time one lightning-attention decode implementation for ``batch_size``.

    ``provider`` selects the implementation: "naive", "kernel" or "triton".
    Inputs are cloned inside every timed call.

    Returns:
        The 0.5, 0.8 and 0.2 timing quantiles from
        ``triton.testing.do_bench``, in that order, multiplied by 1000.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")
    num_heads = 64
    head_dim = 96
    seq_len = 1

    q = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    k = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    v = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "naive":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: lightning_attention_decode_naive(
                q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
            ),
            quantiles=quantiles,
        )
    elif provider == "kernel":
        # Output buffers are allocated once and reused across timed calls.
        output = torch.empty(
            batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
        )
        new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: lightning_attention_decode_kernel(
                q.clone(),
                k.clone(),
                v.clone(),
                past_kv.clone(),
                slope.clone(),
                output,
                new_kv,
            ),
            quantiles=quantiles,
        )
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_lightning_attn_decode(
                q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
            ),
            quantiles=quantiles,
        )

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run a correctness check, then the benchmark sweep.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/lightning_attention_decode_sgl/",
        help="Path to save lightning attention decode benchmark results",
    )
    args = parser.parse_args()

    # Run correctness test
    calculate_diff(batch_size=4)

    # Run performance benchmark
    # Forward --save_path so the option actually controls where results go.
    benchmark.run(print_data=True, save_path=args.save_path)
|
||||
401
sgl-kernel/benchmark/bench_moe_align_block_size.py
Normal file
401
sgl-kernel/benchmark/bench_moe_align_block_size.py
Normal file
@@ -0,0 +1,401 @@
|
||||
import argparse
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||
|
||||
try:
|
||||
from vllm import _custom_ops as ops
|
||||
except ImportError:
|
||||
ops = None
|
||||
|
||||
# When True, the benchmark draws expert ids with get_topk_ids (one randperm per
# token); when False it uses torch.randint.
USE_RANDOM_PERM = False


def ceil_div(a, b):
    """Return ``(a + b - 1) // b``, i.e. ``a / b`` rounded up for positive integers."""
    padded = a + b - 1
    return padded // b
|
||||
|
||||
|
||||
@triton.jit
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    """Stage 1: count, per program, how often each expert id occurs in its slice.

    Program ``p`` walks ``tokens_per_thread`` consecutive entries of
    ``topk_ids`` starting at ``p * tokens_per_thread`` and, for every entry
    below ``numel``, increments the counter at
    ``tokens_cnts_ptr + (p + 1) * num_experts + expert_id``.
    """
    program = tl.program_id(0)
    slice_begin = program * tokens_per_thread
    # Counters of program p are stored one row further down (row p + 1).
    row_base = (program + 1) * num_experts

    for step in range(tokens_per_thread):
        position = slice_begin + step
        if position < numel:
            expert = tl.load(topk_ids_ptr + position)
            counter = tokens_cnts_ptr + row_base + expert
            tl.store(counter, tl.load(counter) + 1)
|
||||
|
||||
|
||||
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    """Stage 2: replace one expert's per-row counts by their running sum.

    Program ``e`` visits rows ``1 .. num_experts`` of column ``e`` in the
    counter table and overwrites each entry with the sum of itself and all
    previously visited rows.
    """
    expert = tl.program_id(0)
    running = 0
    for row in range(1, num_experts + 1):
        cell = tokens_cnts_ptr + row * num_experts + expert
        running = running + tl.load(cell)
        tl.store(cell, running)
|
||||
|
||||
|
||||
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    """Stage 3: build the block-padded cumulative offsets of all experts.

    Reads the ``num_experts`` totals stored at offset
    ``num_experts * num_experts`` of the counter table, rounds each one up to
    a multiple of ``block_size`` and writes the running sum to
    ``cumsum_ptr[1 .. num_experts]``. The final sum is also stored to
    ``total_tokens_post_pad_ptr``.
    """
    padded_total = 0
    totals_base = num_experts * num_experts
    for expert in range(num_experts):
        count = tl.load(tokens_cnts_ptr + totals_base + expert)
        padded = tl.cdiv(count, block_size) * block_size
        padded_total = padded_total + padded
        tl.store(cumsum_ptr + expert + 1, padded_total)
    tl.store(total_tokens_post_pad_ptr, padded_total)
|
||||
|
||||
|
||||
@triton.jit
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    """Stage 4: label the blocks of one expert and scatter one slice of token ids.

    Each program does two independent jobs:

    * as expert ``p`` it writes ``p`` into ``expert_ids`` for every
      ``block_size``-wide block in ``[cumsum[p], cumsum[p + 1])``;
    * as worker ``p`` it walks its slice of ``topk_ids`` and stores each flat
      index at ``cumsum[expert] + tokens_cnts[p, expert]``, bumping that
      counter afterwards.
    """
    program = tl.program_id(0)

    # Job 1: tag every block belonging to this program's expert.
    region_begin = tl.load(cumsum_ptr + program)
    region_end = tl.load(cumsum_ptr + program + 1)
    for block_start in range(region_begin, region_end, block_size):
        tl.store(expert_ids_ptr + block_start // block_size, program)

    # Job 2: place the token indices of this program's slice.
    slice_begin = program * tokens_per_thread
    slice_end = tl.minimum(slice_begin + tokens_per_thread, numel)
    row_base = program * num_experts

    for token in range(slice_begin, slice_end):
        expert = tl.load(topk_ids_ptr + token)
        counter = tokens_cnts_ptr + row_base + expert
        seen = tl.load(counter)
        destination = seen + tl.load(cumsum_ptr + expert)
        tl.store(sorted_token_ids_ptr + destination, token)
        tl.store(counter, seen + 1)
|
||||
|
||||
|
||||
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Launch the four Triton stages that fill the three output tensors in place.

    Args:
        topk_ids: expert ids to align; only its element count and device are
            inspected here.
        num_experts: number of experts; also the launch width of stages 1, 2 and 4.
        block_size: block granularity forwarded to stages 3 and 4.
        sorted_token_ids: output written by stage 4.
        expert_ids: output written by stage 4.
        num_tokens_post_pad: output written by stage 3.
    """
    device = topk_ids.device
    total_ids = topk_ids.numel()
    launch = (num_experts,)

    # Scratch buffers shared by the stages, zero-initialised.
    counts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=device
    )
    offsets = torch.zeros((num_experts + 1,), dtype=torch.int32, device=device)

    # Ceiling division: how many ids each of the num_experts programs handles.
    per_program = (total_ids + num_experts - 1) // num_experts

    moe_align_block_size_stage1[launch](
        topk_ids, counts, num_experts, total_ids, per_program
    )
    moe_align_block_size_stage2[launch](counts, num_experts)
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad, counts, offsets, num_experts, block_size
    )
    moe_align_block_size_stage4[launch](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        counts,
        offsets,
        num_experts,
        block_size,
        total_ids,
        per_program,
    )
|
||||
|
||||
|
||||
def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
    """Run the SGL, Triton and vLLM alignments on one random routing and print the verdicts.

    Only ``expert_ids`` and ``num_tokens_post_pad`` are compared; the sorted
    token ids are produced but not checked. A failure of the vLLM call is
    reported and that comparison is skipped.
    """
    per_token = [
        torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
        for _ in range(num_tokens)
    ]
    topk_ids = torch.stack(per_token)
    device = topk_ids.device
    num_ids = topk_ids.numel()

    max_num_tokens_padded = num_ids + num_experts * (block_size - 1)
    max_num_m_blocks = max_num_tokens_padded // block_size

    def fresh_outputs():
        # sorted ids pre-filled with numel, expert ids zeroed, counter uninitialised
        sorted_ids = torch.empty(
            (max_num_tokens_padded,), dtype=torch.int32, device=device
        )
        sorted_ids.fill_(num_ids)
        expert_ids = torch.zeros(
            (max_num_m_blocks,), dtype=torch.int32, device=device
        )
        num_post_pad = torch.empty((1), dtype=torch.int32, device=device)
        return sorted_ids, expert_ids, num_post_pad

    sorted_ids_cuda, expert_ids_cuda, num_tokens_post_pad_cuda = fresh_outputs()
    sorted_ids_triton, expert_ids_triton, num_tokens_post_pad_triton = fresh_outputs()
    sorted_ids_vllm, expert_ids_vllm, num_tokens_post_pad_vllm = fresh_outputs()
    cumsum_buffer = torch.zeros(num_experts + 1, dtype=torch.int32, device=device)

    # Run each implementation on the same topk_ids.
    sgl_moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_cuda,
        expert_ids_cuda,
        num_tokens_post_pad_cuda,
        cumsum_buffer,
    )
    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_triton,
        expert_ids_triton,
        num_tokens_post_pad_triton,
    )

    vllm_works = True
    try:
        ops.moe_align_block_size(
            topk_ids,
            num_experts,
            block_size,
            sorted_ids_vllm,
            expert_ids_vllm,
            num_tokens_post_pad_vllm,
        )
        print(f"✅ VLLM implementation works with {num_experts} experts!")
    except Exception as e:
        print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
        vllm_works = False

    def agrees_with_sgl(other_expert_ids, other_num_post_pad):
        return torch.allclose(expert_ids_cuda, other_expert_ids) and torch.allclose(
            num_tokens_post_pad_cuda, other_num_post_pad
        )

    if agrees_with_sgl(expert_ids_triton, num_tokens_post_pad_triton):
        print("✅ SGL and Triton implementations match")
    else:
        print("❌ SGL and Triton implementations do not match")
        print("SGL expert_ids:", expert_ids_cuda)
        print("Triton expert_ids:", expert_ids_triton)
        print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
        print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)

    if not vllm_works:
        print("⚠️ VLLM comparison skipped due to failure")
    elif agrees_with_sgl(expert_ids_vllm, num_tokens_post_pad_vllm):
        print("✅ SGL and VLLM implementations match")
    else:
        print("❌ SGL and VLLM implementations do not match")
        print("SGL expert_ids:", expert_ids_cuda)
        print("VLLM expert_ids:", expert_ids_vllm)
        print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
        print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
|
||||
|
||||
|
||||
# Sweep axes for the performance benchmark.
num_tokens_range = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
num_experts_range = [8, 32, 64, 128, 256]
topk_range = [1, 2, 4, 8]

# Full cartesian product; topk varies fastest, num_tokens slowest.
configs = [
    (num_tokens, num_experts, topk)
    for num_tokens in num_tokens_range
    for num_experts in num_experts_range
    for topk in topk_range
]
|
||||
|
||||
|
||||
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
    """Return a ``(num_tokens, topk)`` int32 CUDA tensor of expert ids.

    Each row holds the first ``topk`` entries of its own random permutation of
    ``range(num_experts)``, so ids within a row are distinct.
    """
    ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="cuda")
    for row in range(num_tokens):
        permutation = torch.randperm(num_experts, dtype=torch.int32, device="cuda")
        ids[row, :] = permutation[:topk]
    return ids
|
||||
|
||||
|
||||
def sgl_moe_align_block_size_with_empty(
    topk_ids,
    num_experts,
    block_size,
    sorted_ids,
    expert_ids,
    num_tokens_post_pad,
    pad_sorted_token_ids=False,
):
    """Benchmark wrapper around ``sgl_moe_align_block_size``.

    When ``pad_sorted_token_ids`` is false, ``sorted_ids`` is first filled in
    place with ``topk_ids.numel()``. The kernel is then called on clones of
    the three output tensors together with a freshly allocated,
    uninitialised cumsum buffer. Nothing is returned.
    """
    kernel_pads = pad_sorted_token_ids
    if not kernel_pads:
        sorted_ids.fill_(topk_ids.numel())

    scratch = torch.empty(num_experts + 1, dtype=torch.int32, device=topk_ids.device)

    outputs = (sorted_ids.clone(), expert_ids.clone(), num_tokens_post_pad.clone())
    sgl_moe_align_block_size(
        topk_ids, num_experts, block_size, *outputs, scratch, kernel_pads
    )
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["sgl", "sgl_fusion", "triton"],
        line_names=["SGL", "SGL Fusion", "Triton"],
        styles=[("blue", "-"), ("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="moe-align-block-size-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, provider):
    """Time one moe-align implementation for one (num_tokens, num_experts, topk) point.

    Args:
        num_tokens: number of rows in the random ``topk_ids``.
        num_experts: exclusive upper bound of the expert ids.
        topk: number of columns in ``topk_ids``.
        provider: ``"sgl"``, ``"sgl_fusion"`` or ``"triton"``.

    Returns:
        ``(median, p20, p80)`` in microseconds, in the ``(mean, min, max)``
        order that ``triton.testing.perf_report`` unpacks.

    Raises:
        ValueError: if ``provider`` is not one of the supported names.
    """
    if provider not in ("sgl", "sgl_fusion", "triton"):
        # Fail with a clear message instead of an UnboundLocalError at the return.
        raise ValueError(f"Unknown provider: {provider}")

    block_size = 128

    if USE_RANDOM_PERM:
        topk_ids = get_topk_ids(num_tokens, num_experts, topk)
    else:
        topk_ids = torch.randint(
            0,
            num_experts,
            (num_tokens, topk),
            dtype=torch.int32,
            device="cuda",
        )

    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)

    # do_bench returns one value per requested quantile, in this order.
    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sgl_moe_align_block_size_with_empty(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
            ),
            quantiles=quantiles,
        )
    elif provider == "sgl_fusion":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sgl_moe_align_block_size_with_empty(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
                pad_sorted_token_ids=True,
            ),
            quantiles=quantiles,
        )
    else:  # provider == "triton"
        sorted_ids.fill_(topk_ids.numel())
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: moe_align_block_size_triton(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids.clone(),
                expert_ids.clone(),
                num_tokens_post_pad.clone(),
            ),
            quantiles=quantiles,
        )

    # Milliseconds -> microseconds; lower quantile first so it is reported as "min".
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # NOTE(review): --save_path is accepted but not read anywhere below.
    cli_options = [
        (
            "--save_path",
            dict(
                type=str,
                default="./configs/benchmark_ops/moe_align_blocks/",
                help="Path to save moe align benchmark results",
            ),
        ),
        (
            "--num_experts",
            dict(
                type=int,
                default=256,
                choices=[8, 16, 32, 64, 128, 256],
                help="Number of experts for benchmark",
            ),
        ),
        (
            "--topk",
            dict(
                type=int,
                default=8,
                choices=[2, 4, 8],
                help="Top-k value for benchmark",
            ),
        ),
        (
            "--skip_full_benchmark",
            dict(
                action="store_true",
                help="Only run the calculate_diff function, skip full benchmarking",
            ),
        ),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # The correctness comparison always runs, on a fixed 1024-token case.
    calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)

    run_sweep = not args.skip_full_benchmark
    if run_sweep:
        print(f"\n📊 Running performance benchmark for {args.num_experts} experts...")
        benchmark.run(print_data=True)
|
||||
93
sgl-kernel/benchmark/bench_moe_ep_post_reorder.py
Normal file
93
sgl-kernel/benchmark/bench_moe_ep_post_reorder.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import ep_moe_post_reorder
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
|
||||
|
||||
# Batch sizes swept by the benchmark; each config is a 1-tuple (batch_size,).
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
configs = list(zip(batch_sizes))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["cuda", "triton"],
        line_names=["CUDA Kernel", "Triton Kernel"],
        styles=[("green", "-"), ("orange", "-")],
        ylabel="us",
        plot_name="ep-moe-post-reorder-performance",
        args={},
    )
)
def benchmark(batch_size, provider):
    """Time the EP-MoE post-reorder step for one batch size.

    Args:
        batch_size: number of output rows; also the Triton launch width.
        provider: ``"cuda"`` (``ep_moe_post_reorder``) or ``"triton"``
            (``post_reorder_triton_kernel``).

    Returns:
        ``(median, p20, p80)`` in microseconds, in the ``(mean, min, max)``
        order that ``triton.testing.perf_report`` unpacks.

    Raises:
        ValueError: if ``provider`` is not one of the supported names.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")
    hidden_size, topk, start_expert_id, end_expert_id, block_size = 4096, 8, 0, 255, 512

    def alloc_tensors():
        # Random inputs for one measurement series; allocated once per call.
        down_output = torch.randn(
            batch_size * topk, hidden_size, dtype=dtype, device=device
        )
        output = torch.zeros(batch_size, hidden_size, dtype=dtype, device=device)
        src2dst = torch.randint(
            0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device
        )
        topk_ids = torch.randint(
            start_expert_id,
            end_expert_id + 1,
            (batch_size, topk),
            dtype=torch.int32,
            device=device,
        )
        topk_weights = torch.rand(batch_size, topk, dtype=dtype, device=device)
        return down_output, output, src2dst, topk_ids, topk_weights

    # do_bench returns one value per requested quantile, in this order.
    quantiles = [0.5, 0.2, 0.8]

    if provider == "cuda":
        d_out, out, s2d, tk_ids, tk_weights = alloc_tensors()

        def run_cuda():
            ep_moe_post_reorder(
                d_out,
                out,
                s2d,
                tk_ids,
                tk_weights,
                start_expert_id,
                end_expert_id,
                topk,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)

    elif provider == "triton":
        d_out, out, s2d, tk_ids, tk_weights = alloc_tensors()

        def run_triton():
            post_reorder_triton_kernel[(batch_size,)](
                d_out.view(-1),
                out.view(-1),
                s2d.view(-1),
                tk_ids.view(-1),
                tk_weights.view(-1),
                start_expert_id,
                end_expert_id,
                topk,
                hidden_size,
                0,
                block_size,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)

    else:
        raise ValueError(f"Unknown provider: {provider}")

    # Milliseconds -> microseconds; lower quantile first so it is reported as "min".
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms


if __name__ == "__main__":
    benchmark.run(print_data=True)
|
||||
103
sgl-kernel/benchmark/bench_moe_ep_pre_reorder.py
Normal file
103
sgl-kernel/benchmark/bench_moe_ep_pre_reorder.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import ep_moe_pre_reorder
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel
|
||||
|
||||
# Batch sizes swept by the benchmark; each config is a 1-tuple (batch_size,).
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
configs = list(zip(batch_sizes))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["cuda", "triton"],
        line_names=["CUDA Kernel", "Triton Kernel"],
        styles=[("green", "-"), ("orange", "-")],
        ylabel="us",
        plot_name="ep-moe-pre-reorder-performance",
        args={},
    )
)
def benchmark(batch_size, provider):
    """Time the EP-MoE pre-reorder step for one batch size.

    Args:
        batch_size: number of input rows; also the Triton launch width.
        provider: ``"cuda"`` (``ep_moe_pre_reorder``) or ``"triton"``
            (``pre_reorder_triton_kernel``).

    Returns:
        ``(median, p20, p80)`` in microseconds, in the ``(mean, min, max)``
        order that ``triton.testing.perf_report`` unpacks.

    Raises:
        ValueError: if ``provider`` is not one of the supported names.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")
    hidden_size, topk, start_expert_id, end_expert_id, block_size = (
        4096,
        8,
        0,
        255,
        512,
    )

    # Allocates a fresh set of tensors per benchmark() call (once per provider
    # and batch size); the timed closures below reuse that one set.
    def alloc_tensors():
        input_ = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
        gateup_input = torch.zeros(
            batch_size * topk, hidden_size, dtype=dtype, device=device
        )
        src2dst = torch.randint(
            0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device
        )
        topk_ids = torch.randint(
            start_expert_id,
            end_expert_id + 1,
            (batch_size, topk),
            dtype=torch.int32,
            device=device,
        )
        a1_scales = torch.rand(
            end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device
        )
        return input_, gateup_input, src2dst, topk_ids, a1_scales

    # do_bench returns one value per requested quantile, in this order.
    quantiles = [0.5, 0.2, 0.8]

    if provider == "cuda":
        inp, gout, s2d, tk_ids, scales = alloc_tensors()

        def run_cuda():
            ep_moe_pre_reorder(
                inp,
                gout,
                s2d,
                tk_ids,
                scales,
                start_expert_id,
                end_expert_id,
                topk,
                True,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)

    elif provider == "triton":
        inp, gout, s2d, tk_ids, scales = alloc_tensors()

        def run_triton():
            pre_reorder_triton_kernel[(batch_size,)](
                inp.view(-1),
                gout.view(-1),
                s2d.view(-1),
                tk_ids.view(-1),
                scales,
                start_expert_id,
                end_expert_id,
                topk,
                hidden_size,
                block_size,
                True,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)

    else:
        raise ValueError(f"Unknown provider: {provider}")

    # Milliseconds -> microseconds; lower quantile first so it is reported as "min".
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms


if __name__ == "__main__":
    benchmark.run(print_data=True)
|
||||
77
sgl-kernel/benchmark/bench_moe_fused_gate.py
Normal file
77
sgl-kernel/benchmark/bench_moe_fused_gate.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import itertools
|
||||
import math
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import moe_fused_gate
|
||||
|
||||
from sglang.srt.layers.moe.topk import biased_grouped_topk
|
||||
|
||||
|
||||
def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
    """Call the reference ``biased_grouped_topk`` with renormalisation and a fixed scaling factor."""
    routing_options = dict(
        topk=topk,
        renormalize=True,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        routed_scaling_factor=2.5,  # DeepSeek-R1 : 2.5, Kimi K2: 2.872
    )
    # ``scores`` fills both leading positional arguments of the reference function.
    return biased_grouped_topk(scores, scores, bias, **routing_options)
|
||||
|
||||
|
||||
def biased_grouped_topk_org_fuse_kernel(
    scores, bias, num_expert_group, topk_group, topk
):
    """Forward the arguments unchanged to the ``moe_fused_gate`` kernel and return its result."""
    gate_args = (scores, bias, num_expert_group, topk_group, topk)
    return moe_fused_gate(*gate_args)
|
||||
|
||||
|
||||
# Sequence lengths 5000, 10000, ..., 40000; each config is a 1-tuple (seq_length,).
seq_length_range = list(range(5000, 40001, 5000))
configs = list(zip(seq_length_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_length"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["original", "kernel"],
        line_names=["Original", "SGL Kernel"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="us",
        plot_name="moe-fused-gate-performance",
        args={},
    )
)
def benchmark(seq_length, provider):
    """Time one gating implementation on ``(seq_length, 256)`` random bf16 scores.

    Args:
        seq_length: number of score rows.
        provider: ``"original"`` (``biased_grouped_topk_org``) or ``"kernel"``
            (``biased_grouped_topk_org_fuse_kernel``).

    Returns:
        ``(median, p20, p80)`` in microseconds, in the ``(mean, min, max)``
        order that ``triton.testing.perf_report`` unpacks.

    Raises:
        ValueError: if ``provider`` is not one of the supported names.
    """
    if provider == "original":
        gate_fn = biased_grouped_topk_org
    elif provider == "kernel":
        gate_fn = biased_grouped_topk_org_fuse_kernel
    else:
        # Fail with a clear message instead of an UnboundLocalError at the return.
        raise ValueError(f"Unknown provider: {provider}")

    dtype = torch.bfloat16
    device = torch.device("cuda")
    num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8

    scores = torch.randn((seq_length, num_experts), device=device, dtype=dtype)
    bias = torch.rand(num_experts, device=device, dtype=dtype)

    # do_bench returns one value per requested quantile, in this order.
    quantiles = [0.5, 0.2, 0.8]

    # Inputs are cloned inside the timed call, exactly as for both providers before.
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: gate_fn(
            scores.clone(), bias.clone(), num_expert_group, topk_group, topk
        ),
        quantiles=quantiles,
    )

    # Milliseconds -> microseconds; lower quantile first so it is reported as "min".
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms


if __name__ == "__main__":
    benchmark.run(print_data=True)
|
||||
92
sgl-kernel/benchmark/bench_moe_silu_and_mul.py
Normal file
92
sgl-kernel/benchmark/bench_moe_silu_and_mul.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import ep_moe_silu_and_mul
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_triton_kernel
|
||||
|
||||
# Sweep axes; block_size varies fastest, batch_size slowest.
batch_size_range = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
hidden_size_range = [1024, 2048, 4096, 8192]
block_size_range = [128, 256, 512]
configs = [
    (batch_size, hidden_size, block_size)
    for batch_size in batch_size_range
    for hidden_size in hidden_size_range
    for block_size in block_size_range
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "hidden_size", "block_size"],
        x_vals=[list(cfg) for cfg in configs],
        line_arg="provider",
        line_vals=["cuda", "triton"],
        line_names=["CUDA Kernel", "Triton Kernel"],
        styles=[("green", "-"), ("orange", "-")],
        ylabel="us",
        plot_name="ep-moe-silu-and-mul-performance",
        args={},
    )
)
def benchmark(batch_size, hidden_size, block_size, provider):
    """Time the EP-MoE SiLU-and-mul step for one (batch, hidden, block) point.

    Args:
        batch_size: number of rows; also the Triton launch width.
        hidden_size: width of the gate/up input; the output has half that width.
        block_size: block size handed to the Triton kernel. The CUDA provider
            does not receive it.
        provider: ``"cuda"`` (``ep_moe_silu_and_mul``) or ``"triton"``
            (``silu_and_mul_triton_kernel``).

    Returns:
        ``(median, p20, p80)`` in microseconds, in the ``(mean, min, max)``
        order that ``triton.testing.perf_report`` unpacks.

    Raises:
        ValueError: if ``provider`` is not one of the supported names.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")

    half_hidden_size = hidden_size // 2
    start_expert_id, end_expert_id = 0, 255
    # The swept ``block_size`` argument is used as given; it was previously
    # overwritten with 512, which made the block_size axis meaningless.
    # do_bench returns one value per requested quantile, in this order.
    quantiles = [0.5, 0.2, 0.8]

    def alloc_tensors():
        gateup_output = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
        down_input = torch.empty(
            batch_size, half_hidden_size, dtype=dtype, device=device
        )
        reorder_topk_ids = torch.randint(
            start_expert_id,
            end_expert_id + 1,
            (batch_size,),
            dtype=torch.int32,
            device=device,
        )
        scales = torch.rand(
            end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device
        )
        return gateup_output, down_input, reorder_topk_ids, scales

    if provider == "cuda":
        gateup, down, ids, scales = alloc_tensors()

        def run_cuda():
            ep_moe_silu_and_mul(
                gateup,
                down,
                ids,
                scales,
                start_expert_id,
                end_expert_id,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)

    elif provider == "triton":
        gateup, down, ids, scales = alloc_tensors()

        def run_triton():
            silu_and_mul_triton_kernel[(batch_size,)](
                gateup.view(-1),
                down.view(-1),
                hidden_size,
                ids,
                scales,
                start_expert_id,
                end_expert_id,
                block_size,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
    else:
        raise ValueError(f"Unknown provider: {provider}")

    # Milliseconds -> microseconds; lower quantile first so it is reported as "min".
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms


if __name__ == "__main__":
    benchmark.run(print_data=True)
|
||||
116
sgl-kernel/benchmark/bench_moe_topk_softmax.py
Normal file
116
sgl-kernel/benchmark/bench_moe_topk_softmax.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import topk_softmax
|
||||
from vllm import _custom_ops as vllm_custom_ops
|
||||
|
||||
|
||||
def vllm_topk_softmax(gating_output, topk):
    """Run ``torch.ops._moe_C.topk_softmax`` on a 2-D ``gating_output``.

    Allocates the ``(num_tokens, topk)`` output buffers the op writes into and
    returns ``(topk_weights, topk_indices)``. The token-to-expert index buffer
    is filled by the op but not returned.
    """
    num_tokens, _ = gating_output.shape  # requires exactly two dimensions
    device = gating_output.device
    out_shape = (num_tokens, topk)

    weights = torch.empty(out_shape, device=device, dtype=torch.float32)
    indices = torch.empty(out_shape, dtype=torch.int32, device=device)
    token_expert_indices = torch.empty(out_shape, dtype=torch.int32, device=device)

    torch.ops._moe_C.topk_softmax(
        weights, indices, token_expert_indices, gating_output
    )
    return weights, indices
|
||||
|
||||
|
||||
def sglang_topk_softmax(gating_output, topk):
    """Run sgl_kernel's ``topk_softmax`` on a 2-D ``gating_output``.

    Allocates the ``(num_tokens, topk)`` float32 weight and int32 index
    buffers, lets the kernel fill them, and returns ``(topk_weights, topk_indices)``.
    """
    num_tokens, _ = gating_output.shape  # requires exactly two dimensions
    device = gating_output.device
    out_shape = (num_tokens, topk)

    weights = torch.empty(out_shape, device=device, dtype=torch.float32)
    indices = torch.empty(out_shape, dtype=torch.int32, device=device)

    topk_softmax(
        topk_weights=weights,
        topk_ids=indices,
        gating_output=gating_output,
    )

    return weights, indices
|
||||
|
||||
|
||||
def calculate_diff(num_tokens, num_experts, topk):
    """Compare the vLLM and SGLang top-k softmax on one random input and print the verdict.

    The implementations match when the weights agree within
    ``atol=1e-3, rtol=1e-3`` and the selected indices are exactly equal.
    """
    gating_output = torch.randn(
        (num_tokens, num_experts), device="cuda", dtype=torch.float32
    )
    # Each implementation gets its own copy of the input.
    weights_vllm, indices_vllm = vllm_topk_softmax(gating_output.clone(), topk)
    weights_sglang, indices_sglang = sglang_topk_softmax(gating_output.clone(), topk)

    # Diagnostics used in the mismatch message.
    weights_diff = torch.abs(weights_vllm - weights_sglang).mean().item()
    indices_match = torch.equal(indices_vllm, indices_sglang)

    weights_close = torch.allclose(weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3)
    if weights_close and indices_match:
        print("✅ VLLM and SGLang topk_softmax implementations match")
        return

    print(
        f"❌ Implementations differ: Weights diff={weights_diff}, Indices match={indices_match}"
    )
|
||||
|
||||
|
||||
# Sweep axes for the performance benchmark.
num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
num_experts_range = [32, 64, 128, 256, 12, 512]
topk_range = [1, 2, 4, 8]

# Full cartesian product; topk varies fastest, num_tokens slowest.
configs = [
    (num_tokens, num_experts, topk)
    for num_tokens in num_tokens_range
    for num_experts in num_experts_range
    for topk in topk_range
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["sglang", "vllm"],
        line_names=["SGLang", "VLLM"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="Latency (us)",
        plot_name="topk-softmax-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, provider):
    """Time one top-k softmax implementation on random float32 gating output.

    Args:
        num_tokens: number of rows of the gating output.
        num_experts: number of columns of the gating output.
        topk: number of experts selected per row.
        provider: ``"vllm"``/``"vllm1"`` or ``"sglang"``/``"sglang1"``.

    Returns:
        ``(median, p20, p80)`` in microseconds, in the ``(mean, min, max)``
        order that ``triton.testing.perf_report`` unpacks.

    Raises:
        ValueError: if ``provider`` is not one of the supported names.
    """
    if provider in ("vllm", "vllm1"):
        impl = vllm_topk_softmax
    elif provider in ("sglang", "sglang1"):
        impl = sglang_topk_softmax
    else:
        # Previously an unknown provider left the timed callable unbound.
        raise ValueError(f"Unknown provider: {provider}")

    gating_output = torch.randn(
        (num_tokens, num_experts), device="cuda", dtype=torch.float32
    )

    def fn():
        return impl(gating_output, topk)

    # do_bench returns one value per requested quantile, in this order.
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    # Milliseconds -> microseconds; lower quantile first so it is reported as "min".
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # (num_tokens, num_experts, topk) cases for the correctness comparison.
    # This rebinds the module-level ``configs`` name after the benchmark
    # decorator above has already been evaluated.
    configs = [
        (20, 256, 4),
        (20, 256, 8),
        (20, 12, 4),
        (20, 12, 1),
        (20, 512, 4),
        (20, 512, 1),
    ]
    for num_tokens, num_experts, topk in configs:
        calculate_diff(num_tokens, num_experts, topk)

    # Timing sweep over the benchmark's own configuration grid.
    benchmark.run(print_data=True)
|
||||
172
sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py
Normal file
172
sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
|
||||
FLOAT4_E2M1_MAX = 6.0
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
# Per-model weight shapes, each entry written as
#   ([K, N], TP_SPLIT_DIM)
# TP_SPLIT_DIM is the index (0 -> K, 1 -> N) that is divided by the tensor
# parallel size. Examples:
#   ([14336, 4096], 0):  TP1 -> K = 14336, N = 4096;  TP2 -> K = 7168, N = 4096
#   ([4096, 6144], 1):   TP1 -> K = 4096,  N = 6144;  TP4 -> K = 4096, N = 1536
#
# The shapes listed below are the TP1 shapes.
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=[
            "sglang-fp4-fp16",
            "sglang-fp4-bf16",
        ],
        line_names=[
            "sglang-fp4-fp16",
            "sglang-fp4-bf16",
        ],
        styles=[("green", "-"), ("blue", "-")],
        ylabel="TFLOPS",
        plot_name="fp4 block scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Measure ``cutlass_scaled_fp4_mm`` throughput for an ``(M, K) x (N, K)^T`` product.

    Args:
        batch_size: M, the number of rows of the activation matrix.
        provider: line name; float16 is used when it contains ``"fp16"``,
            bfloat16 otherwise.
        N: number of rows of the weight matrix.
        K: shared inner dimension.

    Returns:
        ``2 * M * N * K`` floating point operations per call divided by the
        average time of 100 timed calls, expressed in TFLOPS.
    """
    # M, N, K = batch_size, 4096, 8192
    timed_calls = 100
    out_dtype = torch.float16 if "fp16" in provider else torch.bfloat16
    M = batch_size
    a = torch.randn((M, K), dtype=out_dtype, device="cuda")
    b = torch.randn((N, K), dtype=out_dtype, device="cuda")

    def global_scale(tensor):
        # (fp8 max * fp4 max) divided by the largest element, as float32
        peak = torch.amax(tensor.flatten(), dim=-1)
        return ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / peak).to(torch.float32)

    a_global_scale = global_scale(a)
    b_global_scale = global_scale(b)
    alpha = 1.0 / (a_global_scale * b_global_scale)
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a, a_global_scale)
    b_fp4, b_scale_interleaved = scaled_fp4_quant(b, b_global_scale)

    def fp4_matmul():
        cutlass_scaled_fp4_mm(
            a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, out_dtype
        )

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Queue 25 dense matmuls before measuring (original note: "bridging the
    # gap between CPU and GPU"); their results are discarded.
    for _ in range(25):
        torch.matmul(a, b.t())

    # Warmup
    for _ in range(5):
        fp4_matmul()

    start_event.record()
    for _ in range(timed_calls):
        fp4_matmul()
    end_event.record()
    end_event.synchronize()
    torch.cuda.synchronize()

    # elapsed_time is in milliseconds; average over the timed calls.
    ms = start_event.elapsed_time(end_event) / timed_calls

    # FLOPs per millisecond * 1e-9 == TFLOPS
    return (2 * M * N * K) * 1e-9 / ms
|
||||
|
||||
|
||||
def prepare_shapes(args):
    """Expand the CLI selection into the list of GEMM shapes to benchmark.

    For every (model, tp_size) pair from ``args.models`` x ``args.tp_sizes``,
    each ``(dims, tp_split_dim)`` entry of ``WEIGHT_SHAPES[model]`` is copied,
    the dimension at ``tp_split_dim`` is floor-divided by ``tp_size``, and the
    model name is appended to the dimension list.

    Returns:
        A list of ``[*dims, model]`` lists (the caller unpacks each one as
        ``K, N, model_name``).
    """
    shapes = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        assert model in WEIGHT_SHAPES
        # Deep copy so the shared WEIGHT_SHAPES table is never mutated.
        for dims, split_axis in copy.deepcopy(WEIGHT_SHAPES[model]):
            dims[split_axis] //= tp_size
            shapes.append([*dims, model])
    return shapes
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line options: which models to sweep and at which TP sizes.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    cli.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = cli.parse_args()

    # One benchmark sweep (over batch sizes) per weight shape.
    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp4_res",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
|
||||
98
sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py
Normal file
98
sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import itertools
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
# Platform probe, evaluated once at import time.
_is_hip = is_hip()

# FP8 storage dtype: the "fnuz" e4m3 variant when is_hip() is true,
# the plain e4m3fn variant otherwise.
if _is_hip:
    fp8_type_ = torch.float8_e4m3fnuz
else:
    fp8_type_ = torch.float8_e4m3fn
|
||||
|
||||
|
||||
def vllm_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` with vLLM's ``scaled_fp8_quant`` custom op.

    Thin adapter so the vLLM and SGLang paths share one call signature.
    ``scale`` is forwarded unchanged (``None`` by default).
    """
    quantized_and_scale = ops.scaled_fp8_quant(input, scale)
    return quantized_and_scale
|
||||
|
||||
|
||||
def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` with the SGLang per-tensor FP8 kernel.

    Args:
        input: Tensor to quantize.
        scale: Optional pre-computed scale. When given, the kernel is called
            in static mode (``is_static=True``); when ``None``, a one-element
            float32 buffer is allocated and the kernel is called in dynamic
            mode.

    Returns:
        ``(output, scale)``: the FP8 buffer passed to the kernel and the
        scale tensor that was used or allocated here.
    """
    # Use the module-level, platform-aware ``fp8_type_`` (fnuz on HIP).
    # A local override hard-coded to float8_e4m3fn used to shadow it, which
    # made the platform selection at the top of the file dead code.
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)

    is_static = scale is not None
    if not is_static:
        # Buffer handed to the kernel in dynamic mode
        # (presumably filled in place by the kernel; confirm against sgl_kernel).
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)

    sgl_per_tensor_quant_fp8(input, output, scale, is_static)

    return output, scale
|
||||
|
||||
|
||||
def calculate_diff(batch_size: int, seq_len: int):
    """Calculate difference between VLLM and SGLang implementations.

    Quantizes one random float16 ``(batch_size, seq_len)`` tensor with both
    implementations in dynamic-scale mode and prints whether the quantized
    outputs and the scales agree within ``rtol=1e-3, atol=1e-5``. On a
    mismatch the magnitude of the disagreement is printed as well.
    """
    device = torch.device("cuda")
    x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)

    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
    sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)

    if torch.allclose(
        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
        # These diagnostics used to be computed unconditionally and then
        # discarded; report them where they are actually useful.
        scale_diff = torch.abs(vllm_scale - sglang_scale).max().item()
        output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
        print(f"  max scale abs diff:   {scale_diff:.8f}")
        print(f"  mean output abs diff: {output_diff:.8f}")
|
||||
|
||||
|
||||
# Sweep axes; every (batch_size, seq_len) pair becomes one benchmark row.
batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048]

configs = list(itertools.product(batch_size_range, seq_len_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm", "sglang"],
        line_names=["VLLM", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-tensor-quant-fp8-performance",
        args={},
    )
)
def benchmark(batch_size, seq_len, provider):
    """Time per-tensor FP8 quantization of a ``(batch_size * seq_len, 4096)`` tensor.

    Args:
        batch_size, seq_len: Multiplied together to give the row count.
        provider: ``"vllm"`` or ``"sglang"``.

    Returns:
        ``(median, low, high)`` latency in microseconds, where low/high are
        the 0.2 and 0.8 quantiles requested from ``do_bench``.

    Raises:
        ValueError: If ``provider`` is not one of the known values.
    """
    dtype = torch.float16
    device = torch.device("cuda")

    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    # Note: the input clone is part of the timed region for both providers.
    if provider == "vllm":
        fn = lambda: vllm_scaled_fp8_quant(x.clone())
    elif provider == "sglang":
        fn = lambda: sglang_scaled_fp8_quant(x.clone())
    else:
        # Previously fell through and failed later with an unbound ``fn``.
        raise ValueError(f"Unknown provider: {provider!r}")

    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    # perf_report unpacks the result as (y, y_min, y_max). These are raw
    # latencies (not an inverse metric such as throughput), so the low
    # quantile must come before the high one. Convert ms -> us.
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms


if __name__ == "__main__":
    calculate_diff(batch_size=4, seq_len=4096)
    benchmark.run(print_data=True)
|
||||
98
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py
Normal file
98
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import itertools
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from sglang.srt.bench_utils import bench_kineto
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
create_per_token_group_quant_fp8_output_scale,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
# Platform probe, evaluated once at import time.
_is_hip = is_hip()

# FP8 storage dtype: the "fnuz" e4m3 variant when is_hip() is true,
# the plain e4m3fn variant otherwise.
if _is_hip:
    fp8_type_ = torch.float8_e4m3fnuz
else:
    fp8_type_ = torch.float8_e4m3fn
|
||||
|
||||
|
||||
# Sweep axes of the benchmark grid.
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
hidden_dim_range = [1536, 7168, 18432]  # For DeepSeek V3/R1
group_size_range = [128]  # For DeepSeek V3/R1
# TODO test int8
dst_dtype_range = [fp8_type_]

# Scale-layout flag combinations; each dict is forwarded to the quantization
# function as keyword arguments.
flags_range = [
    {
        "column_major_scales": False,
        "scale_tma_aligned": False,
        "scale_ue8m0": False,
    },
    {
        "column_major_scales": True,
        "scale_tma_aligned": False,
        "scale_ue8m0": False,
    },
    {
        "column_major_scales": True,
        "scale_tma_aligned": True,
        "scale_ue8m0": False,
    },
    {
        "column_major_scales": True,
        "scale_tma_aligned": True,
        "scale_ue8m0": True,
    },
]

# Full Cartesian product: one tuple per benchmark row.
configs = list(
    itertools.product(
        num_tokens_range,
        hidden_dim_range,
        group_size_range,
        dst_dtype_range,
        flags_range,
    )
)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["triton", "sglang"],
        line_names=["Triton", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-group-quant-8bit-performance",
        args={},
    )
)
def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
    """Time per-token-group 8-bit quantization for one grid point.

    A random bfloat16 ``(num_tokens, hidden_dim)`` tensor is quantized by the
    implementation named by ``provider`` (``"triton"`` or ``"sglang"``), with
    ``flags`` forwarded as keyword arguments. The kernel time reported by
    ``bench_kineto`` is returned scaled by 1e6 (plotted as microseconds).

    Returns ``None`` when ``scale_ue8m0`` is requested with a group size
    other than 128.
    """
    wants_ue8m0 = flags["scale_ue8m0"]
    if wants_ue8m0 and group_size != 128:
        return

    device = torch.device("cuda")
    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)

    # provider -> (callable under test, kernel name to look up in the profile)
    implementations = {
        "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
        "sglang": (
            sglang_per_token_group_quant_8bit,
            "per_token_group_quant_8bit_kernel",
        ),
    }
    quant_fn, kernel_names = implementations[provider]

    def bench_fn():
        return quant_fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)

    time_s = bench_kineto(bench_fn, kernel_names=kernel_names)
    return time_s * 1e6


if __name__ == "__main__":
    benchmark.run(print_data=True)
|
||||
177
sgl-kernel/benchmark/bench_per_token_quant_fp8.py
Normal file
177
sgl-kernel/benchmark/bench_per_token_quant_fp8.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import itertools
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import sgl_per_token_quant_fp8
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
# Platform probe, evaluated once at import time.
_is_hip = is_hip()

# FP8 storage dtype: the "fnuz" e4m3 variant when is_hip() is true,
# the plain e4m3fn variant otherwise.
if _is_hip:
    fp8_type_ = torch.float8_e4m3fnuz
else:
    fp8_type_ = torch.float8_e4m3fn

# Get correct FP8 E4M3 maximum value:
# ROCM uses 224.0; for CUDA, get the actual max value from the type.
FP8_E4M3_MAX = 224.0 if _is_hip else float(torch.finfo(fp8_type_).max)
|
||||
|
||||
|
||||
def torch_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pure PyTorch reference implementation for per-token FP8 quantization.

    Each row of the 2-D ``input`` gets its own scale,
    ``max(|row|) / FP8_E4M3_MAX``. The row is divided by that scale, clamped
    to ``[-FP8_E4M3_MAX, FP8_E4M3_MAX]`` and cast to ``fp8_type_``.

    Returns:
        ``(quantized, scales)`` where ``scales`` has one entry per row.

    Note:
        There is deliberately no special handling for all-zero rows, so such
        a row has scale 0 and its reciprocal is not finite.
    """
    # Per-row absolute maximum -> per-row scale, shape [num_tokens].
    row_absmax = input.abs().max(dim=1)[0]
    scales = row_absmax / FP8_E4M3_MAX

    # Multiply by the reciprocal scale (broadcast over columns), then clamp
    # to the representable FP8 range before the cast.
    inv_scales = 1.0 / scales
    scaled = input * inv_scales.unsqueeze(1)
    clamped = torch.clamp(scaled, -FP8_E4M3_MAX, FP8_E4M3_MAX)

    return clamped.to(fp8_type_), scales
|
||||
|
||||
|
||||
def vllm_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` via vLLM's ``scaled_fp8_quant`` with
    ``use_per_token_if_dynamic=True`` and no explicit scale."""
    result = ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
    return result
|
||||
|
||||
|
||||
def sglang_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` with the SGLang per-token FP8 kernel.

    Allocates the ``fp8_type_`` output buffer and a float32 scale buffer with
    one entry per row of ``input``, passes both to ``sgl_per_token_quant_fp8``
    and returns them as ``(output, scale)``.
    """
    num_rows = input.size(0)
    scale = torch.zeros(num_rows, device=input.device, dtype=torch.float32)
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)

    sgl_per_token_quant_fp8(input, output, scale)
    return output, scale
|
||||
|
||||
|
||||
def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
    """Compare Torch reference, VLLM, and SGLang implementations.

    Quantizes one random float16 ``(batch_size * seq_len, hidden_dim)`` tensor
    with all three implementations, then prints the mean absolute difference
    of scales and outputs for each pair and whether each pair agrees within
    tolerance.

    The VLLM-vs-SGLang comparison uses a relaxed ``rtol`` of 1e-2 whenever
    ``hidden_dim`` is not a multiple of 16; every other comparison uses
    ``rtol=1e-3``. ``atol`` is always 1e-5.
    """
    device = torch.device("cuda")
    x = torch.rand(
        (batch_size * seq_len, hidden_dim), dtype=torch.float16, device=device
    )

    # Get all three implementations
    torch_out, torch_scale = torch_per_token_quant_fp8(x)
    vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
    sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)

    print(f"\n=== Comparison for hidden_dim={hidden_dim} ===")

    # Compare scales
    torch_vllm_scale_diff = torch.abs(torch_scale - vllm_scale).mean().item()
    torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item()
    vllm_sglang_scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()

    print("Scale differences:")
    print(f" Torch vs VLLM: {torch_vllm_scale_diff:.8f}")
    print(f" Torch vs SGLang: {torch_sglang_scale_diff:.8f}")
    print(f" VLLM vs SGLang: {vllm_sglang_scale_diff:.8f}")

    # Compare outputs
    torch_vllm_out_diff = torch.abs(torch_out.float() - vllm_out.float()).mean().item()
    torch_sglang_out_diff = (
        torch.abs(torch_out.float() - sglang_out.float()).mean().item()
    )
    vllm_sglang_out_diff = (
        torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
    )

    print("Output differences:")
    print(f" Torch vs VLLM: {torch_vllm_out_diff:.8f}")
    print(f" Torch vs SGLang: {torch_sglang_out_diff:.8f}")
    print(f" VLLM vs SGLang: {vllm_sglang_out_diff:.8f}")

    # Check tolerances
    rtol, atol = 1e-3, 1e-5

    torch_vllm_match = torch.allclose(
        torch_out.float(), vllm_out.float(), rtol=rtol, atol=atol
    ) and torch.allclose(torch_scale, vllm_scale, rtol=rtol, atol=atol)
    torch_sglang_match = torch.allclose(
        torch_out.float(), sglang_out.float(), rtol=rtol, atol=atol
    ) and torch.allclose(torch_scale, sglang_scale, rtol=rtol, atol=atol)

    # we found vllm sglang has diff when hidden dim is not divisible by 16
    # and we believe SGLang is closer to Torch implementation.
    # The condition is expressed directly instead of matching the single
    # value 1368, and kept in its own variable so the strict tolerance used
    # above is still what the report header shows.
    vllm_sglang_rtol = 1e-2 if hidden_dim % 16 != 0 else rtol

    vllm_sglang_match = torch.allclose(
        vllm_out.float(), sglang_out.float(), rtol=vllm_sglang_rtol, atol=atol
    ) and torch.allclose(vllm_scale, sglang_scale, rtol=vllm_sglang_rtol, atol=atol)

    relaxed_note = "" if vllm_sglang_rtol == rtol else f" (rtol={vllm_sglang_rtol})"

    print(f"Matches (rtol={rtol}, atol={atol}):")
    print(f" Torch vs VLLM: {'✅' if torch_vllm_match else '❌'}")
    print(f" Torch vs SGLang: {'✅' if torch_sglang_match else '❌'}")
    print(f" VLLM vs SGLang: {'✅' if vllm_sglang_match else '❌'}{relaxed_note}")
|
||||
|
||||
|
||||
# Sweep axes; every (batch_size, seq_len, hidden_dim) triple is one row.
batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
hidden_dim_range = [1368, 2048, 4096]

configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "hidden_dim"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "vllm", "sglang"],
        line_names=["Torch Reference", "VLLM", "SGL Kernel"],
        styles=[("red", "-"), ("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        args={},
    )
)
def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
    """Time per-token FP8 quantization of a ``(batch_size * seq_len, hidden_dim)`` tensor.

    Args:
        batch_size, seq_len: Multiplied together to give the row count.
        hidden_dim: Column count of the input.
        provider: ``"torch"``, ``"vllm"`` or ``"sglang"``.

    Returns:
        ``(median, low, high)`` latency in microseconds, where low/high are
        the 0.2 and 0.8 quantiles requested from ``do_bench``.

    Raises:
        ValueError: If ``provider`` is not one of the known values.
    """
    dtype = torch.float16
    device = torch.device("cuda")

    x = torch.randn(batch_size * seq_len, hidden_dim, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    # Note: the input clone is part of the timed region for every provider.
    if provider == "torch":
        fn = lambda: torch_per_token_quant_fp8(x.clone())
    elif provider == "vllm":
        fn = lambda: vllm_per_token_quant_fp8(x.clone())
    elif provider == "sglang":
        fn = lambda: sglang_per_token_quant_fp8(x.clone())
    else:
        # Previously fell through and failed later with an unbound ``fn``.
        raise ValueError(f"Unknown provider: {provider!r}")

    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    # perf_report unpacks the result as (y, y_min, y_max). These are raw
    # latencies, so the low quantile must come before the high one.
    # Convert ms -> us.
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Test various hidden dimensions for correctness
    test_dims = [1368, 2048, 4096]
    for dim in test_dims:
        calculate_diff(batch_size=4, seq_len=4096, hidden_dim=dim)

    # Visual separator between the correctness report and the timing table.
    separator = "=" * 60
    print("\n" + separator)
    print("Starting performance benchmark...")
    benchmark_quantization.run(print_data=True)
|
||||
198
sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py
Normal file
198
sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import (
|
||||
int8_scaled_mm,
|
||||
qserve_w4a8_per_chn_gemm,
|
||||
qserve_w4a8_per_group_gemm,
|
||||
)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    """Saturate ``tensor`` to [-128, 127], round, and cast to ``torch.int8``."""
    saturated = torch.clamp(tensor, min=-128, max=127)
    return saturated.round().to(dtype=torch.int8)
|
||||
|
||||
|
||||
# GEMM weight shapes per model.
# Each entry is ``([K, N], tp_split_dim)``: prepare_shapes() floor-divides the
# dimension at index ``tp_split_dim`` by the tensor-parallel size, and the
# script entry point unpacks the resulting list as ``K, N, model_name``.
# The inner shapes must stay lists because prepare_shapes() mutates a copy.
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
        line_names=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
        styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
        ylabel="ms",
        plot_name="FP16_vs_W8A8_vs_Qserve_W4A8_GEMM",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Time one GEMM variant for an ``(M=batch_size, K) x (K, N)`` problem.

    Args:
        batch_size: Number of activation rows (M).
        provider: One of ``"FP16"``, ``"W8A8"``, ``"Qserve_W4A8_Per_Channel"``
            or ``"Qserve_W4A8_Per_Group"``.
        N, K: Weight dimensions. ``K`` must be a multiple of 128, the group
            size used for the per-group variant.

    Returns:
        ``(median, low, high)`` latency in milliseconds, where low/high are
        the 0.2 and 0.8 quantiles requested from ``do_bench``.

    Raises:
        ValueError: If ``provider`` is not one of the known values.
    """
    M = batch_size

    # For W8A8
    a = to_int8(torch.randn((M, K), device="cuda") * 5)
    b = to_int8(torch.randn((N, K), device="cuda").t() * 5)
    a_fp16 = a.to(torch.float16)
    b_fp16 = b.to(torch.float16)
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)

    # For Qserve W4A8 per channel
    a_qserve_chn = a
    # two int4s pack into one int8
    b_qserve_chn = to_int8(torch.randn((N, K // 2), device="cuda") * 5)
    scale_a_qserve_chn = scale_a.to(torch.float16)
    scale_b_qserve_chn = scale_b.to(torch.float16)
    szero_b_qserve_chn = torch.randn((N,), device="cuda", dtype=torch.float16)
    a_sum_qserve_chn = torch.randn((M,), device="cuda", dtype=torch.float16)

    # For Qserve W4A8 per group
    group_size = 128
    assert K % group_size == 0, "K must be divisible by group_size"
    a_qserve_group = a
    # two int4s pack into one int8
    b_qserve_group = to_int8(torch.randn((N, K // 2), device="cuda") * 5)
    scale_a_qserve_group = scale_a.to(torch.float16)
    scale_b_qserve_group = scale_b.to(torch.float16)
    scale_i8_b_qserve_group = to_int8(
        torch.randn((K // group_size, N), device="cuda", dtype=torch.float16)
    )
    zero_i8_b_qserve_group = to_int8(
        torch.randn((K // group_size, N), device="cuda", dtype=torch.float16)
    )

    # Select the single closure to time. The former chain of independent
    # ``if`` statements left the result unbound for an unknown provider.
    if provider == "FP16":
        run = lambda: torch.matmul(a_fp16, b_fp16)
    elif provider == "W8A8":
        run = lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16)
    elif provider == "Qserve_W4A8_Per_Channel":
        run = lambda: qserve_w4a8_per_chn_gemm(
            a_qserve_chn,
            b_qserve_chn,
            scale_b_qserve_chn,
            scale_a_qserve_chn,
            szero_b_qserve_chn,
            a_sum_qserve_chn,
        )
    elif provider == "Qserve_W4A8_Per_Group":
        run = lambda: qserve_w4a8_per_group_gemm(
            a_qserve_group,
            b_qserve_group,
            zero_i8_b_qserve_group,
            scale_i8_b_qserve_group,
            scale_b_qserve_group,
            scale_a_qserve_group,
        )
    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(run, quantiles=quantiles)

    # perf_report unpacks the result as (y, y_min, y_max). These are raw
    # latencies, so the low quantile must come before the high one.
    return ms, min_ms, max_ms
|
||||
|
||||
|
||||
def prepare_shapes(args):
    """Expand the CLI selection into the list of GEMM shapes to benchmark.

    For every (model, tp_size) pair from ``args.models`` x ``args.tp_sizes``,
    each ``([K, N], tp_split_dim)`` entry of ``WEIGHT_SHAPES[model]`` is
    copied, the dimension at ``tp_split_dim`` is floor-divided by ``tp_size``,
    and the model name is appended.

    Returns:
        A list of ``[K, N, model]`` lists.
    """
    shapes = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        assert model in WEIGHT_SHAPES
        # Deep copy so the shared WEIGHT_SHAPES table is never mutated.
        for dims, split_axis in copy.deepcopy(WEIGHT_SHAPES[model]):
            dims[split_axis] //= tp_size
            shapes.append([*dims, model])
    return shapes
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line options: which models to sweep and at which TP sizes.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    cli.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = cli.parse_args()

    # One benchmark sweep (over batch sizes) per weight shape.
    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name} N={N} K={K}: ")
        run_options = dict(
            print_data=True,
            show_plots=True,
            save_path="bench_qserve_w4a8_gemm_res",
            N=N,
            K=K,
        )
        benchmark.run(**run_options)

    print("Benchmark finished!")
|
||||
96
sgl-kernel/benchmark/bench_rotary_embedding.py
Normal file
96
sgl-kernel/benchmark/bench_rotary_embedding.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import FusedSetKVBufferArg
|
||||
from sgl_kernel.testing.rotary_embedding import (
|
||||
FlashInferRotaryEmbedding,
|
||||
MHATokenToKVPool,
|
||||
RotaryEmbedding,
|
||||
create_inputs,
|
||||
)
|
||||
|
||||
from sglang.srt.bench_utils import bench_kineto
|
||||
|
||||
# Benchmark grid: every (batch_size, seq_len) shape below is measured once
# without and once with the fused KV-cache write (save_kv_cache).
configs = [
    (batch_size, seq_len, save_kv_cache)
    for (batch_size, seq_len), save_kv_cache in itertools.product(
        (
            (1, 1),
            (32, 1),
            (128, 1),
            (512, 1),
            (2, 512),
            (4, 4096),
        ),
        (False, True),
    )
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "save_kv_cache"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["sglang"],
        line_names=["SGL Kernel"],
        styles=[("green", "-")],
        ylabel="us",
        plot_name="bench_rotary_embedding",
        args={},
    )
)
def benchmark(batch_size, seq_len, save_kv_cache, provider):
    """Time ``FlashInferRotaryEmbedding.forward_cuda`` for one grid point.

    Uses 32 query heads, 8 KV heads, head size 64 and bfloat16 inputs. When
    ``save_kv_cache`` is true a ``FusedSetKVBufferArg`` pointing at the pool's
    first K/V buffers is passed to the call; otherwise ``None`` is passed.
    ``provider`` is accepted for the perf_report interface but not read.

    Returns the time reported by ``bench_kineto`` for the
    ``BatchQKApplyRotaryPosIds`` kernel, scaled by 1e6 (plotted as
    microseconds).
    """
    device = torch.device("cuda")
    dtype = torch.bfloat16
    num_q_heads, num_kv_heads, head_size = 32, 8, 64

    rope_flashinfer = FlashInferRotaryEmbedding(
        head_size=head_size,
        rotary_dim=64,
        max_position_embeddings=4096,
        base=8000,
        is_neox_style=True,
        dtype=dtype,
    ).to(device)
    pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)

    inputs = create_inputs(
        head_size=head_size,
        batch_size=batch_size,
        seq_len=seq_len,
        device=device,
        dtype=dtype,
        num_q_heads=num_q_heads,
        num_kv_heads=num_kv_heads,
    )

    # Private copies of the query/key tensors handed to the kernel.
    query_flashinfer = inputs["query"].clone()
    key_flashinfer = inputs["key"].clone()

    def bench_fn():
        # The fused-write argument is rebuilt on every invocation.
        fused_arg = None
        if save_kv_cache:
            fused_arg = FusedSetKVBufferArg(
                value=inputs["value"],
                k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
                v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
                k_scale=None,
                v_scale=None,
                cache_loc=inputs["out_cache_loc"],
            )
        return rope_flashinfer.forward_cuda(
            inputs["pos_ids"],
            query_flashinfer,
            key_flashinfer,
            fused_set_kv_buffer_arg=fused_arg,
        )

    time_s = bench_kineto(bench_fn, kernel_names="BatchQKApplyRotaryPosIds")
    return time_s * 1e6


if __name__ == "__main__":
    benchmark.run(print_data=True)
|
||||
128
sgl-kernel/benchmark/bench_top_k_top_p_sampling.py
Normal file
128
sgl-kernel/benchmark/bench_top_k_top_p_sampling.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import itertools
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
|
||||
|
||||
def torch_top_k_top_p_joint_sampling_from_probs(
    normalized_prob, top_k, top_p, eps=1e-4
):
    """Reference PyTorch implementation of joint top-k top-p sampling.

    Args:
        normalized_prob: ``(batch_size, vocab_size)`` probabilities.
        top_k: Per-row integer k, indexable by row.
        top_p: Per-row p, indexable by row.
        eps: Slack subtracted from the top-p threshold.

    Returns:
        An int64 tensor of ``batch_size`` sampled token indices. Each row is
        sampled from its probabilities restricted to tokens that pass both
        the top-p and the top-k filter, renormalized.
    """
    batch_size, vocab_size = normalized_prob.shape
    device = normalized_prob.device
    samples = torch.empty(batch_size, dtype=torch.int64, device=device)

    for row, probs in enumerate(normalized_prob):
        p_val = top_p[row].item()
        k_val = top_k[row].item()

        # top-p mask: in ascending order, keep tokens whose cumulative mass
        # exceeds (1 - p) - eps, then scatter back to vocabulary order.
        ascending, order = torch.sort(probs, descending=False)
        keep_sorted = torch.cumsum(ascending, dim=-1) > (1 - p_val) - eps
        top_p_keep = torch.zeros(vocab_size, dtype=torch.int32, device=device)
        top_p_keep.scatter_add_(0, order, keep_sorted.int())

        # top-k mask: keep tokens at least as likely as the k-th largest.
        descending, _ = torch.sort(probs, descending=True)
        top_k_keep = (probs >= descending[k_val - 1]).int()

        # joint mask: a token survives only if both filters keep it.
        joint = torch.minimum(top_p_keep, top_k_keep).bool()

        # sample from the renormalized surviving mass
        kept = probs * joint
        samples[row] = torch.multinomial(kept / kept.sum(), 1)

    return samples
|
||||
|
||||
|
||||
def calculate_diff(batch_size, vocab_size, p):
    """Run the Torch reference and the SGLang kernel on the same inputs.

    ``p`` must be 0.1 or 0.5 and selects ``k`` as 50% or 10% of
    ``vocab_size`` respectively; any other value raises ``ValueError``.

    NOTE(review): both sample tensors are produced but never compared or
    returned, so this only checks that the two code paths execute.
    """
    torch.manual_seed(42)
    if p not in (0.1, 0.5):
        raise ValueError("p not recognized")
    k_fraction = 0.5 if p == 0.1 else 0.1
    k = int(vocab_size * k_fraction)

    device = torch.device("cuda")
    # Random non-negative rows, normalized so each row sums to 1.
    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
    row_sums = pre_norm_prob.sum(dim=-1, keepdim=True)
    normalized_prob = pre_norm_prob / row_sums

    # The same p and k are used for every row of the batch.
    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)

    torch_samples = torch_top_k_top_p_joint_sampling_from_probs(
        normalized_prob, top_k_tensor, top_p_tensor
    )
    sglang_samples = sgl_kernel.top_k_top_p_sampling_from_probs(
        normalized_prob, top_k_tensor, top_p_tensor, filter_apply_order="joint"
    )
|
||||
|
||||
|
||||
# parameter space
batch_size_range = [16, 64, 128]
vocab_size_range = [111, 32000]
p_range = [0.1, 0.5]
configs = list(itertools.product(batch_size_range, vocab_size_range, p_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "vocab_size", "p"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "sglang"],
        line_names=["Torch Reference", "SGL Kernel"],
        styles=[("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="top-k-top-p-joint-sampling-performance",
        args={},
    )
)
def benchmark_sampling(batch_size, vocab_size, p, provider):
    """Time joint top-k/top-p sampling for one grid point.

    Args:
        batch_size, vocab_size: Shape of the probability matrix.
        p: 0.1 or 0.5; selects ``k`` as 50% or 10% of ``vocab_size``.
        provider: ``"torch"`` or ``"sglang"``.

    Returns:
        ``(median, low, high)`` latency in microseconds, where low/high are
        the 0.2 and 0.8 quantiles requested from ``do_bench``.

    Raises:
        ValueError: If ``p`` or ``provider`` is not one of the known values.
    """
    torch.manual_seed(42)
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")

    device = torch.device("cuda")
    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)

    # Note: the input clone is part of the timed region for both providers.
    if provider == "torch":
        fn = lambda: torch_top_k_top_p_joint_sampling_from_probs(
            normalized_prob.clone(), top_k_tensor, top_p_tensor
        )
    elif provider == "sglang":
        fn = lambda: sgl_kernel.top_k_top_p_sampling_from_probs(
            normalized_prob.clone(),
            top_k_tensor,
            top_p_tensor,
            filter_apply_order="joint",
        )
    else:
        # Previously fell through and failed later with an unbound ``fn``.
        raise ValueError(f"Unknown provider: {provider!r}")

    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])

    # perf_report unpacks the result as (y, y_min, y_max). These are raw
    # latencies, so the low quantile must come before the high one.
    # Convert ms -> us.
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Correctness check
    for cfg in configs:
        calculate_diff(*cfg)

    # Visual separator between the check pass and the timing table.
    separator = "=" * 60
    print("\n" + separator)
    print("Starting performance benchmark...")
    benchmark_sampling.run(print_data=True)
|
||||
Reference in New Issue
Block a user