fix undefined symbol cudaGetDriverEntryPointByVersion (#3372)
This commit is contained in:
@@ -1,3 +1,12 @@
|
|||||||
|
import ctypes
|
||||||
|
import os
|
||||||
|
|
||||||
|
if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
|
||||||
|
ctypes.CDLL(
|
||||||
|
"/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12",
|
||||||
|
mode=ctypes.RTLD_GLOBAL,
|
||||||
|
)
|
||||||
|
|
||||||
from sgl_kernel.ops import (
|
from sgl_kernel.ops import (
|
||||||
apply_rope_with_cos_sin_cache_inplace,
|
apply_rope_with_cos_sin_cache_inplace,
|
||||||
bmm_fp8,
|
bmm_fp8,
|
||||||
|
|||||||
Reference in New Issue
Block a user