diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py old mode 100644 new mode 100755 index 223d7d43f..36ffd1275 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -35,10 +35,20 @@ if TYPE_CHECKING: from sglang.srt.layers.moe.topk import TopKOutput if is_cuda(): - from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant + from sgl_kernel import scaled_fp4_quant + +try: + from flashinfer import mm_fp4 as fp4_gemm + + enable_flashinfer_fp4_gemm = True +except ImportError: + if is_cuda(): + from sgl_kernel import cutlass_scaled_fp4_mm as fp4_gemm + else: + fp4_gemm = None + enable_flashinfer_fp4_gemm = False try: - from flashinfer import fp4_quantize as fp4_quantize from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe except ImportError: flashinfer_cutlass_fused_moe = None @@ -683,11 +693,16 @@ class ModelOptFp4LinearMethod(LinearMethodBase): assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn assert layer.alpha.dtype == torch.float32 - out = cutlass_scaled_fp4_mm( + w = layer.weight + w_scale_interleaved = layer.weight_scale_interleaved + if enable_flashinfer_fp4_gemm: + w = layer.weight.T + w_scale_interleaved = layer.weight_scale_interleaved.T + out = fp4_gemm( x_fp4, - layer.weight, + w, x_scale_interleaved, - layer.weight_scale_interleaved, + w_scale_interleaved, layer.alpha, output_dtype, ) diff --git a/sgl-kernel/benchmark/bench_fp4_gemm.py b/sgl-kernel/benchmark/bench_fp4_gemm.py new file mode 100755 index 000000000..80773eb07 --- /dev/null +++ b/sgl-kernel/benchmark/bench_fp4_gemm.py @@ -0,0 +1,210 @@ +import argparse +import copy +import csv +import itertools + +import pytest +import torch +import triton +from flashinfer import mm_fp4 +from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant + +FLOAT4_E2M1_MAX = 6.0 +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + + +def get_weight_shapes(args): + models_tps = args.tp_sizes + + if models_tps == [4]: + return [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]] + + if models_tps == [8]: + return [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]] + return [ + [1024, 3584], + [7168, 256], + [7168, 2304], + [9216, 3584], + [512, 3584], + [7168, 128], + [7168, 1152], + [4608, 3584], + ] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[ + 1, + 2, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 3072, + 4096, + 8192, + 16384, + ], + # x_vals = [64], + x_log=False, + line_arg="provider", + line_vals=["cutlass", "cudnn", "trtllm"], + line_names=["baseline cutlass fp4", "cudnn fp4", "trtllm fp4"], + styles=[("red", "solid"), ("blue", "solid"), ("green", "solid")], + ylabel="latency (ms)", + plot_name="fp4_gemm_benchmark", + args={}, + ) +) +def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file): + M = batch_size + packed_k = K + K = 2 * packed_k + a_dtype = torch.randn((M, K), dtype=dtype, device="cuda") + b_dtype = torch.randn((N, K), dtype=dtype, device="cuda") + a_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1) + ).to(torch.float32) + b_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1) + ).to(torch.float32) + + alpha = 1.0 / (a_global_scale * b_global_scale) + a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale) + # print("a_fp4", a_fp4) + b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale) + res_fi = torch.empty((M, N), dtype=dtype, device="cuda") + + quantiles = [0.5, 0.2, 0.8] + if provider == "cutlass": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: cutlass_scaled_fp4_mm( + a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype + ), + quantiles=quantiles, + ) + if provider == "cudnn": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: mm_fp4( + a_fp4, + b_fp4.T, + a_scale_interleaved, + b_scale_interleaved.T, + alpha, + dtype, + res_fi, + ), + quantiles=quantiles, + ) + if provider == "trtllm": + a_scale_interleaved = a_scale_interleaved.to(torch.uint8) + b_scale_interleaved = b_scale_interleaved.to(torch.uint8) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: mm_fp4( + a_fp4, + b_fp4.T, + a_scale_interleaved, + b_scale_interleaved.T, + alpha, + dtype, + res_fi, + backend="trtllm", + ), + quantiles=quantiles, + ) + if correctness: + res_cutlass = cutlass_scaled_fp4_mm( + a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype + ) + mm_fp4( + a_fp4, + b_fp4.T, + a_scale_interleaved, + b_scale_interleaved.T, + alpha, + dtype, + res_fi, + backend="cudnn", + ) + assert torch.allclose( + res_fi, res_cutlass, atol=1e-3, rtol=1e-3 + ), "cudnn fp4 doesn't match cutlass fp4" + mm_fp4( + a_fp4, + b_fp4.T, + a_scale_interleaved, + b_scale_interleaved.T, + alpha, + dtype, + res_fi, + backend="trtllm", + ) + assert torch.allclose( + res_fi, res_cutlass, atol=1e-3, rtol=1e-3 + ), "trtllm fp4 doesn't match cutlass fp4" + + if csv_file: + with open(csv_file, "a", newline="") as f: + writer = csv.writer(f) + writer.writerow([provider, M, N, K, ms]) + + return ms, min_ms, max_ms + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--tp-sizes", + nargs="+", + type=int, + default=[1], + help="List of tensor parallel sizes", + ) + parser.add_argument( + "--dtype", + type=torch.dtype, + default=torch.bfloat16, + help="Data type", + ) + parser.add_argument( + "--correctness", + action="store_true", + help="Check correctness", + ) + parser.add_argument( + "--csv", + type=str, + default="results_cutlass_cudnn.csv", + help="CSV file to save results", + ) + args = parser.parse_args() + + if args.csv: + with open(args.csv, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["provider", "m", "n", "k", "time_ms"]) + + NKs = get_weight_shapes(args) + for N, K in NKs: + print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ") + benchmark.run( + print_data=True, + show_plots=True, + save_path="bench_fp4_res", + N=N, + K=K, + dtype=args.dtype, + correctness=args.correctness, + csv_file=args.csv, + ) + + print("Benchmark finished!")