import torch
import triton
import triton.language as tl


@triton.jit
def _per_token_quant_int8(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    BLOCK: tl.constexpr,
):
    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
    # One program quantizes one token (one row of N elements).
    row_id = tl.program_id(0)

    cols = tl.arange(0, BLOCK)
    mask = cols < N

    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    # Symmetric per-token scale; clamp absmax so a zero row cannot divide by zero.
    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
    scale_x = absmax / 127
    x_q = x * (127 / absmax)
    # libdevice round: round half away from zero, then narrow to int8.
    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)

    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
    tl.store(scale_ptr + row_id, scale_x)


def per_token_quant_int8(x):
    """Quantize `x` to int8 per token (per row of the last dimension).

    Returns the int8 tensor and a float32 scale of shape x.shape[:-1] + (1,),
    such that x ≈ x_q * scales. Assumes `x` is contiguous with at least
    two dimensions (the launch uses `stride(-2)` as the row stride).
    """
    M = x.numel() // x.shape[-1]  # number of tokens (rows)
    N = x.shape[-1]               # elements per token
    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
    BLOCK = triton.next_power_of_2(N)
    # Heuristic for the number of warps: scale with the block size, capped at 8.
    num_warps = min(max(BLOCK // 256, 1), 8)

    assert x.is_contiguous()
    _per_token_quant_int8[(M,)](
        x,
        x_q,
        scales,
        stride_x=x.stride(-2),
        stride_xq=x_q.stride(-2),
        N=N,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )

    return x_q, scales
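
# --- Usage sketch (added for illustration, not part of the original file) ---
# A minimal eager-PyTorch reference to sanity-check the kernel's output.
# The shapes, device, and tolerance below are assumptions: torch.round uses
# round-half-to-even while the kernel rounds half away from zero via
# libdevice, so quantized values may disagree by at most 1 at exact halves.
def _reference_per_token_quant_int8(x):
    # Per-token symmetric quantization in eager mode, mirroring the kernel:
    # scale = absmax / 127, x_q = round(x / scale).
    absmax = x.float().abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scales = absmax / 127
    x_q = torch.round(x.float() / scales).to(torch.int8)
    return x_q, scales


if __name__ == "__main__":
    x = torch.randn(4, 512, device="cuda", dtype=torch.float16)
    x_q, scales = per_token_quant_int8(x)
    x_q_ref, scales_ref = _reference_per_token_quant_int8(x)
    assert torch.allclose(scales, scales_ref)
    assert (x_q.int() - x_q_ref.int()).abs().max().item() <= 1
    print("per_token_quant_int8 matches the eager reference")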