[Perf] Tunings for SM100 FP8 CUTLASS kernel (#8818)

henryg committed 2025-08-13 21:59:22 -07:00 (committed by GitHub)
parent 733446dd36
commit 841810f227
3 changed files with 177 additions and 27 deletions


@@ -1,10 +1,12 @@
 import argparse
 import copy
 import itertools
+from typing import Optional, Tuple

 import torch
 import triton
 from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
+from sgl_kernel import sgl_per_tensor_quant_fp8
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
 from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
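
For context on what the benchmark measures, here is a minimal sketch of the vLLM baseline path these imports wire up: dynamic per-tensor FP8 quantization followed by cutlass_scaled_mm. The shapes and fill values are illustrative only (not taken from the benchmark config), and the sketch assumes a CUDA GPU with FP8 support and vLLM installed.

import torch
from vllm._custom_ops import cutlass_scaled_mm, scaled_fp8_quant

# Illustrative shapes, not from the benchmark's WEIGHT_SHAPES table.
M, N, K = 16, 4096, 4096
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((N, K), device="cuda", dtype=torch.float16)

# Dynamic per-tensor quantization: pass no scale and let vLLM compute one.
a_fp8, scale_a = scaled_fp8_quant(a)
b_fp8, scale_b = scaled_fp8_quant(b)

# cutlass_scaled_mm expects B column-major, hence the transpose
# (the benchmark below does the same with b_fp8.t()).
out = cutlass_scaled_mm(a_fp8, b_fp8.t(), scale_a, scale_b, torch.float16)
print(out.shape)  # torch.Size([16, 4096])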
@@ -69,6 +71,21 @@ WEIGHT_SHAPES = {
 }


+def sglang_scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    fp8_type_: torch.dtype = torch.float8_e4m3fn
+    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
+    is_static = True
+    if scale is None:
+        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+        is_static = False
+    sgl_per_tensor_quant_fp8(input, output, scale, is_static)
+    return output, scale


 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
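
The new sglang_scaled_fp8_quant helper above has two modes: when no scale is passed it allocates a zero-filled placeholder and sets is_static=False, which reads as the kernel computing the scale on the fly; with a caller-supplied scale it quantizes statically. A hedged usage sketch, assuming sgl_kernel is installed, a CUDA device is available, and the helper from the diff above is in scope (tensor shape and dtype are illustrative):

import torch

x = torch.randn((16, 4096), device="cuda", dtype=torch.float16)

# Dynamic mode: no scale supplied, so the helper passes a zero-filled
# placeholder with is_static=False and the kernel fills in the scale.
x_fp8, x_scale = sglang_scaled_fp8_quant(x)

# Static mode: reuse a precomputed scale; only the cast is performed.
y_fp8, _ = sglang_scaled_fp8_quant(x, scale=x_scale)

assert x_fp8.dtype == torch.float8_e4m3fn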
@@ -100,19 +117,22 @@ def benchmark(batch_size, provider, N, K):
     b = torch.ones((N, K), device="cuda") * 5.0
     scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
     scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
-    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-    b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-    b_fp8 = b_fp8.t()
     quantiles = [0.5, 0.2, 0.8]
     dtype = torch.float16 if "fp16" in provider else torch.bfloat16
     if "vllm-fp8" in provider:
+        a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
+        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
+        b_fp8 = b_fp8.t()
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
             quantiles=quantiles,
         )
     elif "sglang-fp8" in provider:
+        a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
+        b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
+        b_fp8 = b_fp8.t()
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: sgl_scaled_mm(
                 a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None