Add int8 quant kernel (#2848)
This commit is contained in:
94
benchmark/kernels/quantization/bench_int8_quant.py
Normal file
94
benchmark/kernels/quantization/bench_int8_quant.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compile(backend="inductor")
|
||||||
|
def torch_int8_quant(x):
|
||||||
|
int8_max = torch.iinfo(torch.int8).max
|
||||||
|
|
||||||
|
abs_max = x.abs().max(dim=-1, keepdim=True).values
|
||||||
|
scales = abs_max.to(torch.float32) / float(int8_max)
|
||||||
|
|
||||||
|
q_x = (x / scales).round().to(torch.int8)
|
||||||
|
|
||||||
|
return q_x, scales
|
||||||
|
|
||||||
|
|
||||||
|
def _test_accuracy_once(M, K, input_dtype, device):
|
||||||
|
x = torch.randn(M, K, dtype=input_dtype, device=device) * 5000
|
||||||
|
out, scales, _ = vllm_scaled_int8_quant(x, symmetric=True)
|
||||||
|
out1, scales1 = per_token_quant_int8(x)
|
||||||
|
out2, scales2 = torch_int8_quant(x)
|
||||||
|
torch.testing.assert_close(out, out2, atol=1, rtol=0)
|
||||||
|
torch.testing.assert_close(out, out1, atol=1, rtol=0)
|
||||||
|
torch.testing.assert_close(scales, scales2)
|
||||||
|
torch.testing.assert_close(scales1, scales2)
|
||||||
|
print(f"M: {M}, K: {K}, type: {input_dtype} OK")
|
||||||
|
|
||||||
|
|
||||||
|
def test_accuracy():
|
||||||
|
Ms = [1, 13, 128, 1024, 2048, 4096]
|
||||||
|
Ks = [512, 1024, 2048, 8192]
|
||||||
|
input_dtypes = [torch.float16, torch.bfloat16]
|
||||||
|
for M in Ms:
|
||||||
|
for K in Ks:
|
||||||
|
for input_dtype in input_dtypes:
|
||||||
|
_test_accuracy_once(M, K, input_dtype, "cuda")
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size"],
|
||||||
|
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
|
||||||
|
x_log=False,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["vllm op", "triton", "torch.compile"],
|
||||||
|
line_names=["vllm op", "triton", "torch.compile"],
|
||||||
|
styles=[("blue", "-"), ("orange", "-"), ("red", "-")],
|
||||||
|
ylabel="ms",
|
||||||
|
plot_name="int8 per token quant",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, provider):
|
||||||
|
M, K = batch_size, 16384
|
||||||
|
x = torch.randn(M, K, dtype=torch.float16, device="cuda") * 1000
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
if provider == "vllm op":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: vllm_scaled_int8_quant(x, symmetric=True),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
if provider == "triton":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: per_token_quant_int8(x),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
if provider == "torch.compile":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: torch_int8_quant(x),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ms, min_ms, max_ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_path",
|
||||||
|
type=str,
|
||||||
|
default="./bench_int8_quant_res",
|
||||||
|
help="Path to save int8 quant benchmark results",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
test_accuracy()
|
||||||
|
|
||||||
|
benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
|
||||||
53
python/sglang/srt/layers/quantization/int8_kernel.py
Normal file
53
python/sglang/srt/layers/quantization/int8_kernel.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _per_token_quant_int8(
|
||||||
|
x_ptr,
|
||||||
|
xq_ptr,
|
||||||
|
scale_ptr,
|
||||||
|
stride_x,
|
||||||
|
stride_xq,
|
||||||
|
N,
|
||||||
|
BLOCK: tl.constexpr,
|
||||||
|
):
|
||||||
|
# Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
|
||||||
|
row_id = tl.program_id(0)
|
||||||
|
|
||||||
|
cols = tl.arange(0, BLOCK)
|
||||||
|
mask = cols < N
|
||||||
|
|
||||||
|
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
|
||||||
|
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
|
||||||
|
scale_x = absmax / 127
|
||||||
|
x_q = tl.extra.cuda.libdevice.round(x / scale_x).to(tl.int8)
|
||||||
|
|
||||||
|
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
|
||||||
|
tl.store(scale_ptr + row_id, scale_x)
|
||||||
|
|
||||||
|
|
||||||
|
def per_token_quant_int8(x):
|
||||||
|
M = x.numel() // x.shape[-1]
|
||||||
|
N = x.shape[-1]
|
||||||
|
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
|
||||||
|
scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
|
||||||
|
BLOCK = triton.next_power_of_2(N)
|
||||||
|
# heuristics for number of warps
|
||||||
|
num_warps = min(max(BLOCK // 256, 1), 8)
|
||||||
|
|
||||||
|
assert x.is_contiguous()
|
||||||
|
_per_token_quant_int8[(M,)](
|
||||||
|
x,
|
||||||
|
x_q,
|
||||||
|
scales,
|
||||||
|
stride_x=x.stride(-2),
|
||||||
|
stride_xq=x_q.stride(-2),
|
||||||
|
N=N,
|
||||||
|
BLOCK=BLOCK,
|
||||||
|
num_warps=num_warps,
|
||||||
|
num_stages=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
return x_q, scales
|
||||||
Reference in New Issue
Block a user