[misc] deep_gemm fallback to NVRTC when NVCC not found (#6252)
This commit is contained in:
@@ -15,6 +15,7 @@ _ENABLE_JIT_DEEPGEMM = False
|
|||||||
if is_cuda():
|
if is_cuda():
|
||||||
import deep_gemm
|
import deep_gemm
|
||||||
from deep_gemm import get_num_sms
|
from deep_gemm import get_num_sms
|
||||||
|
from deep_gemm.jit.compiler import get_nvcc_compiler
|
||||||
from deep_gemm.jit_kernels.gemm import get_best_configs
|
from deep_gemm.jit_kernels.gemm import get_best_configs
|
||||||
from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
|
from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
|
||||||
from deep_gemm.jit_kernels.tuner import jit_tuner
|
from deep_gemm.jit_kernels.tuner import jit_tuner
|
||||||
@@ -48,7 +49,17 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
|
|||||||
# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
|
# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
|
||||||
# NVRTC may have performance loss with some cases.
|
# NVRTC may have performance loss with some cases.
|
||||||
# And NVCC JIT speed is also 9x faster in the ref commit
|
# And NVCC JIT speed is also 9x faster in the ref commit
|
||||||
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
|
_USE_NVRTC_DEFAULT = "0"
|
||||||
|
if _ENABLE_JIT_DEEPGEMM:
|
||||||
|
try:
|
||||||
|
get_nvcc_compiler()
|
||||||
|
except:
|
||||||
|
logger.warning(
|
||||||
|
"NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
|
||||||
|
"and may have performance loss with some cases."
|
||||||
|
)
|
||||||
|
_USE_NVRTC_DEFAULT = "1"
|
||||||
|
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
|
||||||
|
|
||||||
|
|
||||||
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
|
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
|
||||||
|
|||||||
Reference in New Issue
Block a user