DeepGemm integrate to sgl-kernel (#4165)
Co-authored-by: sleepcoo <sleepcoo@gmail.com> Co-authored-by: HandH1998 <1335248067@qq.com> Co-authored-by: shuaills <shishuaiuoe@gmail.com> Co-authored-by: yinfan98 <1106310035@qq.com> Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -14,11 +14,13 @@
|
||||
# ==============================================================================
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from setuptools import find_packages, setup
|
||||
from setuptools.command.build_py import build_py
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
root = Path(__file__).parent.resolve()
|
||||
@@ -52,6 +54,7 @@ operator_namespace = "sgl_kernel"
|
||||
cutlass_default = root / "3rdparty" / "cutlass"
|
||||
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
|
||||
flashinfer = root / "3rdparty" / "flashinfer"
|
||||
deepgemm = root / "3rdparty" / "deepgemm"
|
||||
include_dirs = [
|
||||
root / "include",
|
||||
root / "csrc",
|
||||
@@ -63,6 +66,51 @@ include_dirs = [
|
||||
"cublas",
|
||||
]
|
||||
|
||||
|
||||
class CustomBuildPy(build_py):
|
||||
def run(self):
|
||||
self.copy_deepgemm_to_build_lib()
|
||||
self.make_jit_include_symlinks()
|
||||
build_py.run(self)
|
||||
|
||||
def make_jit_include_symlinks(self):
|
||||
# Make symbolic links of third-party include directories
|
||||
build_include_dir = os.path.join(self.build_lib, "deep_gemm/include")
|
||||
os.makedirs(build_include_dir, exist_ok=True)
|
||||
|
||||
third_party_include_dirs = [
|
||||
cutlass.resolve() / "include" / "cute",
|
||||
cutlass.resolve() / "include" / "cutlass",
|
||||
]
|
||||
|
||||
for d in third_party_include_dirs:
|
||||
dirname = str(d).split("/")[-1]
|
||||
src_dir = d
|
||||
dst_dir = f"{build_include_dir}/{dirname}"
|
||||
assert os.path.exists(src_dir)
|
||||
if os.path.exists(dst_dir):
|
||||
assert os.path.islink(dst_dir)
|
||||
os.unlink(dst_dir)
|
||||
os.symlink(src_dir, dst_dir, target_is_directory=True)
|
||||
|
||||
def copy_deepgemm_to_build_lib(self):
|
||||
"""
|
||||
This function copies DeepGemm to python's site-packages
|
||||
"""
|
||||
dst_dir = os.path.join(self.build_lib, "deep_gemm")
|
||||
os.makedirs(dst_dir, exist_ok=True)
|
||||
|
||||
# Copy deepgemm/deep_gemm to the build directory
|
||||
src_dir = os.path.join(str(deepgemm.resolve()), "deep_gemm")
|
||||
|
||||
# Remove existing directory if it exists
|
||||
if os.path.exists(dst_dir):
|
||||
shutil.rmtree(dst_dir)
|
||||
|
||||
# Copy the directory
|
||||
shutil.copytree(src_dir, dst_dir)
|
||||
|
||||
|
||||
nvcc_flags = [
|
||||
"-DNDEBUG",
|
||||
f"-DOPERATOR_NAMESPACE={operator_namespace}",
|
||||
@@ -175,6 +223,9 @@ setup(
|
||||
packages=find_packages(where="python"),
|
||||
package_dir={"": "python"},
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
|
||||
cmdclass={
|
||||
"build_ext": BuildExtension.with_options(use_ninja=True),
|
||||
"build_py": CustomBuildPy,
|
||||
},
|
||||
options={"bdist_wheel": {"py_limited_api": "cp39"}},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user