From 230106304db3c3b3857490832a74bfeb9458ed0f Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sun, 11 May 2025 07:41:37 -0700
Subject: [PATCH] chore: upgrade sgl-kernel v0.1.2.post1 (#6196)

Co-authored-by: alcanderian
---
 python/pyproject.toml                    |   2 +-
 python/sglang/srt/entrypoints/engine.py  |   2 +-
 .../srt/layers/quantization/deep_gemm.py | 124 ++++++++----
 scripts/ci_install_dependency.sh         |   2 +-
 scripts/ci_install_dependency_8_gpu.sh   |   2 +-
 5 files changed, 61 insertions(+), 71 deletions(-)

diff --git a/python/pyproject.toml b/python/pyproject.toml
index 2e6105650..1a0a498f0 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -48,7 +48,7 @@ runtime_common = [

 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.1.1",
+    "sgl-kernel==0.1.2.post1",
     "flashinfer_python==0.2.5",
     "torch==2.6.0",
     "torchvision==0.21.0",
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 535935654..6a6961f2f 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -486,7 +486,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.1",
+            "0.1.2.post1",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )

diff --git a/python/sglang/srt/layers/quantization/deep_gemm.py b/python/sglang/srt/layers/quantization/deep_gemm.py
index e14f861fc..16e7d6be2 100644
--- a/python/sglang/srt/layers/quantization/deep_gemm.py
+++ b/python/sglang/srt/layers/quantization/deep_gemm.py
@@ -16,11 +16,7 @@ if is_cuda():
     import deep_gemm
     from deep_gemm import get_num_sms
     from deep_gemm.jit_kernels.gemm import get_best_configs
-    from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
-    from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
-    from deep_gemm.jit_kernels.m_grouped_gemm import (
-        template as deep_gemm_grouped_gemm_template,
-    )
+    from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
     from deep_gemm.jit_kernels.tuner import jit_tuner

 sm_version = get_device_sm()
@@ -45,10 +41,15 @@
 _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
 _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")

 # Force redirect deep_gemm cache_dir
-os.environ["DG_CACHE_DIR"] = os.getenv(
-    "SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
+os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
+    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
 )

+# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
+# NVRTC may cause performance loss in some cases.
+# NVCC JIT is also about 9x faster as of the referenced commit.
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
+
 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     global _BUILTIN_M_LIST
@@ -130,10 +131,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    # Auto-tuning with compilation
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -146,24 +155,11 @@
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedMasked",
+            "GEMM_TYPE": GemmType.GroupedMasked,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )


@@ -173,9 +169,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -188,25 +193,11 @@
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedContiguous",
+            "GEMM_TYPE": GemmType.GroupedContiguous,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("num_groups", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )


@@ -216,9 +207,20 @@ def _compile_gemm_nt_f8f8bf16_one(
     _: int,  # _ is a dummy parameter to align with other interfaces
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "GEMM_TYPE": GemmType.Normal,
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "NUM_GROUPS": 1,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -232,20 +234,8 @@
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )


@@ -373,7 +363,7 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):

     from deep_gemm.jit.runtime import RuntimeCache

-    origin_func = RuntimeCache.__getitem__
+    origin_func = RuntimeCache.get

     def __patched_func(self, *args, **kwargs):
         ret = origin_func(self, *args, **kwargs)
@@ -385,6 +375,6 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
             )
         return ret

-    RuntimeCache.__getitem__ = __patched_func
+    RuntimeCache.get = __patched_func
     yield
-    RuntimeCache.__getitem__ = origin_func
+    RuntimeCache.get = origin_func
diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh
index 45290c8a6..4e50cbbdb 100755
--- a/scripts/ci_install_dependency.sh
+++ b/scripts/ci_install_dependency.sh
@@ -16,7 +16,7 @@ rm -rf /usr/local/lib/python3.10/dist-packages/sgl_kernel*
 pip install --upgrade pip

 # Install sgl-kernel
-pip install sgl-kernel==0.1.1 --no-cache-dir
+pip install sgl-kernel==0.1.2.post1 --no-cache-dir

 # Install the main package
 pip install -e "python[all]"
diff --git a/scripts/ci_install_dependency_8_gpu.sh b/scripts/ci_install_dependency_8_gpu.sh
index 5fe1bb419..6bd28c4c5 100755
--- a/scripts/ci_install_dependency_8_gpu.sh
+++ b/scripts/ci_install_dependency_8_gpu.sh
@@ -34,7 +34,7 @@ rm -rf /usr/local/include/nvshmem*
 pip install --upgrade pip

 # Install sgl-kernel
-pip install sgl-kernel==0.1.1 --no-cache-dir
+pip install sgl-kernel==0.1.2.post1 --no-cache-dir

 # Install the main package
 pip install -e "python[all]"
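
Note for reviewers: a quick, self-contained sketch of the environment plumbing introduced in the deep_gemm.py hunk above. The SGL_-prefixed variables are the user-facing overrides; sglang forwards them to the DG_JIT_* names that the newer DeepGEMM reads (replacing the old DG_CACHE_DIR name, as the patch suggests). The two assignments are copied verbatim from the patch; the prints are illustrative only.

import os

# User overrides (optional): SGL_DG_CACHE_DIR, SGL_DG_USE_NVRTC.
# Forwarded to the names DeepGEMM reads: DG_JIT_CACHE_DIR, DG_JIT_USE_NVRTC.
os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
)
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")  # NVCC by default

print(os.environ["DG_JIT_CACHE_DIR"])  # e.g. /home/<user>/.cache/deep_gemm
print(os.environ["DG_JIT_USE_NVRTC"])  # "0" unless SGL_DG_USE_NVRTC is set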
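Note for reviewers: the _log_jit_build change above re-points the monkey-patch from RuntimeCache.__getitem__ to RuntimeCache.get, matching the newer DeepGEMM cache API in which a None return signals a cache miss. Below is a minimal, runnable sketch of the same pattern; ToyRuntimeCache and log_jit_builds are hypothetical stand-ins for deep_gemm.jit.runtime.RuntimeCache and the patched context manager, and the try/finally restore is a small hardening not present in the patch, which restores the method after yield instead.

import contextlib

class ToyRuntimeCache:
    # Hypothetical stand-in for deep_gemm.jit.runtime.RuntimeCache.
    def __init__(self):
        self._store = {}

    def get(self, key):
        # None signals a cache miss, i.e. a JIT build is about to run.
        return self._store.get(key)

@contextlib.contextmanager
def log_jit_builds(cache_cls):
    # Temporarily wrap cache_cls.get so misses are logged, then restore it.
    origin_get = cache_cls.get

    def patched_get(self, *args, **kwargs):
        ret = origin_get(self, *args, **kwargs)
        if ret is None:
            print(f"JIT build triggered for key={args}")
        return ret

    cache_cls.get = patched_get
    try:
        yield
    finally:
        cache_cls.get = origin_get

cache = ToyRuntimeCache()
with log_jit_builds(ToyRuntimeCache):
    cache.get("gemm_fp8_fp8_bf16_nt")  # miss -> "JIT build triggered ..."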