# Source: enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/setup.py
# (226 lines, 7.2 KiB, Python; revision dated 2026-02-04 17:39:32 +08:00)
import os
import sys
import shutil
import torch
import subprocess
import platform
from pathlib import Path
from packaging.version import Version
from jinja2 import Environment, FileSystemLoader
import distutils.command.clean # pylint: disable=C0411
import setuptools.command.install # pylint: disable=C0411
from setuptools import setup, find_packages, distutils # pylint: disable=C0411
from torch_mlu.utils.cpp_extension import BuildExtension, MLUExtension
def get_tmo_version(tmo_version_file):
    """Read the TMO base version from a ``KEY=VALUE`` property file.

    Scans ``tmo_version_file`` line by line and uses the first line that
    mentions ``TMO_VERSION`` and contains an ``=``. The text after the first
    ``=`` is parsed with ``packaging.version.Version`` and its
    ``base_version`` is returned.

    Args:
        tmo_version_file: Path to the property file (``main`` passes
            ``build.property`` next to this setup.py).

    Returns:
        The ``base_version`` string of the parsed version.

    Raises:
        SystemExit: Exit status 1 if the file does not exist.
        RuntimeError: If no ``TMO_VERSION`` assignment line is found.
        packaging.version.InvalidVersion: If the value is not a valid version.
    """
    if not os.path.isfile(tmo_version_file):
        print("Failed to find version file: {0}".format(tmo_version_file),
              file=sys.stderr)
        sys.exit(1)
    with open(tmo_version_file, 'r', encoding='utf-8') as f:
        for line in f:
            if "TMO_VERSION" not in line:
                continue
            # Split on the first '=' only. A mention without an assignment
            # (e.g. a comment) is skipped instead of raising IndexError.
            _, sep, value = line.partition("=")
            if not sep:
                continue
            return Version(value.strip()).base_version
    raise RuntimeError(f"Can not find version from {tmo_version_file}.")
def get_source_files(dir_path, include_kernel: bool = True):
    """Recursively collect C++ (and optionally MLU kernel) sources.

    Every ``*.cpp`` file below ``dir_path`` comes first. When
    ``include_kernel`` is truthy, every ``*.mlu`` file follows. Paths are
    returned in POSIX form and are relative whenever ``dir_path`` is relative.
    """
    root = Path(dir_path)
    patterns = ('*.cpp', '*.mlu') if include_kernel else ('*.cpp',)
    return [hit.as_posix() for pattern in patterns for hit in root.rglob(pattern)]
def get_include_dirs(dir_path):
    """Return every subdirectory below ``dir_path`` at any depth.

    Order follows a top-down ``os.walk``. ``dir_path`` itself is not
    included.
    """
    return [
        os.path.join(parent, child)
        for parent, children, _files in os.walk(dir_path)
        for child in children
    ]
def _check_env_flag(name, default=''):
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
class install(setuptools.command.install.install):
    """``install`` command registered in ``cmdclass``; adds no extra steps."""

    def run(self):
        # Delegate entirely to the setuptools implementation.
        super(install, self).run()
class Build(BuildExtension):
    """``build_ext`` command that first runs an optional pre-build hook.

    ``pre_run(func, *args, **kwargs)`` registers ``func``. ``run()`` calls the
    registered hook with the stored arguments, then performs the normal
    extension build. Without a registration, the hook is the default
    ``build_kernel`` below, which builds nothing. State lives on the class, so
    a registration affects every instance.
    """

    # Arguments that run() forwards to the hook. In this file they are only
    # ever rebound (by build_kernel/pre_run), never mutated in place.
    kernel_args = ()
    kernel_kwargs = {}

    @classmethod
    def build_kernel(cls, *args, **kwargs):
        """Default hook: builds nothing and only records the arguments it got."""
        cls.kernel_args, cls.kernel_kwargs = args, kwargs

    @classmethod
    def pre_run(cls, func, *args, **kwargs):
        """Register ``func(*args, **kwargs)`` to run at the start of ``run()``."""
        # Stored as a plain class attribute. run() looks it up on the class
        # (``Build.build_kernel``), so a plain function is called unbound.
        cls.build_kernel = func
        cls.kernel_args, cls.kernel_kwargs = args, kwargs

    def run(self):
        hook = Build.build_kernel
        hook(*Build.kernel_args, **Build.kernel_kwargs)
        super().run()
