From 844e2f227ab0cce6ef818a719170ce37b9eb1e1b Mon Sep 17 00:00:00 2001 From: lukec <118525388+sleepcoo@users.noreply.github.com> Date: Mon, 19 May 2025 15:44:03 +0800 Subject: [PATCH] Fix nodeepgemm init (#6417) --- python/sglang/srt/layers/quantization/deep_gemm.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/quantization/deep_gemm.py b/python/sglang/srt/layers/quantization/deep_gemm.py index 0b1aa591f..0292b21aa 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm.py +++ b/python/sglang/srt/layers/quantization/deep_gemm.py @@ -11,8 +11,10 @@ from tqdm.contrib.concurrent import thread_map from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda +logger = logging.getLogger(__name__) _ENABLE_JIT_DEEPGEMM = False -if is_cuda(): + +try: import deep_gemm from deep_gemm import get_num_sms from deep_gemm.jit.compiler import get_nvcc_compiler @@ -24,14 +26,14 @@ if is_cuda(): if sm_version == 90: if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"): _ENABLE_JIT_DEEPGEMM = True +except ImportError: + logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.") def get_enable_jit_deepgemm(): return _ENABLE_JIT_DEEPGEMM -logger = logging.getLogger(__name__) - _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1)) _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var( "SGL_JIT_DEEPGEMM_PRECOMPILE", "true"