From bf3e7149bed51f46ccea2493460c7865585683a0 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sat, 11 Oct 2025 21:56:30 +0800 Subject: [PATCH] Fix enable_v2 in int8 quant (#11470) --- .../srt/layers/quantization/int8_kernel.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/quantization/int8_kernel.py b/python/sglang/srt/layers/quantization/int8_kernel.py index c77dab1eb..9e92412ac 100644 --- a/python/sglang/srt/layers/quantization/int8_kernel.py +++ b/python/sglang/srt/layers/quantization/int8_kernel.py @@ -15,10 +15,12 @@ if _is_cuda: # Temporary try: from sgl_kernel import sgl_per_token_group_quant_8bit + + enable_sgl_per_token_group_quant_8bit = True except ImportError: - from sgl_kernel import ( - sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit, - ) + from sgl_kernel import sgl_per_token_group_quant_int8 + + enable_sgl_per_token_group_quant_8bit = False logger = logging.getLogger(__name__) @@ -211,9 +213,14 @@ def sglang_per_token_group_quant_int8( dtype=torch.float32, ) - sgl_per_token_group_quant_8bit( - x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2 - ) + # Temporary + if enable_sgl_per_token_group_quant_8bit: + sgl_per_token_group_quant_8bit( + x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2 + ) + else: + assert not enable_v2 + sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max) return x_q, x_s