forked from EngineX-Cambricon/enginex-mlu370-vllm
226 lines
7.2 KiB
Python
226 lines
7.2 KiB
Python
|
|
import os
|
||
|
|
import sys
|
||
|
|
import shutil
|
||
|
|
import torch
|
||
|
|
import subprocess
|
||
|
|
import platform
|
||
|
|
from pathlib import Path
|
||
|
|
from packaging.version import Version
|
||
|
|
from jinja2 import Environment, FileSystemLoader
|
||
|
|
import distutils.command.clean # pylint: disable=C0411
|
||
|
|
import setuptools.command.install # pylint: disable=C0411
|
||
|
|
from setuptools import setup, find_packages, distutils # pylint: disable=C0411
|
||
|
|
from torch_mlu.utils.cpp_extension import BuildExtension, MLUExtension
|
||
|
|
|
||
|
|
|
||
|
|
def get_tmo_version(tmo_version_file):
    """Read the ``TMO_VERSION`` entry from a ``KEY=VALUE`` property file.

    Args:
        tmo_version_file: Path to the property file (``build.property``).

    Returns:
        ``Version(value).base_version`` for the first non-comment line whose
        key part contains ``TMO_VERSION``.

    Exits the process with status 1 (message on stderr) if the file does not
    exist, and raises ``RuntimeError`` if no ``TMO_VERSION`` entry is found.
    """
    if not os.path.isfile(tmo_version_file):
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit("Failed to find version file: {0}".format(tmo_version_file))
    with open(tmo_version_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.lstrip().startswith('#'):
                continue
            key, sep, value = line.partition("=")
            # Require an '=' and match only on the key side: a line that merely
            # mentions TMO_VERSION (no '=') used to raise IndexError, and a
            # value containing the text could be picked up by mistake.
            if sep and "TMO_VERSION" in key:
                return Version(value.strip()).base_version
    raise RuntimeError(f"Can not find version from {tmo_version_file}.")
|
||
|
|
|
||
|
|
def get_source_files(dir_path, include_kernel: bool = True):
    """Collect C++ (and optionally MLU kernel) sources under *dir_path*.

    Args:
        dir_path: Directory searched recursively. Returned paths keep the same
            prefix as *dir_path*, so a relative *dir_path* yields relative paths.
        include_kernel: When true, also collect ``*.mlu`` kernel sources.

    Returns:
        POSIX-style path strings: every ``*.cpp`` file first, then (if
        requested) every ``*.mlu`` file. Each group is sorted by path string
        so the compile/link order is reproducible regardless of the order in
        which the filesystem lists directory entries.
    """
    root = Path(dir_path)
    patterns = ['*.cpp']
    if include_kernel:
        patterns.append('*.mlu')
    source_list = []
    for pattern in patterns:
        source_list.extend(sorted(p.as_posix() for p in root.rglob(pattern)))
    return source_list
|
||
|
|
|
||
|
|
def get_include_dirs(dir_path):
    """Return every subdirectory beneath *dir_path*, at any depth.

    *dir_path* itself is not part of the result. Entries appear in top-down
    ``os.walk`` order, each joined onto its parent path.
    """
    return [
        os.path.join(parent, child)
        for parent, children, _files in os.walk(dir_path)
        for child in children
    ]
|
||
|
|
|
||
|
|
def _check_env_flag(name, default=''):
    """Tell whether environment variable *name* holds a truthy flag.

    When *name* is unset, *default* is tested instead. Recognised spellings,
    compared case-insensitively: ON, 1, YES, TRUE, Y.
    """
    truthy = ('ON', '1', 'YES', 'TRUE', 'Y')
    raw = os.environ.get(name, default)
    return raw.upper() in truthy
|
||
|
|
|
||
|
|
|
||
|
|
class install(setuptools.command.install.install):
    """``install`` command that currently defers entirely to setuptools.

    NOTE(review): setuptools' ``install.run`` is reported to inspect its
    caller's stack frame when deciding how to install, so the presence of
    this override (and the depth of the call chain) may matter. Confirm
    before removing it or adding extra call layers.
    """

    def run(self):
        # Explicit two-argument super(): resolves to the same setuptools
        # install.run as the zero-argument form, with the same call depth.
        super(install, self).run()
|
||
|
|
|
||
|
|
class Build(BuildExtension):
    """``build_ext`` command that runs an optional kernel-build hook first.

    ``pre_run`` registers a callable plus the arguments to call it with;
    ``run`` invokes that callable and then delegates to
    ``BuildExtension.run``. With no hook registered, ``build_kernel`` is a
    no-op.
    """

    # Arguments forwarded to the kernel-build hook. Only ever rebound by
    # pre_run, never mutated in place, so the shared class-level dict is safe.
    kernel_args = ()
    kernel_kwargs = {}

    @classmethod
    def build_kernel(cls, *args, **kwargs):
        """Default kernel-build hook: does nothing (replaced by ``pre_run``)."""

    @classmethod
    def pre_run(cls, func, *args, **kwargs):
        """Register ``func(*args, **kwargs)`` to run before the extension build."""
        # staticmethod keeps func from being bound to an instance or class
        # however the attribute is looked up later.
        cls.build_kernel = staticmethod(func)
        cls.kernel_args = args
        cls.kernel_kwargs = kwargs

    def run(self):
        # Look the hook up on the actual command class so a hook registered
        # via a subclass's pre_run is honoured (Build's own is inherited).
        cls = type(self)
        cls.build_kernel(*cls.kernel_args, **cls.kernel_kwargs)
        super().run()
|
||
|
|
|
||
|
|
class Clean(distutils.command.clean.clean):
    """``clean`` command that also purges local build artifacts.

    After the standard distutils ``clean`` has run, the tree rooted at the
    current working directory is walked. Every directory named
    ``__pycache__`` or ``build`` is deleted, along with every ``.pyc``,
    ``.pyo`` and ``.so`` file.

    NOTE(review): the walk covers everything under '.', including any
    virtualenv or vendored tree placed inside the checkout, so ``.so`` files
    there are removed too. Confirm this breadth is intended.
    """

    def run(self):
        super().run()
        artifact_dirs = ('__pycache__', 'build')
        artifact_suffixes = ('.pyc', '.pyo', '.so')
        for root, dirs, files in os.walk('.'):
            # Iterate over a snapshot so deleted entries can be pruned from
            # the live ``dirs`` list. Top-down os.walk then never tries to
            # descend into a directory that was just removed.
            for dirname in list(dirs):
                if dirname in artifact_dirs:
                    shutil.rmtree(os.path.join(root, dirname))
                    dirs.remove(dirname)
            for filename in files:
                if filename.endswith(artifact_suffixes):
                    os.remove(os.path.join(root, filename))
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Configure and run ``setup()`` for the ``torch_mlu_ops`` package.

    Environment flags (read via ``_check_env_flag``):
      * ``TMO_BUILD_KERNELS_WITH_CMAKE``: build the ``*.mlu`` kernels with
        ``csrc/kernels/build.sh`` (run as the ``Build`` pre-run hook). The
        resulting ``libtmo_kernels.a`` is then linked whole-archive, instead
        of the kernels being compiled as extension sources.
      * ``TMO_DEBUG_MODE``: add ``-Og -g -DDEBUG`` and do not strip symbols.
      * ``TMO_MEM_CHECK``: add AddressSanitizer compile flags and do not
        strip symbols.

    ``NEUWARE_HOME`` (default ``/usr/local/neuware``) supplies the include
    and ``lib64`` directories.

    Side effect: renders ``torch_mlu_ops/_version.py`` from ``version.tpl``.
    NOTE: sources are collected via the relative path "csrc", so this expects
    the current working directory to be the directory containing this file.
    """
    BUILD_KERNELS_WITH_CMAKE = _check_env_flag('TMO_BUILD_KERNELS_WITH_CMAKE')
    DEBUG_MODE = _check_env_flag('TMO_DEBUG_MODE')
    TMO_MEM_CHECK = _check_env_flag('TMO_MEM_CHECK')
    CXX_FLAGS = []
    CNCC_FLAGS = []

    # base_dir: bangtransformer (the directory containing this setup script)
    base_dir = os.path.dirname(os.path.abspath(__file__))
    csrc_dir = os.path.join(base_dir, "csrc")
    include_dir_list = [csrc_dir]

    # Package version: the TMO base version plus a local "+ptXY" tag naming
    # the torch major/minor this build targets.
    tmo_version_file = os.path.join(base_dir, 'build.property')
    torch_tmo_version = get_tmo_version(tmo_version_file)
    tv = Version(torch.__version__)
    torch_tmo_version = f"{torch_tmo_version}+pt{tv.major}{tv.minor}"

    # Render torch_mlu_ops/_version.py. Anchored at base_dir, like the other
    # base_dir-relative paths, so it does not depend on the current directory.
    env = Environment(loader=FileSystemLoader(base_dir))
    tpl = env.get_template('version.tpl')
    version_py = os.path.join(base_dir, 'torch_mlu_ops', '_version.py')
    with open(version_py, 'w', encoding='utf-8') as f:
        f.write(tpl.render({'VERSION': torch_tmo_version}))

    # source files in source_list can not be absolute path
    source_list = get_source_files("csrc", not BUILD_KERNELS_WITH_CMAKE)
    include_dir_list.extend(get_include_dirs("csrc"))

    neuware_home = os.getenv("NEUWARE_HOME", "/usr/local/neuware")
    include_dir_list.append(os.path.join(neuware_home, "include"))
    library_dir_list = [os.path.join(neuware_home, "lib64")]

    CXX_FLAGS += [
        '-fPIC',
        '-Wall',
        '-Werror',
        '-Wno-error=deprecated-declarations',
    ]
    if DEBUG_MODE:
        CXX_FLAGS += ['-Og', '-g', '-DDEBUG']

    if TMO_MEM_CHECK:
        SANITIZER_FLAGS = [
            "-O1",
            "-g",
            "-DDEBUG",
            "-fsanitize=address",
            "-fno-omit-frame-pointer",
        ]
        CXX_FLAGS.extend(SANITIZER_FLAGS)
        CNCC_FLAGS.extend(SANITIZER_FLAGS)

    def build_kernel_with_cmake(*args, **kwargs):
        """Run ``csrc/kernels/build.sh``; registered as the Build pre-run hook.

        ``args``/``kwargs`` are accepted to match the hook call and are unused.

        Raises:
            RuntimeError: if the script exits with a non-zero status.
        """
        script = os.path.join(csrc_dir, "kernels", "build.sh")
        # Pass BUILD_MODE only to the child instead of mutating this
        # process's os.environ; env=None inherits the environment unchanged.
        child_env = dict(os.environ, BUILD_MODE='debug') if DEBUG_MODE else None
        # argv list without a shell, so a csrc_dir containing spaces or shell
        # metacharacters is passed through verbatim.
        res = subprocess.run([script], env=child_env, check=False)
        # Explicit raise rather than assert, which is stripped under -O.
        if res.returncode != 0:
            raise RuntimeError(
                f'failed to build tmo kernels by cmake (exit code {res.returncode})')

    if BUILD_KERNELS_WITH_CMAKE:
        Build.pre_run(build_kernel_with_cmake)
    else:
        CNCC_FLAGS += [
            '-fPIC',
            '--bang-arch=compute_30',
            '--bang-mlu-arch=mtp_592',
            '--bang-mlu-arch=mtp_613',
            '-Xbang-cnas',
            '-fno-soft-pipeline',
            '-Wall',
            '-Werror',
            '-Wno-error=deprecated-declarations',
            '--no-neuware-version-check',
        ]
        if platform.machine() == "x86_64":
            CNCC_FLAGS += ['-mcmodel=large']
        if DEBUG_MODE:
            CNCC_FLAGS += ['-Og', '-g', '-DDEBUG']

    # MLUExtension does not add include_dirs automatically for mlu files,
    # so we have to add them in CNCC_FLAGS.
    CNCC_FLAGS.extend("-I" + inc_dir for inc_dir in include_dir_list)

    packages = find_packages(include=("torch_mlu_ops",))

    # Read in README.md for long_description
    with open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f:
        long_description = f.read()

    extra_link_args = ['-Wl,-rpath,$ORIGIN/']
    # Keep symbols whenever debug info was requested: both DEBUG_MODE and
    # TMO_MEM_CHECK compile with -g, which --strip-all would discard.
    if not (DEBUG_MODE or TMO_MEM_CHECK):
        extra_link_args += ['-Wl,--strip-all']
    if BUILD_KERNELS_WITH_CMAKE:
        # Pull in every object of the CMake-built kernel archive.
        extra_link_args += [
            '-Wl,-whole-archive',
            os.path.join(csrc_dir, "kernels/build/lib/libtmo_kernels.a"),
            '-Wl,-no-whole-archive',
        ]

    C = MLUExtension(
        'torch_mlu_ops._C',
        sources=list(source_list),
        include_dirs=list(include_dir_list),
        library_dirs=list(library_dir_list),
        extra_compile_args={'cxx': CXX_FLAGS, 'cncc': CNCC_FLAGS},
        extra_link_args=extra_link_args,
    )
    ext_modules = [C]

    cmdclass = {
        'build_ext': Build,
        'clean': Clean,
        'install': install,
    }

    install_requires = [
        "torch >= 2.1.0",
        "torch_mlu >= 1.20.0",
        "ninja",
        "jinja2",
        "packaging",
    ]

    setup(name="torch_mlu_ops",
          version=torch_tmo_version,
          packages=packages,
          description='BangTransformer Torch API',
          long_description=long_description,
          long_description_content_type="text/markdown",
          url='http://www.cambricon.com',
          ext_modules=ext_modules,
          # NOTE(review): not a standard setuptools keyword; presumably
          # consumed by a torch_mlu build hook. Confirm it is still needed.
          cmake_find_package=True,
          cmdclass=cmdclass,
          install_requires=install_requires,
          python_requires=">=3.10")
|
||
|
|
|
||
|
|
# Script entry point: build/install only when this file is executed,
# never when it is merely imported.
if __name__ == '__main__':
    main()
|