CUDA Arch Independent (#8813)
This commit is contained in:
40
sgl-kernel/python/sgl_kernel/__init__.py
Executable file → Normal file
40
sgl-kernel/python/sgl_kernel/__init__.py
Executable file → Normal file
@@ -1,14 +1,46 @@
|
|||||||
import ctypes
|
import ctypes
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
SYSTEM_ARCH = platform.machine()
|
|
||||||
|
|
||||||
cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12"
|
# copy & modify from torch/utils/cpp_extension.py
|
||||||
if os.path.exists(cuda_path):
|
def _find_cuda_home():
|
||||||
ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL)
|
"""Find the CUDA install path."""
|
||||||
|
# Guess #1
|
||||||
|
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
|
||||||
|
if cuda_home is None:
|
||||||
|
# Guess #2
|
||||||
|
nvcc_path = shutil.which("nvcc")
|
||||||
|
if nvcc_path is not None:
|
||||||
|
cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
|
||||||
|
else:
|
||||||
|
# Guess #3
|
||||||
|
cuda_home = "/usr/local/cuda"
|
||||||
|
return cuda_home
|
||||||
|
|
||||||
|
|
||||||
|
if torch.version.hip is None:
|
||||||
|
cuda_home = Path(_find_cuda_home())
|
||||||
|
|
||||||
|
if (cuda_home / "lib").is_dir():
|
||||||
|
cuda_path = cuda_home / "lib"
|
||||||
|
elif (cuda_home / "lib64").is_dir():
|
||||||
|
cuda_path = cuda_home / "lib64"
|
||||||
|
else:
|
||||||
|
# Search for 'libcudart.so.12' in subdirectories
|
||||||
|
for path in cuda_home.rglob("libcudart.so.12"):
|
||||||
|
cuda_path = path.parent
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Could not find CUDA lib directory.")
|
||||||
|
|
||||||
|
cuda_include = (cuda_path / "libcudart.so.12").resolve()
|
||||||
|
if cuda_include.exists():
|
||||||
|
ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)
|
||||||
|
|
||||||
from sgl_kernel import common_ops
|
from sgl_kernel import common_ops
|
||||||
from sgl_kernel.allreduce import *
|
from sgl_kernel.allreduce import *
|
||||||
|
|||||||
Reference in New Issue
Block a user