Fix enable_v2 in int8 quant (#11470)

This commit is contained in:
fzyzcjy
2025-10-11 21:56:30 +08:00
committed by GitHub
parent f5754d1256
commit bf3e7149be

View File

@@ -15,10 +15,12 @@ if _is_cuda:
# Temporary
try:
from sgl_kernel import sgl_per_token_group_quant_8bit
enable_sgl_per_token_group_quant_8bit = True
except ImportError:
from sgl_kernel import (
sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
)
from sgl_kernel import sgl_per_token_group_quant_int8
enable_sgl_per_token_group_quant_8bit = False
logger = logging.getLogger(__name__)
@@ -211,9 +213,14 @@ def sglang_per_token_group_quant_int8(
dtype=torch.float32,
)
sgl_per_token_group_quant_8bit(
x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
)
# Temporary
if enable_sgl_per_token_group_quant_8bit:
sgl_per_token_group_quant_8bit(
x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
)
else:
assert not enable_v2
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
return x_q, x_s