diff --git a/python/sglang/srt/layers/quantization/int8_kernel.py b/python/sglang/srt/layers/quantization/int8_kernel.py index c77dab1eb..9e92412ac 100644 --- a/python/sglang/srt/layers/quantization/int8_kernel.py +++ b/python/sglang/srt/layers/quantization/int8_kernel.py @@ -15,10 +15,12 @@ if _is_cuda: # Temporary try: from sgl_kernel import sgl_per_token_group_quant_8bit + + enable_sgl_per_token_group_quant_8bit = True except ImportError: - from sgl_kernel import ( - sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit, - ) + from sgl_kernel import sgl_per_token_group_quant_int8 + + enable_sgl_per_token_group_quant_8bit = False logger = logging.getLogger(__name__) @@ -211,9 +213,14 @@ def sglang_per_token_group_quant_int8( dtype=torch.float32, ) - sgl_per_token_group_quant_8bit( - x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2 - ) + # Temporary + if enable_sgl_per_token_group_quant_8bit: + sgl_per_token_group_quant_8bit( + x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2 + ) + else: + assert not enable_v2 + sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max) return x_q, x_s