Fix enable_v2 in int8 quant (#11470)
This commit is contained in:
@@ -15,10 +15,12 @@ if _is_cuda:
|
|||||||
# Temporary
|
# Temporary
|
||||||
try:
|
try:
|
||||||
from sgl_kernel import sgl_per_token_group_quant_8bit
|
from sgl_kernel import sgl_per_token_group_quant_8bit
|
||||||
|
|
||||||
|
enable_sgl_per_token_group_quant_8bit = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from sgl_kernel import (
|
from sgl_kernel import sgl_per_token_group_quant_int8
|
||||||
sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
|
|
||||||
)
|
enable_sgl_per_token_group_quant_8bit = False
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -211,9 +213,14 @@ def sglang_per_token_group_quant_int8(
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
sgl_per_token_group_quant_8bit(
|
# Temporary
|
||||||
x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
|
if enable_sgl_per_token_group_quant_8bit:
|
||||||
)
|
sgl_per_token_group_quant_8bit(
|
||||||
|
x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert not enable_v2
|
||||||
|
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
|
||||||
|
|
||||||
return x_q, x_s
|
return x_q, x_s
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user