feat: Add FlashMLA submodule (#4449)

Co-authored-by: sleepcoo <sleepcoo@gmail.com>
This commit is contained in:
Shi Shuai
2025-03-16 14:30:25 +08:00
committed by GitHub
parent 65b7c9b78f
commit 81f431eded
3 changed files with 205 additions and 0 deletions

View File

@@ -18,6 +18,21 @@ import shutil
import sys
from pathlib import Path
# Setup flash_mla at the top level for tests to find
# This makes the module importable without installation
root_dir = Path(__file__).parent.resolve()
module_src = root_dir / "3rdparty" / "flashmla" / "flash_mla"
module_dest = root_dir / "flash_mla"
if module_src.exists() and not module_dest.exists():
try:
os.symlink(module_src, module_dest, target_is_directory=True)
print(f"Created symbolic link from {module_src} to {module_dest}")
except (OSError, NotImplementedError):
if module_src.exists():
shutil.copytree(module_src, module_dest)
print(f"Copied directory from {module_src} to {module_dest}")
import torch
from setuptools import find_packages, setup
from setuptools.command.build_py import build_py
@@ -55,6 +70,7 @@ cutlass_default = root / "3rdparty" / "cutlass"
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
flashinfer = root / "3rdparty" / "flashinfer"
deepgemm = root / "3rdparty" / "deepgemm"
flashmla = root / "3rdparty" / "flashmla"
include_dirs = [
root / "include",
root / "csrc",
@@ -63,6 +79,7 @@ include_dirs = [
flashinfer.resolve() / "include",
flashinfer.resolve() / "include" / "gemm",
flashinfer.resolve() / "csrc",
flashmla.resolve() / "csrc",
"cublas",
]
@@ -70,6 +87,7 @@ include_dirs = [
class CustomBuildPy(build_py):
def run(self):
self.copy_deepgemm_to_build_lib()
self.copy_flashmla_to_build_lib()
self.make_jit_include_symlinks()
build_py.run(self)
@@ -93,6 +111,17 @@ class CustomBuildPy(build_py):
os.unlink(dst_dir)
os.symlink(src_dir, dst_dir, target_is_directory=True)
# Create symbolic links for FlashMLA
flash_mla_include_dir = os.path.join(self.build_lib, "flash_mla/include")
os.makedirs(flash_mla_include_dir, exist_ok=True)
# Create empty directories for FlashMLA's include paths
# This is safer than creating symlinks as the targets might not exist in CI
for dirname in ["cute", "cutlass"]:
dst_dir = f"{flash_mla_include_dir}/{dirname}"
if not os.path.exists(dst_dir):
os.makedirs(dst_dir, exist_ok=True)
def copy_deepgemm_to_build_lib(self):
"""
This function copies DeepGemm to python's site-packages
@@ -110,6 +139,26 @@ class CustomBuildPy(build_py):
# Copy the directory
shutil.copytree(src_dir, dst_dir)
def copy_flashmla_to_build_lib(self):
"""
This function copies FlashMLA to python's site-packages
"""
dst_dir = os.path.join(self.build_lib, "flash_mla")
os.makedirs(dst_dir, exist_ok=True)
src_dir = os.path.join(str(flashmla.resolve()), "flash_mla")
if not os.path.exists(src_dir):
print(
f"Warning: Source directory {src_dir} does not exist, possibly the submodule is not properly initialized"
)
return
if os.path.exists(dst_dir):
shutil.rmtree(dst_dir)
shutil.copytree(src_dir, dst_dir)
nvcc_flags = [
"-DNDEBUG",