diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 85e247452..3a5eb6ec3 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -1,10 +1,13 @@ cmake_minimum_required(VERSION 3.26 FATAL_ERROR) project(sgl-kernel LANGUAGES CXX CUDA) +# utils +include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) +include(FetchContent) + # CMake cmake_policy(SET CMP0169 OLD) cmake_policy(SET CMP0177 NEW) -include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) set(CMAKE_COLOR_DIAGNOSTICS ON) set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON") set(CMAKE_POSITION_INDEPENDENT_CODE ON) @@ -37,11 +40,9 @@ endif() # Torch find_package(Torch REQUIRED) -# clean Torch Flag clear_cuda_arches(CMAKE_FLAG) -include(FetchContent) - +# Third Party # cutlass FetchContent_Declare( repo-cutlass @@ -69,7 +70,7 @@ FetchContent_Declare( ) FetchContent_Populate(repo-fmt) -# Triton +# Triton kernel FetchContent_Declare( repo-triton GIT_REPOSITORY "https://github.com/triton-lang/triton" @@ -143,12 +144,6 @@ endif() include_directories( ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/csrc - ${repo-cutlass_SOURCE_DIR}/include - ${repo-cutlass_SOURCE_DIR}/tools/util/include - ${repo-flashinfer_SOURCE_DIR}/include - ${repo-flashinfer_SOURCE_DIR}/csrc - ${repo-mscclpp_SOURCE_DIR}/include - ${repo-fast-hadamard-transform}/csrc ) set(SGL_KERNEL_CUDA_FLAGS @@ -350,6 +345,7 @@ set(SOURCES "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp" ) +# =========================== Common SM90 Build ============================= # # Build SM90 library with fast math optimization (same namespace, different directory) Python_add_library(common_ops_sm90_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES}) @@ -360,7 +356,11 @@ target_compile_options(common_ops_sm90_build PRIVATE $<$:${SGL_KERNEL_CUDA_FLAGS} -use_fast_math> ) target_include_directories(common_ops_sm90_build PRIVATE - ${PROJECT_SOURCE_DIR}/csrc + ${repo-cutlass_SOURCE_DIR}/include + ${repo-cutlass_SOURCE_DIR}/tools/util/include + ${repo-flashinfer_SOURCE_DIR}/include + ${repo-flashinfer_SOURCE_DIR}/csrc + ${repo-mscclpp_SOURCE_DIR}/include ${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha ${repo-cutlass_SOURCE_DIR}/examples/common ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src @@ -371,6 +371,7 @@ set_target_properties(common_ops_sm90_build PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/sm90" ) +# =========================== Common SM100+ Build ============================= # # Build SM100+ library with precise math (same namespace, different directory) Python_add_library(common_ops_sm100_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES}) @@ -381,7 +382,11 @@ target_compile_options(common_ops_sm100_build PRIVATE $<$:${SGL_KERNEL_CUDA_FLAGS}> ) target_include_directories(common_ops_sm100_build PRIVATE - ${PROJECT_SOURCE_DIR}/csrc + ${repo-cutlass_SOURCE_DIR}/include + ${repo-cutlass_SOURCE_DIR}/tools/util/include + ${repo-flashinfer_SOURCE_DIR}/include + ${repo-flashinfer_SOURCE_DIR}/csrc + ${repo-mscclpp_SOURCE_DIR}/include ${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha ${repo-cutlass_SOURCE_DIR}/examples/common ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src @@ -408,7 +413,7 @@ else() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1") endif() -# mscclpp +# mscclpp option set(MSCCLPP_USE_CUDA ON) set(MSCCLPP_BYPASS_GPU_CHECK ON) set(MSCCLPP_BUILD_TESTS OFF) @@ -419,7 +424,7 @@ add_subdirectory( target_link_libraries(common_ops_sm90_build 
PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static) target_link_libraries(common_ops_sm100_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static) -# flash attention +# sparse flash attention target_compile_definitions(common_ops_sm90_build PRIVATE FLASHATTENTION_DISABLE_BACKWARD FLASHATTENTION_DISABLE_DROPOUT @@ -506,6 +511,8 @@ if (SGL_KERNEL_ENABLE_FA3) target_compile_options(flash_ops PRIVATE $<$:${SGL_FLASH_KERNEL_CUDA_FLAGS}>) target_include_directories(flash_ops PRIVATE + ${repo-cutlass_SOURCE_DIR}/include + ${repo-cutlass_SOURCE_DIR}/tools/util/include ${repo-flash-attention_SOURCE_DIR}/hopper ) target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda) @@ -535,6 +542,8 @@ target_compile_options(spatial_ops PRIVATE $<$:${SGL_KERN target_link_libraries(spatial_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda) install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel) +# ============================ Extra Install ============================= # +include(${CMAKE_CURRENT_LIST_DIR}/cmake/flashmla.cmake) # ============================ DeepGEMM (JIT) ============================= # # Create a separate library for DeepGEMM's Python API. diff --git a/sgl-kernel/cmake/flashmla.cmake b/sgl-kernel/cmake/flashmla.cmake new file mode 100644 index 000000000..50f1c68dd --- /dev/null +++ b/sgl-kernel/cmake/flashmla.cmake @@ -0,0 +1,60 @@ +include(FetchContent) + +# flash_mla +FetchContent_Declare( + repo-flashmla + GIT_REPOSITORY https://github.com/sgl-project/FlashMLA + GIT_TAG bc8576abc3e507425cf6498f3d3393df7733ce37 + GIT_SHALLOW OFF +) +FetchContent_Populate(repo-flashmla) + +set(FLASHMLA_CUDA_FLAGS + "--expt-relaxed-constexpr" + "--expt-extended-lambda" + "--use_fast_math" +) + +# The FlashMLA kernels only work on hopper and require CUDA 12.4 or later. 
+# Only build FlashMLA kernels if we are building for something compatible with
+# sm90a
+if(${CUDA_VERSION} VERSION_GREATER 12.4)
+  list(APPEND FLASHMLA_CUDA_FLAGS
+    "-gencode=arch=compute_90a,code=sm_90a"
+  )
+endif()
+if(${CUDA_VERSION} VERSION_GREATER 12.8)
+  list(APPEND FLASHMLA_CUDA_FLAGS
+    "-gencode=arch=compute_100a,code=sm_100a"
+  )
+endif()
+
+
+set(FlashMLA_SOURCES
+  "csrc/flashmla_extension.cc"
+  ${repo-flashmla_SOURCE_DIR}/csrc/python_api.cpp
+  ${repo-flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
+  ${repo-flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
+  ${repo-flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
+  ${repo-flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
+  ${repo-flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
+  ${repo-flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
+  ${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
+  ${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
+  ${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
+)
+
+Python_add_library(flashmla_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FlashMLA_SOURCES})
+target_compile_options(flashmla_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${FLASHMLA_CUDA_FLAGS}>)
+target_include_directories(flashmla_ops PRIVATE
+  ${repo-flashmla_SOURCE_DIR}/csrc
+  ${repo-flashmla_SOURCE_DIR}/csrc/sm90
+  ${repo-flashmla_SOURCE_DIR}/csrc/cutlass/include
+  ${repo-flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
+)
+
+target_link_libraries(flashmla_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
+
+install(TARGETS flashmla_ops LIBRARY DESTINATION "sgl_kernel")
+
+target_compile_definitions(flashmla_ops PRIVATE)
diff --git a/sgl-kernel/csrc/flashmla_extension.cc b/sgl-kernel/csrc/flashmla_extension.cc
new file mode 100644
index 000000000..e72cfa11c
--- /dev/null
+++ b/sgl-kernel/csrc/flashmla_extension.cc
@@ -0,0 +1,46 @@
+/* Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <torch/all.h>
+#include <torch/library.h>
+
+#include "sgl_kernel_ops.h"
+
+TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
+  /*
+   * From FlashMLA
+   */
+  m.def(
+      "get_mla_decoding_metadata(Tensor seqlens_k, int num_q_tokens_per_head_k, int h_k, int? h_q, bool "
+      "is_fp8_kvcache, int? topk) -> Tensor[]");
+  m.impl("get_mla_decoding_metadata", torch::kCUDA, &get_mla_decoding_metadata);
+
+  m.def(
+      "fwd_kvcache_mla(Tensor q, Tensor kv_cache, int head_size_v, Tensor seqlens_k, Tensor block_table, float "
+      "softmax_scale, bool is_causal, Tensor tile_scheduler_metadata, Tensor num_splits, bool is_fp8, Tensor? indices) "
+      "-> Tensor[]");
+  m.impl("fwd_kvcache_mla", torch::kCUDA, &fwd_kvcache_mla);
+
+  m.def(
+      "dense_prefill_fwd(Tensor workspace_buffer, Tensor q, Tensor k, Tensor v, Tensor cumulative_seqlen_q, Tensor "
+      "cumulative_seqlen_kv, Tensor o, Tensor lse, int mask_mode_code, float softmax_scale, int max_seqlen_q, int "
+      "max_seqlen_kv, bool is_varlen) -> ()");
+  m.impl("dense_prefill_fwd", torch::kCUDA, &FMHACutlassSM100FwdRun);
+
+  m.def("sparse_prefill_fwd(Tensor q, Tensor kv, Tensor indices, float sm_scale, int d_v) -> Tensor[]");
+  m.impl("sparse_prefill_fwd", torch::kCUDA, &sparse_prefill_fwd);
+}
+
+REGISTER_EXTENSION(flashmla_ops)
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 6b095069c..6be8af703 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -842,6 +842,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
     const torch::Tensor& problem_sizes,
     const torch::Tensor& expert_offsets,
     const torch::Tensor& workspace);
+
 /*
  * From fast-hadamard-transform
  */
@@ -850,3 +851,47 @@ torch::Tensor fast_hadamard_transform_12N(torch::Tensor& x, double scale);
 torch::Tensor fast_hadamard_transform_20N(torch::Tensor& x, double scale);
 torch::Tensor fast_hadamard_transform_28N(torch::Tensor& x, double scale);
 torch::Tensor fast_hadamard_transform_40N(torch::Tensor& x, double scale);
+
+/*
+ * From FlashMLA
+ */
+std::vector<at::Tensor> get_mla_decoding_metadata(
+    at::Tensor& seqlens_k,
+    const int64_t num_q_tokens_per_head_k,
+    const int64_t h_k,
+    const std::optional<int64_t> h_q,
+    const bool is_fp8_kvcache,
+    const std::optional<int64_t> topk);
+
+std::vector<at::Tensor> fwd_kvcache_mla(
+    at::Tensor& q,                 // batch_size x seqlen_q x num_heads x head_size
+    const at::Tensor& kcache,      // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or
+                                   // num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
+    const int64_t head_size_v,
+    const at::Tensor& seqlens_k,   // batch_size
+    const at::Tensor& block_table, // batch_size x max_num_blocks_per_seq
+    const double softmax_scale,
+    bool is_causal,
+    const at::Tensor& tile_scheduler_metadata,  // num_sm_parts x TileSchedulerMetaDataSize
+    const at::Tensor& num_splits,               // batch_size + 1
+    const bool& is_fp8,
+    const std::optional<at::Tensor>& indices    // None, or batch_size x seqlen_q x topk
+);
+
+void FMHACutlassSM100FwdRun(
+    at::Tensor workspace_buffer,
+    at::Tensor q,
+    at::Tensor k,
+    at::Tensor v,
+    at::Tensor cumulative_seqlen_q,
+    at::Tensor cumulative_seqlen_kv,
+    at::Tensor o,
+    at::Tensor lse,
+    int64_t mask_mode_code,
+    double softmax_scale,
+    int64_t max_seqlen_q,
+    int64_t max_seqlen_kv,
+    bool is_varlen);
+
+std::vector<at::Tensor>
+sparse_prefill_fwd(const at::Tensor& q, const at::Tensor& kv, const at::Tensor& indices, double sm_scale, int64_t d_v);
diff --git a/sgl-kernel/python/sgl_kernel/flash_mla.py b/sgl-kernel/python/sgl_kernel/flash_mla.py
new file mode 100644
index 000000000..614c54e2e
--- /dev/null
+++ b/sgl-kernel/python/sgl_kernel/flash_mla.py
@@ -0,0 +1,126 @@
+from typing import Optional, Tuple
+
+import torch
+
+try:
+    from . import flashmla_ops  # triggers TORCH extension registration
+except Exception as _e:
+    _flashmla_import_error = _e
+else:
+    _flashmla_import_error = None
+
+_IMPORT_ERROR = ImportError(
+    "Failed to load sgl_kernel.flashmla_ops extension. 
Ensure CUDA Driver >= 12.4" +) + + +def get_mla_metadata( + cache_seqlens: torch.Tensor, + num_q_tokens_per_head_k: int, + num_heads_k: int, + num_heads_q: Optional[int] = None, + is_fp8_kvcache: bool = False, + topk: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k. + num_heads_k: The number of k heads. + num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled + is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format. + topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to. + + Returns: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. + """ + return torch.ops.sgl_kernel.get_mla_decoding_metadata.default( + cache_seqlens, + num_q_tokens_per_head_k, + num_heads_k, + num_heads_q, + is_fp8_kvcache, + topk, + ) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head dimension of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. + softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md + indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md. + + Returns: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + if indices is not None: + assert causal == False, "causal must be `false` if sparse attention is enabled." + out, softmax_lse = torch.ops.sgl_kernel.fwd_kvcache_mla.default( + q, + k_cache, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + is_fp8_kvcache, + indices, + ) + return out, softmax_lse + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. 
Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + + Returns: + (output, max_logits, lse) + About the definition of output, max_logits and lse, please refer to README.md + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, 2-based log-sum-exp + """ + results = torch.ops.sgl_kernel.sparse_prefill_fwd.default( + q, kv, indices, sm_scale, d_v + ) + return results diff --git a/sgl-kernel/tests/test_flashmla.py b/sgl-kernel/tests/test_flashmla.py new file mode 100644 index 000000000..b6c049999 --- /dev/null +++ b/sgl-kernel/tests/test_flashmla.py @@ -0,0 +1,518 @@ +import math +import random +from typing import Optional, Tuple + +import pytest +import torch +import triton +from sgl_kernel.flash_mla import ( + flash_mla_sparse_fwd, + flash_mla_with_kvcache, + get_mla_metadata, +) + +skip_condition = torch.cuda.get_device_capability() < (10, 0) + +# ================ prefill usage ================ # +S_Q_PREFILL = [1, 62] +KV_TOPK_PREFILL = [ + # Regular shapes + (128, 128), + (256, 256), + (512, 512), + # Irregular shapes + (592, 128), + (1840, 256), + (1592, 384), + (1521, 512), + # Irregular shapes with OOB TopK + (95, 128), + (153, 256), + (114, 384), +] + +# ================= decode usage ================= # +B_DECODE = [1, 2, 6, 64] +S_Q_DECODE = [1, 2, 4] +S_K_DECODE = [20, 140, 4096] +IS_VARLEN = [False, True] +CAUSAL_TOPK = [(True, None), (False, None), (False, 128), (False, 2048)] +DTYPE = [torch.float16, torch.bfloat16] + + +def quantize_k_cache( + input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d) + dv: int, + tile_size: int = 128, +) -> torch.Tensor: + """ + Quantize the k-cache + Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size() + For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md + """ + assert dv % tile_size == 0 + num_tiles = dv // tile_size + num_blocks, block_size, h_k, d = input_k_cache.shape + assert h_k == 1 + input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d] + input_elem_size = input_k_cache.element_size() + + result = torch.empty( + (num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)), + dtype=torch.float8_e4m3fn, + device=input_k_cache.device, + ) + result_k_nope_part = result[..., :dv] + result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32) + result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype) + result_k_rope_part[:] = input_k_cache[..., dv:] + + for tile_idx in range(0, num_tiles): + cur_scale_factors_inv = ( + torch.abs( + input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] + ) + .max(dim=-1) + .values + / 448.0 + ) # [num_blocks, block_size] + result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv + + cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1] + cur_quantized_nope = ( + input_k_cache[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ].float() + / cur_scale_factors_inv.float() + ).to(torch.float8_e4m3fn) + result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = ( + cur_quantized_nope + ) + + result = result.view(num_blocks, block_size, 1, -1) + return result + + +def dequantize_k_cache( + quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token) + dv: int = 512, + tile_size: int = 128, + d: int 
= 576, +) -> torch.Tensor: + """ + De-quantize the k-cache + """ + assert dv % tile_size == 0 + num_tiles = dv // tile_size + num_blocks, block_size, h_k, _ = quant_k_cache.shape + assert h_k == 1 + result = torch.empty( + (num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device + ) + + quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1) + + input_nope = quant_k_cache[..., :dv] + input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32) + input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16) + result[..., dv:] = input_rope + + for tile_idx in range(0, num_tiles): + cur_nope = input_nope[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ].to(torch.float32) + cur_scales = input_scale[..., tile_idx].unsqueeze(-1) + result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = ( + cur_nope * cur_scales + ) + + result = result.view(num_blocks, block_size, 1, d) + return result + + +def cdiv(x: int, y: int): + return (x + y - 1) // y + + +def get_window_size(causal, window): + if window > 0: + window_size = (window - 1, 0) if causal else (window - 1, window - 1) + else: + window_size = (-1, -1) + return window_size + + +def get_attn_bias(s_q, s_k, causal, window): + attn_bias = torch.zeros(s_q, s_k, dtype=torch.float32, device="cuda") + if causal: + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda").tril( + diagonal=s_k - s_q + ) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + if window > 0: + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda").tril( + diagonal=s_k - s_q - window + ) + attn_bias.masked_fill_(temp_mask, float("-inf")) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda").tril( + diagonal=s_k - s_q + window - 1 + ) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + return attn_bias + + +def sdpa(query, key, value, attn_bias, softmax_scale=None): + query = query.float().transpose(-3, -2) + key = key.float().transpose(-3, -2) + value = value.float().transpose(-3, -2) + key = key.repeat_interleave(h // h_k, dim=-3) + value = value.repeat_interleave(h // h_k, dim=-3) + if softmax_scale is None: + softmax_scale = query.shape[-1] ** (-0.5) + attn_weight = (query @ key.transpose(-2, -1)) * softmax_scale + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight.to(query.dtype) @ value, lse + + +def sdpa_checkpoint(*args, **kwargs): + return checkpoint(sdpa, *args, use_reentrant=False, **kwargs) + + +def reference_torch_prefill( + s_q, s_kv, topk, indices, q, kv, sm_scale: float +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor: + return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e) + + indices = indices[0, :, 0, :] # [s_q, topk] + invalid_indices_mask = (indices < 0) | (indices >= s_kv) + qs = q[0, :, :, :].float() # [s_q, h_q, d_qk] + kvs = kv[0, :, 0, :].float() # [s_kv, d_qk] + + kvs = torch.index_select( + kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten() + ).view( + s_q, topk, 576 + ) # [s_q, topk, d_qk] + attn_score = qs @ kvs.transpose(1, 2) # [s_q, h_q, topk] + attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf")) + attn_score *= sm_scale * math.log2(math.e) + max_logits = torch.max(attn_score, dim=-1)[0] # [s_q, h_q] + lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q] + attn_score = torch.exp2(attn_score - 
lse.unsqueeze(-1)) # [s_q, h_q, topk] + result = attn_score @ kvs[:, :, :512] + return (max_logits, lse, result) + + +def reference_torch_decode( + cache_seqlens: torch.Tensor, # [batch_size] + block_table: torch.Tensor, # [batch_size, ?] + q: torch.Tensor, # [batch_size, s_q, h_q, d] + blocked_k: torch.Tensor, # [?, block_size, h_kv, d] + dv: int, + is_causal: bool, + indices: Optional[torch.Tensor] = None, # [batch_size, s_q, topk] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + A reference implementation in PyTorch + """ + + def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor): + mask = torch.zeros(s_q, s_k, dtype=torch.bool, device="cuda") + for i in range(s_q): + cur_indices = indices[i] + valid_indices = cur_indices[cur_indices != -1] + mask[i, valid_indices] = True + return mask + + def scaled_dot_product_attention( + batch_idx: int, + query: torch.Tensor, # [h_q, s_q, d] + kv: torch.Tensor, # [h_kv, s_k, d] + dv: int, + is_causal, + indices: Optional[torch.Tensor], # [s_q, topk] + ) -> Tuple[torch.Tensor, torch.Tensor]: + h_q = query.size(0) + h_kv = kv.size(0) + s_q = query.shape[-2] + s_k = kv.shape[-2] + query = query.float() + kv = kv.float() + if h_kv != 1: + kv = kv.repeat_interleave(h_q // h_kv, dim=0) + kv[kv != kv] = 0.0 + attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k] + if (is_causal and query.size(1) > 1) or indices is not None: + mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda") + if is_causal: + assert indices is None + mask = mask.tril(diagonal=s_k - s_q) + if indices is not None: + mask &= get_topk_attn_mask(s_q, s_k, indices) + attn_bias = torch.zeros(s_q, s_k, dtype=torch.float, device="cuda") + attn_bias.masked_fill_(mask.logical_not(), float("-inf")) + attn_weight += attn_bias.to(q.dtype) + attn_weight /= math.sqrt(query.size(-1)) + lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q] + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv] + # Correct for q tokens which has no attendable k + lonely_q_mask = lse == float("-inf") + output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0 + lse[lonely_q_mask] = float("+inf") + + return output, lse + + b, s_q, h_q, d = q.size() + block_size = blocked_k.size(1) + h_kv = blocked_k.size(2) + cache_seqlens_cpu = cache_seqlens.cpu() + out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device="cuda") + lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32, device="cuda") + for i in range(b): + cur_len = cache_seqlens_cpu[i].item() + cur_num_blocks = cdiv(cur_len, block_size) + cur_block_indices = block_table[i][0:cur_num_blocks] + cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...] 
+ cur_out, cur_lse = scaled_dot_product_attention( + i, + q[i].transpose(0, 1), + cur_kv.transpose(0, 1), + dv, + is_causal, + indices[i] if indices is not None else None, + ) + out_ref[i] = cur_out.transpose(0, 1) + lse_ref[i] = cur_lse + out_ref = out_ref.to(torch.bfloat16) + return out_ref, lse_ref + + +@pytest.mark.parametrize("s_q", S_Q_PREFILL) +@pytest.mark.parametrize("kv_topk", KV_TOPK_PREFILL) +@torch.inference_mode() +def test_flashmla_prefill( + s_q: int, + kv_topk: Tuple[int, int], +): + + torch.cuda.empty_cache() + + q = torch.randn((1, s_q, 128, 576), dtype=torch.bfloat16, device="cuda") / 10 + kv = torch.randn((1, kv_topk[0], 1, 576), dtype=torch.bfloat16, device="cuda") / 10 + + q.clamp_(-10, 10) + kv.clamp_(-10, 10) + + indices = torch.full( + (1, s_q, 1, kv_topk[1]), kv_topk[0], dtype=torch.int32, device="cuda" + ) + for s in range(s_q): + # NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention + near_mask = ( + torch.randint(0, 32, (min(kv_topk[1], kv_topk[0]),), device="cuda") < 31 + ) + cur_indices = torch.randperm(kv_topk[0], device="cuda")[: kv_topk[1]] + cur_indices[near_mask] = torch.randint( + max(0, kv_topk[0] - 20000), + kv_topk[0] - 1, + (near_mask.sum().item(),), + device="cuda", + ) + if len(cur_indices) < kv_topk[1]: + cur_indices = torch.cat( + [ + cur_indices, + torch.full( + (kv_topk[1] - len(cur_indices),), 2147480000, device="cuda" + ), + ] + ) + cur_indices = cur_indices[torch.randperm(kv_topk[1], device="cuda")] + indices[0, s, 0] = cur_indices + indices = indices.to(q.device) + + sm_scale = 1 / math.sqrt(576) + torch.cuda.synchronize() + + ans_out, ans_max_logits, ans_lse = flash_mla_sparse_fwd( + q.squeeze(0), kv.squeeze(0), indices.squeeze(0), sm_scale=sm_scale + ) + + ans_out, ans_max_logits, ans_lse = ( + ans_out.float(), + ans_max_logits.float(), + ans_lse.float(), + ) + + torch.cuda.synchronize() + ref_max_logits, ref_lse, ref_out = reference_torch_prefill( + s_q, kv_topk[0], kv_topk[1], indices, q, kv, sm_scale + ) + torch.cuda.synchronize() + + torch.testing.assert_close(ans_out, ref_out, atol=8e-4, rtol=2.01 / 128) + torch.testing.assert_close( + ans_max_logits, + ref_max_logits, + atol=1e-6, + rtol=2.01 / 65536, + ) + torch.testing.assert_close(ans_lse, ref_lse, atol=1e-6, rtol=2.01 / 65536) + + +@pytest.mark.parametrize("b", B_DECODE) +@pytest.mark.parametrize("s_q", S_Q_DECODE) +@pytest.mark.parametrize("s_k", S_K_DECODE) +@pytest.mark.parametrize("is_varlen", IS_VARLEN) +@pytest.mark.parametrize("causal_topk", CAUSAL_TOPK) +@pytest.mark.parametrize("dtype", DTYPE) +@torch.inference_mode() +def test_flash_mla_decode( + b: int, + s_q: int, + s_k: int, + is_varlen: bool, + causal_topk: Tuple[bool, Optional[int]], + dtype: torch.dtype, +): + d = 576 + dv = 512 + block_size = 64 + h_q = 128 + h_kv = 1 + is_causal = causal_topk[0] + topk = causal_topk[1] + + # Generating test data + torch.cuda.synchronize() + + cache_seqlens_cpu = torch.full((b,), s_k, dtype=torch.int32, device="cpu") + if is_varlen: + for i in range(b): + cache_seqlens_cpu[i] = max(random.normalvariate(s_k, s_k / 2), s_q) + + max_seqlen = cache_seqlens_cpu.max().item() + max_seqlen_pad = cdiv(max_seqlen, 256) * 256 + cache_seqlens = cache_seqlens_cpu.cuda() + + q = torch.randn(b, s_q, 128, d, dtype=torch.bfloat16, device="cuda") + q.clamp_(min=-1.0, max=1.0) + + block_table = torch.arange( + b * max_seqlen_pad // block_size, dtype=torch.int32, device="cuda" + ).view(b, 
max_seqlen_pad // block_size) + block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(b, -1) + blocked_k = ( + torch.randn( + block_table.numel(), + block_size, + h_kv, + d, + dtype=torch.bfloat16, + device="cuda", + ) + / 10 + ) + blocked_k.clamp_(min=-1.0, max=1.0) + + if topk is None: + for i in range(b): + cur_len = cache_seqlens_cpu[i].item() + cur_num_blocks = cdiv(cur_len, block_size) + blocked_k[block_table[i][cur_num_blocks:]] = float("nan") + if cur_len % block_size != 0: + blocked_k[block_table[i][cur_num_blocks - 1]][ + cur_len % block_size : + ] = float("nan") + block_table[i][cur_num_blocks:] = 2147480000 + abs_indices = None + indices_in_kvcache = None + else: + block_table_cpu = block_table.cpu() + abs_indices = torch.empty(b, s_q, topk, dtype=torch.int32, device="cpu") + indices_in_kvcache = torch.empty(b, s_q, topk, dtype=torch.int32, device="cpu") + for i in range(b): + # Generate indices + for j in range(s_q): + cur_abs_indices = torch.randperm( + int(cache_seqlens_cpu[i].item()), device="cpu" + )[:topk] + cur_blocked_indices = block_table_cpu[ + i, cur_abs_indices // block_size + ] * block_size + (cur_abs_indices % block_size) + if len(cur_abs_indices) < topk: + pad_len = topk - len(cur_abs_indices) + cur_abs_indices = torch.cat( + [cur_abs_indices, torch.full((pad_len,), -1, device="cpu")] + ) + cur_blocked_indices = torch.cat( + [cur_blocked_indices, torch.full((pad_len,), -1, device="cpu")] + ) + + # Mask KV + perm = torch.randperm(topk, device="cpu") + cur_abs_indices = cur_abs_indices[perm] + cur_blocked_indices = cur_blocked_indices[perm] + + abs_indices[i, j, :] = cur_abs_indices + indices_in_kvcache[i, j, :] = cur_blocked_indices + + # Mask nonused KV as NaN + all_indices = indices_in_kvcache.flatten().tolist() + all_indices = list(set(all_indices)) + if -1 in all_indices: + all_indices.remove(-1) + all_indices = torch.tensor(all_indices, dtype=torch.int32, device="cpu") + + blocked_k = blocked_k.view(-1, h_kv, d) + nonused_indices_mask = torch.ones( + blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device="cpu" + ) + nonused_indices_mask[all_indices] = False + blocked_k[nonused_indices_mask, :, :] = float("nan") + blocked_k = blocked_k.view(-1, block_size, h_kv, d) + + abs_indices = abs_indices.to(q.device) + indices_in_kvcache = indices_in_kvcache.to(q.device) + + is_fp8 = topk is not None + if is_fp8: + # The quantization error may be too large to be distinguished from wrong kernels + # So we quantize and de-quantize kv-cache here to mitigate quantization error + blocked_k_quantized = quantize_k_cache(blocked_k, dv, 128) + blocked_k_dequantized = dequantize_k_cache(blocked_k_quantized) + blocked_k = blocked_k_dequantized + + # Get schedule metadata + torch.cuda.synchronize() + tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens, s_q * h_q // h_kv, h_kv, h_q, is_fp8, topk + ) + torch.cuda.synchronize() + + out_ans, lse_ans = flash_mla_with_kvcache( + q, + blocked_k if not is_fp8 else blocked_k_quantized, # type: ignore + block_table, + cache_seqlens, + dv, + tile_scheduler_metadata, + num_splits, + causal=is_causal, + is_fp8_kvcache=is_fp8, + indices=indices_in_kvcache, + ) + + out_ref, lse_ref = reference_torch_decode( + cache_seqlens, block_table, q, blocked_k, dv, is_causal, abs_indices + ) + torch.testing.assert_close(out_ans, out_ref, atol=8e-4, rtol=2.01 / 128) + torch.testing.assert_close(lse_ans, lse_ref, atol=1e-6, rtol=8.01 / 65536) + + +if __name__ == "__main__": + pytest.main([__file__])
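Reviewer note (not part of the diff): the decode-path wrappers added in sgl_kernel/flash_mla.py are exercised end to end only inside the test above, so the following is a minimal usage sketch. It assumes an SM90/SM100 GPU, a CUDA 12.4+ driver, and an sgl-kernel build that ships the flashmla_ops extension; the concrete batch size, sequence length, and block count are illustrative only, while shapes follow the docstrings above.

# Minimal decode-path sketch for the new FlashMLA wrappers (illustrative sizes).
import math

import torch

from sgl_kernel.flash_mla import flash_mla_with_kvcache, get_mla_metadata

b, s_q, h_q, h_kv, d, dv, block_size = 2, 1, 128, 1, 576, 512, 64
max_blocks_per_seq = 8  # enough blocks for the 256-token sequences below

cache_seqlens = torch.full((b,), 256, dtype=torch.int32, device="cuda")
block_table = torch.arange(
    b * max_blocks_per_seq, dtype=torch.int32, device="cuda"
).view(b, max_blocks_per_seq)
q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(
    b * max_blocks_per_seq, block_size, h_kv, d, dtype=torch.bfloat16, device="cuda"
)

# Schedule metadata depends only on the sequence lengths and head configuration,
# so it is computed once and reused for every call with the same lengths.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, s_q * h_q // h_kv, h_kv, h_q, is_fp8_kvcache=False, topk=None
)

out, lse = flash_mla_with_kvcache(
    q,
    k_cache,
    block_table,
    cache_seqlens,
    dv,
    tile_scheduler_metadata,
    num_splits,
    softmax_scale=1.0 / math.sqrt(d),
    causal=True,
)
print(out.shape, lse.shape)  # (b, s_q, h_q, dv), (b, h_q, s_q)

The 656 bytes per token quoted for the FP8 KV cache in sgl_kernel_ops.h is consistent with quantize_k_cache in the test: with d = 576, dv = 512, and 128-wide scale tiles, each token stores 512 FP8 NoPE bytes, 4 float32 scales (16 bytes), and 64 bf16 RoPE values (128 bytes), i.e. 512 + 16 + 128 = 656.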