55 lines
1.4 KiB
Python
55 lines
1.4 KiB
Python
import torch
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
|
|
@triton.jit
def _per_token_quant_int8(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    BLOCK: tl.constexpr,
):
    """Quantize one row of ``x`` to int8 with a per-row symmetric scale.

    Each program instance owns a single row: it loads up to ``N`` float
    values, finds the row's absolute maximum, derives the scale
    ``absmax / 127``, then stores the rounded int8 values to ``xq_ptr``
    and the float32 scale to ``scale_ptr``.
    """
    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
    row = tl.program_id(0)

    offs = tl.arange(0, BLOCK)
    # BLOCK is padded up to a power of two, so mask off the tail past N.
    within = offs < N

    vals = tl.load(x_ptr + row * stride_x + offs, mask=within, other=0.0).to(tl.float32)

    # Clamp away from zero so an all-zero row does not divide by zero.
    max_abs = tl.maximum(tl.max(tl.abs(vals)), 1e-10)
    scale = max_abs / 127
    quantized = tl.extra.cuda.libdevice.round(vals * (127 / max_abs)).to(tl.int8)

    tl.store(xq_ptr + row * stride_xq + offs, quantized, mask=within)
    tl.store(scale_ptr + row, scale)
|
|
|
|
|
|
def per_token_quant_int8(x):
    """Symmetric per-token (per-row) int8 quantization.

    Treats ``x`` as ``(M, N)`` where ``N = x.shape[-1]`` and every leading
    dimension is flattened into rows; each row is quantized independently
    with scale ``absmax(row) / 127``.

    Args:
        x: contiguous CUDA tensor of any rank >= 1.

    Returns:
        Tuple of (``x_q``, ``scales``) where ``x_q`` is an int8 tensor with
        the same shape as ``x`` and ``scales`` is a float32 tensor of shape
        ``x.shape[:-1] + (1,)``.
    """
    # Validate layout before allocating outputs: the kernel launch below
    # assumes rows are densely packed.
    assert x.is_contiguous()

    N = x.shape[-1]
    M = x.numel() // N

    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)

    # One program per row; the row must fit in a single power-of-two block.
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)

    # For a contiguous tensor consecutive rows are exactly N elements apart,
    # so pass N directly instead of x.stride(-2): identical for rank >= 2
    # and, unlike stride(-2), it does not raise IndexError on 1-D input.
    _per_token_quant_int8[(M,)](
        x,
        x_q,
        scales,
        stride_x=N,
        stride_xq=N,
        N=N,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )

    return x_q, scales
|