Fix quant kernel accuracy issue (#2865)

This commit is contained in:
Ke Bao
2025-01-13 20:32:17 +08:00
committed by GitHub
parent 17de02f98d
commit f3516c2894

View File

@@ -22,7 +22,8 @@ def _per_token_quant_int8(
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = tl.extra.cuda.libdevice.round(x / scale_x).to(tl.int8)
x_q = x * (127 / absmax)
x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x)