Fix nodeepgemm init (#6417)
This commit is contained in:
@@ -11,8 +11,10 @@ from tqdm.contrib.concurrent import thread_map
|
|||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda
|
from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
_ENABLE_JIT_DEEPGEMM = False
|
_ENABLE_JIT_DEEPGEMM = False
|
||||||
if is_cuda():
|
|
||||||
|
try:
|
||||||
import deep_gemm
|
import deep_gemm
|
||||||
from deep_gemm import get_num_sms
|
from deep_gemm import get_num_sms
|
||||||
from deep_gemm.jit.compiler import get_nvcc_compiler
|
from deep_gemm.jit.compiler import get_nvcc_compiler
|
||||||
@@ -24,14 +26,14 @@ if is_cuda():
|
|||||||
if sm_version == 90:
|
if sm_version == 90:
|
||||||
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
|
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
|
||||||
_ENABLE_JIT_DEEPGEMM = True
|
_ENABLE_JIT_DEEPGEMM = True
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.")
|
||||||
|
|
||||||
|
|
||||||
def get_enable_jit_deepgemm():
|
def get_enable_jit_deepgemm():
|
||||||
return _ENABLE_JIT_DEEPGEMM
|
return _ENABLE_JIT_DEEPGEMM
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
|
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
|
||||||
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
|
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
|
||||||
"SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
|
"SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
|
||||||
|
|||||||
Reference in New Issue
Block a user