# 2025-01-13 13:16:58 +08:00
import torch
import triton
import triton . language as tl
@triton.jit
def _per_token_quant_int8(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    BLOCK: tl.constexpr,
):
    """Symmetric per-token (per-row) int8 quantization kernel.

    One program instance handles one row of length ``N``: it finds the
    row's absolute maximum, derives a float32 scale (``absmax / 127``),
    writes the rounded int8 values to ``xq_ptr`` and the scale to
    ``scale_ptr``.
    """
    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
    row_id = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < N  # BLOCK is next_power_of_2(N); mask off the padding lanes
    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    # Clamp absmax away from zero so an all-zero row does not divide by zero.
    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
    scale_x = absmax / 127
    # Multiply by the reciprocal scale, then round to nearest before the
    # int8 cast (a plain cast would truncate toward zero).
    x_q = x * (127 / absmax)
    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
    tl.store(scale_ptr + row_id, scale_x)
def per_token_quant_int8(x):
    """Quantize ``x`` to int8 with one float32 scale per token.

    A "token" is a row along the last dimension: each row is scaled
    independently by its own absmax so that ``x ≈ x_q.float() * scales``.

    Args:
        x: contiguous tensor of shape ``(..., N)``.

    Returns:
        Tuple ``(x_q, scales)`` where ``x_q`` is an int8 tensor with the
        same shape as ``x`` and ``scales`` is a float32 tensor of shape
        ``x.shape[:-1] + (1,)``.
    """
    # Validate before allocating output buffers; the kernel assumes a
    # dense row layout.
    assert x.is_contiguous()
    M = x.numel() // x.shape[-1]  # number of token rows
    N = x.shape[-1]
    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
    BLOCK = triton.next_power_of_2(N)
    # Heuristic for number of warps: one warp per 256 lanes, capped at 8.
    num_warps = min(max(BLOCK // 256, 1), 8)
    # NOTE(review): stride(-2) raises for 1-D input, so callers appear to
    # pass at least 2-D tensors — confirm against call sites.
    _per_token_quant_int8[(M,)](
        x,
        x_q,
        scales,
        stride_x=x.stride(-2),
        stride_xq=x_q.stride(-2),
        N=N,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return x_q, scales