[FEATURE] Enhance platform compatibility for ARM (#5746)

This commit is contained in:
Johnny
2025-04-30 00:06:16 +02:00
committed by GitHub
parent 9a62191ba7
commit 2c7dbb7cc2
4 changed files with 32 additions and 14 deletions

View File

@@ -1,13 +1,14 @@
import ctypes
import os
import platform
import torch
if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
ctypes.CDLL(
"/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12",
mode=ctypes.RTLD_GLOBAL,
)
SYSTEM_ARCH = platform.machine()
cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12"
if os.path.exists(cuda_path):
ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL)
from sgl_kernel import common_ops
from sgl_kernel.allreduce import *