add ops
This commit is contained in:
225
torch_mlu_ops-v1.3.2/setup.py
Normal file
225
torch_mlu_ops-v1.3.2/setup.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import torch
|
||||
import subprocess
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from packaging.version import Version
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
import distutils.command.clean # pylint: disable=C0411
|
||||
import setuptools.command.install # pylint: disable=C0411
|
||||
from setuptools import setup, find_packages, distutils # pylint: disable=C0411
|
||||
from torch_mlu.utils.cpp_extension import BuildExtension, MLUExtension
|
||||
|
||||
|
||||
def get_tmo_version(tmo_version_file):
|
||||
if (not os.path.isfile(tmo_version_file)):
|
||||
print("Failed to find version file: {0}".format(tmo_version_file))
|
||||
sys.exit(1)
|
||||
with open(tmo_version_file, 'r') as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
if "TMO_VERSION" in line:
|
||||
return Version(line.split("=")[1].strip()).base_version
|
||||
raise RuntimeError(f"Can not find version from {tmo_version_file}.")
|
||||
|
||||
def get_source_files(dir_path, include_kernel: bool = True):
|
||||
source_list = []
|
||||
for file in Path(dir_path).rglob('*.cpp'):
|
||||
source_list.append(file.as_posix())
|
||||
if include_kernel:
|
||||
for file in Path(dir_path).rglob('*.mlu'):
|
||||
source_list.append(file.as_posix())
|
||||
return source_list
|
||||
|
||||
def get_include_dirs(dir_path):
|
||||
out_dirs = []
|
||||
for root, dirs, files in os.walk(dir_path):
|
||||
for dir in dirs:
|
||||
out_dirs.append(os.path.join(root, dir))
|
||||
return out_dirs
|
||||
|
||||
def _check_env_flag(name, default=''):
|
||||
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
|
||||
|
||||
|
||||
class install(setuptools.command.install.install):
|
||||
def run(self):
|
||||
super().run()
|
||||
|
||||
class Build(BuildExtension):
|
||||
kernel_args = ()
|
||||
kernel_kwargs = {}
|
||||
|
||||
@classmethod
|
||||
def build_kernel(cls, *args, **kwargs):
|
||||
cls.kernel_args = args
|
||||
cls.kernel_kwargs = kwargs
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def pre_run(cls, func, *args, **kwargs):
|
||||
cls.build_kernel = func
|
||||
cls.kernel_args = args
|
||||
cls.kernel_kwargs = kwargs
|
||||
|
||||
def run(self):
|
||||
Build.build_kernel(*Build.kernel_args, **Build.kernel_kwargs)
|
||||
super().run()
|
||||
|
||||
class Clean(distutils.command.clean.clean):
|
||||
def run(self):
|
||||
super().run()
|
||||
for root, dirs, files in os.walk('.'):
|
||||
for dir in dirs:
|
||||
if dir == '__pycache__' or dir == 'build':
|
||||
shutil.rmtree(os.path.join(root, dir))
|
||||
for file in files:
|
||||
if file.endswith('.pyc') or file.endswith('.pyo') or file.endswith('.so'):
|
||||
os.remove(os.path.join(root, file))
|
||||
|
||||
|
||||
def main():
|
||||
BUILD_KERNELS_WITH_CMAKE = _check_env_flag('TMO_BUILD_KERNELS_WITH_CMAKE')
|
||||
DEBUG_MODE = _check_env_flag('TMO_DEBUG_MODE')
|
||||
CXX_FLAGS = []
|
||||
CNCC_FLAGS = []
|
||||
|
||||
# base_dir: bangtransformer
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
csrc_dir = os.path.join(base_dir, "csrc")
|
||||
include_dir_list = [csrc_dir]
|
||||
|
||||
# get tmo ops version
|
||||
tmo_version_file = os.path.join(base_dir, 'build.property')
|
||||
torch_tmo_version = get_tmo_version(tmo_version_file)
|
||||
tv = Version(torch.__version__)
|
||||
torch_tmo_version = torch_tmo_version + "+pt" + f"{tv.major}{tv.minor}"
|
||||
|
||||
# create _version.py
|
||||
env = Environment(loader=FileSystemLoader(base_dir))
|
||||
tpl = env.get_template('version.tpl')
|
||||
ver_cxt = {'VERSION': torch_tmo_version}
|
||||
with open('torch_mlu_ops/_version.py', 'w', encoding='utf-8') as f:
|
||||
f.write(tpl.render(ver_cxt))
|
||||
|
||||
# source files in source_list can not be absolute path
|
||||
source_list = get_source_files("csrc", not BUILD_KERNELS_WITH_CMAKE)
|
||||
include_dir_list.extend(get_include_dirs("csrc"))
|
||||
library_dir_list = []
|
||||
|
||||
neuware_home = os.getenv("NEUWARE_HOME", "/usr/local/neuware")
|
||||
neuware_home_include = os.path.join(neuware_home, "include")
|
||||
neuware_home_lib = os.path.join(neuware_home, "lib64")
|
||||
include_dir_list.append(neuware_home_include)
|
||||
library_dir_list.append(neuware_home_lib)
|
||||
|
||||
CXX_FLAGS += [
|
||||
'-fPIC',
|
||||
'-Wall',
|
||||
'-Werror',
|
||||
'-Wno-error=deprecated-declarations',
|
||||
]
|
||||
if DEBUG_MODE:
|
||||
CXX_FLAGS += ['-Og', '-g', '-DDEBUG']
|
||||
|
||||
TMO_MEM_CHECK = _check_env_flag('TMO_MEM_CHECK')
|
||||
if TMO_MEM_CHECK:
|
||||
SANITIZER_FLAGS=["-O1",
|
||||
"-g",
|
||||
"-DDEBUG",
|
||||
"-fsanitize=address",
|
||||
"-fno-omit-frame-pointer",]
|
||||
CXX_FLAGS.extend(SANITIZER_FLAGS)
|
||||
CNCC_FLAGS.extend(SANITIZER_FLAGS)
|
||||
|
||||
def build_kernel_with_cmake(*args, **kwargs):
|
||||
if DEBUG_MODE:
|
||||
os.environ['BUILD_MODE'] = 'debug'
|
||||
cmd = os.path.join(csrc_dir, "kernels/build.sh")
|
||||
res = subprocess.run(cmd, shell=True)
|
||||
assert res.returncode == 0, 'failed to build tmo kernels by cmake'
|
||||
|
||||
if BUILD_KERNELS_WITH_CMAKE:
|
||||
Build.pre_run(build_kernel_with_cmake)
|
||||
else:
|
||||
CNCC_FLAGS += [
|
||||
'-fPIC',
|
||||
'--bang-arch=compute_30',
|
||||
'--bang-mlu-arch=mtp_592',
|
||||
'--bang-mlu-arch=mtp_613',
|
||||
'-Xbang-cnas',
|
||||
'-fno-soft-pipeline',
|
||||
'-Wall',
|
||||
'-Werror',
|
||||
'-Wno-error=deprecated-declarations',
|
||||
'--no-neuware-version-check',
|
||||
]
|
||||
if platform.machine() == "x86_64":
|
||||
CNCC_FLAGS += ['-mcmodel=large']
|
||||
if DEBUG_MODE:
|
||||
CNCC_FLAGS += ['-Og', '-g', '-DDEBUG']
|
||||
|
||||
# MLUExtension does not add include_dirs automatically for mlu files,
|
||||
# so we have to add them in CNCC_FLAGS.
|
||||
for dir in include_dir_list:
|
||||
CNCC_FLAGS.append("-I" + dir)
|
||||
|
||||
packages=find_packages(
|
||||
include=(
|
||||
"torch_mlu_ops",
|
||||
)
|
||||
)
|
||||
|
||||
# Read in README.md for long_description
|
||||
with open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f:
|
||||
long_description = f.read()
|
||||
|
||||
extra_link_args = ['-Wl,-rpath,$ORIGIN/']
|
||||
if not DEBUG_MODE:
|
||||
extra_link_args += ['-Wl,--strip-all']
|
||||
if BUILD_KERNELS_WITH_CMAKE:
|
||||
extra_link_args+=['-Wl,-whole-archive',
|
||||
os.path.join(csrc_dir, "kernels/build/lib/libtmo_kernels.a"),
|
||||
'-Wl,-no-whole-archive',]
|
||||
|
||||
C = MLUExtension('torch_mlu_ops._C',
|
||||
sources=[*source_list],
|
||||
include_dirs=[*include_dir_list],
|
||||
library_dirs=[*library_dir_list],
|
||||
extra_compile_args={'cxx': CXX_FLAGS,
|
||||
'cncc': CNCC_FLAGS},
|
||||
extra_link_args=extra_link_args,
|
||||
)
|
||||
ext_modules = [C]
|
||||
|
||||
cmdclass={
|
||||
'build_ext': Build,
|
||||
'clean': Clean,
|
||||
'install': install,
|
||||
}
|
||||
|
||||
install_requires=[
|
||||
"torch >= 2.1.0",
|
||||
"torch_mlu >= 1.20.0",
|
||||
"ninja",
|
||||
"jinja2",
|
||||
"packaging",
|
||||
]
|
||||
|
||||
setup(name="torch_mlu_ops",
|
||||
version=torch_tmo_version,
|
||||
packages=packages,
|
||||
description='BangTransformer Torch API',
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url = 'http://www.cambricon.com',
|
||||
ext_modules=ext_modules,
|
||||
cmake_find_package=True,
|
||||
cmdclass=cmdclass,
|
||||
install_requires=install_requires,
|
||||
python_requires=">=3.10")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user