[FEATURE] Enhance platform compatibility for ARM (#5746)
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# ==============================================================================
|
||||
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -24,9 +25,19 @@ from setuptools.command.build_py import build_py
|
||||
from torch.utils.cpp_extension import BuildExtension, CppExtension
|
||||
|
||||
root = Path(__file__).parent.resolve()
|
||||
arch = platform.machine().lower()
|
||||
|
||||
if arch in ("x86_64", "amd64"):
|
||||
plat_name = "manylinux2014_x86_64"
|
||||
elif arch in ("aarch64", "arm64"):
|
||||
plat_name = "manylinux2014_aarch64"
|
||||
elif arch.startswith("ppc"):
|
||||
plat_name = "manylinux2014_ppc64le"
|
||||
else:
|
||||
plat_name = f"manylinux2014_{arch}"
|
||||
|
||||
if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv:
|
||||
sys.argv.extend(["--plat-name", "manylinux2014_x86_64"])
|
||||
sys.argv.extend(["--plat-name", plat_name])
|
||||
|
||||
|
||||
def _get_version():
|
||||
@@ -70,7 +81,7 @@ cmdclass = {
|
||||
}
|
||||
Extension = CppExtension
|
||||
|
||||
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
|
||||
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu"]
|
||||
|
||||
ext_modules = [
|
||||
Extension(
|
||||
|
||||
Reference in New Issue
Block a user