chore: upgrade sgl-kernel v0.1.6 (#6945)
This commit is contained in:
@@ -49,7 +49,7 @@ runtime_common = [
|
||||
|
||||
srt = [
|
||||
"sglang[runtime_common]",
|
||||
"sgl-kernel==0.1.5",
|
||||
"sgl-kernel==0.1.6",
|
||||
"flashinfer_python==0.2.5",
|
||||
"torch==2.6.0",
|
||||
"torchvision==0.21.0",
|
||||
|
||||
@@ -579,7 +579,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
if _is_cuda:
|
||||
assert_pkg_version(
|
||||
"sgl-kernel",
|
||||
"0.1.5",
|
||||
"0.1.6",
|
||||
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
||||
)
|
||||
|
||||
|
||||
@@ -17,10 +17,10 @@ _ENABLE_JIT_DEEPGEMM = False
|
||||
try:
|
||||
import deep_gemm
|
||||
from deep_gemm import get_num_sms
|
||||
from deep_gemm.jit import build
|
||||
from deep_gemm.jit.compiler import get_nvcc_compiler
|
||||
from deep_gemm.jit_kernels.gemm import get_best_configs
|
||||
from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
|
||||
from deep_gemm.jit_kernels.tuner import jit_tuner
|
||||
|
||||
sm_version = get_device_sm()
|
||||
if sm_version == 90:
|
||||
@@ -148,32 +148,28 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
kwargs = {
|
||||
"GEMM_TYPE": GemmType.GroupedMasked,
|
||||
"NUM_TMA_THREADS": num_tma_threads,
|
||||
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
|
||||
"N": n,
|
||||
"K": k,
|
||||
"NUM_GROUPS": 1,
|
||||
"BLOCK_M": block_m,
|
||||
"BLOCK_N": block_n,
|
||||
"BLOCK_K": block_k,
|
||||
"SWIZZLE_D_MODE": smem_config[1],
|
||||
"BLOCK_N_PADDING": smem_config[2],
|
||||
"NUM_STAGES": num_stages,
|
||||
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||
"NUM_SMS": num_sms,
|
||||
"SMEM_SIZE": smem_config[0],
|
||||
}
|
||||
_, _ = jit_tuner.compile_and_tune(
|
||||
name="m_grouped_gemm_fp8_fp8_bf16_nt",
|
||||
keys={
|
||||
"N": n,
|
||||
"K": k,
|
||||
"BLOCK_M": block_m,
|
||||
"BLOCK_N": block_n,
|
||||
"SWIZZLE_D_MODE": smem_config[1],
|
||||
"BLOCK_N_PADDING": smem_config[2],
|
||||
"NUM_GROUPS": num_groups,
|
||||
"NUM_STAGES": num_stages,
|
||||
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||
"GEMM_TYPE": GemmType.GroupedMasked,
|
||||
},
|
||||
space=(),
|
||||
kwargs=kwargs,
|
||||
runtime_cls=FP8GemmRuntime,
|
||||
)
|
||||
|
||||
code = FP8GemmRuntime.generate(kwargs)
|
||||
_ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
|
||||
|
||||
|
||||
def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
|
||||
@@ -187,31 +183,26 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
kwargs = {
|
||||
"GEMM_TYPE": GemmType.GroupedContiguous,
|
||||
"NUM_TMA_THREADS": num_tma_threads,
|
||||
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
|
||||
"N": n,
|
||||
"K": k,
|
||||
"NUM_GROUPS": 1,
|
||||
"BLOCK_M": block_m,
|
||||
"BLOCK_N": block_n,
|
||||
"BLOCK_K": block_k,
|
||||
"SWIZZLE_D_MODE": smem_config[1],
|
||||
"BLOCK_N_PADDING": smem_config[2],
|
||||
"NUM_STAGES": num_stages,
|
||||
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||
"NUM_SMS": num_sms,
|
||||
"SMEM_SIZE": smem_config[0],
|
||||
}
|
||||
_, _ = jit_tuner.compile_and_tune(
|
||||
name="m_grouped_gemm_fp8_fp8_bf16_nt",
|
||||
keys={
|
||||
"N": n,
|
||||
"K": k,
|
||||
"BLOCK_M": block_m,
|
||||
"BLOCK_N": block_n,
|
||||
"SWIZZLE_D_MODE": smem_config[1],
|
||||
"BLOCK_N_PADDING": smem_config[2],
|
||||
"NUM_GROUPS": num_groups,
|
||||
"NUM_STAGES": num_stages,
|
||||
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||
"GEMM_TYPE": GemmType.GroupedContiguous,
|
||||
},
|
||||
space=(),
|
||||
kwargs=kwargs,
|
||||
runtime_cls=FP8GemmRuntime,
|
||||
)
|
||||
|
||||
code = FP8GemmRuntime.generate(kwargs)
|
||||
_ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
|
||||
|
||||
|
||||
def _compile_gemm_nt_f8f8bf16_one(
|
||||
@@ -228,28 +219,23 @@ def _compile_gemm_nt_f8f8bf16_one(
|
||||
"GEMM_TYPE": GemmType.Normal,
|
||||
"NUM_TMA_THREADS": num_tma_threads,
|
||||
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
|
||||
"N": n,
|
||||
"K": k,
|
||||
"NUM_GROUPS": 1,
|
||||
"BLOCK_M": block_m,
|
||||
"BLOCK_N": block_n,
|
||||
"BLOCK_K": block_k,
|
||||
"SWIZZLE_D_MODE": smem_config[1],
|
||||
"BLOCK_N_PADDING": smem_config[2],
|
||||
"NUM_STAGES": num_stages,
|
||||
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||
"NUM_SMS": num_sms,
|
||||
"SMEM_SIZE": smem_config[0],
|
||||
}
|
||||
_, _ = jit_tuner.compile_and_tune(
|
||||
name="gemm_fp8_fp8_bf16_nt",
|
||||
keys={
|
||||
"N": n,
|
||||
"K": k,
|
||||
"BLOCK_M": block_m,
|
||||
"BLOCK_N": block_n,
|
||||
"SWIZZLE_D_MODE": smem_config[1],
|
||||
"BLOCK_N_PADDING": smem_config[2],
|
||||
"NUM_STAGES": num_stages,
|
||||
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||
},
|
||||
space=(),
|
||||
kwargs=kwargs,
|
||||
runtime_cls=FP8GemmRuntime,
|
||||
)
|
||||
|
||||
code = FP8GemmRuntime.generate(kwargs)
|
||||
_ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
|
||||
|
||||
|
||||
_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
|
||||
|
||||
Reference in New Issue
Block a user