class Clean(distutils.command.clean.clean):
    """``clean`` command that also removes build artifacts from the tree.

    After the standard distutils clean, walks the current working directory.
    It deletes every directory named ``__pycache__`` or ``build`` and every
    file ending in ``.pyc``, ``.pyo`` or ``.so``.
    """

    _ARTIFACT_DIRS = ('__pycache__', 'build')
    _ARTIFACT_SUFFIXES = ('.pyc', '.pyo', '.so')

    def run(self):
        super().run()
        for root, dirs, files in os.walk('.'):
            # Iterate over a copy so removed directories can be pruned from
            # ``dirs`` in place. os.walk then skips descending into them.
            for name in list(dirs):
                if name not in self._ARTIFACT_DIRS:
                    continue
                path = os.path.join(root, name)
                if os.path.islink(path):
                    # os.walk lists symlinks to directories in ``dirs``, but
                    # shutil.rmtree refuses symlinks. Remove only the link,
                    # never the target it points to.
                    os.remove(path)
                else:
                    shutil.rmtree(path)
                dirs.remove(name)
            for name in files:
                if name.endswith(self._ARTIFACT_SUFFIXES):
                    os.remove(os.path.join(root, name))
def main():
    """Configure and run ``setup()`` for the ``torch_mlu_ops`` package.

    Environment switches (parsed with ``_check_env_flag``):
      * ``TMO_BUILD_KERNELS_WITH_CMAKE``: build the ``.mlu`` kernels with
        ``csrc/kernels/build.sh`` and whole-archive link
        ``libtmo_kernels.a``, instead of compiling ``.mlu`` files as
        extension sources.
      * ``TMO_DEBUG_MODE``: add ``-Og -g -DDEBUG`` and skip ``--strip-all``.
        Also exports ``BUILD_MODE=debug`` to the CMake kernel build.
      * ``TMO_MEM_CHECK``: add AddressSanitizer compile flags.
    ``NEUWARE_HOME`` (default ``/usr/local/neuware``) supplies the extra
    ``include`` and ``lib64`` directories.

    Side effect: renders ``version.tpl`` into ``torch_mlu_ops/_version.py``.
    Extension source paths must be relative, so ``csrc`` is resolved against
    the working directory. This assumes setup.py is run from its own directory.
    """
    import shlex  # only needed to quote the kernel build script path

    BUILD_KERNELS_WITH_CMAKE = _check_env_flag('TMO_BUILD_KERNELS_WITH_CMAKE')
    DEBUG_MODE = _check_env_flag('TMO_DEBUG_MODE')
    CXX_FLAGS = []
    CNCC_FLAGS = []

    # base_dir: directory containing this setup.py.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    csrc_dir = os.path.join(base_dir, "csrc")
    include_dir_list = [csrc_dir]

    # Package version: "<TMO base version>+pt<torch major><torch minor>".
    tmo_version_file = os.path.join(base_dir, 'build.property')
    torch_tmo_version = get_tmo_version(tmo_version_file)
    tv = Version(torch.__version__)
    torch_tmo_version = f"{torch_tmo_version}+pt{tv.major}{tv.minor}"

    # Render torch_mlu_ops/_version.py. Anchor the output to base_dir, like
    # the template loader, rather than to the current working directory.
    env = Environment(loader=FileSystemLoader(base_dir))
    tpl = env.get_template('version.tpl')
    version_py = os.path.join(base_dir, 'torch_mlu_ops', '_version.py')
    with open(version_py, 'w', encoding='utf-8') as f:
        f.write(tpl.render({'VERSION': torch_tmo_version}))

    # Source files in source_list can not be absolute paths.
    source_list = get_source_files("csrc", not BUILD_KERNELS_WITH_CMAKE)
    include_dir_list.extend(get_include_dirs("csrc"))

    neuware_home = os.getenv("NEUWARE_HOME", "/usr/local/neuware")
    include_dir_list.append(os.path.join(neuware_home, "include"))
    library_dir_list = [os.path.join(neuware_home, "lib64")]

    CXX_FLAGS += [
        '-fPIC',
        '-Wall',
        '-Werror',
        '-Wno-error=deprecated-declarations',
    ]
    if DEBUG_MODE:
        CXX_FLAGS += ['-Og', '-g', '-DDEBUG']

    if _check_env_flag('TMO_MEM_CHECK'):
        sanitizer_flags = [
            "-O1",
            "-g",
            "-DDEBUG",
            "-fsanitize=address",
            "-fno-omit-frame-pointer",
        ]
        CXX_FLAGS.extend(sanitizer_flags)
        CNCC_FLAGS.extend(sanitizer_flags)

    def build_kernel_with_cmake(*args, **kwargs):
        """Run csrc/kernels/build.sh; raise RuntimeError if it fails."""
        if DEBUG_MODE:
            os.environ['BUILD_MODE'] = 'debug'
        cmd = os.path.join(csrc_dir, "kernels/build.sh")
        # Quote the path so a checkout directory containing spaces or shell
        # metacharacters is not re-split by the shell.
        res = subprocess.run(shlex.quote(cmd), shell=True)
        # Raise rather than assert, so the check still runs under `python -O`.
        if res.returncode != 0:
            raise RuntimeError('failed to build tmo kernels by cmake')

    if BUILD_KERNELS_WITH_CMAKE:
        Build.pre_run(build_kernel_with_cmake)
    else:
        CNCC_FLAGS += [
            '-fPIC',
            '--bang-arch=compute_30',
            '--bang-mlu-arch=mtp_592',
            '--bang-mlu-arch=mtp_613',
            '-Xbang-cnas',
            '-fno-soft-pipeline',
            '-Wall',
            '-Werror',
            '-Wno-error=deprecated-declarations',
            '--no-neuware-version-check',
        ]
        if platform.machine() == "x86_64":
            CNCC_FLAGS += ['-mcmodel=large']
        if DEBUG_MODE:
            CNCC_FLAGS += ['-Og', '-g', '-DDEBUG']
        # MLUExtension does not add include_dirs automatically for mlu files,
        # so we have to add them in CNCC_FLAGS.
        CNCC_FLAGS.extend("-I" + inc_dir for inc_dir in include_dir_list)

    packages = find_packages(include=("torch_mlu_ops",))

    # Read in README.md for long_description
    with open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f:
        long_description = f.read()

    # rpath $ORIGIN: search the extension's own directory for shared libraries.
    extra_link_args = ['-Wl,-rpath,$ORIGIN/']
    if not DEBUG_MODE:
        extra_link_args += ['-Wl,--strip-all']
    if BUILD_KERNELS_WITH_CMAKE:
        # Pull in every object of the CMake-built kernel archive.
        extra_link_args += [
            '-Wl,-whole-archive',
            os.path.join(csrc_dir, "kernels/build/lib/libtmo_kernels.a"),
            '-Wl,-no-whole-archive',
        ]

    C = MLUExtension(
        'torch_mlu_ops._C',
        sources=list(source_list),
        include_dirs=list(include_dir_list),
        library_dirs=list(library_dir_list),
        extra_compile_args={'cxx': CXX_FLAGS, 'cncc': CNCC_FLAGS},
        extra_link_args=extra_link_args,
    )
    cmdclass = {
        'build_ext': Build,
        'clean': Clean,
        'install': install,
    }
    install_requires = [
        "torch >= 2.1.0",
        "torch_mlu >= 1.20.0",
        "ninja",
        "jinja2",
        "packaging",
    ]
    setup(name="torch_mlu_ops",
          version=torch_tmo_version,
          packages=packages,
          description='BangTransformer Torch API',
          long_description=long_description,
          long_description_content_type="text/markdown",
          url='http://www.cambricon.com',
          ext_modules=[C],
          # NOTE(review): not a standard setuptools keyword (setuptools warns
          # about unknown distribution options); kept as-is. Confirm whether
          # anything reads it.
          cmake_find_package=True,
          cmdclass=cmdclass,
          install_requires=install_requires,
          python_requires=">=3.10")


if __name__ == "__main__":
    main()