fix undefined symbol cudaGetDriverEntryPointByVersion (#3372)

This commit is contained in:
Yineng Zhang
2025-02-07 19:32:45 +08:00
committed by GitHub
parent 2b1808cec4
commit 45c87e083f

View File

@@ -1,3 +1,12 @@
import ctypes
import os
if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
ctypes.CDLL(
"/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12",
mode=ctypes.RTLD_GLOBAL,
)
from sgl_kernel.ops import (
apply_rope_with_cos_sin_cache_inplace,
bmm_fp8,