CUDA Arch Independent (#8813)
This commit is contained in:
40
sgl-kernel/python/sgl_kernel/__init__.py
Executable file → Normal file
40
sgl-kernel/python/sgl_kernel/__init__.py
Executable file → Normal file
@@ -1,14 +1,46 @@
|
||||
import ctypes
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
SYSTEM_ARCH = platform.machine()
|
||||
|
||||
cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12"
|
||||
if os.path.exists(cuda_path):
|
||||
ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL)
|
||||
# copy & modify from torch/utils/cpp_extension.py
|
||||
def _find_cuda_home():
|
||||
"""Find the CUDA install path."""
|
||||
# Guess #1
|
||||
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
|
||||
if cuda_home is None:
|
||||
# Guess #2
|
||||
nvcc_path = shutil.which("nvcc")
|
||||
if nvcc_path is not None:
|
||||
cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
|
||||
else:
|
||||
# Guess #3
|
||||
cuda_home = "/usr/local/cuda"
|
||||
return cuda_home
|
||||
|
||||
|
||||
if torch.version.hip is None:
|
||||
cuda_home = Path(_find_cuda_home())
|
||||
|
||||
if (cuda_home / "lib").is_dir():
|
||||
cuda_path = cuda_home / "lib"
|
||||
elif (cuda_home / "lib64").is_dir():
|
||||
cuda_path = cuda_home / "lib64"
|
||||
else:
|
||||
# Search for 'libcudart.so.12' in subdirectories
|
||||
for path in cuda_home.rglob("libcudart.so.12"):
|
||||
cuda_path = path.parent
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("Could not find CUDA lib directory.")
|
||||
|
||||
cuda_include = (cuda_path / "libcudart.so.12").resolve()
|
||||
if cuda_include.exists():
|
||||
ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)
|
||||
|
||||
from sgl_kernel import common_ops
|
||||
from sgl_kernel.allreduce import *
|
||||
|
||||
Reference in New Issue
Block a user