[1/2] Optimizations and refactors about quant kernel (#9534)

Author:    fzyzcjy
Date:      2025-09-05 18:45:08 +08:00
Committed: GitHub
Parent:    afd9f2f560
Commit:    339f8eef09

11 changed files with 996 additions and 329 deletions

python/sglang/srt/layers/quantization/fp8_kernel.py

@@ -43,11 +43,17 @@ _is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
-    from sgl_kernel import (
-        sgl_per_tensor_quant_fp8,
-        sgl_per_token_group_quant_fp8,
-        sgl_per_token_quant_fp8,
-    )
+    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
+
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+
+        enable_sgl_per_token_group_quant_8bit = True
+    except ImportError:
+        from sgl_kernel import sgl_per_token_group_quant_fp8
+
+        enable_sgl_per_token_group_quant_8bit = False
 
 if _is_hip:
     if _use_aiter:
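Note: the try/except above is a version probe. Newer sgl_kernel builds ship a unified sgl_per_token_group_quant_8bit kernel, while older builds only expose the FP8-specific one, so the outcome is recorded once in a module-level flag and the hot path only pays a boolean check. A hypothetical snippet to see which path a given install selected (module path assumed from the sglang source layout):

import sglang.srt.layers.quantization.fp8_kernel as fp8_kernel

# True if this sgl_kernel build exposes the unified 8-bit kernel.
print(fp8_kernel.enable_sgl_per_token_group_quant_8bit)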
@@ -496,9 +502,24 @@ def sglang_per_token_group_quant_fp8(
     )
 
     if x.shape[0] > 0:
-        sgl_per_token_group_quant_fp8(
-            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
-        )
+        # Temporary
+        if enable_sgl_per_token_group_quant_8bit:
+            sgl_per_token_group_quant_8bit(
+                x,
+                x_q,
+                x_s,
+                group_size,
+                eps,
+                fp8_min,
+                fp8_max,
+                scale_ue8m0,
+                fuse_silu_and_mul,
+                masked_m,
+            )
+        else:
+            sgl_per_token_group_quant_fp8(
+                x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+            )
 
     return x_q, x_s
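For reference, per-token-group FP8 quantization computes one scale per group_size-wide slice of the last dimension by mapping the group's max-abs value onto the FP8 range; judging by the names, the extra arguments let the unified kernel fuse the SiLU-and-mul activation into the quant pass (fuse_silu_and_mul) and skip padded rows in masked MoE layouts (masked_m). A minimal torch-only sketch of the unfused path, assuming float8_e4m3fn bounds of ±448 and ignoring scale_ue8m0 (per_token_group_quant_fp8_ref is illustrative, not part of the codebase):

import torch

def per_token_group_quant_fp8_ref(x, group_size, eps=1e-10,
                                  fp8_min=-448.0, fp8_max=448.0):
    # One scale per contiguous group of `group_size` elements in the last dim.
    xg = x.reshape(-1, group_size).float()
    amax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    x_s = amax / fp8_max  # group scale
    x_q = (xg / x_s).clamp(fp8_min, fp8_max).to(torch.float8_e4m3fn)
    return x_q.reshape(x.shape), x_s.reshape(*x.shape[:-1], -1)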

python/sglang/srt/layers/quantization/int8_kernel.py

@@ -12,7 +12,13 @@ from sglang.srt.utils import get_device_name, is_cuda
 _is_cuda = is_cuda()
 
 if _is_cuda:
-    from sgl_kernel import sgl_per_token_group_quant_int8
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+    except ImportError:
+        from sgl_kernel import (
+            sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
+        )
 
 logger = logging.getLogger(__name__)
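Aliasing the old int8 kernel to the new name means the call site below can be renamed once and still run against older sgl_kernel builds: on those builds the alias simply resolves to sgl_per_token_group_quant_int8, which receives the same int8 arguments it always took.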
@@ -204,7 +210,7 @@ def sglang_per_token_group_quant_int8(
         dtype=torch.float32,
     )
 
-    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    sgl_per_token_group_quant_8bit(x, x_q, x_s, group_size, eps, int8_min, int8_max)
 
     return x_q, x_s
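The call-site rename is the whole change here: with the alias in place, the int8 path goes through the same sgl_per_token_group_quant_8bit entry point, presumably selecting the output format from x_q's dtype and the clamp bounds. A torch-only sketch of the equivalent int8 computation, under the same max-abs-per-group assumption as the FP8 sketch above (again illustrative only):

import torch

def per_token_group_quant_int8_ref(x, group_size, eps=1e-10,
                                   int8_min=-128.0, int8_max=127.0):
    # Same group-wise max-abs scaling, rounded into the int8 range.
    xg = x.reshape(-1, group_size).float()
    amax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    x_s = amax / int8_max  # group scale
    x_q = (xg / x_s).round().clamp(int8_min, int8_max).to(torch.int8)
    return x_q.reshape(x.shape), x_s.reshape(*x.shape[:-1], -1)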