Use FlashInfer FP4 gemm. (#8241)
This commit is contained in:
25
python/sglang/srt/layers/quantization/modelopt_quant.py
Normal file → Executable file
25
python/sglang/srt/layers/quantization/modelopt_quant.py
Normal file → Executable file
@@ -35,10 +35,20 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
if is_cuda():
|
||||
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from sgl_kernel import scaled_fp4_quant
|
||||
|
||||
try:
|
||||
from flashinfer import mm_fp4 as fp4_gemm
|
||||
|
||||
enable_flashinfer_fp4_gemm = True
|
||||
except ImportError:
|
||||
if is_cuda():
|
||||
from sgl_kernel import cutlass_scaled_fp4_mm as fp4_gemm
|
||||
else:
|
||||
fp4_gemm = None
|
||||
enable_flashinfer_fp4_gemm = False
|
||||
|
||||
try:
|
||||
from flashinfer import fp4_quantize as fp4_quantize
|
||||
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
|
||||
except ImportError:
|
||||
flashinfer_cutlass_fused_moe = None
|
||||
@@ -683,11 +693,16 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
|
||||
assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
|
||||
assert layer.alpha.dtype == torch.float32
|
||||
|
||||
out = cutlass_scaled_fp4_mm(
|
||||
w = layer.weight
|
||||
w_scale_interleaved = layer.weight_scale_interleaved
|
||||
if enable_flashinfer_fp4_gemm:
|
||||
w = layer.weight.T
|
||||
w_scale_interleaved = layer.weight_scale_interleaved.T
|
||||
out = fp4_gemm(
|
||||
x_fp4,
|
||||
layer.weight,
|
||||
w,
|
||||
x_scale_interleaved,
|
||||
layer.weight_scale_interleaved,
|
||||
w_scale_interleaved,
|
||||
layer.alpha,
|
||||
output_dtype,
|
||||
)
|
||||
|
||||
210
sgl-kernel/benchmark/bench_fp4_gemm.py
Executable file
210
sgl-kernel/benchmark/bench_fp4_gemm.py
Executable file
@@ -0,0 +1,210 @@
|
||||
import argparse
|
||||
import copy
|
||||
import csv
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
from flashinfer import mm_fp4
|
||||
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
|
||||
FLOAT4_E2M1_MAX = 6.0
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
|
||||
def get_weight_shapes(args):
    """Return the benchmark weight shapes [N, packed_K] for the requested TP sizes.

    A tp-size list of exactly [4] or [8] selects that model's shard shapes;
    any other value benchmarks the union of both shape sets.
    """
    tp4_shapes = [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]]
    tp8_shapes = [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]]

    shapes_by_tp = {
        (4,): tp4_shapes,
        (8,): tp8_shapes,
    }
    selected = shapes_by_tp.get(tuple(args.tp_sizes))
    if selected is not None:
        return selected
    # Fallback: benchmark every shape from both TP configurations.
    return tp4_shapes + tp8_shapes
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1,
            2,
            4,
            8,
            16,
            32,
            64,
            128,
            256,
            512,
            1024,
            2048,
            3072,
            4096,
            8192,
            16384,
        ],
        # x_vals = [64],
        x_log=False,
        line_arg="provider",
        line_vals=["cutlass", "cudnn", "trtllm"],
        line_names=["baseline cutlass fp4", "cudnn fp4", "trtllm fp4"],
        styles=[("red", "solid"), ("blue", "solid"), ("green", "solid")],
        ylabel="latency (ms)",
        plot_name="fp4_gemm_benchmark",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
    """Time one FP4 GEMM provider for a single (M=batch_size, N, K) problem.

    Providers:
      - "cutlass": sgl_kernel.cutlass_scaled_fp4_mm (baseline)
      - "cudnn":   flashinfer.mm_fp4 with its default backend
      - "trtllm":  flashinfer.mm_fp4 with backend="trtllm"

    Returns the (median, 20th-, 80th-percentile) latency in ms as measured
    by triton.testing.do_bench_cudagraph.  Optionally appends one CSV row
    and cross-checks both flashinfer paths against the cutlass baseline.
    """
    M = batch_size
    # K arrives in packed form; the unpacked element count is doubled here.
    # NOTE(review): presumably FP4 packs two values per byte and
    # get_weight_shapes() yields packed K — confirm against the kernels.
    packed_k = K
    K = 2 * packed_k
    a_dtype = torch.randn((M, K), dtype=dtype, device="cuda")
    b_dtype = torch.randn((N, K), dtype=dtype, device="cuda")
    # Per-tensor global scales: map each tensor's max |value| onto the
    # combined FP8-scale x FP4-value dynamic range.
    a_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    b_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
    ).to(torch.float32)

    # Dequantization factor applied by the GEMM epilogue.
    alpha = 1.0 / (a_global_scale * b_global_scale)
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
    # print("a_fp4", a_fp4)
    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
    # Output buffer written in place by both flashinfer code paths.
    res_fi = torch.empty((M, N), dtype=dtype, device="cuda")

    quantiles = [0.5, 0.2, 0.8]
    if provider == "cutlass":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: cutlass_scaled_fp4_mm(
                a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
            ),
            quantiles=quantiles,
        )
    if provider == "cudnn":
        # flashinfer takes the B operand and its scales transposed.
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: mm_fp4(
                a_fp4,
                b_fp4.T,
                a_scale_interleaved,
                b_scale_interleaved.T,
                alpha,
                dtype,
                res_fi,
            ),
            quantiles=quantiles,
        )
    if provider == "trtllm":
        # NOTE(review): the trtllm backend appears to expect uint8 scale
        # tensors; rebinding the scales here also changes what the
        # correctness block below sees when provider == "trtllm" — confirm
        # that is intended.
        a_scale_interleaved = a_scale_interleaved.to(torch.uint8)
        b_scale_interleaved = b_scale_interleaved.to(torch.uint8)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: mm_fp4(
                a_fp4,
                b_fp4.T,
                a_scale_interleaved,
                b_scale_interleaved.T,
                alpha,
                dtype,
                res_fi,
                backend="trtllm",
            ),
            quantiles=quantiles,
        )
    if correctness:
        # Baseline result from the cutlass kernel; both flashinfer backends
        # are compared against it (each overwrites res_fi in place).
        res_cutlass = cutlass_scaled_fp4_mm(
            a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
        )
        mm_fp4(
            a_fp4,
            b_fp4.T,
            a_scale_interleaved,
            b_scale_interleaved.T,
            alpha,
            dtype,
            res_fi,
            backend="cudnn",
        )
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "cudnn fp4 doesn't match cutlass fp4"
        mm_fp4(
            a_fp4,
            b_fp4.T,
            a_scale_interleaved,
            b_scale_interleaved.T,
            alpha,
            dtype,
            res_fi,
            backend="trtllm",
        )
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "trtllm fp4 doesn't match cutlass fp4"

    if csv_file:
        # Append mode: the header row is written once by the __main__ driver.
        with open(csv_file, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([provider, M, N, K, ms])

    return ms, min_ms, max_ms
|
||||
|
||||
|
||||
def parse_dtype(name):
    """Convert a CLI string such as 'bfloat16' into the torch.dtype object.

    Fixes the previous `type=torch.dtype` converter: `torch.dtype` is not
    callable with a string, so any explicit `--dtype` value raised a
    TypeError before argument parsing could finish.

    Accepts either 'bfloat16' or 'torch.bfloat16' style names and raises
    argparse.ArgumentTypeError for anything that is not a torch.dtype.
    """
    attr = getattr(torch, name.replace("torch.", "", 1), None)
    if not isinstance(attr, torch.dtype):
        raise argparse.ArgumentTypeError(f"unknown torch dtype: {name!r}")
    return attr


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    parser.add_argument(
        "--dtype",
        type=parse_dtype,
        default=torch.bfloat16,
        help="Data type",
    )
    parser.add_argument(
        "--correctness",
        action="store_true",
        help="Check correctness",
    )
    parser.add_argument(
        "--csv",
        type=str,
        default="results_cutlass_cudnn.csv",
        help="CSV file to save results",
    )
    args = parser.parse_args()

    # Write the CSV header once; benchmark() appends one row per run.
    if args.csv:
        with open(args.csv, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["provider", "m", "n", "k", "time_ms"])

    NKs = get_weight_shapes(args)
    for N, K in NKs:
        print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp4_res",
            N=N,
            K=K,
            dtype=args.dtype,
            correctness=args.correctness,
            csv_file=args.csv,
        )

    print("Benchmark finished!")
|
||||
Reference in New Issue
Block a user