Fix enable_v2 in int8 quant (#11470)

This commit is contained in:
fzyzcjy
2025-10-11 21:56:30 +08:00
committed by GitHub
parent f5754d1256
commit bf3e7149be

View File

@@ -15,10 +15,12 @@ if _is_cuda:
     # Temporary
     try:
         from sgl_kernel import sgl_per_token_group_quant_8bit
+
+        enable_sgl_per_token_group_quant_8bit = True
     except ImportError:
-        from sgl_kernel import (
-            sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
-        )
+        from sgl_kernel import sgl_per_token_group_quant_int8
+
+        enable_sgl_per_token_group_quant_8bit = False
 
 logger = logging.getLogger(__name__)
@@ -211,9 +213,14 @@ def sglang_per_token_group_quant_int8(
         dtype=torch.float32,
     )
 
-    sgl_per_token_group_quant_8bit(
-        x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
-    )
+    # Temporary
+    if enable_sgl_per_token_group_quant_8bit:
+        sgl_per_token_group_quant_8bit(
+            x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
+        )
+    else:
+        assert not enable_v2
+        sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
 
     return x_q, x_s