diff --git a/sgl-kernel/benchmark/bench_cutlass_mla.py b/sgl-kernel/benchmark/bench_cutlass_mla.py new file mode 100644 index 000000000..9ac97e20d --- /dev/null +++ b/sgl-kernel/benchmark/bench_cutlass_mla.py @@ -0,0 +1,133 @@ +import argparse +import copy +import itertools + +import torch +import triton +from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size + +bs_range = [1, 8, 32, 64, 128, 256] +qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192] + +configs = list(itertools.product(bs_range, qlen_range)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len"], + x_vals=configs, + x_log=False, + line_arg="provider", + line_vals=[ + "128 heads", + "64 heads", + "32 heads", + "16 heads", + ], + line_names=[ + "128 heads", + "64 heads", + "32 heads", + "16 heads", + ], + styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")], + ylabel="GB/s", + plot_name="cutlass mla", + args={}, + ) +) +def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits): + d = 576 + dv = 512 + + h_q_map = { + "128": 128, + "64": 64, + "32": 32, + "16": 16, + } + parsed_h_q = next( + (value for key, value in h_q_map.items() if key in provider), None + ) + + if parsed_h_q is None: + raise ValueError(f"Unknown head configuration in provider: {provider}") + h_q = parsed_h_q + + seq_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") + max_seq_len = seq_lens.max().item() + block_num = (max_seq_len + block_size - 1) // block_size + + # Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles. + # One 128-wide tile can hold (128 // block_size) small blocks. + pack_factor = 128 // block_size + block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor + + q = torch.randn(batch_size, h_q, d, dtype=torch.bfloat16, device="cuda") * 100.0 + block_table = torch.randint( + 0, + batch_size * block_num, + (batch_size, block_num), + dtype=torch.int32, + device="cuda", + ) + + kv_cache = torch.randn( + block_table.numel(), block_size, d, dtype=torch.bfloat16, device="cuda" + ) + + workspace_size = cutlass_mla_get_workspace_size( + block_num * block_size, batch_size, num_kv_splits=num_kv_splits + ) + workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8) + + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: cutlass_mla_decode( + q, kv_cache, seq_lens, block_table, workspace, num_kv_splits + ), + quantiles=quantiles, + ) + + gbps = ( + lambda ms: ( + q.numel() * q.element_size() + + q.numel() * q.element_size() * dv / d + + kv_cache.numel() * kv_cache.element_size() + ) + * 1e-9 + / (ms * 1e-3) + ) + return gbps(ms), gbps(max_ms), gbps(min_ms) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--block-sizes", + nargs="+", + type=int, + default=[1, 32, 64, 128], + help="List of batch sizes", + ) + parser.add_argument( + "--num-kv-splits", + nargs="+", + type=int, + default=[-1], + help="List of batch sizes", + ) + args = parser.parse_args() + + for block_size in args.block_sizes: + for kv_split in args.num_kv_splits: + print(f"block_size={block_size}, num_kv_splits={kv_split}: ") + benchmark.run( + print_data=True, + show_plots=True, + save_path="bench_blackwell_mla_res", + block_size=block_size, + num_kv_splits=kv_split, + ) + + print("Benchmark finished!") diff --git a/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu b/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu index 6fb9e83f3..d72da50c5 100644 --- a/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu +++ b/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu @@ -22,8 +22,9 @@ limitations under the License. #include #include -#include -#include + +#include "cutlass_sm100_mla/device/sm100_mla.hpp" +#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp" // clang-format off #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 @@ -55,7 +56,7 @@ struct IsPersistent { static const bool value = v; }; -template > +template > struct MlaSm100 { using Element = T; using ElementAcc = float; @@ -83,7 +84,7 @@ struct MlaSm100 { ElementOut, ElementAcc, TileScheduler, - /*kIsCpAsync=*/true>; + /*kIsCpAsync=*/!IsPaged128>; using Fmha = cutlass::fmha::device::MLA; }; @@ -93,7 +94,8 @@ typename T::Fmha::Arguments args_from_options( at::Tensor const& q_nope_and_q_pe, at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens, - at::Tensor const& page_table) { + at::Tensor const& page_table, + int64_t num_kv_splits) { cutlass::KernelHardwareInfo hw_info; hw_info.device_id = q_nope_and_q_pe.device().index(); hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); @@ -154,8 +156,8 @@ typename T::Fmha::Arguments args_from_options( // TODO(trevor-m): Change split_kv back to -1 when // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will // perform worse with larger context length and smaller batch sizes. - 1, // split_kv - nullptr, // is_var_split_kv + num_kv_splits, // split_kv + nullptr, // is_var_split_kv }; // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute // split_kv automatically based on batch size and sequence length to balance @@ -165,7 +167,7 @@ typename T::Fmha::Arguments args_from_options( return arguments; } -template +template void runMla( at::Tensor const& out, at::Tensor const& q_nope_and_q_pe, @@ -173,10 +175,11 @@ void runMla( at::Tensor const& seq_lens, at::Tensor const& page_table, at::Tensor const& workspace, + int64_t num_kv_splits, cudaStream_t stream) { - using MlaSm100Type = MlaSm100; + using MlaSm100Type = MlaSm100; typename MlaSm100Type::Fmha fmha; - auto arguments = args_from_options(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table); + auto arguments = args_from_options(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits); CUTLASS_CHECK(fmha.can_implement(arguments)); @@ -185,31 +188,57 @@ void runMla( CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream)); } +#define DISPATCH_BOOL(expr, const_expr, ...) \ + [&]() -> bool { \ + if (expr) { \ + constexpr bool const_expr = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + return __VA_ARGS__(); \ + } \ + }() + void cutlass_mla_decode( torch::Tensor const& out, torch::Tensor const& q_nope_and_q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, torch::Tensor const& page_table, - torch::Tensor const& workspace) { + torch::Tensor const& workspace, + int64_t num_kv_splits) { auto in_dtype = q_nope_and_q_pe.dtype(); at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()}; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device()); - if (in_dtype == at::ScalarType::Half) { - runMla(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream); - } else if (in_dtype == at::ScalarType::BFloat16) { - runMla(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream); - } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { - runMla(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream); - } else { - TORCH_CHECK(false, "Unsupported input data type of MLA"); - } + const int page_size = kv_c_and_k_pe_cache.sizes()[1]; + + // NOTE(alcanderian): IsPersistent has bug with manual split_kv. + // Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8) + // Maybe per batch split kv will fix this. + DISPATCH_BOOL(page_size == 128, IsPaged128, [&] { + DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] { + if (in_dtype == at::ScalarType::Half) { + runMla>( + out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream); + } else if (in_dtype == at::ScalarType::BFloat16) { + runMla>( + out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream); + } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { + runMla>( + out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream); + } else { + TORCH_CHECK(false, "Unsupported input data type of MLA"); + } + return true; + }); + return true; + }); } -int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count) { +int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) { // Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc) // which are float, so Element type here doesn't matter. - using MlaSm100Type = MlaSm100; + using MlaSm100Type = MlaSm100; // Get split kv. Requires problem shape and sm_count only. typename MlaSm100Type::Fmha::Arguments arguments; @@ -220,6 +249,7 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, // Assumes device 0 when getting sm_count. arguments.hw_info.sm_count = sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count; + arguments.split_kv = num_kv_splits; MlaSm100Type::Fmha::set_split_kv(arguments); return MlaSm100Type::Fmha::get_workspace_size(arguments); diff --git a/sgl-kernel/csrc/attention/cutlass_sm100_mla/device/sm100_mla.hpp b/sgl-kernel/csrc/attention/cutlass_sm100_mla/device/sm100_mla.hpp new file mode 100644 index 000000000..dd4ed231b --- /dev/null +++ b/sgl-kernel/csrc/attention/cutlass_sm100_mla/device/sm100_mla.hpp @@ -0,0 +1,358 @@ +/*************************************************************************************************** + * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +// clang-format off +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp" +#include "../kernel/sm100_fmha_mla_reduction.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +using namespace cute; +using namespace cutlass::fmha::kernel; + + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template< + class Kernel_ +> +class MLA { +public: + + using Kernel = Kernel_; + + using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel< + typename Kernel::ElementOut, + typename Kernel::ElementAcc, + typename Kernel::ElementAcc, + Kernel::TileShapeH::value, + Kernel::TileShapeL::value, + 256 /*Max split*/ + >; + + /// Argument structure: User API + using KernelArguments = typename Kernel::Arguments; + using ReductionArguments = typename ReductionKernel::Arguments; + + using Arguments = KernelArguments; + + /// Argument structure: Kernel API + using KernelParams = typename Kernel::Params; + using ReductionParams = typename ReductionKernel::Params; + struct Params { + KernelParams fmha_params; + ReductionParams reduction_params; + }; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + + static ReductionArguments to_reduction_args(Arguments const& args) { + auto [H, K, D, B] = args.problem_shape; + return ReductionArguments{ + nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse, + args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq, + args.ptr_split_kv, Kernel::TileShapeS::value + }; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + static void set_split_kv (KernelArguments& args) { + if (args.split_kv >= 1) return; + auto [H, K, D, B] = args.problem_shape; + int sm_count = args.hw_info.sm_count; + int max_splits = ceil_div(K, 128); + int sms_per_batch = max(1, sm_count / B); + int split_heur = min(max_splits, sms_per_batch); + int waves = ceil_div(B * split_heur, sm_count); + int k_waves = ceil_div(max_splits, split_heur); + int split_wave_aware = ceil_div(max_splits, k_waves); + args.split_kv = split_wave_aware; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (! Kernel::can_implement(args)) { + return Status::kInvalid; + } + if (! ReductionKernel::can_implement(to_reduction_args(args))) { + return Status::kInvalid; + } + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args)); + return workspace_bytes; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("MLA::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream); + if (status != Status::kSuccess) { + return status; + } + KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params {kernel_params, reduction_params}; + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + // no dynamic smem is needed for reduction kernel + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + auto fmha_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params {fmha_params, reduction_params}; + + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("MLA::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = Kernel::get_grid_shape(params.fmha_params); + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms.fmha_params}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + device_kernel<<>>(params.fmha_params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess != result or Status::kSuccess != launch_result) { + //return Status::kSuccess; + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + if (params.reduction_params.split_kv > 1) { + // launch reduction kernel + dim3 const block = ReductionKernel::get_block_shape(); + dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params); + device_kernel<<>>(params.reduction_params); + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + else { + return Status::kSuccess; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp b/sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp new file mode 100644 index 000000000..b75870d0c --- /dev/null +++ b/sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// clang-format off +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/arch.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +template< + class ElementOut, + class ElementAcc, + class ElementScale, + size_t kNumHeads, + size_t kHeadDimLatent, + int kMaxSplits +> +struct Sm100FmhaMlaReductionKernel { + + static const int SharedStorageSize = 0; + static const int MaxThreadsPerBlock = 128; + static const int MinBlocksPerMultiprocessor = 1; + + using ArchTag = cutlass::arch::Sm100; + + static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0); + struct Arguments { + ElementAcc* ptr_oaccum = nullptr; + ElementOut* ptr_o = nullptr; + ElementAcc* ptr_lseaccum = nullptr; + ElementAcc* ptr_lse = nullptr; + ElementScale scale = 1.f; + int num_batches = 0; + int split_kv = -1; + int dim_k = -1; + int* ptr_seq = nullptr; + int* ptr_split_kv = nullptr; + int tile_shape_s = 128; + }; + using Params = Arguments; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse, + args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq, + args.ptr_split_kv, args.tile_shape_s}; + } + + static size_t get_workspace_size(Arguments const& /*args*/) { + return 0; + } + + static Status initialize_workspace( + Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return dim3(kNumHeads, 1, params.num_batches); + } + + static dim3 get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + static bool can_implement(Arguments const& args) { + if (args.num_batches <= 0) return false; + if (args.split_kv <= 0) return false; + return true; + } + + CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) { + if (params.split_kv <= 1) return; + auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z); + + __shared__ ElementAcc sLseScale[kMaxSplits]; + const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord); + const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord); + + Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum), + make_shape(params.split_kv), Stride>{}); + + Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse), + Shape<_1>{}, Stride<_1>{}); + + auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)]; + auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)]; + auto k_tile_total = ceil_div(dim_k, params.tile_shape_s); + auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv); + local_split_kv = ceil_div(k_tile_total, k_tile_per_cta); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0) { + constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); + + ElementAcc local_lse[kNLsePerThread]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits::infinity(); + } + + ElementAcc lse_max = -std::numeric_limits::infinity(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + lse_max = max(lse_max, local_lse[i]); + } + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset)); + } + lse_max = lse_max == -std::numeric_limits::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf + lse_max = __shfl_sync(0xffffffff, lse_max, 0); + + ElementAcc sum_lse = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + sum_lse = sum_lse + expf(local_lse[i] - lse_max); + } + + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset); + } + + sum_lse = __shfl_sync(0xffffffff, sum_lse, 0); + + ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits::infinity() : logf(sum_lse) + lse_max; + if (threadIdx.x == 0 and params.ptr_lse != nullptr) { + gLSE(0) = global_lse; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + if (split < local_split_kv) { + sLseScale[split] = expf(local_lse[i] - global_lse); + } + } + } + __syncthreads(); + + constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock; + const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord)); + Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum), + Shape>{}, Stride<_1>{}); + ElementAcc local_val[Elements] = {0}; + for (int split = 0; split < local_split_kv; ++split) { + ElementAcc lse_scale = sLseScale[split]; + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Elements; ++i) { + local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i); + } + gOaccum.data() = gOaccum.data() + kHeadDimLatent; + } + auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent; + Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape>{}, Stride<_1>{}); + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Elements; ++i) { + gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast(local_val[i]); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp new file mode 100644 index 000000000..9809db84e --- /dev/null +++ b/sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -0,0 +1,2018 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// clang-format off +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "gather_tensor.hpp" // from examples/common +#include "common/pow_2.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template< + class TileShape, + class Element_, + class ElementAcc_, + class ElementOut_, + class ElementLSE_, + class TileScheduler, +#ifdef CPASYNC + bool kIsCpAsync = true +#else + bool kIsCpAsync = false +#endif +> +struct Sm100FmhaMlaKernelTmaWarpspecialized { + + using Element = Element_; + using ElementAcc = ElementAcc_; + using ElementOut = ElementOut_; + using ElementLSE = ElementLSE_; + + // only 2Sm mode is supported + static const bool kIs2Sm = true; + static const int MaxThreadsPerBlock = 256; + static const int MinBlocksPerMultiprocessor = 1; + static const int TotalSNum = 2; + static const int TotalPNum = 2; + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = cute::conditional_t, Shape<_1, _1, _1>>; + + using TileShapeH = tuple_element_t<0, TileShape>; + using TileShapeS = tuple_element_t<1, TileShape>; + using TileShapeD = tuple_element_t<2, TileShape>; + + using TileShapeL = tuple_element_t<0, TileShapeD>; + using TileShapeR = tuple_element_t<1, TileShapeD>; + static_assert(TileShapeL{} % TileShapeR{} == 0, "Rope head dim must divide latent head dim"); + + using ProblemShape = Shape; + using TensorStride = Stride; + using TmemAllocator = cute::conditional_t; + + static_assert(TileShapeH{} == 128); + static const int kWarpsInN = kIs2Sm ? 2 : 1; + + static const int kNumComputeWarps = 4; + static const int kNumLoadWarps = kIsCpAsync ? 2 : 1; + + enum class WarpRole { + kMma = 0x1, kLoad = 0x2, kCompute = 0x3, kLoadPageTable = 0x4, kEmpty=0x0 + }; + + static const long long unsigned int kWarpAssignment = kIsCpAsync ? 0x4221'3333ull : 0x0021'3333ull; + + static CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + static const int Alignment = 128 / sizeof_bits_v; + static const int AlignmentOut = 128 / sizeof_bits_v; + + using TileShapeQK = Shape; + static const int StagesQK = 24 / sizeof(Element); // free parameter + static const int IterationsQKLatent = decltype(TileShapeL{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQKRope = decltype(TileShapeR{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQK = IterationsQKLatent + IterationsQKRope; + + using Schedule = cute::conditional_t; + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStride, Alignment, + Element, TensorStride, Alignment, + ElementAcc, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + using CtaShapeQK = typename CollectiveMmaQK::CtaShape_MNK; + + // chosen for unified smem staging between K and V + using TileShapePV = Shape; + using TransposeTensorStride = decltype(select<1,0,2>(TensorStride{})); + static const int StagesPV = StagesQK; // not sure why, but must be at least two. check pipes + static const int IterationsPV_K = decltype(TileShapeS{} / get<2>(TileShapePV{}))::value; + static const int IterationsPV_N = decltype(TileShapeL{} / get<1>(TileShapePV{}))::value; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStride, Alignment, + Element, TransposeTensorStride, Alignment, + ElementAcc, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using CtaShapePV = typename CollectiveMmaPV::CtaShape_MNK; + static_assert(std::is_same_v); + + using TiledMmaPV = typename CollectiveMmaPV::TiledMma; + + using AtomThrShapeMNK = typename CollectiveMmaQK::AtomThrShapeMNK; + static_assert(typename CollectiveMmaQK::AtomThrShapeMNK{} == typename CollectiveMmaPV::AtomThrShapeMNK{}, "schedule must match"); + + static const int StagesPageTable = kIsCpAsync ? StagesPV : 1; + + // pipelines from load to mma, PipelineTmaUmmaAsync, stages tbd + // use expect_tx for Q load + using PipelineLoadQK = cute::conditional_t, PipelineTmaUmmaAsync>; + using PipelineLoadPV = PipelineLoadQK; + // pipeline from mma (Q@K) to softmax, PipelineUmmaAsync, 2 stages + using PipelineS = PipelineUmmaAsync; + // pipeline from softmax (P) to mma (bmm2), PipelineUmmaAsync, 2 stages + using PipelineP = PipelineUmmaConsumerAsync; + // pipeline from mma to softmax (for rescale), PipelineUmmaAsync, 1 stage + using PipelineO = PipelineUmmaAsync<1, AtomThrShapeMNK>; + + using PipelinePT = PipelineAsync; + + struct PipelineStorage { + alignas(16) typename PipelineLoadQK::SharedStorage load_qk; + alignas(16) typename PipelineS::SharedStorage mma_s; + alignas(16) typename PipelineP::SharedStorage p_mma; + alignas(16) typename PipelineO::SharedStorage mma_o; + alignas(16) typename PipelinePT::SharedStorage load_page_table; + }; + + template + static CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutKC = typename CollectiveMmaQK::SmemLayoutB; + using SmemLayoutVC = typename CollectiveMmaPV::SmemLayoutB; + using SmemLayoutP = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutA{}, make_shape(Int{}, _2{}))); + + static const int kBytesLoadQ = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static const int kBytesLoadKC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutKC{})) * cute::sizeof_bits_v); + static const int kBytesLoadVC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutVC{})) * cute::sizeof_bits_v); + // pre-condition for overlapped smem staging + static_assert(kBytesLoadKC == kBytesLoadVC); + static_assert(StagesQK == StagesPV); + + static const int kTransactionsBytesLoadQK = kBytesLoadKC; + static const int kTransactionsBytesLoadExtraQ = kBytesLoadQ; + static const int kTransactionsBytesLoadPV = kBytesLoadVC; + + static const int kNamedBarrierExchange = (int) cutlass::arch::ReservedNamedBarriers::TransformBarrier; + // This Named Barrier is introduced to solve Q tile loading overwritten issue when enable persistent + // tile scheduler for FP8 MLA. + static const int kNamedBarrierEpilogue = (int) cutlass::arch::ReservedNamedBarriers::EpilogueBarrier; + // + static const int kNamedBarrierTmemDealloc = (int) cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier; + + enum class TmemAllocation : uint32_t { + kSizeS = TileShapeS::value / kWarpsInN, + // Overall + kSizeO = TileShapeL::value / kWarpsInN, + // Between accumulators we loop over + kSizeAccO = decltype(get<1>(TileShapePV{}))::value / kWarpsInN, + kNumS = TotalSNum, + kNumP = TotalPNum, + kNumO = 1, + kS0 = 0, + kS1 = kS0 + kSizeS, + kO0 = kS1 + kSizeS, + kTotal = kO0 + kSizeO + }; + + static_assert(static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, "using too much tmem"); + + struct TensorStorage { + // to communicate max and row_sum + cute::array smem_exchange; + cute::array smem_page_table; + alignas(2048) cute::array> smem_q; + union { + alignas(2048) cute::array> smem_kc; + alignas(2048) cute::array> smem_vc; + }; + alignas(2048) cute::array> smem_p; + }; + + struct SharedStorage { + PipelineStorage pipelines; + TensorStorage tensors; + uint32_t tmem_base_ptr; + }; + + static const int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + struct MainloopArguments { + ElementAcc softmax_scale; + + // all tensors strides are (num_heads or seqlen, head_dim, batch) + // head_dim stride is always 1 + Element* ptr_q_latent; + TensorStride stride_q_latent; + Element* ptr_q_rope; + TensorStride stride_q_rope; + + Element* ptr_c_latent; + TensorStride stride_c_latent; + Element* ptr_k_rope; + TensorStride stride_k_rope; + + // for paged attention, we interpret what was previously [batch, seqlen] + // as [page_count, page_size], and index according to page_table + int* ptr_seq = nullptr; + int* ptr_page_table = nullptr; + // page table is [batch, seqlen or similar] + Stride<_1, int> stride_page_table = {}; + int page_count = 0; + int page_size = TileShapeS{}; // powers of two if kIsCpAsync, otherwise TileShapeS + }; + + struct EpilogueArguments { + ElementOut* ptr_o = nullptr; + TensorStride stride_o; + ElementLSE* ptr_lse = nullptr; + Stride<_1, int> stride_lse; + ElementAcc output_scale = 1.0f; + }; + + struct Arguments { + // (num_heads=128, seqlen, (d_latent=512, d_rope=64), batch_count) + // for paged attention, seqlen is max seqlen + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + int split_kv = -1; + int* ptr_split_kv = nullptr; + }; + + using TmaLoadQLatent = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadQRope = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadCLatent = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadKRope = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadCLatentTranspose = typename CollectiveMmaPV::Params::TMA_B; + + struct MainloopParams { + TmaLoadQLatent tma_load_q_latent; + TmaLoadQRope tma_load_q_rope; + TmaLoadCLatent tma_load_c_latent; + TmaLoadKRope tma_load_k_rope; + TmaLoadCLatentTranspose tma_load_c_latent_transpose; + }; + + struct EpilogueParams { + ElementOut* ptr_o = nullptr; + ElementAcc* ptr_o_acc = nullptr; + TensorStride stride_o; + TensorStride stride_o_acc; + ElementLSE* ptr_lse = nullptr; + ElementLSE* ptr_lse_acc = nullptr; + Stride<_1, int> stride_lse; + Stride<_1, int> stride_lse_acc; + ElementAcc output_scale = 1.0f; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueParams epilogue; + MainloopParams mainloop_params; + typename TileScheduler::Params tile_scheduler; + int split_kv = -1; + int* ptr_split_kv = nullptr; + }; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + //workspace = nullptr; // let's get an error if one of these needs workspace + + auto [H, K, D, B] = args.problem_shape; + auto [L, R] = D; + + int paged_B = B; + int paged_K = K; + if (args.mainloop.ptr_page_table != nullptr) { + paged_B = args.mainloop.page_count; + paged_K = args.mainloop.page_size; + } + + auto params_qk_latent = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, K, L, B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, + }, nullptr); + + auto params_qk_latent_paged = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, paged_K, L, paged_B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, + }, nullptr); + + auto params_qk_rope = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, K, R, B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, + }, nullptr); + + auto params_qk_rope_paged = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, paged_K, R, paged_B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, + }, nullptr); + + + auto stride_c_latent_transpose = select<1,0,2>(args.mainloop.stride_c_latent); + auto params_pv_latent = CollectiveMmaPV::to_underlying_arguments( + make_shape(H, L, paged_K, paged_B), + typename CollectiveMmaPV::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, // dummy, never used + args.mainloop.ptr_c_latent, stride_c_latent_transpose, + }, nullptr); + + MainloopParams mainloop_params { + params_qk_latent.tma_load_a, + params_qk_rope.tma_load_a, + params_qk_latent_paged.tma_load_b, + params_qk_rope_paged.tma_load_b, + params_pv_latent.tma_load_b + }; + + EpilogueParams epilogue_params; + + epilogue_params.ptr_o = args.epilogue.ptr_o; + epilogue_params.stride_o = args.epilogue.stride_o; + epilogue_params.ptr_lse = args.epilogue.ptr_lse; + epilogue_params.stride_lse = args.epilogue.stride_lse; + epilogue_params.output_scale = args.epilogue.output_scale; + + if (args.split_kv > 1) { + ElementAcc* ptr_o_acc = reinterpret_cast(workspace); + ElementLSE* ptr_lse_acc = reinterpret_cast(ptr_o_acc + H * L * args.split_kv * B); + epilogue_params.ptr_o_acc = ptr_o_acc; + epilogue_params.ptr_lse_acc = ptr_lse_acc; + + epilogue_params.stride_o_acc = make_tuple(static_cast(0 + L) * args.split_kv, _1{}, static_cast(0 + H * L) * args.split_kv); + epilogue_params.stride_lse_acc = make_tuple(_1{}, (0 + H) * args.split_kv); + } + + return {args.problem_shape, args.mainloop, epilogue_params, mainloop_params, + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv), args.split_kv, args.ptr_split_kv}; + } + + static size_t get_workspace_size(Arguments const& args) { + ProblemShape problem_shape = args.problem_shape; + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + auto split_kv = args.split_kv; + return (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B; + } + static Status initialize_workspace( + Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static bool can_implement(Arguments const& args) { + if (kIsCpAsync) { + if ((args.mainloop.page_size & (args.mainloop.page_size - 1)) != 0) { + return false; + } + if (args.mainloop.page_size > TileShapeS{}) { + return false; + } + } + else { + if (args.mainloop.ptr_page_table != nullptr && args.mainloop.page_size != TileShapeS{}) { + return false; + } + } + if (get<0>(args.problem_shape) != 128) { + return false; + } + if (get<1>(args.problem_shape) <= 0) { + return false; + } + if (args.split_kv <= 0) { + return false; + } + return true; + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) { + + TileScheduler tile_scheduler(params.tile_scheduler); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + int cta_coord_v = cta_rank_in_cluster % size<0>(AtomThrShapeMNK{}); + bool is_mma_leader_cta = cta_coord_v == 0; + + if (role == WarpRole::kLoad && lane_predicate && ! kIsCpAsync) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent_transpose.get_tma_descriptor()); + } + SharedStorage& shared_storage = *reinterpret_cast(smem_raw); + + typename PipelineLoadQK::Params pipeline_load_qk_params; + if (role == WarpRole::kLoad) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Producer; + } + if (role == WarpRole::kMma) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Consumer; + } + if constexpr (kIsCpAsync) { + // we can make our life easier by unconditionally loading blocks + // since we know it'll always be legal + pipeline_load_qk_params.producer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + } + else { + pipeline_load_qk_params.is_leader = lane_predicate && (role == WarpRole::kLoad) && is_mma_leader_cta; + pipeline_load_qk_params.transaction_bytes = kTransactionsBytesLoadQK; + } + pipeline_load_qk_params.initializing_warp = 0; + PipelineLoadQK pipeline_load_qk(shared_storage.pipelines.load_qk, pipeline_load_qk_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineS::Params pipeline_mma_s_params; + if (role == WarpRole::kMma) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_s_params.initializing_warp = 1; + PipelineS pipeline_mma_s( + shared_storage.pipelines.mma_s, + pipeline_mma_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineP::Params pipeline_p_mma_params; + if (role == WarpRole::kMma) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Consumer; + } + if (role == WarpRole::kCompute) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Producer; + } + pipeline_p_mma_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_p_mma_params.consumer_arv_count = 1; + pipeline_p_mma_params.initializing_warp = 2; + PipelineP pipeline_p_mma( + shared_storage.pipelines.p_mma, + pipeline_p_mma_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineO::Params pipeline_mma_o_params; + if (role == WarpRole::kMma) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_o_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_o_params.initializing_warp = 3; + PipelineO pipeline_mma_o( + shared_storage.pipelines.mma_o, + pipeline_mma_o_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelinePT::Params pipeline_pt_params; + if (role == WarpRole::kLoad) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Consumer; + } + if (role == WarpRole::kLoadPageTable) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Producer; + } + pipeline_pt_params.consumer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp; + pipeline_pt_params.producer_arv_count = cutlass::NumThreadsPerWarp; + pipeline_pt_params.initializing_warp = 4; + PipelinePT pipeline_page_table( + shared_storage.pipelines.load_page_table, + pipeline_pt_params); + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_qk.init_masks(ClusterShape{}); // do we need an update here for 2Sm? + pipeline_mma_s.init_masks(ClusterShape{}); + pipeline_p_mma.init_masks(ClusterShape{}); + pipeline_mma_o.init_masks(ClusterShape{}); + + typename PipelineLoadQK::PipelineState pipeline_load_qk_consumer_state; + typename PipelineLoadQK::PipelineState pipeline_load_qk_producer_state = cutlass::make_producer_start_state(); + + typename PipelineS::PipelineState pipeline_mma_s_consumer_state; + typename PipelineS::PipelineState pipeline_mma_s_producer_state = cutlass::make_producer_start_state(); + + typename PipelineP::PipelineState pipeline_p_mma_consumer_state; + typename PipelineP::PipelineState pipeline_p_mma_producer_state = cutlass::make_producer_start_state(); + + typename PipelineO::PipelineState pipeline_mma_o_consumer_state; + typename PipelineO::PipelineState pipeline_mma_o_producer_state = cutlass::make_producer_start_state(); + + typename PipelinePT::PipelineState pipeline_pt_consumer_state; + typename PipelinePT::PipelineState pipeline_pt_producer_state = cutlass::make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + if (role == WarpRole::kLoadPageTable) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_page_table( + blk_coord, + problem_shape, + params.mainloop, + shared_storage.tensors, + pipeline_page_table, pipeline_pt_producer_state, + local_split_kv + ); + } + } + else if (role == WarpRole::kLoad) { + if constexpr (kIsCpAsync) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_cpasync( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv, + /* must be shared pipe */ + pipeline_page_table, pipeline_pt_consumer_state + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + else { + if (params.mainloop.ptr_page_table != nullptr) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_tma( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + else { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_tma( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + } + } + else if (role == WarpRole::kMma) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + if (is_mma_leader_cta) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + mma(blk_coord, + problem_shape, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_consumer_state, + pipeline_load_qk, pipeline_load_qk_consumer_state, + pipeline_mma_s, pipeline_mma_s_producer_state, + pipeline_p_mma, pipeline_p_mma_consumer_state, + pipeline_mma_o, pipeline_mma_o_producer_state, + local_split_kv + ); + } + } + + //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive_and_wait(); + + //uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + //tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + else if (role == WarpRole::kCompute) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto split_kv = params.split_kv; + auto local_split_kv = split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + compute( + blk_coord, + problem_shape, + params.mainloop, // for softmax_scale + params.epilogue, + shared_storage.tensors, // for smem_comm + pipeline_mma_s, pipeline_mma_s_consumer_state, + pipeline_p_mma, pipeline_p_mma_producer_state, + pipeline_mma_o, pipeline_mma_o_consumer_state, + local_split_kv + ); + } + + //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + } + + cute::cluster_sync(); + cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + if (role == WarpRole::kMma) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } + + template + CUTLASS_DEVICE void load_page_table( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_producer_state, int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + int batch_coord = get<2>(blk_coord); + + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), + make_shape(mainloop_args.page_count, B), + mainloop_args.stride_page_table); + auto mPT = mPT_l(_, batch_coord); + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + auto page_size = Pow2{mainloop_args.page_size}; + auto pages_per_tile = Pow2{TileShapeS{} / page_size}; + int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarp; + +#if 1 + for (; k_tile_count > 0; ++k_index, --k_tile_count) { + pipeline_page_table.producer_acquire(pipeline_pt_producer_state); + + // assume a single warp + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TileShapeS{}; i += cutlass::NumThreadsPerWarp) { + int idx = i + thread_idx; + bool guard = idx < pages_per_tile; + int smem_idx = pipeline_pt_producer_state.index() * TileShapeS::value + idx; + int pt_idx = pages_per_tile * k_index + idx; + + cutlass::arch::cp_async_zfill( + &shared_tensors.smem_page_table[smem_idx], &mPT(pt_idx), guard + ); + } + + pipeline_page_table.producer_commit(pipeline_pt_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_pt_producer_state; + } +#endif + } + + + struct Gather { + int& page_table_stage; + Pow2 pages_per_tile; + const int * __restrict__ smem_page_table; + + CUTLASS_DEVICE int operator()(int idx) const { + return smem_page_table[page_table_stage * TileShapeS::value + idx % pages_per_tile]; + } + + CUTLASS_DEVICE friend void print(Gather const&) { + printf(""); + } + + }; + + + template + CUTLASS_DEVICE void load_cpasync( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load, + typename PipelineLoadQK::PipelineState& pipeline_load_producer_state, + int const& split_kv, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_consumer_state) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + using X = Underscore; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // partition all tensors + auto mQL = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_latent), make_shape(H, D_latent, B), mainloop_args.stride_q_latent); + auto mQR = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_rope), make_shape(H, D_rope, B), mainloop_args.stride_q_rope); + + int paged_B = mainloop_args.page_count; + auto paged_K = Pow2{mainloop_args.page_size}; + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + + int batch_coord = get<2>(blk_coord); + auto mPT = mPT_l(_, batch_coord); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto make_copy_for = [](auto sT) { + auto rT_a = sT.layout()(_, _, _, _0{}); + auto rT = make_ordered_layout(shape(rT_a), stride(rT_a)); + auto threads = Int{}; + auto values = Int{}; + return make_cotiled_copy( + Copy_Atom, Element>{}, + make_ordered_layout( + make_shape(threads, values), + make_stride(_1{}, _0{})), + rT); + }; + + // like cute::copy, but makes sure we do all page table lookups first + auto copy_split = [](auto atom, auto src, auto dst) { + auto src_v = group_modes<1, rank_v>(src); + auto dst_v = group_modes<1, rank_v>(dst); + + auto src_v_ptrs = make_tensor(size<1>(src_v)); + for (int i = 0; i < size<1>(src_v); i++) { + src_v_ptrs(i) = &src_v(_0{}, i); + } + + + for (int i = 0; i < size<1>(src_v); i++) { + auto src_v_i = make_tensor( + make_gmem_ptr(src_v_ptrs(i)), + make_shape(shape<0>(src_v)), + make_stride(make_stride(_1{}, _0{})) + ); + atom.call(src_v_i, dst_v(_, i)); + } + }; + + auto tiled_copy_q = make_copy_for(sQ); + auto tiled_copy_kc = make_copy_for(sKC); + auto tiled_copy_vc = make_copy_for(sVC); + + auto thr_copy_q = tiled_copy_q.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_kc = tiled_copy_kc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_vc = tiled_copy_vc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + + auto tQsQ = thr_copy_q.partition_D(sQ); + auto tQgQL = thr_copy_q.partition_S(tSgQL); + auto tQgQR = thr_copy_q.partition_S(tSgQR); + + auto tKCsKC = thr_copy_kc.partition_D(sKC); + auto tVCsVC = thr_copy_vc.partition_D(sVC); + + auto pipeline_pt_release_state = pipeline_pt_consumer_state; + + int page_table_stage = -1; + Pow2 pages_per_tile{TileShapeS{} / paged_K}; + const int * __restrict__ smem_page_table = shared_tensors.smem_page_table.begin(); + Gather gather{page_table_stage, pages_per_tile, smem_page_table}; + + auto mCL = make_tensor( + make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout( + make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))), get<1>(mainloop_args.stride_c_latent))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mKR = make_tensor( + make_gmem_ptr(mainloop_args.ptr_k_rope), + ComposedLayout{ + make_layout( + make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_k_rope), example::CustomStride(gather, get<2>(mainloop_args.stride_k_rope))), get<1>(mainloop_args.stride_k_rope))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mCLT = make_tensor( + make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout( + make_shape(_1{}, make_shape(paged_K, paged_B)), + make_stride(get<1>(mainloop_args.stride_c_latent), make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(D_latent, paged_K * paged_B))}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + auto tKCgCL = thr_copy_kc.partition_S(tSgCL); + auto tKCgKR = thr_copy_kc.partition_S(tSgKR); + auto tVCgCLT = thr_copy_vc.partition_S(tOgCLT); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + auto& pipeline_acquire_state = pipeline_load_producer_state; + auto pipeline_commit_state = pipeline_acquire_state; + int pipeline_offset = 0; + + for (int i = 0; i < StagesPV; i++) { + cutlass::arch::cp_async_fence(); + } + + auto load_stage = [&](auto fn) { + pipeline_load.producer_acquire(pipeline_acquire_state); + fn(pipeline_acquire_state.index()); + cutlass::arch::cp_async_fence(); + + ++pipeline_acquire_state; + ++pipeline_offset; + + if (pipeline_offset == StagesPV - 1) { + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + }; + + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQL(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, i)); + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQR(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, IterationsQKLatent + i)); + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + k_index += 1; + k_tile_count -= 1; + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + while (pipeline_offset > 0) { + cutlass::arch::cp_async_fence(); + + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + + cutlass::arch::cp_async_wait<0>(); + + } + + + template + CUTLASS_DEVICE void load_tma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_producer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_producer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + using X = Underscore; + + // partition all tensors + auto mQL = mainloop_params.tma_load_q_latent.get_tma_tensor(make_shape(H, D_latent, B)); + auto mQR = mainloop_params.tma_load_q_rope.get_tma_tensor(make_shape(H, D_rope, B)); + + int paged_B = B; + int paged_K = K; + if constexpr (kIsPaged) { + paged_B = mainloop_args.page_count; + paged_K = mainloop_args.page_size; + } + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + + auto mCL = mainloop_params.tma_load_c_latent.get_tma_tensor(make_shape(paged_K, D_latent, paged_B)); + auto mKR = mainloop_params.tma_load_k_rope.get_tma_tensor(make_shape(paged_K, D_rope, paged_B)); + + auto mCLT = mainloop_params.tma_load_c_latent_transpose.get_tma_tensor(make_shape(D_latent, paged_K, paged_B)); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto [tQLgQL_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q_latent, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQL)); + + auto [tQRgQR_mkl, tQsQ_ignore] = tma_partition( + mainloop_params.tma_load_q_rope, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQR)); + + auto [tCLgCL_nkl, tKCsKC] = tma_partition( + mainloop_params.tma_load_c_latent, _0{}, make_layout(_1{}), + group_modes<0,3>(sKC), group_modes<0,3>(tSgCL)); + + auto [tKRgKR_nkl, tKCsKC_ignore] = tma_partition( + mainloop_params.tma_load_k_rope, _0{}, make_layout(_1{}), + group_modes<0,3>(sKC), group_modes<0,3>(tSgKR)); + + auto [tCLTgCLT_nkl, tVCsVC] = tma_partition( + mainloop_params.tma_load_c_latent_transpose, _0{}, make_layout(_1{}), + group_modes<0,3>(sVC), group_modes<0,3>(tOgCLT)); + + uint16_t mcast_mask = 0; + + int batch_coord = get<2>(blk_coord); + Tensor tQLgQL = tQLgQL_mkl(_, _, _, batch_coord); + Tensor tQRgQR = tQRgQR_mkl(_, _, _, batch_coord); + + auto mPT = mPT_l(_, batch_coord); + + Tensor tCLgCL = tCLgCL_nkl(_, _, _, _); + Tensor tKRgKR = tKRgKR_nkl(_, _, _, _); + + // careful: stage and k are swapped here! + Tensor tCLTgCLT = tCLTgCLT_nkl(_, _, _, _); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_latent.with(*tma_barrier, mcast_mask), tQLgQL(_, _0{}, i), tQsQ(_, i)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_rope.with(*tma_barrier, mcast_mask), tQRgQR(_, _0{}, i), tQsQ(_, i + IterationsQKLatent)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + // perform K load + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + // prefetch next K load to keep busy while we transpose-load from cache + const int kPrefetchDistance = 1; + for (int i = 0; i < IterationsQKLatent; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch( + mainloop_params.tma_load_c_latent, + tCLgCL(_, _0{}, i, mPT(k_index + kPrefetchDistance)) + ); + } + } + else { + cute::prefetch( + mainloop_params.tma_load_c_latent, + tCLgCL(_, k_index + kPrefetchDistance, i, batch_coord) + ); + } + } + } + + for (int i = 0; i < IterationsQKRope; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch( + mainloop_params.tma_load_k_rope, + tKRgKR(_, _0{}, i, mPT(k_index + kPrefetchDistance)) + ); + } + } + else { + cute::prefetch( + mainloop_params.tma_load_k_rope, + tKRgKR(_, k_index + kPrefetchDistance, i, batch_coord) + ); + } + } + } + + // perform V load (k_idx - 1) + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices! + // note we are off-by-one on k_index + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + } + ++pipeline_load_pv_producer_state; + } + } + + k_index += 1; + k_tile_count -= 1; + } + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices + // note we are off-by-one on k_index + + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + } + ++pipeline_load_pv_producer_state; + } + } + } + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_consumer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_consumer_state, + PipelineS& pipeline_mma_s, + typename PipelineS::PipelineState& pipeline_mma_s_producer_state, + PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_consumer_state, + PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_producer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // mma init + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}); + + Tensor tSrQ = TiledMmaQK::make_fragment_A(sQ); + Tensor tSrKC = TiledMmaQK::make_fragment_B(sKC); + Tensor tOrP = TiledMmaPV::make_fragment_A(sP); + Tensor tOrVC = TiledMmaPV::make_fragment_B(sVC); + + TiledMmaQK tiled_mma_qk; + TiledMmaPV tiled_mma_pv; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::Zero; + + pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + + // Mma S0 S1 O0 S2 O1 ... Sn On-1 On + // S0 ownership -- ----- -- -- + // S1 ownership -- ----- ---- + // O ownership -- -- ---- -- + + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSrQ(_,_,k_block,i), + tSrKC(_,_,k_block,read_stage), + tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSrQ(_,_,k_block,i), + tSrKC(_,_,k_block,read_stage), + tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tOtO.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_,_,k_block,read_stage), + tOtO); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + + --k_tile_count; + } + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tOtO.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_,_,k_block,read_stage), + tOtO); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + } + + + template + CUTLASS_DEVICE void softmax( + IsLastTile const& is_last_tile, + ElementAcc& row_max, + ElementAcc& row_sum, + ElementAcc& correction_factor, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + int k_index, + uint32_t tmem_s, + int smem_p_index) { + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaQK tiled_mma_qk; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + tStS.data() = tmem_s; + + CUTE_STATIC_ASSERT_V(shape<1>(tStS) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tStS) == _1{}); + Tensor tAcc = tStS(make_coord(_,_),_0{},_0{}); + + Tensor cS = make_identity_tensor(take<0,2>(CtaShapeQK{})); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_cS = thread_t2r.partition_D(cS); + Tensor tTR_rAcc = make_tensor(shape(tTR_cS)); + + Tensor tTR_rS_frag = make_tensor(shape(tTR_rAcc)); + const int AlignmentS = 4; + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + Tensor tTR_rAcc_vec = recast>(tTR_rAcc); + Tensor tTR_rS_vec = recast>(tTR_rS_frag); + + // load s + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + if (is_last_tile) { + for (int i = 0; i < size(tTR_rAcc); i++) { + if (get<1>(tTR_cS(i)) + TileShapeS{} * k_index >= get<1>(problem_shape)) { + tTR_rAcc(i) = -std::numeric_limits::infinity(); + } + } + } + + // max + ElementAcc row_max_new = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 1) { + row_max_new = ::fmax(row_max_new, tTR_rAcc(i)); + } + + // for 2x2 dp, reduce here + if constexpr (kWarpsInN > 1) { + shared_tensors.smem_exchange[threadIdx.x] = row_max_new; + cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_max_new = cutlass::max(row_max_new, shared_tensors.smem_exchange[peer_index]); + } + +#ifndef B2B + // find correction factor + ElementAcc softmax_scale_log2 = mainloop_args.softmax_scale * static_cast(M_LOG2E); + correction_factor = ::exp2f(softmax_scale_log2 * (row_max - row_max_new)); + row_max = row_max_new; + + // softmax + ElementAcc row_max_scale_log2 = row_max * softmax_scale_log2; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rAcc(i) = ::exp2f(softmax_scale_log2 * tTR_rAcc(i) - row_max_scale_log2); + } +#endif + + // quantize + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc_vec); i++) { + tTR_rS_vec(i) = epilogue_op(tTR_rAcc_vec(i)); + } + + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{})(_, _, _, make_coord(_, smem_p_index)); + + Tensor tOcP = TiledMmaPV{}.get_slice(_0{}).partition_A(cS); + + // have a mapping for each thread to coord + // find identical mapping to coords for the MMA + auto l = make_ordered_layout(make_shape(make_shape(_64{}, _2{}), make_shape(_16{}, TileShapeS{} / _32{})), make_stride(make_stride(_0{}, _3{}), make_stride(_1{}, _2{}))); + auto sP_ = as_position_independent_swizzle_tensor(sP); + copy_aligned(tTR_rS_frag, sP_.compose(l)(threadIdx.x, _)); + + // sum + row_sum *= correction_factor; + + static_assert(cute::is_same_v); + auto tTR_rAcc_float2 = recast(tTR_rAcc); + auto sums = make_tensor(_4{}); + static_assert(size(tTR_rAcc_float2) % size(sums) == 0); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(sums); i++) { + sums(i) = tTR_rAcc_float2(i); + } + CUTLASS_PRAGMA_UNROLL + for (int i = size(sums); i < size(tTR_rAcc_float2); i += size(sums)) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j++) { + cute::add(sums(j), sums(j), tTR_rAcc_float2(i + j)); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < size(sums); i *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j += 2*i) { + cute::add(sums(j), sums(j), sums(j+i)); + } + } + row_sum += sums(0).x + sums(0).y; + } + + + CUTLASS_DEVICE void rescale( + ElementAcc correction_factor, + uint32_t tmem_o) { + + // for b2b gemm, do nothing +#ifndef B2B + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + auto store_op = TMEM::tmem_load_to_store(load_op); + + TiledMmaPV tiled_mma_pv; + + Tensor tOtO = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + tOtO.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); + Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{}); + + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = make_tensor(make_gmem_ptr((ElementAcc*) nullptr), cta_tiler_pv, make_stride(0, 0)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto tiled_r2t = make_tmem_copy(store_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + auto thread_r2t = tiled_r2t.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + // load o + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + // multiply by correction factor + float2 correction_factor_vec = make_float2(correction_factor, correction_factor); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 2) { + float2 in = make_float2(tTR_rAcc(i + 0), tTR_rAcc(i + 1)); + float2 out; + cute::mul(out, in, correction_factor_vec); + tTR_rAcc(i + 0) = out.x; + tTR_rAcc(i + 1) = out.y; + } + + // store o + copy(tiled_r2t, tTR_rAcc, tTR_tAcc); +#endif + } + + + template + CUTLASS_DEVICE void epilogue( + ElementAcc& row_max, + ElementAcc& row_sum, + BlkCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + uint32_t tmem_o, + int const& split_kv) { + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaPV tiled_mma_pv; + + Tensor tOtO = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{}))); + tOtO.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); + Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{}); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + if (epilogue_args.ptr_o_acc != nullptr) { + using ElementOutAcc = ElementAcc; + constexpr auto AlignmentOutAcc = 128 / cute::sizeof_bits_v; + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent), make_shape(H, D_latent, B), epilogue_args.stride_o_acc); + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + +#ifndef B2B + + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), epilogue_args.stride_lse_acc); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + // for 2x2 dp, this must be conditional and the index is wrong + if (! kIs2Sm || (threadIdx.x < 64)) + { + gLSE(threadIdx.x) = lse; + } + #endif + } + else { + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o); + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + +#ifndef B2B + if (epilogue_args.ptr_lse != nullptr) { + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + + // for 2x2 dp, this must be conditional and the index is wrong + if (! kIs2Sm || (threadIdx.x < 64)) + { + gLSE(threadIdx.x) = lse; + } + } +#endif + } + } + + + template + CUTLASS_DEVICE void compute( + CtaCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + PipelineS& pipeline_mma_s, + typename PipelineS::PipelineState& pipeline_mma_s_consumer_state, + PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_producer_state, + PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_consumer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(cta_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + + // if we return early, we have to make sure we release the load warp + cutlass::arch::NamedBarrier( + (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue + ).arrive(); + + return; + } + int k_index_final = k_tile_total - 1; + + ElementAcc row_max = -std::numeric_limits::infinity(); + ElementAcc row_sum = 0; + ElementAcc correction_factor = 1; + + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + // softmax s0 -> p0 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax( + is_last_tile, + row_max, row_sum, correction_factor, + problem_shape, mainloop_args, shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index() + ); + }); + + k_index += 1; + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + // softmax s1 -> p1 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax( + is_last_tile, + row_max, row_sum, correction_factor, + problem_shape, mainloop_args, shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index() + ); + }); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + + // rescale + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + rescale(correction_factor, uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO)); + } + + cutlass::arch::fence_view_async_tmem_store(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + + --k_tile_count; + k_index += 1; + } + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + +#ifdef B2B + row_sum = 1; +#else + if constexpr (kWarpsInN > 1) { + // reduce row_sum if needed (for 2x2 dp) + shared_tensors.smem_exchange[threadIdx.x] = row_sum; + cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_sum += shared_tensors.smem_exchange[peer_index]; + } +#endif + + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive(); + + // epilogue + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + epilogue( + row_max, row_sum, + replace<1>(cta_coord, j), problem_shape, + mainloop_args, epilogue_args, shared_tensors, + uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO), split_kv + ); + } + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp b/sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp new file mode 100644 index 000000000..30389e79f --- /dev/null +++ b/sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp @@ -0,0 +1,160 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// clang-format off +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaIndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z); + } + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaPersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + FastDivmod divmod_split_kv; + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = size<0>(cluster_shape); + int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */; + num_blocks *= split_kv; /* Maximum Split KV*/ + + return Params { + num_blocks, + { num_m_blocks}, { get<3>(problem_shape) }, {split_kv}, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, n_split_kv; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_split_kv(block_decode, n_split_kv, block_decode); + return make_coord(m_block, _0{}, bidb, n_split_kv); + } + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 1bc227197..f0f78f8be 100755 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -60,7 +60,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2); m.def( "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor " - "page_table, Tensor workspace) -> ()"); + "page_table, Tensor! workspace, int num_kv_splits) -> ()"); m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size); diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index dbd969ea4..7b939b67f 100755 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -109,8 +109,10 @@ void cutlass_mla_decode( torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, torch::Tensor const& page_table, - torch::Tensor const& workspace); -int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0); + torch::Tensor const& workspace, + int64_t num_kv_splits = -1); +int64_t cutlass_mla_get_workspace_size( + int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0, int64_t num_kv_splits = -1); /* * From csrc/elementwise */ diff --git a/sgl-kernel/python/sgl_kernel/attention.py b/sgl-kernel/python/sgl_kernel/attention.py index 749462ccd..2ece6abdd 100644 --- a/sgl-kernel/python/sgl_kernel/attention.py +++ b/sgl-kernel/python/sgl_kernel/attention.py @@ -57,6 +57,7 @@ def cutlass_mla_decode( seq_lens: torch.Tensor, page_table: torch.Tensor, workspace: torch.Tensor, + num_kv_splits: int = -1, ) -> torch.Tensor: assert ( q_nope_and_q_pe.ndim == 3 @@ -73,7 +74,12 @@ def cutlass_mla_decode( f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, " f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}" ) - assert H == 128, f"H must be 128, but got {H}" + MAX_HEADS = 128 + assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}" + if H < MAX_HEADS: + q_nope_and_q_pe_padded = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_q)) + q_nope_and_q_pe_padded[:, :H] = q_nope_and_q_pe + q_nope_and_q_pe = q_nope_and_q_pe_padded assert len(page_table.shape) == 2 B_block_table, block_num = page_table.shape @@ -97,21 +103,25 @@ def cutlass_mla_decode( page_table.dtype == torch.int32 ), f"page_table.dtype needs to be int32 but got {page_table.dtype}." - out = torch.empty( - (B_q, H, D_latent), device=q_nope_and_q_pe.device, dtype=q_nope_and_q_pe.dtype - ) + out = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_latent)) torch.ops.sgl_kernel.cutlass_mla_decode.default( - out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace + out, + q_nope_and_q_pe, + kv_c_and_k_pe_cache, + seq_lens, + page_table, + workspace, + num_kv_splits, ) - return out + return out[:, :H].contiguous() def cutlass_mla_get_workspace_size( - max_seq_len: int, num_batches: int, sm_count: int = 0 + max_seq_len: int, num_batches: int, sm_count: int = 0, num_kv_splits: int = -1 ) -> int: assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}" assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}" return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default( - max_seq_len, num_batches, sm_count + max_seq_len, num_batches, sm_count, num_kv_splits ) diff --git a/sgl-kernel/tests/test_cutlass_mla.py b/sgl-kernel/tests/test_cutlass_mla.py index d6e506490..22f850af7 100644 --- a/sgl-kernel/tests/test_cutlass_mla.py +++ b/sgl-kernel/tests/test_cutlass_mla.py @@ -40,15 +40,23 @@ def ref_mla( @pytest.mark.parametrize("bs", [1, 2, 4]) @pytest.mark.parametrize("varlen", [False, True]) @pytest.mark.parametrize("block_size", [1, 16, 64, 128]) +@pytest.mark.parametrize("num_heads", [16, 32, 64, 128]) +@pytest.mark.parametrize("num_kv_splits", [-1, 1]) def test_cutlass_mla_decode( - dtype: torch.dtype, mean_seq_len: int, bs: int, varlen: bool, block_size: int + dtype: torch.dtype, + mean_seq_len: int, + bs: int, + varlen: bool, + block_size: int, + num_heads: int, + num_kv_splits: int, ): torch.set_default_dtype(dtype) torch.set_default_device("cuda") torch.manual_seed(42) d = 576 - h_q = 128 + h_q = num_heads dv = 512 q_nope_dim = 128 @@ -67,17 +75,22 @@ def test_cutlass_mla_decode( pack_factor = 128 // block_size block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor + # Lager q values to detect split kv error q = torch.randn(bs, h_q, d) * 100.0 block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32) kv_cache = torch.randn(block_table.numel(), block_size, d) - workspace_size = cutlass_mla_get_workspace_size(block_num * block_size, bs) + workspace_size = cutlass_mla_get_workspace_size( + block_num * block_size, bs, num_kv_splits=num_kv_splits + ) workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8) out_ref = q.new_zeros(bs, h_q, dv) ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) - out = cutlass_mla_decode(q, kv_cache, seq_lens, block_table, workspace) + out = cutlass_mla_decode( + q, kv_cache, seq_lens, block_table, workspace, num_kv_splits + ) torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)