Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/setup.py
2026-02-04 17:39:32 +08:00

226 lines
7.2 KiB
Python

import os
import sys
import shutil
import torch
import subprocess
import platform
from pathlib import Path
from packaging.version import Version
from jinja2 import Environment, FileSystemLoader
import distutils.command.clean # pylint: disable=C0411
import setuptools.command.install # pylint: disable=C0411
from setuptools import setup, find_packages, distutils # pylint: disable=C0411
from torch_mlu.utils.cpp_extension import BuildExtension, MLUExtension
def get_tmo_version(tmo_version_file):
    """Read the TMO version from a ``KEY=VALUE`` style property file.

    The first non-comment line that mentions ``TMO_VERSION`` and contains an
    ``=`` provides the version; the text between the first and second ``=``
    is parsed with ``packaging.version.Version``.

    Args:
        tmo_version_file: Path of the property file to read.

    Returns:
        The ``base_version`` string of the parsed version.

    Raises:
        RuntimeError: If the file holds no usable ``TMO_VERSION`` entry.
        packaging.version.InvalidVersion: If the value is not a valid version.

    Note:
        A missing file is reported on stdout and terminates the process with
        exit status 1 instead of raising.
    """
    if not os.path.isfile(tmo_version_file):
        print("Failed to find version file: {0}".format(tmo_version_file))
        sys.exit(1)
    with open(tmo_version_file, 'r', encoding='utf-8') as f:
        for line in f:
            entry = line.strip()
            # Comments that merely mention the key must not be parsed.
            if entry.startswith('#') or "TMO_VERSION" not in entry:
                continue
            fields = entry.split("=")
            # A mention without an assignment carries no value; keep looking
            # instead of failing with IndexError.
            if len(fields) < 2:
                continue
            return Version(fields[1].strip()).base_version
    raise RuntimeError(f"Can not find version from {tmo_version_file}.")
def get_source_files(dir_path, include_kernel: bool = True):
    """Collect the extension's source files found recursively under *dir_path*.

    Args:
        dir_path: Directory to search. Returned paths keep this prefix, so a
            relative *dir_path* yields relative paths.
        include_kernel: When true, ``*.mlu`` files are appended after the
            ``*.cpp`` files.

    Returns:
        A list of POSIX-style path strings: all ``*.cpp`` files first, then
        (optionally) all ``*.mlu`` files, each group sorted alphabetically.
    """
    search_root = Path(dir_path)
    patterns = ['*.cpp']
    if include_kernel:
        patterns.append('*.mlu')

    source_list = []
    for pattern in patterns:
        # rglob yields entries in arbitrary order; sorting keeps the
        # compile/link order identical across machines and runs.
        source_list.extend(
            sorted(path.as_posix() for path in search_root.rglob(pattern))
        )
    return source_list
def get_include_dirs(dir_path):
    """List every sub-directory found recursively under *dir_path*.

    Args:
        dir_path: Directory to walk. It is not itself part of the result.

    Returns:
        A list of ``os.path.join(root, name)`` strings, produced top-down with
        the sub-directories of each level in alphabetical order.
    """
    out_dirs = []
    for root, subdirs, _files in os.walk(dir_path):
        # Sorting in place also fixes the order in which os.walk descends,
        # so the resulting include search order is reproducible.
        subdirs.sort()
        out_dirs.extend(os.path.join(root, name) for name in subdirs)
    return out_dirs
def _check_env_flag(name, default=''):
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
class install(setuptools.command.install.install):
    """``install`` command registered in ``cmdclass``; it adds no behaviour of its own."""

    def run(self):
        """Delegate straight to the base setuptools ``install`` command."""
        super().run()
class Build(BuildExtension):
    """``build_ext`` command that invokes a configurable hook before building.

    The hook and its arguments are stored on the class: register them with
    :meth:`pre_run`, and :meth:`run` calls the hook right before the regular
    ``BuildExtension`` build.
    """

    # Arguments forwarded to ``build_kernel`` by ``run``. They are only ever
    # rebound, never mutated in place, so the shared dict default is safe.
    kernel_args = ()
    kernel_kwargs = {}

    @classmethod
    def build_kernel(cls, *args, **kwargs):
        """Default hook: just records the arguments it was called with."""
        cls.kernel_args = args
        cls.kernel_kwargs = kwargs

    @classmethod
    def pre_run(cls, func, *args, **kwargs):
        """Register *func* as the hook to run before the extension build.

        Args:
            func: Callable invoked as ``func(*args, **kwargs)`` by ``run``.
            *args: Positional arguments stored for the hook.
            **kwargs: Keyword arguments stored for the hook.
        """
        # staticmethod keeps a plain function from being bound to an instance
        # (and receiving ``self``) if the hook is ever reached through one.
        cls.build_kernel = staticmethod(func)
        cls.kernel_args = args
        cls.kernel_kwargs = kwargs

    def run(self):
        """Call the registered hook, then run the normal extension build."""
        # Look the hook up on the concrete command class so that a hook
        # registered through a subclass' ``pre_run`` is honoured; attributes
        # set on ``Build`` itself are still found through inheritance.
        command_cls = type(self)
        command_cls.build_kernel(*command_cls.kernel_args,
                                 **command_cls.kernel_kwargs)
        super().run()
