feat: Add FlashMLA submodule (#4449)
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
This commit is contained in:
@@ -18,6 +18,21 @@ import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Setup flash_mla at the top level for tests to find
|
||||
# This makes the module importable without installation
|
||||
root_dir = Path(__file__).parent.resolve()
|
||||
module_src = root_dir / "3rdparty" / "flashmla" / "flash_mla"
|
||||
module_dest = root_dir / "flash_mla"
|
||||
|
||||
if module_src.exists() and not module_dest.exists():
|
||||
try:
|
||||
os.symlink(module_src, module_dest, target_is_directory=True)
|
||||
print(f"Created symbolic link from {module_src} to {module_dest}")
|
||||
except (OSError, NotImplementedError):
|
||||
if module_src.exists():
|
||||
shutil.copytree(module_src, module_dest)
|
||||
print(f"Copied directory from {module_src} to {module_dest}")
|
||||
|
||||
import torch
|
||||
from setuptools import find_packages, setup
|
||||
from setuptools.command.build_py import build_py
|
||||
@@ -55,6 +70,7 @@ cutlass_default = root / "3rdparty" / "cutlass"
|
||||
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
|
||||
flashinfer = root / "3rdparty" / "flashinfer"
|
||||
deepgemm = root / "3rdparty" / "deepgemm"
|
||||
flashmla = root / "3rdparty" / "flashmla"
|
||||
include_dirs = [
|
||||
root / "include",
|
||||
root / "csrc",
|
||||
@@ -63,6 +79,7 @@ include_dirs = [
|
||||
flashinfer.resolve() / "include",
|
||||
flashinfer.resolve() / "include" / "gemm",
|
||||
flashinfer.resolve() / "csrc",
|
||||
flashmla.resolve() / "csrc",
|
||||
"cublas",
|
||||
]
|
||||
|
||||
@@ -70,6 +87,7 @@ include_dirs = [
|
||||
class CustomBuildPy(build_py):
|
||||
def run(self):
|
||||
self.copy_deepgemm_to_build_lib()
|
||||
self.copy_flashmla_to_build_lib()
|
||||
self.make_jit_include_symlinks()
|
||||
build_py.run(self)
|
||||
|
||||
@@ -93,6 +111,17 @@ class CustomBuildPy(build_py):
|
||||
os.unlink(dst_dir)
|
||||
os.symlink(src_dir, dst_dir, target_is_directory=True)
|
||||
|
||||
# Create symbolic links for FlashMLA
|
||||
flash_mla_include_dir = os.path.join(self.build_lib, "flash_mla/include")
|
||||
os.makedirs(flash_mla_include_dir, exist_ok=True)
|
||||
|
||||
# Create empty directories for FlashMLA's include paths
|
||||
# This is safer than creating symlinks as the targets might not exist in CI
|
||||
for dirname in ["cute", "cutlass"]:
|
||||
dst_dir = f"{flash_mla_include_dir}/{dirname}"
|
||||
if not os.path.exists(dst_dir):
|
||||
os.makedirs(dst_dir, exist_ok=True)
|
||||
|
||||
def copy_deepgemm_to_build_lib(self):
|
||||
"""
|
||||
This function copies DeepGemm to python's site-packages
|
||||
@@ -110,6 +139,26 @@ class CustomBuildPy(build_py):
|
||||
# Copy the directory
|
||||
shutil.copytree(src_dir, dst_dir)
|
||||
|
||||
def copy_flashmla_to_build_lib(self):
|
||||
"""
|
||||
This function copies FlashMLA to python's site-packages
|
||||
"""
|
||||
dst_dir = os.path.join(self.build_lib, "flash_mla")
|
||||
os.makedirs(dst_dir, exist_ok=True)
|
||||
|
||||
src_dir = os.path.join(str(flashmla.resolve()), "flash_mla")
|
||||
|
||||
if not os.path.exists(src_dir):
|
||||
print(
|
||||
f"Warning: Source directory {src_dir} does not exist, possibly the submodule is not properly initialized"
|
||||
)
|
||||
return
|
||||
|
||||
if os.path.exists(dst_dir):
|
||||
shutil.rmtree(dst_dir)
|
||||
|
||||
shutil.copytree(src_dir, dst_dir)
|
||||
|
||||
|
||||
nvcc_flags = [
|
||||
"-DNDEBUG",
|
||||
|
||||
Reference in New Issue
Block a user