[FEATURE] Enhance platform compatibility for ARM (#5746)
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
import ctypes
|
||||
import os
|
||||
import platform
|
||||
|
||||
import torch
|
||||
|
||||
if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
|
||||
ctypes.CDLL(
|
||||
"/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12",
|
||||
mode=ctypes.RTLD_GLOBAL,
|
||||
)
|
||||
SYSTEM_ARCH = platform.machine()
|
||||
|
||||
cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12"
|
||||
if os.path.exists(cuda_path):
|
||||
ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL)
|
||||
|
||||
from sgl_kernel import common_ops
|
||||
from sgl_kernel.allreduce import *
|
||||
|
||||
Reference in New Issue
Block a user