Fix enable_v2 in int8 quant (#11470)
This commit is contained in:
@@ -15,10 +15,12 @@ if _is_cuda:
|
||||
# Temporary
|
||||
try:
|
||||
from sgl_kernel import sgl_per_token_group_quant_8bit
|
||||
|
||||
enable_sgl_per_token_group_quant_8bit = True
|
||||
except ImportError:
|
||||
from sgl_kernel import (
|
||||
sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
|
||||
)
|
||||
from sgl_kernel import sgl_per_token_group_quant_int8
|
||||
|
||||
enable_sgl_per_token_group_quant_8bit = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -211,9 +213,14 @@ def sglang_per_token_group_quant_int8(
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
sgl_per_token_group_quant_8bit(
|
||||
x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
|
||||
)
|
||||
# Temporary
|
||||
if enable_sgl_per_token_group_quant_8bit:
|
||||
sgl_per_token_group_quant_8bit(
|
||||
x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
|
||||
)
|
||||
else:
|
||||
assert not enable_v2
|
||||
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
Reference in New Issue
Block a user