import os
import sys
import shutil
import torch
import subprocess
import platform
from pathlib import Path
from packaging.version import Version
from jinja2 import Environment, FileSystemLoader
import distutils.command.clean  # pylint: disable=C0411
import setuptools.command.install  # pylint: disable=C0411
from setuptools import setup, find_packages, distutils  # pylint: disable=C0411
from torch_mlu.utils.cpp_extension import BuildExtension, MLUExtension


def get_tmo_version(tmo_version_file):
    """Return the base TMO version declared in ``tmo_version_file``.

    The file is scanned for the first line containing ``TMO_VERSION``. The text
    after its first ``=`` is parsed with ``packaging.version.Version``, and only
    ``base_version`` (release segment without pre/post/dev/local parts) is
    returned.

    Exits the process with status 1 if the file does not exist, and raises
    ``RuntimeError`` if no ``TMO_VERSION`` line is found.
    """
    if not os.path.isfile(tmo_version_file):
        print("Failed to find version file: {0}".format(tmo_version_file))
        sys.exit(1)
    with open(tmo_version_file, 'r', encoding='utf-8') as f:
        for line in f:
            if "TMO_VERSION" in line:
                return Version(line.split("=", 1)[1].strip()).base_version
    raise RuntimeError(f"Can not find version from {tmo_version_file}.")


def get_source_files(dir_path, include_kernel: bool = True):
    """Collect ``.cpp`` (and optionally ``.mlu``) files under ``dir_path``.

    Each returned path is in POSIX form and prefixed by ``dir_path`` exactly as
    passed, so a relative ``dir_path`` yields relative paths. All ``.cpp``
    files come first, followed by ``.mlu`` files when ``include_kernel`` is true.
    """
    patterns = ['*.cpp', '*.mlu'] if include_kernel else ['*.cpp']
    return [
        file.as_posix()
        for pattern in patterns
        for file in Path(dir_path).rglob(pattern)
    ]


def get_include_dirs(dir_path):
    """Return every subdirectory of ``dir_path``, recursively (not ``dir_path`` itself)."""
    return [
        os.path.join(root, name)
        for root, dirs, _ in os.walk(dir_path)
        for name in dirs
    ]


def _check_env_flag(name, default=''):
    """Return True if env var ``name`` is ON/1/YES/TRUE/Y (case-insensitive)."""
    return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']


class install(setuptools.command.install.install):
    """``install`` command; currently identical to the setuptools default."""

    def run(self):
        super().run()


class Build(BuildExtension):
    """``build_ext`` command that runs an optional pre-build hook first.

    ``pre_run`` registers a callable plus its arguments; ``run`` invokes it
    before delegating to ``BuildExtension.run``. Without a registered hook the
    default ``build_kernel`` only records its (empty) arguments.
    """

    kernel_args = ()
    kernel_kwargs = {}

    @classmethod
    def build_kernel(cls, *args, **kwargs):
        # Default hook: records the arguments and builds nothing.
        cls.kernel_args = args
        cls.kernel_kwargs = kwargs

    @classmethod
    def pre_run(cls, func, *args, **kwargs):
        # staticmethod keeps ``func`` unbound even if it is looked up through
        # an instance rather than the class.
        cls.build_kernel = staticmethod(func)
        cls.kernel_args = args
        cls.kernel_kwargs = kwargs

    def run(self):
        Build.build_kernel(*Build.kernel_args, **Build.kernel_kwargs)
        super().run()


class Clean(distutils.command.clean.clean):
    """``clean`` command that also removes build artifacts under the CWD.

    Deletes every ``__pycache__`` and ``build`` directory, plus every ``.pyc``,
    ``.pyo`` and ``.so`` file, found recursively from ``'.'``.
    """

    def run(self):
        super().run()
        for root, dirs, files in os.walk('.'):
            # Iterate over a copy so deleted directories can be pruned from
            # ``dirs``, which stops os.walk from descending into them.
            for dirname in list(dirs):
                if dirname in ('__pycache__', 'build'):
                    shutil.rmtree(os.path.join(root, dirname))
                    dirs.remove(dirname)
            for filename in files:
                if filename.endswith(('.pyc', '.pyo', '.so')):
                    os.remove(os.path.join(root, filename))


