diff --git a/.gitmodules b/.gitmodules
index 265d8d989..035314ec3 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -10,3 +10,6 @@
 [submodule "sgl-kernel/3rdparty/deepgemm"]
 	path = sgl-kernel/3rdparty/deepgemm
 	url = https://github.com/deepseek-ai/DeepGEMM
+[submodule "sgl-kernel/3rdparty/flashmla"]
+	path = sgl-kernel/3rdparty/flashmla
+	url = https://github.com/deepseek-ai/FlashMLA
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index ac2cc26bd..37555ab55 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -18,6 +18,21 @@ import shutil
 import sys
 from pathlib import Path
 
+# Setup flash_mla at the top level for tests to find
+# This makes the module importable without installation
+root_dir = Path(__file__).parent.resolve()
+module_src = root_dir / "3rdparty" / "flashmla" / "flash_mla"
+module_dest = root_dir / "flash_mla"
+
+if module_src.exists() and not module_dest.exists():
+    try:
+        os.symlink(module_src, module_dest, target_is_directory=True)
+        print(f"Created symbolic link from {module_src} to {module_dest}")
+    except (OSError, NotImplementedError):
+        if module_src.exists():
+            shutil.copytree(module_src, module_dest)
+            print(f"Copied directory from {module_src} to {module_dest}")
+
 import torch
 from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
@@ -55,6 +70,7 @@ cutlass_default = root / "3rdparty" / "cutlass"
 cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
 flashinfer = root / "3rdparty" / "flashinfer"
 deepgemm = root / "3rdparty" / "deepgemm"
+flashmla = root / "3rdparty" / "flashmla"
 include_dirs = [
     root / "include",
     root / "csrc",
@@ -63,6 +79,7 @@ include_dirs = [
     flashinfer.resolve() / "include",
     flashinfer.resolve() / "include" / "gemm",
     flashinfer.resolve() / "csrc",
+    flashmla.resolve() / "csrc",
     "cublas",
 ]
 
@@ -70,6 +87,7 @@ include_dirs = [
 class CustomBuildPy(build_py):
     def run(self):
         self.copy_deepgemm_to_build_lib()
+        self.copy_flashmla_to_build_lib()
         self.make_jit_include_symlinks()
         build_py.run(self)
 
@@ -93,6 +111,17 @@ class CustomBuildPy(build_py):
             os.unlink(dst_dir)
             os.symlink(src_dir, dst_dir, target_is_directory=True)
 
+        # Create symbolic links for FlashMLA
+        flash_mla_include_dir = os.path.join(self.build_lib, "flash_mla/include")
+        os.makedirs(flash_mla_include_dir, exist_ok=True)
+
+        # Create empty directories for FlashMLA's include paths
+        # This is safer than creating symlinks as the targets might not exist in CI
+        for dirname in ["cute", "cutlass"]:
+            dst_dir = f"{flash_mla_include_dir}/{dirname}"
+            if not os.path.exists(dst_dir):
+                os.makedirs(dst_dir, exist_ok=True)
+
     def copy_deepgemm_to_build_lib(self):
         """
         This function copies DeepGemm to python's site-packages
@@ -110,6 +139,26 @@ class CustomBuildPy(build_py):
         # Copy the directory
         shutil.copytree(src_dir, dst_dir)
 
+    def copy_flashmla_to_build_lib(self):
+        """
+        This function copies FlashMLA to python's site-packages
+        """
+        dst_dir = os.path.join(self.build_lib, "flash_mla")
+        os.makedirs(dst_dir, exist_ok=True)
+
+        src_dir = os.path.join(str(flashmla.resolve()), "flash_mla")
+
+        if not os.path.exists(src_dir):
+            print(
+                f"Warning: Source directory {src_dir} does not exist, possibly the submodule is not properly initialized"
+            )
+            return
+
+        if os.path.exists(dst_dir):
+            shutil.rmtree(dst_dir)
+
+        shutil.copytree(src_dir, dst_dir)
+
 
 nvcc_flags = [
     "-DNDEBUG",
diff --git a/sgl-kernel/tests/test_flash_mla.py b/sgl-kernel/tests/test_flash_mla.py
new file mode 100644
index 000000000..b3b47ad3f
--- /dev/null
+++ b/sgl-kernel/tests/test_flash_mla.py
@@ -0,0 +1,153 @@
+import argparse
+import math
+import random
+
+import torch
+import triton
+from flash_mla import flash_mla_with_kvcache, get_mla_metadata
+
+"""
+fork FlashMLA/tests/test_flash_mla.py
+"""
+
+
+def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
+    query = query.float()
+    key = key.float()
+    value = value.float()
+    key = key.repeat_interleave(h_q // h_kv, dim=0)
+    value = value.repeat_interleave(h_q // h_kv, dim=0)
+    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
+    if is_causal:
+        s_q = query.shape[-2]
+        s_k = key.shape[-2]
+        attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
+        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
+        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+        attn_bias.to(query.dtype)
+        attn_weight += attn_bias
+    lse = attn_weight.logsumexp(dim=-1)
+    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
+    return attn_weight @ value, lse
+
+
+def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
+    x, y = x.double(), y.double()
+    RMSE = ((x - y) * (x - y)).mean().sqrt().item()
+    cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
+    amax_diff = (x - y).abs().max().item()
+    # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
+    assert cos_diff < 1e-5
+
+
+@torch.inference_mode()
+def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
+    print(
+        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
+    )
+
+    cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
+    if varlen:
+        for i in range(b):
+            cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
+    total_seqlens = cache_seqlens.sum().item()
+    mean_seqlens = cache_seqlens.float().mean().int().item()
+    max_seqlen = cache_seqlens.max().item()
+    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
+    # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
+
+    q = torch.randn(b, s_q, h_q, d)
+    block_size = 64
+    block_table = torch.arange(
+        b * max_seqlen_pad // block_size, dtype=torch.int32
+    ).view(b, max_seqlen_pad // block_size)
+    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
+    for i in range(b):
+        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = (
+            float("nan")
+        )
+    blocked_v = blocked_k[..., :dv]
+
+    tile_scheduler_metadata, num_splits = get_mla_metadata(
+        cache_seqlens, s_q * h_q // h_kv, h_kv
+    )
+
+    def flash_mla():
+        return flash_mla_with_kvcache(
+            q,
+            blocked_k,
+            block_table,
+            cache_seqlens,
+            dv,
+            tile_scheduler_metadata,
+            num_splits,
+            causal=causal,
+        )
+
+    def ref_mla():
+        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
+        lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
+        for i in range(b):
+            begin = i * max_seqlen_pad
+            end = begin + cache_seqlens[i]
+            O, LSE = scaled_dot_product_attention(
+                q[i].transpose(0, 1),
+                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
+                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
+                h_q=h_q,
+                h_kv=h_kv,
+                is_causal=causal,
+            )
+            out[i] = O.transpose(0, 1)
+            lse[i] = LSE
+        return out, lse
+
+    out_flash, lse_flash = flash_mla()
+    out_torch, lse_torch = ref_mla()
+    cal_diff(out_flash, out_torch, "out")
+    cal_diff(lse_flash, lse_torch, "lse")
+
+    t = triton.testing.do_bench(flash_mla)
+    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
+    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
+        torch.finfo(q.dtype).bits // 8
+    )
+    print(
+        f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
+    )
+
+
+def main(torch_dtype):
+    device = torch.device("cuda:0")
+    torch.set_default_dtype(torch_dtype)
+    torch.set_default_device(device)
+    torch.cuda.set_device(device)
+    torch.manual_seed(0)
+    random.seed(0)
+
+    h_kv = 1
+    d, dv = 576, 512
+    causal = True
+
+    for b in [128]:
+        for s in [4096, 8192]:
+            for h_q in [16, 32, 64, 128]:  # TP = 8, 4, 2, 1
+                for s_q in [1, 2]:  # MTP = 1, 2
+                    for varlen in [False, True]:
+                        test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        choices=["bf16", "fp16"],
+        default="bf16",
+        help="Data type to use for testing (bf16 or fp16)",
+    )
+    args = parser.parse_args()
+    torch_dtype = torch.bfloat16
+    if args.dtype == "fp16":
+        torch_dtype = torch.float16
+    main(torch_dtype)
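For anyone who wants to exercise the new binding outside the test harness, the call pattern the test wraps reduces to the sketch below. This is an illustration, not part of the patch: shapes follow the test's defaults (h_kv = 1, d = 576, dv = 512, 64-token KV-cache blocks, bf16), and it assumes the flashmla submodule has been initialized and the extension built so that flash_mla is importable.

    import torch
    from flash_mla import flash_mla_with_kvcache, get_mla_metadata

    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device("cuda")

    # One decode step for a single sequence: s_q = 1 query token, h_q = 16 query
    # heads sharing h_kv = 1 KV head (MQA-style), d = 576 key dim, dv = 512 value dim.
    b, s_q, h_q, h_kv, d, dv, block_size = 1, 1, 16, 1, 576, 512, 64
    seqlen = 1024  # tokens already in the KV cache
    cache_seqlens = torch.tensor([seqlen], dtype=torch.int32)
    num_blocks = seqlen // block_size

    q = torch.randn(b, s_q, h_q, d)
    blocked_k = torch.randn(num_blocks, block_size, h_kv, d)  # paged KV cache
    block_table = torch.arange(num_blocks, dtype=torch.int32).view(b, num_blocks)

    # Scheduler metadata depends only on sequence lengths and head counts.
    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens, s_q * h_q // h_kv, h_kv
    )
    out, lse = flash_mla_with_kvcache(
        q, blocked_k, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits, causal=True,
    )
    # Per the reference implementation in the test, out is (b, s_q, h_q, dv)
    # and lse is (b, h_q, s_q).

The benchmark arithmetic in the test follows from the same shapes: per batch element, each of the h_q query heads performs an (s_q × d) @ (d × seqlen) score matmul plus an (s_q × seqlen) @ (seqlen × dv) value matmul, which summed over the batch gives FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2.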