From f65b8d5c896c3f51f8f1e1ff03cd556e1c5dc1fd Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 11 Apr 2025 22:16:51 -0700 Subject: [PATCH] Blackwell Cutlass MLA kernel (#5142) --- sgl-kernel/CMakeLists.txt | 5 +- .../csrc/attention/cutlass_mla_kernel.cu | 207 ++++++++++++++++++ sgl-kernel/csrc/common_extension.cc | 5 + sgl-kernel/include/sgl_kernel_ops.h | 9 +- sgl-kernel/python/sgl_kernel/__init__.py | 6 +- sgl-kernel/python/sgl_kernel/attention.py | 61 ++++++ sgl-kernel/tests/test_cutlass_mla.py | 81 +++++++ 7 files changed, 371 insertions(+), 3 deletions(-) create mode 100644 sgl-kernel/csrc/attention/cutlass_mla_kernel.cu create mode 100644 sgl-kernel/tests/test_cutlass_mla.py diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 220896975..c3c769d5e 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -33,7 +33,7 @@ include(FetchContent) FetchContent_Declare( repo-cutlass GIT_REPOSITORY https://github.com/NVIDIA/cutlass - GIT_TAG 6f4921858b3bb0a82d7cbeb4e499690e9ae60d16 + GIT_TAG df8a550d3917b0e97f416b2ed8c2d786f7f686a3 GIT_SHALLOW OFF ) FetchContent_Populate(repo-cutlass) @@ -76,6 +76,8 @@ include_directories( ${PROJECT_SOURCE_DIR}/csrc ${repo-cutlass_SOURCE_DIR}/include ${repo-cutlass_SOURCE_DIR}/tools/util/include + ${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha + ${repo-cutlass_SOURCE_DIR}/examples/common ${repo-flashinfer_SOURCE_DIR}/include ${repo-flashinfer_SOURCE_DIR}/csrc ${repo-flash-attention_SOURCE_DIR}/hopper @@ -158,6 +160,7 @@ string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE set(SOURCES "csrc/allreduce/custom_all_reduce.cu" + "csrc/attention/cutlass_mla_kernel.cu" "csrc/attention/lightning_attention_decode_kernel.cu" "csrc/elementwise/activation.cu" "csrc/elementwise/fused_add_rms_norm_kernel.cu" diff --git a/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu b/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu new file mode 100644 index 000000000..da6ea2a08 --- /dev/null 
+++ b/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu @@ -0,0 +1,207 @@ +/* +Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <ATen/cuda/CUDAContext.h> +#include <c10/cuda/CUDAGuard.h> +#include <cutlass/cutlass.h> +#include <cutlass/kernel_hardware_info.h> +#include <torch/all.h> + +#include <cute/tensor.hpp> +#include <device/sm100_mla.hpp> +#include <kernel/sm100_mla_tile_scheduler.hpp> + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ + } + +using namespace cute; +using namespace cutlass::fmha::kernel; + +template <bool v> +struct IsPersistent { + static const bool value = v; +}; + +template <typename T, typename PersistenceOption = IsPersistent<true>> +struct MlaSm100 { + using Element = T; + using ElementAcc = float; + using ElementOut = T; + + using TileShape = Shape<_128, _128, Shape<_512, _64>>; + using TileShapeH = cute::tuple_element_t<0, TileShape>; + using TileShapeD = cute::tuple_element_t<2, TileShape>; + + // H K (D_latent D_rope) B + using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>; + + using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B + using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B + using StrideO = StrideK; // H D B + using StrideLSE = cute::tuple<_1, int>; // H B + + using TileScheduler = + std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>; + + using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized< + TileShape, + Element, + ElementAcc, + ElementOut, + ElementAcc, + TileScheduler, + /*kIsCpAsync=*/true>; + using Fmha = cutlass::fmha::device::MLA<FmhaKernel>; +}; + 
+template <typename T> +typename T::Fmha::Arguments args_from_options( + at::Tensor const& out, + at::Tensor const& q_nope_and_q_pe, + at::Tensor const& kv_c_and_k_pe_cache, + at::Tensor const& seq_lens, + at::Tensor const& page_table) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = q_nope_and_q_pe.device().index(); + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + int batches = q_nope_and_q_pe.sizes()[0]; + int page_count_per_seq = page_table.sizes()[1]; + int page_count_total = kv_c_and_k_pe_cache.sizes()[0]; + int page_size = kv_c_and_k_pe_cache.sizes()[1]; + int max_seq_len = page_size * page_count_per_seq; + using TileShapeH = typename T::TileShapeH; + using TileShapeD = typename T::TileShapeD; + auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + // the scale is based on the non-absorbed sizes, change as appropriate + // we can't determine this parameter from the info we have, it's an input + int D_non_latent = 128; + float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope)); + + using StrideQ = typename T::StrideQ; + using StrideK = typename T::StrideK; + using StrideO = typename T::StrideO; + using StrideLSE = typename T::StrideLSE; + + StrideQ stride_Q = cute::make_tuple( + static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(H * (0 + D_latent + D_rope))); + StrideK stride_C = cute::make_tuple( + static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope))); + StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq); + StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H); + StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent)); + + using Element = typename T::Element; + using ElementOut = typename T::ElementOut; + using ElementAcc = typename T::ElementAcc; + auto Q_ptr = 
static_cast<Element*>(q_nope_and_q_pe.data_ptr()); + auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr()); + typename T::Fmha::Arguments arguments{ + problem_shape, + {scale, + Q_ptr, + stride_Q, + Q_ptr + D_latent, + stride_Q, + C_ptr, + stride_C, + C_ptr + D_latent, + stride_C, + static_cast<int*>(seq_lens.data_ptr()), + static_cast<int*>(page_table.data_ptr()), + stride_PT, + page_count_total, + page_size}, + {static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE}, + hw_info, + -1, // split_kv + nullptr, // is_var_split_kv + }; + // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute + // split_kv automatically based on batch size and sequence length to balance + // workload across available SMs. Consider using var_split_kv for manual + // control if needed. + T::Fmha::set_split_kv(arguments); + return arguments; +} + +template <typename Element> +void runMla( + at::Tensor const& out, + at::Tensor const& q_nope_and_q_pe, + at::Tensor const& kv_c_and_k_pe_cache, + at::Tensor const& seq_lens, + at::Tensor const& page_table, + at::Tensor const& workspace, + cudaStream_t stream) { + using MlaSm100Type = MlaSm100<Element>; + typename MlaSm100Type::Fmha fmha; + auto arguments = args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table); + + CUTLASS_CHECK(fmha.can_implement(arguments)); + + CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream)); + + CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream)); +} + +void cutlass_mla_decode( + torch::Tensor const& out, + torch::Tensor const& q_nope_and_q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table, + torch::Tensor const& workspace) { + auto in_dtype = q_nope_and_q_pe.dtype(); + at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device()); + if (in_dtype == at::ScalarType::Half) { + runMla<cutlass::half_t>(out, q_nope_and_q_pe, 
kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream); + } else if (in_dtype == at::ScalarType::BFloat16) { + runMla<cutlass::bfloat16_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream); + } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { + runMla<cutlass::float_e4m3_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream); + } else { + TORCH_CHECK(false, "Unsupported input data type of MLA"); + } +} + +int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count) { + // Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc) + // which are float, so Element type here doesn't matter. + using MlaSm100Type = MlaSm100<cutlass::half_t>; + + // Get split kv. Requires problem shape and sm_count only. + typename MlaSm100Type::Fmha::Arguments arguments; + using TileShapeH = typename MlaSm100Type::TileShapeH; + using TileShapeD = typename MlaSm100Type::TileShapeD; + arguments.problem_shape = + cute::make_tuple(TileShapeH{}, static_cast<int>(max_seq_len), TileShapeD{}, static_cast<int>(num_batches)); + // Assumes device 0 when getting sm_count. + arguments.hw_info.sm_count = + sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count; + MlaSm100Type::Fmha::set_split_kv(arguments); + + return MlaSm100Type::Fmha::get_workspace_size(arguments); +} diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index c9b2c8516..346b2e133 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -45,6 +45,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! " "new_kv) -> ()"); m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode); + m.def( + "cutlass_mla_decode(Tensor!
out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor " + "page_table, Tensor workspace) -> ()"); + m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); + m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size); /* * From csrc/elementwise diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 4aa8535f5..d1222b1dd 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -87,7 +87,14 @@ void lightning_attention_decode( const torch::Tensor& slope, torch::Tensor output, torch::Tensor new_kv); - +void cutlass_mla_decode( + torch::Tensor const& out, + torch::Tensor const& q_nope_and_q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table, + torch::Tensor const& workspace); +int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0); /* * From csrc/elementwise */ diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 0fc68cfcc..2d6bc0d56 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -11,7 +11,11 @@ if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"): from sgl_kernel import common_ops from sgl_kernel.allreduce import * -from sgl_kernel.attention import lightning_attention_decode +from sgl_kernel.attention import ( + cutlass_mla_decode, + cutlass_mla_get_workspace_size, + lightning_attention_decode, +) from sgl_kernel.elementwise import ( apply_rope_with_cos_sin_cache_inplace, fused_add_rmsnorm, diff --git a/sgl-kernel/python/sgl_kernel/attention.py b/sgl-kernel/python/sgl_kernel/attention.py index 6ad1d347e..b90834194 100644 --- a/sgl-kernel/python/sgl_kernel/attention.py +++ b/sgl-kernel/python/sgl_kernel/attention.py @@ -5,3 +5,64 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv): 
torch.ops.sgl_kernel.lightning_attention_decode.default( q, k, v, past_kv, slope, output, new_kv ) + + +def cutlass_mla_decode( + q_nope_and_q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + seq_lens: torch.Tensor, + page_table: torch.Tensor, + workspace: torch.Tensor, +) -> torch.Tensor: + assert ( + q_nope_and_q_pe.ndim == 3 + ), f"q_nope_and_q_pe must be a 3D tensor, but got {q_nope_and_q_pe.ndim}" + assert ( + kv_c_and_k_pe_cache.ndim == 3 + ), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}" + B_q, H, D_q = q_nope_and_q_pe.shape + _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape + + D_latent = 512 + D_rope = 64 + assert D_q == D_ckv and D_q == D_latent + D_rope, ( + f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, " + f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}" + ) + assert H == 128, f"H must be 128, but got {H}" + # TODO: There is currently an illegal memory access issue with page size != + # 128. Change this when it is fixed. + assert PAGE_SIZE == 128, f"PAGE_SIZE must be 128, but got {PAGE_SIZE}" + + # TODO(kaixih@nvidia): support fp8 + assert q_nope_and_q_pe.dtype in ( + torch.float16, + torch.bfloat16, + ), f"q_nope_and_q_pe.dtype needs to be fp16 or bf16 but got {q_nope_and_q_pe.dtype}." + assert kv_c_and_k_pe_cache.dtype == q_nope_and_q_pe.dtype, ( + f"kv_c_and_k_pe_cache.dtype needs to be the same as q_nope_and_q_pe.dtype, " + f"but got {kv_c_and_k_pe_cache.dtype}." + ) + assert ( + seq_lens.dtype == torch.int32 + ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}." + assert ( + page_table.dtype == torch.int32 + ), f"page_table.dtype needs to be int32 but got {page_table.dtype}." 
+ + out = torch.empty( + (B_q, H, D_latent), device=q_nope_and_q_pe.device, dtype=q_nope_and_q_pe.dtype + ) + + torch.ops.sgl_kernel.cutlass_mla_decode( + out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace + ) + return out + + +def cutlass_mla_get_workspace_size( + max_seq_len: int, num_batches: int, sm_count: int = 0 +) -> int: + return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size( + max_seq_len, num_batches, sm_count + ) diff --git a/sgl-kernel/tests/test_cutlass_mla.py b/sgl-kernel/tests/test_cutlass_mla.py new file mode 100644 index 000000000..26a59ad7c --- /dev/null +++ b/sgl-kernel/tests/test_cutlass_mla.py @@ -0,0 +1,81 @@ +import pytest +import torch +import torch.nn.functional as F +from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size +from torch import Tensor + +if torch.cuda.get_device_capability() < (10, 0): + pytest.skip( + reason="Cutlass MLA Requires compute capability of 10 or above.", + allow_module_level=True, + ) + + +def ref_mla( + out: Tensor, # (bs, num_heads, v_head_dim) + query: Tensor, # (bs, num_heads, head_dim) + kv_cache: Tensor, # (num_blocks, block_size, head_dim) + scale: float, + block_tables: Tensor, # (bs, max_num_blocks) + seq_lens: Tensor, # (bs,) +): + bs, num_heads, v_head_dim = out.shape + head_dim = query.shape[2] + + for i in range(bs): + # gather and flatten KV-cache + kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim) + kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim) + v = kv[:, :, :v_head_dim] + + q = query[i].view(num_heads, 1, head_dim) + o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True) + out[i] = o.view(num_heads, v_head_dim) + + return out + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096]) +@pytest.mark.parametrize("bs", [1, 2, 4]) +@pytest.mark.parametrize("varlen", [False, True]) 
+@pytest.mark.parametrize("block_size", [128]) +def test_cutlass_mla_decode( + dtype: torch.dtype, mean_seq_len: int, bs: int, varlen: bool, block_size: int +): + torch.set_default_dtype(dtype) + torch.set_default_device("cuda") + torch.manual_seed(42) + + d = 576 + h_q = 128 + dv = 512 + + q_nope_dim = 128 + q_pe_dim = 64 + scale = (q_nope_dim + q_pe_dim) ** (-0.5) + if varlen: + seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2) + seq_lens = seq_lens.clip(2).to(torch.int32) + else: + seq_lens = torch.full((bs,), mean_seq_len, dtype=torch.int32) + max_seq_len = seq_lens.max().item() + block_num = (max_seq_len + block_size - 1) // block_size + + q = torch.randn(bs, h_q, d) + block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32) + + kv_cache = torch.randn(block_table.numel(), block_size, d) + + workspace_size = cutlass_mla_get_workspace_size(block_num * block_size, bs) + workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8) + + out_ref = q.new_zeros(bs, h_q, dv) + ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) + out = cutlass_mla_decode(q, kv_cache, seq_lens, block_table, workspace) + + torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + pytest.main([__file__])