Reland [1/2] Optimizations and refactors about quant kernel (#10312)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -43,11 +43,17 @@ _is_cpu = is_cpu()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import (
|
||||
sgl_per_tensor_quant_fp8,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
sgl_per_token_quant_fp8,
|
||||
)
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
|
||||
|
||||
# Temporary
|
||||
try:
|
||||
from sgl_kernel import sgl_per_token_group_quant_8bit
|
||||
|
||||
enable_sgl_per_token_group_quant_8bit = True
|
||||
except ImportError:
|
||||
from sgl_kernel import sgl_per_token_group_quant_fp8
|
||||
|
||||
enable_sgl_per_token_group_quant_8bit = False
|
||||
|
||||
if _is_hip:
|
||||
if _use_aiter:
|
||||
@@ -477,6 +483,7 @@ def sglang_per_token_group_quant_fp8(
|
||||
scale_ue8m0: bool = False,
|
||||
fuse_silu_and_mul: bool = False,
|
||||
masked_m: Optional[torch.Tensor] = None,
|
||||
enable_v2: Optional[bool] = None,
|
||||
):
|
||||
assert (
|
||||
x.shape[-1] % group_size == 0
|
||||
@@ -496,9 +503,26 @@ def sglang_per_token_group_quant_fp8(
|
||||
)
|
||||
|
||||
if x.shape[0] > 0:
|
||||
sgl_per_token_group_quant_fp8(
|
||||
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||
)
|
||||
# Temporary
|
||||
if enable_sgl_per_token_group_quant_8bit:
|
||||
sgl_per_token_group_quant_8bit(
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
group_size,
|
||||
eps,
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
scale_ue8m0,
|
||||
fuse_silu_and_mul,
|
||||
masked_m,
|
||||
enable_v2=enable_v2,
|
||||
)
|
||||
else:
|
||||
assert not enable_v2
|
||||
sgl_per_token_group_quant_fp8(
|
||||
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||
)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
@@ -514,6 +538,7 @@ def sglang_per_token_group_quant_8bit(
|
||||
scale_ue8m0: bool = False,
|
||||
fuse_silu_and_mul: bool = False,
|
||||
masked_m: Optional[torch.Tensor] = None,
|
||||
enable_v2: Optional[bool] = None,
|
||||
):
|
||||
from sglang.srt.layers.quantization.int8_kernel import (
|
||||
sglang_per_token_group_quant_int8,
|
||||
@@ -529,6 +554,7 @@ def sglang_per_token_group_quant_8bit(
|
||||
group_size=group_size,
|
||||
eps=eps,
|
||||
dtype=dst_dtype,
|
||||
enable_v2=enable_v2,
|
||||
)
|
||||
|
||||
return sglang_per_token_group_quant_fp8(
|
||||
@@ -540,6 +566,7 @@ def sglang_per_token_group_quant_8bit(
|
||||
scale_ue8m0=scale_ue8m0,
|
||||
fuse_silu_and_mul=fuse_silu_and_mul,
|
||||
masked_m=masked_m,
|
||||
enable_v2=enable_v2,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,11 +8,17 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.utils import get_device_name, is_cuda
|
||||
from sglang.srt.utils import get_bool_env_var, get_device_name, is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
from sgl_kernel import sgl_per_token_group_quant_int8
|
||||
# Temporary
|
||||
try:
|
||||
from sgl_kernel import sgl_per_token_group_quant_8bit
|
||||
except ImportError:
|
||||
from sgl_kernel import (
|
||||
sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -187,6 +193,7 @@ def sglang_per_token_group_quant_int8(
|
||||
group_size: int,
|
||||
eps: float = 1e-10,
|
||||
dtype: torch.dtype = torch.int8,
|
||||
enable_v2: Optional[bool] = None,
|
||||
):
|
||||
assert (
|
||||
x.shape[-1] % group_size == 0
|
||||
@@ -204,7 +211,9 @@ def sglang_per_token_group_quant_int8(
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
|
||||
sgl_per_token_group_quant_8bit(
|
||||
x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
|
||||
)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from contextlib import nullcontext
|
||||
|
||||
@@ -108,7 +109,8 @@ def bench_kineto(
|
||||
if not with_multiple_kernels:
|
||||
for name in kernel_names:
|
||||
assert (
|
||||
sum([name in line for line in prof_lines]) == 1
|
||||
sum([int(re.search(name, line) is not None) for line in prof_lines])
|
||||
== 1
|
||||
), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"
|
||||
|
||||
# Save chrome traces
|
||||
@@ -122,7 +124,7 @@ def bench_kineto(
|
||||
total_time = 0
|
||||
total_num = 0
|
||||
for line in prof_lines:
|
||||
if name in line:
|
||||
if re.search(name, line) is not None:
|
||||
time_str = line.split()[-2]
|
||||
num_str = line.split()[-1]
|
||||
for unit, scale in units.items():
|
||||
|
||||
Reference in New Issue
Block a user