def main():
    """Generate ``torch_mlu_ops/_version.py`` and run ``setup()`` for the extension.

    Environment variables read:
      * ``TMO_BUILD_KERNELS_WITH_CMAKE``: build the ``.mlu`` kernels with
        ``csrc/kernels/build.sh`` and link the resulting static library
        whole-archive, instead of compiling ``.mlu`` files via ``MLUExtension``.
      * ``TMO_DEBUG_MODE``: add ``-Og -g -DDEBUG``, skip ``--strip-all`` and
        export ``BUILD_MODE=debug`` to the CMake kernel build.
      * ``TMO_MEM_CHECK``: add AddressSanitizer flags to both the cxx and cncc
        flag lists.
      * ``NEUWARE_HOME``: Neuware SDK root (default ``/usr/local/neuware``).

    Like the original script, source discovery uses paths relative to the
    current directory, so this is expected to run from the package root.
    """
    import shlex  # only needed to quote the kernel build script path

    BUILD_KERNELS_WITH_CMAKE = _check_env_flag('TMO_BUILD_KERNELS_WITH_CMAKE')
    DEBUG_MODE = _check_env_flag('TMO_DEBUG_MODE')
    CXX_FLAGS = []
    CNCC_FLAGS = []

    # base_dir: bangtransformer
    base_dir = os.path.dirname(os.path.abspath(__file__))
    csrc_dir = os.path.join(base_dir, "csrc")
    include_dir_list = [csrc_dir]

    # get tmo ops version, suffixed with the torch major/minor (e.g. "+pt21")
    tmo_version_file = os.path.join(base_dir, 'build.property')
    torch_tmo_version = get_tmo_version(tmo_version_file)
    tv = Version(torch.__version__)
    torch_tmo_version = torch_tmo_version + "+pt" + f"{tv.major}{tv.minor}"

    # create _version.py from version.tpl; anchored at base_dir like the
    # template itself so the output does not depend on the current directory.
    env = Environment(loader=FileSystemLoader(base_dir))
    tpl = env.get_template('version.tpl')
    ver_cxt = {'VERSION': torch_tmo_version}
    version_py = os.path.join(base_dir, 'torch_mlu_ops', '_version.py')
    with open(version_py, 'w', encoding='utf-8') as f:
        f.write(tpl.render(ver_cxt))

    # source files in source_list can not be absolute path
    source_list = get_source_files("csrc", not BUILD_KERNELS_WITH_CMAKE)
    include_dir_list.extend(get_include_dirs("csrc"))
    library_dir_list = []

    neuware_home = os.getenv("NEUWARE_HOME", "/usr/local/neuware")
    neuware_home_include = os.path.join(neuware_home, "include")
    neuware_home_lib = os.path.join(neuware_home, "lib64")
    include_dir_list.append(neuware_home_include)
    library_dir_list.append(neuware_home_lib)

    CXX_FLAGS += [
        '-fPIC',
        '-Wall',
        '-Werror',
        '-Wno-error=deprecated-declarations',
    ]
    if DEBUG_MODE:
        CXX_FLAGS += ['-Og', '-g', '-DDEBUG']

    TMO_MEM_CHECK = _check_env_flag('TMO_MEM_CHECK')
    if TMO_MEM_CHECK:
        SANITIZER_FLAGS = [
            "-O1",
            "-g",
            "-DDEBUG",
            "-fsanitize=address",
            "-fno-omit-frame-pointer",
        ]
        CXX_FLAGS.extend(SANITIZER_FLAGS)
        CNCC_FLAGS.extend(SANITIZER_FLAGS)

    def build_kernel_with_cmake(*args, **kwargs):
        """Run csrc/kernels/build.sh; raise RuntimeError if it fails."""
        if DEBUG_MODE:
            os.environ['BUILD_MODE'] = 'debug'
        cmd = os.path.join(csrc_dir, "kernels/build.sh")
        # Quote the script path so a checkout path with spaces or shell
        # metacharacters is still executed as a single command.
        res = subprocess.run(shlex.quote(cmd), shell=True)
        # Explicit check instead of ``assert``: asserts are stripped under -O,
        # which would let a failed kernel build continue to the link step.
        if res.returncode != 0:
            raise RuntimeError('failed to build tmo kernels by cmake')

    if BUILD_KERNELS_WITH_CMAKE:
        Build.pre_run(build_kernel_with_cmake)
    else:
        CNCC_FLAGS += [
            '-fPIC',
            '--bang-arch=compute_30',
            '--bang-mlu-arch=mtp_592',
            '--bang-mlu-arch=mtp_613',
            '-Xbang-cnas',
            '-fno-soft-pipeline',
            '-Wall',
            '-Werror',
            '-Wno-error=deprecated-declarations',
            '--no-neuware-version-check',
        ]
        if platform.machine() == "x86_64":
            CNCC_FLAGS += ['-mcmodel=large']
        if DEBUG_MODE:
            CNCC_FLAGS += ['-Og', '-g', '-DDEBUG']
        # MLUExtension does not add include_dirs automatically for mlu files,
        # so we have to add them in CNCC_FLAGS.
        for inc_dir in include_dir_list:
            CNCC_FLAGS.append("-I" + inc_dir)

    packages = find_packages(include=("torch_mlu_ops",))

    # Read in README.md for long_description
    with open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f:
        long_description = f.read()

    extra_link_args = ['-Wl,-rpath,$ORIGIN/']
    if not DEBUG_MODE:
        extra_link_args += ['-Wl,--strip-all']
    if BUILD_KERNELS_WITH_CMAKE:
        # Link every object of the prebuilt kernel archive into _C.
        extra_link_args += [
            '-Wl,-whole-archive',
            os.path.join(csrc_dir, "kernels/build/lib/libtmo_kernels.a"),
            '-Wl,-no-whole-archive',
        ]

    C = MLUExtension(
        'torch_mlu_ops._C',
        sources=[*source_list],
        include_dirs=[*include_dir_list],
        library_dirs=[*library_dir_list],
        extra_compile_args={'cxx': CXX_FLAGS, 'cncc': CNCC_FLAGS},
        extra_link_args=extra_link_args,
    )
    ext_modules = [C]

    cmdclass = {
        'build_ext': Build,
        'clean': Clean,
        'install': install,
    }

    install_requires = [
        "torch >= 2.1.0",
        "torch_mlu >= 1.20.0",
        "ninja",
        "jinja2",
        "packaging",
    ]

    setup(
        name="torch_mlu_ops",
        version=torch_tmo_version,
        packages=packages,
        description='BangTransformer Torch API',
        long_description=long_description,
        long_description_content_type="text/markdown",
        url='http://www.cambricon.com',
        ext_modules=ext_modules,
        cmake_find_package=True,
        cmdclass=cmdclass,
        install_requires=install_requires,
        python_requires=">=3.10",
    )


if __name__ == "__main__":
    main()