99 lines
2.6 KiB
Python
99 lines
2.6 KiB
Python
import itertools
|
|
import time
|
|
from functools import partial
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import triton
|
|
|
|
from sglang.srt.bench_utils import bench_kineto
|
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
|
create_per_token_group_quant_fp8_output_scale,
|
|
)
|
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
|
per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
|
|
)
|
|
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
|
|
from sglang.srt.utils import is_hip
|
|
|
|
# Select the FP8 output dtype for this platform: HIP builds use
# torch.float8_e4m3fnuz, every other build uses torch.float8_e4m3fn.
_is_hip = is_hip()
if _is_hip:
    fp8_type_ = torch.float8_e4m3fnuz
else:
    fp8_type_ = torch.float8_e4m3fn
|
|
|
|
|
|
# Benchmark sweep: every combination of the ranges below is one x-value
# passed to `benchmark`.
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
hidden_dim_range = [1536, 7168, 18432]  # For DeepSeek V3/R1
group_size_range = [128]  # For DeepSeek V3/R1
# TODO test int8
dst_dtype_range = [fp8_type_]

# Scale-option combinations, forwarded as keyword arguments to each
# quantization function. Each tuple is
# (column_major_scales, scale_tma_aligned, scale_ue8m0).
flags_range = [
    {
        "column_major_scales": column_major,
        "scale_tma_aligned": tma_aligned,
        "scale_ue8m0": ue8m0,
    }
    for column_major, tma_aligned, ue8m0 in (
        (False, False, False),
        (True, False, False),
        (True, True, False),
        (True, True, True),
    )
]

# Cartesian product of all ranges; num_tokens varies slowest, flags fastest.
configs = [
    *itertools.product(
        num_tokens_range,
        hidden_dim_range,
        group_size_range,
        dst_dtype_range,
        flags_range,
    )
]
|
|
|
|
|
|
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["triton", "sglang"],
        line_names=["Triton", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-group-quant-8bit-performance",
        args={},
    )
)
def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
    """Time one quantization provider on one sweep configuration.

    Args:
        num_tokens: Number of rows of the random bfloat16 input.
        hidden_dim: Number of columns of the random bfloat16 input.
        group_size: Group size forwarded to the quantization function.
        dst_dtype: Output dtype forwarded to the quantization function.
        flags: Dict of extra keyword arguments forwarded to the quantization
            function; must contain the key ``"scale_ue8m0"``.
        provider: ``"triton"`` or ``"sglang"``; selects the implementation
            and the profiler kernel name passed to ``bench_kineto``.

    Returns:
        The ``bench_kineto`` measurement multiplied by 1e6 (reported on the
        "us" axis), or ``None`` for skipped combinations (``scale_ue8m0``
        set with a group size other than 128).

    Raises:
        KeyError: If ``provider`` is neither ``"triton"`` nor ``"sglang"``.
    """
    # UE8M0 scales are only benchmarked with 128-wide groups.
    if flags["scale_ue8m0"] and group_size != 128:
        return None

    x = torch.randn(
        num_tokens, hidden_dim, device=torch.device("cuda"), dtype=torch.bfloat16
    )

    # Pick the implementation to call and the kernel name to look up in the
    # profiler trace.
    if provider == "triton":
        quant_fn = triton_per_token_group_quant_8bit
        kernel_name = "_per_token_group_quant_fp8"
    elif provider == "sglang":
        quant_fn = sglang_per_token_group_quant_8bit
        kernel_name = "per_token_group_quant_8bit_kernel"
    else:
        raise KeyError(provider)

    def run_once():
        return quant_fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)

    elapsed_s = bench_kineto(run_once, kernel_names=kernel_name)
    return elapsed_s * 1e6
|
|
|
|
|
|
if __name__ == "__main__":
    # Run every configuration of the sweep and print the timing table.
    benchmark.run(print_data=True)
|