sglang v0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
133
benchmark/kernels/quantization/bench_fp4_quant.py
Normal file
133
benchmark/kernels/quantization/bench_fp4_quant.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import argparse
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant
|
||||
from sgl_kernel.elementwise import silu_and_mul
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
|
||||
|
||||
def _test_accuracy_once(E, M, K, input_dtype, device):
    """Check the fused silu_and_mul + FP4 grouped quant kernel against the unfused path.

    Builds a random (E, M, K) activation tensor, runs both the fused kernel and
    the reference composition (silu_and_mul followed by grouped FP4 quant), and
    asserts that the quantized outputs and block scales agree.
    """
    activations = torch.randn(E, M, K, device=device, dtype=input_dtype)
    global_scales = torch.ones((E,), dtype=torch.float32, device=device)
    # Every expert is fully occupied: each mask entry equals M.
    expert_masks = torch.full((E,), M, dtype=torch.int32, device=device)

    fused_out, fused_scales = silu_and_mul_scaled_fp4_grouped_quant(
        activations, global_scales, expert_masks
    )
    ref_out, ref_scales = scaled_fp4_grouped_quant(
        silu_and_mul(activations), global_scales, expert_masks
    )

    torch.testing.assert_close(fused_out, ref_out)
    torch.testing.assert_close(fused_scales, ref_scales)
    print(f"E: {E}, M: {M}, K: {K}, type: {input_dtype} OK")
|
||||
|
||||
|
||||
# Benchmark sweep: total token counts (Ms) are per-rank sizes scaled by the
# number of ranks; Ks are representative MoE hidden dimensions.
NUM_RANKS = 48
M_PER_RANKs = [128, 256, 512, 1024]
Ms = [NUM_RANKS * per_rank for per_rank in M_PER_RANKs]
Ks = [2048, 4096, 7168]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["M", "K"],
        x_vals=list(itertools.product(Ms, Ks)),
        x_log=False,
        line_arg="provider",
        line_vals=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
        line_names=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
        styles=[("blue", "-"), ("orange", "-"), ("green", "-")],
        ylabel="ms",
        plot_name="fp4 quant",
        args={},
    )
)
def benchmark(M, K, provider):
    """Time one quantization variant on an (E, M, K) bfloat16 input.

    Providers:
      * triton_fp8       -- silu_and_mul + masked per-block FP8 quant (triton)
      * cuda_unfused_fp4 -- silu_and_mul followed by grouped FP4 quant
      * cuda_fused_fp4   -- fused silu_and_mul + grouped FP4 quant

    Returns the (median, 20th-percentile, 80th-percentile) latencies in ms
    as reported by triton's cudagraph benchmark helper.
    """
    E = 6
    device = "cuda"
    x = torch.randn(E, M, K, device=device, dtype=torch.bfloat16)
    glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
    # Random per-expert token counts so the masked kernels do partial work.
    masks = torch.randint(1, 4096, (E,), dtype=torch.int32, device=device)

    # Preallocated outputs for the FP8 path; silu_and_mul halves the K dim.
    fp8_out = torch.empty(
        (x.shape[0], x.shape[1], x.shape[2] // 2),
        device=x.device,
        dtype=torch.float8_e4m3fn,
    )
    scale_block_size = 128
    fp8_scales = torch.empty(
        (x.shape[0], x.shape[1], x.shape[2] // 2 // scale_block_size),
        device=x.device,
        dtype=torch.float32,
    )

    quantiles = [0.5, 0.2, 0.8]
    if provider == "triton_fp8":
        target = lambda: silu_and_mul_masked_post_quant_fwd(
            x,
            fp8_out,
            fp8_scales,
            scale_block_size,
            masks,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
    elif provider == "cuda_unfused_fp4":
        target = lambda: scaled_fp4_grouped_quant(silu_and_mul(x), glb_scales, masks)
    elif provider == "cuda_fused_fp4":
        target = lambda: silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks)
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        target, quantiles=quantiles
    )

    return ms, min_ms, max_ms
|
||||
|
||||
|
||||
def test_accuracy():
    """Run the fused-vs-unfused FP4 quant accuracy check over the full sweep."""
    num_experts = 6
    ranks = 48
    per_rank_ms = [128, 256, 512, 1024]
    ks = [2048, 4096, 7168]
    for per_rank_m, k in itertools.product(per_rank_ms, ks):
        _test_accuracy_once(
            num_experts, ranks * per_rank_m, k, torch.bfloat16, "cuda"
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--save_path",
        type=str,
        default="./bench_fp4_quant_res",
        help="Path to save fp4 quant benchmark results",
    )
    args = arg_parser.parse_args()

    # Verify the fused kernel is numerically correct before timing it.
    test_accuracy()

    benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
|
||||
Reference in New Issue
Block a user