From 17299f088a01a4ae1428a5d7e7073def61c584e7 Mon Sep 17 00:00:00 2001 From: JieXin Liang Date: Tue, 13 May 2025 16:41:35 +0800 Subject: [PATCH] [misc] deep_gemm fallback to NVRTC when NVCC not found (#6252) --- python/sglang/srt/layers/quantization/deep_gemm.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/deep_gemm.py b/python/sglang/srt/layers/quantization/deep_gemm.py index 16e7d6be2..0b1aa591f 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm.py +++ b/python/sglang/srt/layers/quantization/deep_gemm.py @@ -15,6 +15,7 @@ _ENABLE_JIT_DEEPGEMM = False if is_cuda(): import deep_gemm from deep_gemm import get_num_sms + from deep_gemm.jit.compiler import get_nvcc_compiler from deep_gemm.jit_kernels.gemm import get_best_configs from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType from deep_gemm.jit_kernels.tuner import jit_tuner @@ -48,7 +49,17 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv( # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f # NVRTC may have performance loss with some cases. # And NVCC JIT speed is also 9x faster in the ref commit -os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0") +_USE_NVRTC_DEFAULT = "0" +if _ENABLE_JIT_DEEPGEMM: + try: + get_nvcc_compiler() + except: + logger.warning( + "NVCC Compiler not found, use NVRTC for DeepGEMM JIT " + "and may have performance loss with some cases." + ) + _USE_NVRTC_DEFAULT = "1" +os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT) def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):