Reland [1/2] Optimizations and refactors about quant kernel (#10312)

Co-authored-by: Yineng Zhang <me@zhyncs.com>
Author: fzyzcjy
Date: 2025-10-11 15:59:03 +08:00 (committed by GitHub)
Parent: 129d299278
Commit: 21337b22b9
13 changed files with 1065 additions and 178 deletions

View File

@@ -43,11 +43,17 @@ _is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
-    from sgl_kernel import (
-        sgl_per_tensor_quant_fp8,
-        sgl_per_token_group_quant_fp8,
-        sgl_per_token_quant_fp8,
-    )
+    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
+
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+
+        enable_sgl_per_token_group_quant_8bit = True
+    except ImportError:
+        from sgl_kernel import sgl_per_token_group_quant_fp8
+
+        enable_sgl_per_token_group_quant_8bit = False

 if _is_hip:
     if _use_aiter:
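Note on the import block above: the try/except feature-detects the installed sgl_kernel build, preferring the new unified 8-bit entry point and falling back to the FP8-only kernel (recording which path was taken in a flag) when the extension predates it. A minimal, self-contained sketch of the same idiom, using standard-library stand-ins rather than sgl_kernel symbols:

# Sketch of the feature-detection idiom above; itertools.batched is a
# stand-in for a symbol that only newer builds provide.
try:
    from itertools import batched as group_iter  # Python >= 3.12 only

    HAS_NEW_API = True
except ImportError:
    # Fallback with equivalent semantics for older runtimes.
    def group_iter(seq, n):
        seq = list(seq)
        for i in range(0, len(seq), n):
            yield tuple(seq[i : i + n])

    HAS_NEW_API = False

print(list(group_iter(range(6), 2)), HAS_NEW_API)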
@@ -477,6 +483,7 @@ def sglang_per_token_group_quant_fp8(
     scale_ue8m0: bool = False,
     fuse_silu_and_mul: bool = False,
     masked_m: Optional[torch.Tensor] = None,
+    enable_v2: Optional[bool] = None,
 ):
     assert (
         x.shape[-1] % group_size == 0
@@ -496,9 +503,26 @@ def sglang_per_token_group_quant_fp8(
     )

     if x.shape[0] > 0:
-        sgl_per_token_group_quant_fp8(
-            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
-        )
+        # Temporary
+        if enable_sgl_per_token_group_quant_8bit:
+            sgl_per_token_group_quant_8bit(
+                x,
+                x_q,
+                x_s,
+                group_size,
+                eps,
+                fp8_min,
+                fp8_max,
+                scale_ue8m0,
+                fuse_silu_and_mul,
+                masked_m,
+                enable_v2=enable_v2,
+            )
+        else:
+            assert not enable_v2
+            sgl_per_token_group_quant_fp8(
+                x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+            )

     return x_q, x_s
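For reference, both branches above compute per-token-group FP8 quantization: each contiguous group of group_size values along the last dimension gets its own scale, chosen so the group's max-abs value maps onto the FP8 range. A pure-PyTorch sketch of that arithmetic (the e4m3 range of ±448 and the reshape layout are assumptions; the CUDA kernels additionally handle scale_ue8m0, the fused SiLU-and-mul, and masking, all omitted here):

import torch


def per_token_group_quant_fp8_ref(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    fp8_min: float = -448.0,
    fp8_max: float = 448.0,
):
    # One scale per group of `group_size` elements along the last dim.
    assert x.shape[-1] % group_size == 0
    g = x.reshape(*x.shape[:-1], -1, group_size).float()
    amax = g.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    scale = amax / fp8_max
    q = (g / scale).clamp(fp8_min, fp8_max).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scale.squeeze(-1)


x = torch.randn(4, 256)
x_q, x_s = per_token_group_quant_fp8_ref(x, group_size=128)
print(x_q.shape, x_s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])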
@@ -514,6 +538,7 @@ def sglang_per_token_group_quant_8bit(
     scale_ue8m0: bool = False,
     fuse_silu_and_mul: bool = False,
     masked_m: Optional[torch.Tensor] = None,
+    enable_v2: Optional[bool] = None,
 ):
     from sglang.srt.layers.quantization.int8_kernel import (
         sglang_per_token_group_quant_int8,
@@ -529,6 +554,7 @@ def sglang_per_token_group_quant_8bit(
             group_size=group_size,
             eps=eps,
             dtype=dst_dtype,
+            enable_v2=enable_v2,
         )

     return sglang_per_token_group_quant_fp8(
@@ -540,6 +566,7 @@ def sglang_per_token_group_quant_8bit(
         scale_ue8m0=scale_ue8m0,
         fuse_silu_and_mul=fuse_silu_and_mul,
         masked_m=masked_m,
+        enable_v2=enable_v2,
     )
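Taken together, the two hunks above simply thread the new enable_v2 flag through the dtype dispatch unchanged. A condensed, runnable sketch of that routing (the stub bodies and the tri-state reading of enable_v2, with None meaning "let the kernel side decide", are assumptions inferred from the Optional[bool] signature):

from typing import Optional

import torch


def quant_int8_stub(x, group_size, enable_v2=None):
    return f"int8 path, enable_v2={enable_v2}"


def quant_fp8_stub(x, group_size, enable_v2=None):
    return f"fp8 path, enable_v2={enable_v2}"


def per_token_group_quant_8bit_sketch(
    x: torch.Tensor,
    group_size: int,
    dst_dtype: torch.dtype,
    enable_v2: Optional[bool] = None,
):
    # Route on destination dtype, forwarding enable_v2 unchanged.
    if dst_dtype == torch.int8:
        return quant_int8_stub(x, group_size, enable_v2=enable_v2)
    return quant_fp8_stub(x, group_size, enable_v2=enable_v2)


print(per_token_group_quant_8bit_sketch(torch.randn(2, 128), 128, torch.int8))
print(
    per_token_group_quant_8bit_sketch(
        torch.randn(2, 128), 128, torch.float8_e4m3fn, enable_v2=True
    )
)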

View File

@@ -8,11 +8,17 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.utils import get_device_name, is_cuda
+from sglang.srt.utils import get_bool_env_var, get_device_name, is_cuda

 _is_cuda = is_cuda()
 if _is_cuda:
-    from sgl_kernel import sgl_per_token_group_quant_int8
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+    except ImportError:
+        from sgl_kernel import (
+            sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
+        )

 logger = logging.getLogger(__name__)
@@ -187,6 +193,7 @@ def sglang_per_token_group_quant_int8(
     group_size: int,
     eps: float = 1e-10,
     dtype: torch.dtype = torch.int8,
+    enable_v2: Optional[bool] = None,
 ):
     assert (
         x.shape[-1] % group_size == 0
@@ -204,7 +211,9 @@ def sglang_per_token_group_quant_int8(
         dtype=torch.float32,
     )

-    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    sgl_per_token_group_quant_8bit(
+        x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
+    )

     return x_q, x_s
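The int8 variant follows the same per-group scheme as the FP8 path, rounding into the symmetric [-128, 127] range that the int8_min/int8_max arguments denote. A round-trip sketch in pure PyTorch (the grouping layout and symmetric clamping are assumptions mirroring the FP8 reference earlier):

import torch


def per_token_group_quant_int8_ref(x, group_size, eps=1e-10):
    # Symmetric per-group int8 quantization: scale = max|x| / 127.
    assert x.shape[-1] % group_size == 0
    g = x.reshape(*x.shape[:-1], -1, group_size).float()
    amax = g.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    scale = amax / 127.0
    q = torch.round(g / scale).clamp(-128, 127).to(torch.int8)
    return q.reshape(x.shape), scale


x = torch.randn(2, 256)
x_q, x_s = per_token_group_quant_int8_ref(x, group_size=128)
# Dequantize and check that the round-trip error stays small.
x_hat = (x_q.reshape(2, -1, 128).float() * x_s).reshape(x.shape)
print((x - x_hat).abs().max())  # bounded by roughly scale / 2 per group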

View File

@@ -1,4 +1,5 @@
 import os
+import re
 import sys
 from contextlib import nullcontext
@@ -108,7 +109,8 @@ def bench_kineto(
     if not with_multiple_kernels:
         for name in kernel_names:
             assert (
-                sum([name in line for line in prof_lines]) == 1
+                sum([int(re.search(name, line) is not None) for line in prof_lines])
+                == 1
             ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"

     # Save chrome traces
@@ -122,7 +124,7 @@ def bench_kineto(
         total_time = 0
         total_num = 0
         for line in prof_lines:
-            if name in line:
+            if re.search(name, line) is not None:
                 time_str = line.split()[-2]
                 num_str = line.split()[-1]
                 for unit, scale in units.items():
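The two bench_kineto changes above swap plain substring matching for re.search, so entries in kernel_names may now be regular expressions; a bare name without regex metacharacters still matches as before. A small illustration of the difference (the profiler table lines are fabricated for the example):

import re

# Fabricated profiler table lines, for illustration only.
prof_lines = [
    "per_token_group_quant_8bit_kernel   12.3us   5",
    "per_token_group_quant_fp8_kernel    10.1us   5",
]

# Old behavior: plain substring containment matches one line.
print(sum("quant_fp8" in line for line in prof_lines))  # 1

# New behavior: the name is a regex pattern, so a single entry
# can cover several kernel-name variants.
pattern = r"quant_(fp8|8bit)_kernel"
print(sum(int(re.search(pattern, line) is not None) for line in prof_lines))  # 2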