diff --git a/python/sglang/srt/layers/quantization/int8_kernel.py b/python/sglang/srt/layers/quantization/int8_kernel.py index d1e74c604..91b56f9e0 100644 --- a/python/sglang/srt/layers/quantization/int8_kernel.py +++ b/python/sglang/srt/layers/quantization/int8_kernel.py @@ -22,7 +22,8 @@ def _per_token_quant_int8( x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32) absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10) scale_x = absmax / 127 - x_q = tl.extra.cuda.libdevice.round(x / scale_x).to(tl.int8) + x_q = x * (127 / absmax) + x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8) tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask) tl.store(scale_ptr + row_id, scale_x)