class Clean(distutils.command.clean.clean):
    """``clean`` command that also removes build leftovers below the cwd."""

    def run(self):
        """Run the base ``clean``, then delete caches and built artifacts.

        Below the current working directory this removes every directory
        named ``__pycache__`` or ``build`` and every file ending in ``.pyc``,
        ``.pyo`` or ``.so``.
        """
        super().run()
        removable_dirs = ('__pycache__', 'build')
        removable_suffixes = ('.pyc', '.pyo', '.so')
        for root, dirs, files in os.walk('.'):
            doomed = [name for name in dirs if name in removable_dirs]
            for name in doomed:
                shutil.rmtree(os.path.join(root, name))
            # Prune the deleted directories from the walk so os.walk does not
            # try to descend into paths that no longer exist.
            dirs[:] = [name for name in dirs if name not in removable_dirs]
            for name in files:
                if name.endswith(removable_suffixes):
                    os.remove(os.path.join(root, name))
def main():
    """Assemble build settings for ``torch_mlu_ops`` and call ``setup()``.

    Environment variables read:
        TMO_BUILD_KERNELS_WITH_CMAKE: When enabled, ``*.mlu`` files are left
            out of the extension sources, ``csrc/kernels/build.sh`` is run
            before the extension build, and
            ``csrc/kernels/build/lib/libtmo_kernels.a`` is linked as a whole
            archive (presumably the output of that script).
        TMO_DEBUG_MODE: Adds ``-Og -g -DDEBUG`` and skips ``--strip-all``.
        TMO_MEM_CHECK: Adds the address-sanitizer flags to both compilers.
        NEUWARE_HOME: Neuware root, defaults to ``/usr/local/neuware``.

    Side effects: renders ``version.tpl`` into ``torch_mlu_ops/_version.py``
    and reads ``build.property`` and ``README.md`` next to this file.

    Raises:
        RuntimeError: From the cmake hook when ``build.sh`` exits non-zero.
    """
    import shlex  # local import: only needed to quote the build-script path

    BUILD_KERNELS_WITH_CMAKE = _check_env_flag('TMO_BUILD_KERNELS_WITH_CMAKE')
    DEBUG_MODE = _check_env_flag('TMO_DEBUG_MODE')
    CXX_FLAGS = []
    CNCC_FLAGS = []

    # base_dir: directory containing this setup.py (bangtransformer)
    base_dir = os.path.dirname(os.path.abspath(__file__))
    csrc_dir = os.path.join(base_dir, "csrc")
    include_dir_list = [csrc_dir]

    # Package version: base TMO version plus a "+pt<major><minor>" local tag
    # taken from the torch version this build runs against.
    tmo_version_file = os.path.join(base_dir, 'build.property')
    torch_tmo_version = get_tmo_version(tmo_version_file)
    tv = Version(torch.__version__)
    torch_tmo_version = torch_tmo_version + "+pt" + f"{tv.major}{tv.minor}"

    # create _version.py
    env = Environment(loader=FileSystemLoader(base_dir))
    tpl = env.get_template('version.tpl')
    ver_cxt = {'VERSION': torch_tmo_version}
    # Anchor the output at base_dir, like the template lookup above, so the
    # file does not land somewhere else when the cwd differs.
    version_py = os.path.join(base_dir, 'torch_mlu_ops', '_version.py')
    with open(version_py, 'w', encoding='utf-8') as f:
        f.write(tpl.render(ver_cxt))

    # source files in source_list can not be absolute path
    source_list = get_source_files("csrc", not BUILD_KERNELS_WITH_CMAKE)
    include_dir_list.extend(get_include_dirs("csrc"))

    library_dir_list = []
    neuware_home = os.getenv("NEUWARE_HOME", "/usr/local/neuware")
    neuware_home_include = os.path.join(neuware_home, "include")
    neuware_home_lib = os.path.join(neuware_home, "lib64")
    include_dir_list.append(neuware_home_include)
    library_dir_list.append(neuware_home_lib)

    CXX_FLAGS += [
        '-fPIC',
        '-Wall',
        '-Werror',
        '-Wno-error=deprecated-declarations',
    ]
    if DEBUG_MODE:
        CXX_FLAGS += ['-Og', '-g', '-DDEBUG']

    TMO_MEM_CHECK = _check_env_flag('TMO_MEM_CHECK')
    if TMO_MEM_CHECK:
        SANITIZER_FLAGS = [
            "-O1",
            "-g",
            "-DDEBUG",
            "-fsanitize=address",
            "-fno-omit-frame-pointer",
        ]
        CXX_FLAGS.extend(SANITIZER_FLAGS)
        CNCC_FLAGS.extend(SANITIZER_FLAGS)

    def build_kernel_with_cmake(*args, **kwargs):
        """Run csrc/kernels/build.sh and fail the build if it fails."""
        if DEBUG_MODE:
            os.environ['BUILD_MODE'] = 'debug'
        cmd = os.path.join(csrc_dir, "kernels/build.sh")
        # Quote the path so a checkout directory containing spaces or shell
        # metacharacters still runs the intended script.
        res = subprocess.run(shlex.quote(cmd), shell=True)
        # An explicit raise survives ``python -O``; an assert would not.
        if res.returncode != 0:
            raise RuntimeError('failed to build tmo kernels by cmake')

    if BUILD_KERNELS_WITH_CMAKE:
        Build.pre_run(build_kernel_with_cmake)
    else:
        CNCC_FLAGS += [
            '-fPIC',
            '--bang-arch=compute_30',
            '--bang-mlu-arch=mtp_592',
            '--bang-mlu-arch=mtp_613',
            '-Xbang-cnas',
            '-fno-soft-pipeline',
            '-Wall',
            '-Werror',
            '-Wno-error=deprecated-declarations',
            '--no-neuware-version-check',
        ]
        if platform.machine() == "x86_64":
            CNCC_FLAGS += ['-mcmodel=large']
        if DEBUG_MODE:
            CNCC_FLAGS += ['-Og', '-g', '-DDEBUG']
        # MLUExtension does not add include_dirs automatically for mlu files,
        # so we have to add them in CNCC_FLAGS.
        CNCC_FLAGS.extend("-I" + include_dir for include_dir in include_dir_list)

    packages = find_packages(
        include=(
            "torch_mlu_ops",
        )
    )

    # Read in README.md for long_description
    with open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f:
        long_description = f.read()

    # $ORIGIN rpath: resolve shared libraries located next to the built module.
    extra_link_args = ['-Wl,-rpath,$ORIGIN/']
    if not DEBUG_MODE:
        extra_link_args += ['-Wl,--strip-all']
    if BUILD_KERNELS_WITH_CMAKE:
        extra_link_args += [
            '-Wl,-whole-archive',
            os.path.join(csrc_dir, "kernels/build/lib/libtmo_kernels.a"),
            '-Wl,-no-whole-archive',
        ]

    C = MLUExtension(
        'torch_mlu_ops._C',
        sources=list(source_list),
        include_dirs=list(include_dir_list),
        library_dirs=list(library_dir_list),
        extra_compile_args={'cxx': CXX_FLAGS,
                            'cncc': CNCC_FLAGS},
        extra_link_args=extra_link_args,
    )
    ext_modules = [C]

    cmdclass = {
        'build_ext': Build,
        'clean': Clean,
        'install': install,
    }

    install_requires = [
        "torch >= 2.1.0",
        "torch_mlu >= 1.20.0",
        "ninja",
        "jinja2",
        "packaging",
    ]

    # NOTE(review): ``cmake_find_package`` does not look like a standard
    # setuptools keyword; kept unchanged - confirm whether it is still needed.
    setup(name="torch_mlu_ops",
          version=torch_tmo_version,
          packages=packages,
          description='BangTransformer Torch API',
          long_description=long_description,
          long_description_content_type="text/markdown",
          url='http://www.cambricon.com',
          ext_modules=ext_modules,
          cmake_find_package=True,
          cmdclass=cmdclass,
          install_requires=install_requires,
          python_requires=">=3.10")
# Run the build only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()