diff --git a/benchmark/kernels/quantization/bench_int8_quant.py b/benchmark/kernels/quantization/bench_int8_quant.py
new file mode 100644
index 000000000..94b795690
--- /dev/null
+++ b/benchmark/kernels/quantization/bench_int8_quant.py
@@ -0,0 +1,97 @@
+import argparse
+
+import torch
+import triton
+from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant
+
+from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+
+
+@torch.compile(backend="inductor")
+def torch_int8_quant(x):
+    int8_max = torch.iinfo(torch.int8).max
+
+    abs_max = x.abs().max(dim=-1, keepdim=True).values
+    # Clamp to avoid division by zero on all-zero rows; mirrors the 1e-10
+    # floor used by the Triton kernel.
+    scales = (abs_max.to(torch.float32) / float(int8_max)).clamp_(min=1e-10)
+
+    q_x = (x / scales).round().to(torch.int8)
+
+    return q_x, scales
+
+
+def _test_accuracy_once(M, K, input_dtype, device):
+    x = torch.randn(M, K, dtype=input_dtype, device=device) * 5000
+    out, scales, _ = vllm_scaled_int8_quant(x, symmetric=True)
+    out1, scales1 = per_token_quant_int8(x)
+    out2, scales2 = torch_int8_quant(x)
+    # Allow an off-by-one difference in the int8 values due to rounding.
+    torch.testing.assert_close(out, out2, atol=1, rtol=0)
+    torch.testing.assert_close(out, out1, atol=1, rtol=0)
+    torch.testing.assert_close(scales, scales2)
+    torch.testing.assert_close(scales1, scales2)
+    print(f"M: {M}, K: {K}, type: {input_dtype} OK")
+
+
+def test_accuracy():
+    Ms = [1, 13, 128, 1024, 2048, 4096]
+    Ks = [512, 1024, 2048, 8192]
+    input_dtypes = [torch.float16, torch.bfloat16]
+    for M in Ms:
+        for K in Ks:
+            for input_dtype in input_dtypes:
+                _test_accuracy_once(M, K, input_dtype, "cuda")
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["batch_size"],
+        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
+        x_log=False,
+        line_arg="provider",
+        line_vals=["vllm op", "triton", "torch.compile"],
+        line_names=["vllm op", "triton", "torch.compile"],
+        styles=[("blue", "-"), ("orange", "-"), ("red", "-")],
+        ylabel="ms",
+        plot_name="int8 per token quant",
+        args={},
+    )
+)
+def benchmark(batch_size, provider):
+    M, K = batch_size, 16384
+    x = torch.randn(M, K, dtype=torch.float16, device="cuda") * 1000
+
+    quantiles = [0.5, 0.2, 0.8]
+    if provider == "vllm op":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: vllm_scaled_int8_quant(x, symmetric=True),
+            quantiles=quantiles,
+        )
+    elif provider == "triton":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: per_token_quant_int8(x),
+            quantiles=quantiles,
+        )
+    elif provider == "torch.compile":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: torch_int8_quant(x),
+            quantiles=quantiles,
+        )
+
+    return ms, min_ms, max_ms
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--save_path",
+        type=str,
+        default="./bench_int8_quant_res",
+        help="Path to save int8 quant benchmark results",
+    )
+    args = parser.parse_args()
+
+    test_accuracy()
+
+    benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
diff --git a/python/sglang/srt/layers/quantization/int8_kernel.py b/python/sglang/srt/layers/quantization/int8_kernel.py
new file mode 100644
index 000000000..d1e74c604
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/int8_kernel.py
@@ -0,0 +1,58 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _per_token_quant_int8(
+    x_ptr,
+    xq_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    BLOCK: tl.constexpr,
+):
+    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
+    # One program instance quantizes one row (token).
+    row_id = tl.program_id(0)
+
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+    scale_x = absmax / 127
+    x_q = tl.extra.cuda.libdevice.round(x / scale_x).to(tl.int8)
+
+    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
+    tl.store(scale_ptr + row_id, scale_x)
+
+
+def per_token_quant_int8(x):
+    """Symmetric per-token (row-wise) int8 quantization.
+
+    Returns the quantized int8 tensor and float32 scales of shape (..., 1).
+    """
+    M = x.numel() // x.shape[-1]
+    N = x.shape[-1]
+    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    BLOCK = triton.next_power_of_2(N)
+    # Heuristic for the number of warps: scale with the block size, capped at 8.
+    num_warps = min(max(BLOCK // 256, 1), 8)
+
+    assert x.is_contiguous()
+    _per_token_quant_int8[(M,)](
+        x,
+        x_q,
+        scales,
+        stride_x=x.stride(-2),
+        stride_xq=x_q.stride(-2),
+        N=N,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+    return x_q, scales
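For reviewers, a minimal usage sketch of the new kernel. This is an illustration, not part of the diff: it assumes a CUDA device, the module path added by this PR, and illustrative shapes.

import torch

from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8

x = torch.randn(64, 4096, dtype=torch.float16, device="cuda")
x_q, scales = per_token_quant_int8(x)  # int8 values and float32 per-row scales
x_dq = x_q.to(torch.float32) * scales  # dequantize by broadcasting the (64, 1) scales
# Round-to-nearest bounds the element-wise error by scale / 2 in each row.
assert (x.to(torch.float32) - x_dq).abs().max() <= scales.max()

The same round trip applies to the torch.compile and vLLM variants exercised by the benchmark, since all three implement the same symmetric absmax scheme.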