diff --git a/.gitmodules b/.gitmodules index 035314ec3..265d8d989 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,6 +10,3 @@ [submodule "sgl-kernel/3rdparty/deepgemm"] path = sgl-kernel/3rdparty/deepgemm url = https://github.com/deepseek-ai/DeepGEMM -[submodule "sgl-kernel/3rdparty/flashmla"] - path = sgl-kernel/3rdparty/flashmla - url = https://github.com/deepseek-ai/FlashMLA diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 26c6d0f70..4805e29f7 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -18,21 +18,6 @@ import shutil import sys from pathlib import Path -# Setup flash_mla at the top level for tests to find -# This makes the module importable without installation -root_dir = Path(__file__).parent.resolve() -module_src = root_dir / "3rdparty" / "flashmla" / "flash_mla" -module_dest = root_dir / "flash_mla" - -if module_src.exists() and not module_dest.exists(): - try: - os.symlink(module_src, module_dest, target_is_directory=True) - print(f"Created symbolic link from {module_src} to {module_dest}") - except (OSError, NotImplementedError): - if module_src.exists(): - shutil.copytree(module_src, module_dest) - print(f"Copied directory from {module_src} to {module_dest}") - import torch from setuptools import find_packages, setup from setuptools.command.build_py import build_py @@ -70,7 +55,6 @@ cutlass_default = root / "3rdparty" / "cutlass" cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) flashinfer = root / "3rdparty" / "flashinfer" deepgemm = root / "3rdparty" / "deepgemm" -flashmla = root / "3rdparty" / "flashmla" include_dirs = [ root / "include", root / "csrc", @@ -79,7 +63,6 @@ include_dirs = [ flashinfer.resolve() / "include", flashinfer.resolve() / "include" / "gemm", flashinfer.resolve() / "csrc", - flashmla.resolve() / "csrc", "cublas", ] @@ -87,7 +70,6 @@ include_dirs = [ class CustomBuildPy(build_py): def run(self): self.copy_deepgemm_to_build_lib() - self.copy_flashmla_to_build_lib() self.make_jit_include_symlinks() build_py.run(self) @@ -111,17 +93,6 @@ class CustomBuildPy(build_py): os.unlink(dst_dir) os.symlink(src_dir, dst_dir, target_is_directory=True) - # Create symbolic links for FlashMLA - flash_mla_include_dir = os.path.join(self.build_lib, "flash_mla/include") - os.makedirs(flash_mla_include_dir, exist_ok=True) - - # Create empty directories for FlashMLA's include paths - # This is safer than creating symlinks as the targets might not exist in CI - for dirname in ["cute", "cutlass"]: - dst_dir = f"{flash_mla_include_dir}/{dirname}" - if not os.path.exists(dst_dir): - os.makedirs(dst_dir, exist_ok=True) - def copy_deepgemm_to_build_lib(self): """ This function copies DeepGemm to python's site-packages @@ -139,26 +110,6 @@ class CustomBuildPy(build_py): # Copy the directory shutil.copytree(src_dir, dst_dir) - def copy_flashmla_to_build_lib(self): - """ - This function copies FlashMLA to python's site-packages - """ - dst_dir = os.path.join(self.build_lib, "flash_mla") - os.makedirs(dst_dir, exist_ok=True) - - src_dir = os.path.join(str(flashmla.resolve()), "flash_mla") - - if not os.path.exists(src_dir): - print( - f"Warning: Source directory {src_dir} does not exist, possibly the submodule is not properly initialized" - ) - return - - if os.path.exists(dst_dir): - shutil.rmtree(dst_dir) - - shutil.copytree(src_dir, dst_dir) - nvcc_flags = [ "-DNDEBUG", diff --git a/sgl-kernel/tests/test_flash_mla.py b/sgl-kernel/tests/test_flash_mla.py deleted file mode 100644 index b3b47ad3f..000000000 --- a/sgl-kernel/tests/test_flash_mla.py +++ /dev/null @@ -1,153 +0,0 @@ -import argparse -import math -import random - -import torch -import triton -from flash_mla import flash_mla_with_kvcache, get_mla_metadata - -""" -fork FlashMLA/tests/test_flash_mla.py -""" - - -def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): - query = query.float() - key = key.float() - value = value.float() - key = key.repeat_interleave(h_q // h_kv, dim=0) - value = value.repeat_interleave(h_q // h_kv, dim=0) - attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) - if is_causal: - s_q = query.shape[-2] - s_k = key.shape[-2] - attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) - temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) - attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) - attn_weight += attn_bias - lse = attn_weight.logsumexp(dim=-1) - attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) - return attn_weight @ value, lse - - -def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: - x, y = x.double(), y.double() - RMSE = ((x - y) * (x - y)).mean().sqrt().item() - cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) - amax_diff = (x - y).abs().max().item() - # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") - assert cos_diff < 1e-5 - - -@torch.inference_mode() -def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): - print( - f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}" - ) - - cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) - if varlen: - for i in range(b): - cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q) - total_seqlens = cache_seqlens.sum().item() - mean_seqlens = cache_seqlens.float().mean().int().item() - max_seqlen = cache_seqlens.max().item() - max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 - # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") - - q = torch.randn(b, s_q, h_q, d) - block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32 - ).view(b, max_seqlen_pad // block_size) - blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - for i in range(b): - blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = ( - float("nan") - ) - blocked_v = blocked_k[..., :dv] - - tile_scheduler_metadata, num_splits = get_mla_metadata( - cache_seqlens, s_q * h_q // h_kv, h_kv - ) - - def flash_mla(): - return flash_mla_with_kvcache( - q, - blocked_k, - block_table, - cache_seqlens, - dv, - tile_scheduler_metadata, - num_splits, - causal=causal, - ) - - def ref_mla(): - out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) - lse = torch.empty(b, h_q, s_q, dtype=torch.float32) - for i in range(b): - begin = i * max_seqlen_pad - end = begin + cache_seqlens[i] - O, LSE = scaled_dot_product_attention( - q[i].transpose(0, 1), - blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), - blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), - h_q=h_q, - h_kv=h_kv, - is_causal=causal, - ) - out[i] = O.transpose(0, 1) - lse[i] = LSE - return out, lse - - out_flash, lse_flash = flash_mla() - out_torch, lse_torch = ref_mla() - cal_diff(out_flash, out_torch, "out") - cal_diff(lse_flash, lse_torch, "lse") - - t = triton.testing.do_bench(flash_mla) - FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(q.dtype).bits // 8 - ) - print( - f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s" - ) - - -def main(torch_dtype): - device = torch.device("cuda:0") - torch.set_default_dtype(torch_dtype) - torch.set_default_device(device) - torch.cuda.set_device(device) - torch.manual_seed(0) - random.seed(0) - - h_kv = 1 - d, dv = 576, 512 - causal = True - - for b in [128]: - for s in [4096, 8192]: - for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1 - for s_q in [1, 2]: # MTP = 1, 2 - for varlen in [False, True]: - test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--dtype", - type=str, - choices=["bf16", "fp16"], - default="bf16", - help="Data type to use for testing (bf16 or fp16)", - ) - args = parser.parse_args() - torch_dtype = torch.bfloat16 - if args.dtype == "fp16": - torch_dtype = torch.float16 - main(torch_dtype)