Add int8 quant kernel (#2848)
python/sglang/srt/layers/quantization/int8_kernel.py (new file, 53 lines added)
@@ -0,0 +1,53 @@
import torch
import triton
import triton.language as tl


@triton.jit
def _per_token_quant_int8(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    BLOCK: tl.constexpr,
):
    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
    row_id = tl.program_id(0)

    cols = tl.arange(0, BLOCK)
    mask = cols < N

    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
    scale_x = absmax / 127
    x_q = tl.extra.cuda.libdevice.round(x / scale_x).to(tl.int8)

    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
    tl.store(scale_ptr + row_id, scale_x)


def per_token_quant_int8(x):
    M = x.numel() // x.shape[-1]
    N = x.shape[-1]
    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)

    assert x.is_contiguous()
    _per_token_quant_int8[(M,)](
        x,
        x_q,
        scales,
        stride_x=x.stride(-2),
        stride_xq=x_q.stride(-2),
        N=N,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )

    return x_q, scales
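Note: the kernel does dynamic per-token (per-row) quantization; each Triton program reduces one row to its absmax and maps it onto the int8 range [-127, 127]. Below is a minimal sketch of how the new helper could be exercised and sanity-checked against an eager PyTorch reference. The tensor shape, dtype, and error bound are illustrative assumptions, not part of the commit.

# Hypothetical usage sketch (not part of the commit): quantize a batch of
# activations row by row and compare against an eager PyTorch reference.
import torch

from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8

x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)  # assumed shape/dtype

x_q, scales = per_token_quant_int8(x)  # int8 values plus one fp32 scale per row

# Eager reference: per-row absmax mapped onto the int8 range [-127, 127].
ref_scales = x.float().abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) / 127

# Scales should match the kernel's, and the dequantization error of each
# element should stay within one quantization step.
assert torch.allclose(scales, ref_scales)
dequant_err = (x_q.float() * scales - x.float()).abs().max()
assert dequant_err <= ref_scales.max()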