chore: upgrade sgl-kernel v0.1.6 (#6945)
This commit is contained in:
@@ -49,7 +49,7 @@ runtime_common = [
|
|||||||
|
|
||||||
srt = [
|
srt = [
|
||||||
"sglang[runtime_common]",
|
"sglang[runtime_common]",
|
||||||
"sgl-kernel==0.1.5",
|
"sgl-kernel==0.1.6",
|
||||||
"flashinfer_python==0.2.5",
|
"flashinfer_python==0.2.5",
|
||||||
"torch==2.6.0",
|
"torch==2.6.0",
|
||||||
"torchvision==0.21.0",
|
"torchvision==0.21.0",
|
||||||
|
|||||||
@@ -579,7 +579,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
assert_pkg_version(
|
assert_pkg_version(
|
||||||
"sgl-kernel",
|
"sgl-kernel",
|
||||||
"0.1.5",
|
"0.1.6",
|
||||||
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -17,10 +17,10 @@ _ENABLE_JIT_DEEPGEMM = False
|
|||||||
try:
|
try:
|
||||||
import deep_gemm
|
import deep_gemm
|
||||||
from deep_gemm import get_num_sms
|
from deep_gemm import get_num_sms
|
||||||
|
from deep_gemm.jit import build
|
||||||
from deep_gemm.jit.compiler import get_nvcc_compiler
|
from deep_gemm.jit.compiler import get_nvcc_compiler
|
||||||
from deep_gemm.jit_kernels.gemm import get_best_configs
|
from deep_gemm.jit_kernels.gemm import get_best_configs
|
||||||
from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
|
from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
|
||||||
from deep_gemm.jit_kernels.tuner import jit_tuner
|
|
||||||
|
|
||||||
sm_version = get_device_sm()
|
sm_version = get_device_sm()
|
||||||
if sm_version == 90:
|
if sm_version == 90:
|
||||||
@@ -148,32 +148,28 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
|
|||||||
block_k = 128
|
block_k = 128
|
||||||
num_tma_threads = 128
|
num_tma_threads = 128
|
||||||
num_math_threads_per_group = 128
|
num_math_threads_per_group = 128
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
"GEMM_TYPE": GemmType.GroupedMasked,
|
||||||
"NUM_TMA_THREADS": num_tma_threads,
|
"NUM_TMA_THREADS": num_tma_threads,
|
||||||
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
|
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
|
||||||
"BLOCK_K": block_k,
|
|
||||||
"NUM_SMS": num_sms,
|
|
||||||
"SMEM_SIZE": smem_config[0],
|
|
||||||
}
|
|
||||||
_, _ = jit_tuner.compile_and_tune(
|
|
||||||
name="m_grouped_gemm_fp8_fp8_bf16_nt",
|
|
||||||
keys={
|
|
||||||
"N": n,
|
"N": n,
|
||||||
"K": k,
|
"K": k,
|
||||||
|
"NUM_GROUPS": 1,
|
||||||
"BLOCK_M": block_m,
|
"BLOCK_M": block_m,
|
||||||
"BLOCK_N": block_n,
|
"BLOCK_N": block_n,
|
||||||
|
"BLOCK_K": block_k,
|
||||||
"SWIZZLE_D_MODE": smem_config[1],
|
"SWIZZLE_D_MODE": smem_config[1],
|
||||||
"BLOCK_N_PADDING": smem_config[2],
|
"BLOCK_N_PADDING": smem_config[2],
|
||||||
"NUM_GROUPS": num_groups,
|
|
||||||
"NUM_STAGES": num_stages,
|
"NUM_STAGES": num_stages,
|
||||||
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||||
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||||
"GEMM_TYPE": GemmType.GroupedMasked,
|
"NUM_SMS": num_sms,
|
||||||
},
|
"SMEM_SIZE": smem_config[0],
|
||||||
space=(),
|
}
|
||||||
kwargs=kwargs,
|
|
||||||
runtime_cls=FP8GemmRuntime,
|
code = FP8GemmRuntime.generate(kwargs)
|
||||||
)
|
_ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
|
def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
|
||||||
@@ -187,31 +183,26 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
|
|||||||
num_tma_threads = 128
|
num_tma_threads = 128
|
||||||
num_math_threads_per_group = 128
|
num_math_threads_per_group = 128
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
"GEMM_TYPE": GemmType.GroupedContiguous,
|
||||||
"NUM_TMA_THREADS": num_tma_threads,
|
"NUM_TMA_THREADS": num_tma_threads,
|
||||||
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
|
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
|
||||||
"BLOCK_K": block_k,
|
|
||||||
"NUM_SMS": num_sms,
|
|
||||||
"SMEM_SIZE": smem_config[0],
|
|
||||||
}
|
|
||||||
_, _ = jit_tuner.compile_and_tune(
|
|
||||||
name="m_grouped_gemm_fp8_fp8_bf16_nt",
|
|
||||||
keys={
|
|
||||||
"N": n,
|
"N": n,
|
||||||
"K": k,
|
"K": k,
|
||||||
|
"NUM_GROUPS": 1,
|
||||||
"BLOCK_M": block_m,
|
"BLOCK_M": block_m,
|
||||||
"BLOCK_N": block_n,
|
"BLOCK_N": block_n,
|
||||||
|
"BLOCK_K": block_k,
|
||||||
"SWIZZLE_D_MODE": smem_config[1],
|
"SWIZZLE_D_MODE": smem_config[1],
|
||||||
"BLOCK_N_PADDING": smem_config[2],
|
"BLOCK_N_PADDING": smem_config[2],
|
||||||
"NUM_GROUPS": num_groups,
|
|
||||||
"NUM_STAGES": num_stages,
|
"NUM_STAGES": num_stages,
|
||||||
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||||
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||||
"GEMM_TYPE": GemmType.GroupedContiguous,
|
"NUM_SMS": num_sms,
|
||||||
},
|
"SMEM_SIZE": smem_config[0],
|
||||||
space=(),
|
}
|
||||||
kwargs=kwargs,
|
|
||||||
runtime_cls=FP8GemmRuntime,
|
code = FP8GemmRuntime.generate(kwargs)
|
||||||
)
|
_ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _compile_gemm_nt_f8f8bf16_one(
|
def _compile_gemm_nt_f8f8bf16_one(
|
||||||
@@ -228,28 +219,23 @@ def _compile_gemm_nt_f8f8bf16_one(
|
|||||||
"GEMM_TYPE": GemmType.Normal,
|
"GEMM_TYPE": GemmType.Normal,
|
||||||
"NUM_TMA_THREADS": num_tma_threads,
|
"NUM_TMA_THREADS": num_tma_threads,
|
||||||
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
|
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
|
||||||
"NUM_GROUPS": 1,
|
|
||||||
"BLOCK_K": block_k,
|
|
||||||
"NUM_SMS": num_sms,
|
|
||||||
"SMEM_SIZE": smem_config[0],
|
|
||||||
}
|
|
||||||
_, _ = jit_tuner.compile_and_tune(
|
|
||||||
name="gemm_fp8_fp8_bf16_nt",
|
|
||||||
keys={
|
|
||||||
"N": n,
|
"N": n,
|
||||||
"K": k,
|
"K": k,
|
||||||
|
"NUM_GROUPS": 1,
|
||||||
"BLOCK_M": block_m,
|
"BLOCK_M": block_m,
|
||||||
"BLOCK_N": block_n,
|
"BLOCK_N": block_n,
|
||||||
|
"BLOCK_K": block_k,
|
||||||
"SWIZZLE_D_MODE": smem_config[1],
|
"SWIZZLE_D_MODE": smem_config[1],
|
||||||
"BLOCK_N_PADDING": smem_config[2],
|
"BLOCK_N_PADDING": smem_config[2],
|
||||||
"NUM_STAGES": num_stages,
|
"NUM_STAGES": num_stages,
|
||||||
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||||
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||||
},
|
"NUM_SMS": num_sms,
|
||||||
space=(),
|
"SMEM_SIZE": smem_config[0],
|
||||||
kwargs=kwargs,
|
}
|
||||||
runtime_cls=FP8GemmRuntime,
|
|
||||||
)
|
code = FP8GemmRuntime.generate(kwargs)
|
||||||
|
_ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
|
||||||
|
|
||||||
|
|
||||||
_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
|
_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
|
||||||
|
|||||||
Reference in New Issue
Block a user