[1/2] Optimizations and refactors about quant kernel (#9534)
This commit is contained in:
@@ -43,11 +43,17 @@ _is_cpu = is_cpu()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import (
|
||||
sgl_per_tensor_quant_fp8,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
sgl_per_token_quant_fp8,
|
||||
)
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
|
||||
|
||||
# Temporary
|
||||
try:
|
||||
from sgl_kernel import sgl_per_token_group_quant_8bit
|
||||
|
||||
enable_sgl_per_token_group_quant_8bit = True
|
||||
except ImportError:
|
||||
from sgl_kernel import sgl_per_token_group_quant_fp8
|
||||
|
||||
enable_sgl_per_token_group_quant_8bit = False
|
||||
|
||||
if _is_hip:
|
||||
if _use_aiter:
|
||||
@@ -496,9 +502,24 @@ def sglang_per_token_group_quant_fp8(
|
||||
)
|
||||
|
||||
if x.shape[0] > 0:
|
||||
sgl_per_token_group_quant_fp8(
|
||||
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||
)
|
||||
# Temporary
|
||||
if enable_sgl_per_token_group_quant_8bit:
|
||||
sgl_per_token_group_quant_8bit(
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
group_size,
|
||||
eps,
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
scale_ue8m0,
|
||||
fuse_silu_and_mul,
|
||||
masked_m,
|
||||
)
|
||||
else:
|
||||
sgl_per_token_group_quant_fp8(
|
||||
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||
)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
@@ -12,7 +12,13 @@ from sglang.srt.utils import get_device_name, is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
from sgl_kernel import sgl_per_token_group_quant_int8
|
||||
# Temporary
|
||||
try:
|
||||
from sgl_kernel import sgl_per_token_group_quant_8bit
|
||||
except ImportError:
|
||||
from sgl_kernel import (
|
||||
sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -204,7 +210,7 @@ def sglang_per_token_group_quant_int8(
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
|
||||
sgl_per_token_group_quant_8bit(x, x_q, x_s, group_size, eps, int8_min, int8_max)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
Reference in New Issue
Block a user