DeepGemm integrate to sgl-kernel (#4165)

Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: HandH1998 <1335248067@qq.com>
Co-authored-by: shuaills <shishuaiuoe@gmail.com>
Co-authored-by: yinfan98 <1106310035@qq.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
laixin
2025-03-10 15:35:07 +08:00
committed by GitHub
parent 7c0541b385
commit c553e1604c
6 changed files with 324 additions and 5 deletions

View File

@@ -14,11 +14,13 @@
# ==============================================================================
import os
import shutil
import sys
from pathlib import Path
import torch
from setuptools import find_packages, setup
from setuptools.command.build_py import build_py
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
root = Path(__file__).parent.resolve()
@@ -52,6 +54,7 @@ operator_namespace = "sgl_kernel"
cutlass_default = root / "3rdparty" / "cutlass"
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
flashinfer = root / "3rdparty" / "flashinfer"
deepgemm = root / "3rdparty" / "deepgemm"
include_dirs = [
root / "include",
root / "csrc",
@@ -63,6 +66,51 @@ include_dirs = [
"cublas",
]
class CustomBuildPy(build_py):
def run(self):
self.copy_deepgemm_to_build_lib()
self.make_jit_include_symlinks()
build_py.run(self)
def make_jit_include_symlinks(self):
# Make symbolic links of third-party include directories
build_include_dir = os.path.join(self.build_lib, "deep_gemm/include")
os.makedirs(build_include_dir, exist_ok=True)
third_party_include_dirs = [
cutlass.resolve() / "include" / "cute",
cutlass.resolve() / "include" / "cutlass",
]
for d in third_party_include_dirs:
dirname = str(d).split("/")[-1]
src_dir = d
dst_dir = f"{build_include_dir}/{dirname}"
assert os.path.exists(src_dir)
if os.path.exists(dst_dir):
assert os.path.islink(dst_dir)
os.unlink(dst_dir)
os.symlink(src_dir, dst_dir, target_is_directory=True)
def copy_deepgemm_to_build_lib(self):
"""
This function copies DeepGemm to python's site-packages
"""
dst_dir = os.path.join(self.build_lib, "deep_gemm")
os.makedirs(dst_dir, exist_ok=True)
# Copy deepgemm/deep_gemm to the build directory
src_dir = os.path.join(str(deepgemm.resolve()), "deep_gemm")
# Remove existing directory if it exists
if os.path.exists(dst_dir):
shutil.rmtree(dst_dir)
# Copy the directory
shutil.copytree(src_dir, dst_dir)
nvcc_flags = [
"-DNDEBUG",
f"-DOPERATOR_NAMESPACE={operator_namespace}",
@@ -175,6 +223,9 @@ setup(
packages=find_packages(where="python"),
package_dir={"": "python"},
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
cmdclass={
"build_ext": BuildExtension.with_options(use_ninja=True),
"build_py": CustomBuildPy,
},
options={"bdist_wheel": {"py_limited_api": "cp39"}},
)