This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from contextlib import nullcontext
|
||||
|
||||
@@ -109,8 +108,7 @@ def bench_kineto(
|
||||
if not with_multiple_kernels:
|
||||
for name in kernel_names:
|
||||
assert (
|
||||
sum([int(re.search(name, line) is not None) for line in prof_lines])
|
||||
== 1
|
||||
sum([name in line for line in prof_lines]) == 1
|
||||
), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"
|
||||
|
||||
# Save chrome traces
|
||||
@@ -124,7 +122,7 @@ def bench_kineto(
|
||||
total_time = 0
|
||||
total_num = 0
|
||||
for line in prof_lines:
|
||||
if re.search(name, line) is not None:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
num_str = line.split()[-1]
|
||||
for unit, scale in units.items():
|
||||
|
||||
@@ -43,17 +43,11 @@ _is_cpu = is_cpu()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
|
||||
|
||||
# Temporary
|
||||
try:
|
||||
from sgl_kernel import sgl_per_token_group_quant_8bit
|
||||
|
||||
enable_sgl_per_token_group_quant_8bit = True
|
||||
except ImportError:
|
||||
from sgl_kernel import sgl_per_token_group_quant_fp8
|
||||
|
||||
enable_sgl_per_token_group_quant_8bit = False
|
||||
from sgl_kernel import (
|
||||
sgl_per_tensor_quant_fp8,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
sgl_per_token_quant_fp8,
|
||||
)
|
||||
|
||||
if _is_hip:
|
||||
if _use_aiter:
|
||||
@@ -502,24 +496,9 @@ def sglang_per_token_group_quant_fp8(
|
||||
)
|
||||
|
||||
if x.shape[0] > 0:
|
||||
# Temporary
|
||||
if enable_sgl_per_token_group_quant_8bit:
|
||||
sgl_per_token_group_quant_8bit(
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
group_size,
|
||||
eps,
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
scale_ue8m0,
|
||||
fuse_silu_and_mul,
|
||||
masked_m,
|
||||
)
|
||||
else:
|
||||
sgl_per_token_group_quant_fp8(
|
||||
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||
)
|
||||
sgl_per_token_group_quant_fp8(
|
||||
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||
)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
@@ -12,13 +12,7 @@ from sglang.srt.utils import get_device_name, is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
# Temporary
|
||||
try:
|
||||
from sgl_kernel import sgl_per_token_group_quant_8bit
|
||||
except ImportError:
|
||||
from sgl_kernel import (
|
||||
sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
|
||||
)
|
||||
from sgl_kernel import sgl_per_token_group_quant_int8
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -210,7 +204,7 @@ def sglang_per_token_group_quant_int8(
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
sgl_per_token_group_quant_8bit(x, x_q, x_s, group_size, eps, int8_min, int8_max)
|
||||
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
Reference in New Issue
Block a user