From 37c66ec8563d4fb9b6a658f0ea282bcf6171b2f0 Mon Sep 17 00:00:00 2001 From: yinfan98 <1106310035@qq.com> Date: Mon, 31 Mar 2025 03:57:10 +0800 Subject: [PATCH] [feat] add fa3 in sgl-kernel (#4902) Co-authored-by: Sleepcoo --- sgl-kernel/CMakeLists.txt | 54 ++ sgl-kernel/README.md | 30 + sgl-kernel/csrc/torch_extension.cc | 5 + sgl-kernel/include/sgl_kernel_ops.h | 47 ++ sgl-kernel/include/sgl_kernel_torch_shim.h | 122 +++ sgl-kernel/python/sgl_kernel/flash_attn.py | 201 +++++ sgl-kernel/tests/test_flash_attention.py | 841 +++++++++++++++++++++ 7 files changed, 1300 insertions(+) create mode 100644 sgl-kernel/include/sgl_kernel_torch_shim.h create mode 100644 sgl-kernel/python/sgl_kernel/flash_attn.py create mode 100644 sgl-kernel/tests/test_flash_attention.py diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 18ca42ba5..e4d03f133 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -25,6 +25,7 @@ find_package(Torch REQUIRED) include(FetchContent) +# cutlass FetchContent_Declare( repo-cutlass GIT_REPOSITORY https://github.com/NVIDIA/cutlass @@ -32,6 +33,7 @@ FetchContent_Declare( GIT_SHALLOW ON ) FetchContent_Populate(repo-cutlass) +# DeepGEMM FetchContent_Declare( repo-deepgemm GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM @@ -39,6 +41,7 @@ FetchContent_Declare( GIT_SHALLOW ON ) FetchContent_Populate(repo-deepgemm) +# flashinfer FetchContent_Declare( repo-flashinfer GIT_REPOSITORY https://github.com/sgl-project/flashinfer @@ -46,6 +49,15 @@ FetchContent_Declare( GIT_SHALLOW OFF ) FetchContent_Populate(repo-flashinfer) +# flash-attention +FetchContent_Declare( + repo-flash-attention + GIT_REPOSITORY https://github.com/sgl-project/sgl-attn + GIT_TAG sgl-kernel + GIT_SHALLOW OFF +) +FetchContent_Populate(repo-flash-attention) + include_directories( ${PROJECT_SOURCE_DIR}/include @@ -54,6 +66,7 @@ include_directories( ${repo-cutlass_SOURCE_DIR}/tools/util/include ${repo-flashinfer_SOURCE_DIR}/include ${repo-flashinfer_SOURCE_DIR}/csrc + ${repo-flash-attention_SOURCE_DIR}/hopper ) set(CMAKE_CXX_STANDARD 17) @@ -78,6 +91,7 @@ set(SGL_KERNEL_CUDA_FLAGS "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1" "-DCUTLASS_DEBUG_TRACE_LEVEL=0" "--expt-relaxed-constexpr" + "--use_fast_math" "-Xcompiler=-Wconversion" "-Xcompiler=-fno-strict-aliasing" ) @@ -130,6 +144,30 @@ string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}") string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}") +# set flash-attention sources file +# BF16 source files +file(GLOB FA3_BF16_GEN_SRCS + "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu") +file(GLOB FA3_BF16_GEN_SRCS_ + "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu") +list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_}) + +# FP16 source files +file(GLOB FA3_FP16_GEN_SRCS + "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu") +file(GLOB FA3_FP16_GEN_SRCS_ + "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu") +list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_}) + +# FP8 source files +file(GLOB FA3_FP8_GEN_SRCS + "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu") +file(GLOB FA3_FP8_GEN_SRCS_ + "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu") 
+list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_}) + +set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS}) + set(SOURCES "csrc/allreduce/trt_reduce_internal.cu" "csrc/allreduce/trt_reduce_kernel.cu" @@ -160,6 +198,10 @@ set(SOURCES "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu" + "${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu" + "${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp" + "${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu" + "${FA3_GEN_SRCS}" ) # Support abi3 for build @@ -173,6 +215,18 @@ target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cubl install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel") +# Add some flash-attention custom flag for inference +target_compile_definitions(common_ops PRIVATE + FLASHATTENTION_DISABLE_SM8x + FLASHATTENTION_DISABLE_BACKWARD + FLASHATTENTION_DISABLE_DROPOUT + # FLASHATTENTION_DISABLE_ALIBI + # FLASHATTENTION_DISABLE_SOFTCAP + FLASHATTENTION_DISABLE_UNEVEN_K + # FLASHATTENTION_DISABLE_LOCAL + FLASHATTENTION_VARLEN_ONLY +) + # JIT Logic # DeepGEMM diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md index 6de25bb01..90c5b5179 100644 --- a/sgl-kernel/README.md +++ b/sgl-kernel/README.md @@ -92,6 +92,36 @@ Steps to add a new kernel: ) ``` +### Integrating Third-Party Libraries with Data Type Conversion + +When integrating new third-party libraries like flash-attention, you may encounter data type compatibility issues between the C++ interface and PyTorch bindings. For example, the third-party code might use `float` or `int` types, while PyTorch requires `double` and `int64_t`. + +To address this issue, we provide the `make_pytorch_shim` function in [sgl_kernel_torch_shim](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_torch_shim.h) that handles data type conversions automatically. + +When you need to support new data type conversions, you can easily add conversion functions like this: + +```cpp +// Map `int` -> `int64_t` +template <> +struct pytorch_library_compatible_type { + using type = int64_t; + static int convert_from_type(int64_t arg) { + TORCH_CHECK(arg <= std::numeric_limits::max(), "int64_t value is too large to be converted to int"); + TORCH_CHECK(arg >= std::numeric_limits::min(), "int64_t value is too small to be converted to int"); + return arg; + } +}; +``` + +To use this with your library functions, simply wrap them with make_pytorch_shim: + +```cpp +/* + * From flash-attention + */ + m.def("fwd", make_pytorch_shim(mha_fwd)); +``` + ### Build & Install Development build: diff --git a/sgl-kernel/csrc/torch_extension.cc b/sgl-kernel/csrc/torch_extension.cc index 263a9d15c..c26b9024c 100644 --- a/sgl-kernel/csrc/torch_extension.cc +++ b/sgl-kernel/csrc/torch_extension.cc @@ -91,6 +91,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { m.def("top_p_renorm_probs", top_p_renorm_probs); m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs); m.def("top_p_sampling_from_probs", top_p_sampling_from_probs); + + /* + * From flash-attention + */ + m.def("fwd", make_pytorch_shim(mha_fwd)); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 847b24ebe..d89fccbb2 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -23,6 +23,8 @@ limitations under the License. 
 #include 
+
+#include "sgl_kernel_torch_shim.h"
+
 #define _CONCAT(A, B) A##B
 #define CONCAT(A, B) _CONCAT(A, B)
@@ -291,3 +293,48 @@ void top_p_sampling_from_probs(
     double top_p_val,
     bool deterministic,
     int64_t cuda_stream);
+
+/*
+ * From flash-attention
+ */
+std::vector<at::Tensor> mha_fwd(
+    at::Tensor& q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
+                          // h_k, d) if there is page_table.
+    const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
+                          // page_size, h_k, dv) if there is page_table.
+    std::optional<const at::Tensor>&
+        k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
+    std::optional<const at::Tensor>&
+        v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
+    std::optional<const at::Tensor>& q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
+    std::optional<at::Tensor>& out_,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
+    std::optional<const at::Tensor>& cu_seqlens_q_,  // b+1
+    std::optional<const at::Tensor>& cu_seqlens_k_,  // b+1
+    std::optional<const at::Tensor>& cu_seqlens_k_new_,  // b+1
+    std::optional<const at::Tensor>&
+        seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
+    std::optional<const at::Tensor>&
+        seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
+    std::optional<int> max_seqlen_q_,
+    // TODO: check if we need max_seqlen_k
+    std::optional<int> max_seqlen_k_,
+    std::optional<const at::Tensor>& page_table_,  // (b_k, max_num_pages_per_seq)
+    std::optional<const at::Tensor>& kv_batch_idx_,  // b. indices to index into the KV cache
+    std::optional<const at::Tensor>& leftpad_k_,  // b
+    std::optional<const at::Tensor>& rotary_cos_,  // seqlen_ro x (rotary_dim / 2)
+    std::optional<const at::Tensor>& rotary_sin_,  // seqlen_ro x (rotary_dim / 2)
+    std::optional<const at::Tensor>& seqlens_rotary_,  // b
+    std::optional<const at::Tensor>& q_descale_,  // (b, h_k), not (b, h)
+    std::optional<const at::Tensor>& k_descale_,  // (b, h_k)
+    std::optional<const at::Tensor>& v_descale_,  // (b, h_k)
+    float const softmax_scale,
+    bool is_causal,
+    int window_size_left,
+    int window_size_right,
+    float const softcap,
+    bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+    std::optional<const at::Tensor>& scheduler_metadata_,  // (b + 1)
+    int num_splits,
+    std::optional<bool> pack_gqa_,
+    int const sm_margin);
diff --git a/sgl-kernel/include/sgl_kernel_torch_shim.h b/sgl-kernel/include/sgl_kernel_torch_shim.h
new file mode 100644
index 000000000..4101726be
--- /dev/null
+++ b/sgl-kernel/include/sgl_kernel_torch_shim.h
@@ -0,0 +1,122 @@
+/* Adapted from:
+https://github.com/neuralmagic/vllm-flash-attention/blob/90eacc1af2a7c3de62ea249e929ed5faccf38954/csrc/common/pytorch_shim.h
+ Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <torch/library.h>
+
+/**
+ * Unfortunately, the type signatures of the flash_attn ops are not compatible
+ * with the PyTorch library bindings.
To get around that we use
+ * `make_pytorch_shim`, which creates a lambda that exposes the API using
+ * PyTorch-compatible types and then converts them to the types
+ * expected by the flash_attn ops. This shim allows us to make minimal changes
+ * to `flash_api.cpp`, making it easier to synchronize with upstream changes.
+ *
+ * The `pytorch_library_compatible_type` struct is used to map from the
+ * flash_attn ops types to a PyTorch library compatible one. The main issue is
+ * that the following types are not supported by PyTorch library bindings:
+ *  - `int`
+ *  - `float`
+ *  - `std::optional<T> &`
+ *  - `std::optional<const at::Tensor> &`
+ * So we convert them to (respectively):
+ *  - `int64_t`
+ *  - `double`
+ *  - `const std::optional<T>&`
+ *  - `const std::optional<at::Tensor>&`
+ */
+
+template <typename T>
+struct pytorch_library_compatible_type {
+  using type = T;
+  static T convert_from_type(T arg) {
+    return arg;
+  }
+};
+
+template <typename T>
+using pytorch_library_compatible_type_t = typename pytorch_library_compatible_type<T>::type;
+
+template <typename T>
+T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t<T> arg) {
+  return pytorch_library_compatible_type<T>::convert_from_type(arg);
+}
+
+// Map `c10::optional<T> &` -> `const c10::optional<T>&`
+//  (NOTE: this is a bit unsafe but none of the ops in flash_attn mutate
+//   the optional container)
+template <typename T>
+struct pytorch_library_compatible_type<c10::optional<T>&> {
+  using type = const c10::optional<T>&;
+  static c10::optional<T>& convert_from_type(const c10::optional<T>& arg) {
+    return const_cast<c10::optional<T>&>(arg);
+  }
+};
+
+// Map `c10::optional<T>` ->
+// `c10::optional<pytorch_library_compatible_type_t<T>>`
+//  (NOTE: tested for `c10::optional<int>` -> `c10::optional<int64_t>`)
+template <typename T>
+struct pytorch_library_compatible_type<c10::optional<T>> {
+  using type = c10::optional<pytorch_library_compatible_type_t<T>>;
+  static c10::optional<pytorch_library_compatible_type_t<T>> convert_from_type(c10::optional<T> arg) {
+    return arg;
+  }
+};
+
+// Map `c10::optional<const at::Tensor>&` -> `const c10::optional<at::Tensor>&`
+template <>
+struct pytorch_library_compatible_type<c10::optional<const at::Tensor>&> {
+  using type = const c10::optional<at::Tensor>&;
+  static c10::optional<const at::Tensor>& convert_from_type(const c10::optional<at::Tensor>& arg) {
+    return const_cast<c10::optional<const at::Tensor>&>(reinterpret_cast<const c10::optional<const at::Tensor>&>(arg));
+  }
+};
+
+// Map `int` -> `int64_t`
+template <>
+struct pytorch_library_compatible_type<int> {
+  using type = int64_t;
+  static int convert_from_type(int64_t arg) {
+    TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
+    TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
+    return arg;
+  }
+};
+
+// Map `float` -> `double`
+template <>
+struct pytorch_library_compatible_type<float> {
+  using type = double;
+  static float convert_from_type(double arg) {
+    TORCH_CHECK(
+        std::abs(arg) <= std::numeric_limits<float>::max(), "double value is too large to be converted to float");
+    return arg;
+  }
+};
+
+//
+// Shim Utils
+//
+
+template <typename Ret, typename... Args>
+auto make_pytorch_shim(Ret (*fun)(Args... args)) {
+  return [fun](pytorch_library_compatible_type_t<Args>...
args) {
+    return fun(convert_from_pytorch_compatible_type<Args>(args)...);
+  };
+}
diff --git a/sgl-kernel/python/sgl_kernel/flash_attn.py b/sgl-kernel/python/sgl_kernel/flash_attn.py
new file mode 100644
index 000000000..b23a64623
--- /dev/null
+++ b/sgl-kernel/python/sgl_kernel/flash_attn.py
@@ -0,0 +1,201 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+
+def is_fa3_supported(device=None) -> bool:
+    # FA3 can fail without enough shared memory for some shapes; currently
+    # only 8.0 and 8.7 have enough shared memory for all shapes
+    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
+    return FA3_AVAILABLE and (
+        torch.cuda.get_device_capability(device)[0] >= 9
+        or torch.cuda.get_device_capability(device) == (8, 0)
+        or torch.cuda.get_device_capability(device) == (8, 7)
+    )
+
+
+def maybe_contiguous(x):
+    return x.contiguous() if x is not None and x.stride(-1) != 1 else x
+
+
+def flash_attn_with_kvcache(
+    q,
+    k_cache,
+    v_cache,
+    k=None,
+    v=None,
+    qv=None,
+    rotary_cos=None,
+    rotary_sin=None,
+    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
+    cache_leftpad: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k_new: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    rotary_seqlens: Optional[torch.Tensor] = None,
+    q_descale: Optional[torch.Tensor] = None,
+    k_descale: Optional[torch.Tensor] = None,
+    v_descale: Optional[torch.Tensor] = None,
+    softmax_scale=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
+    rotary_interleaved=True,
+    scheduler_metadata=None,
+    num_splits=0,  # Can be tuned for speed
+    pack_gqa=None,  # Can be tuned for speed
+    sm_margin=0,  # Can be tuned if some SMs are used for communication
+    return_softmax_lse=False,
+):
+    """
+    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
+    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
+    the previous step, and update them with the new keys/values from the current step, and do
+    attention with the updated cache, all in 1 kernel.
+
+    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
+    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
+    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
+
+    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
+    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
+    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
+    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
+
+    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
+
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Note: Does not support backward pass. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) + page_block_size must be a multiple of 256. + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache) + k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate + k with k_cache, starting at the indices specified by cache_seqlens. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k. + qv [optional]: (batch_size, seqlen, nheads, headdim_v) + rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding + to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. + rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. + cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the + KV cache. + cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. + If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. + If the indices are not distinct, and k and v are provided, the values updated in the cache + might come from any of the duplicate indices. + cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. + page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. + If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, + rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 + (i.e. GPT-NeoX style). + num_splits: int. If > 1, split the key/value into this many chunks along the sequence. + If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic + to automatically determine the number of splits. + Don't change this unless you know what you are doing. + return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. + + Return: + out: (batch_size, seqlen, nheads, headdim). 
+ softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + """ + assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" + assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" + if softmax_scale is None: + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** ( + -0.5 + ) + if cache_seqlens is not None and isinstance(cache_seqlens, int): + cache_seqlens = torch.full( + (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + ) + cache_seqlens = maybe_contiguous(cache_seqlens) + + q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)] + v_cache = ( + v_cache.contiguous() + if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1 + else v_cache + ) + cu_seqlens_q, cu_seqlens_k_new = [ + maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new) + ] + page_table, cache_batch_idx, cache_leftpad = [ + maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad) + ] + rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] + rotary_seqlens = maybe_contiguous(rotary_seqlens) + + out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( + q, + k_cache, + v_cache, + k, + v, + qv, + None, # out + cu_seqlens_q, + None, # cu_seqlens_k + cu_seqlens_k_new, + None, # seqused_q + cache_seqlens, + max_seqlen_q, + None, # max_seqlen_k + page_table, + cache_batch_idx, + cache_leftpad, + rotary_cos, + rotary_sin, + rotary_seqlens, + q_descale, + k_descale, + v_descale, + softmax_scale, + causal, + window_size[0], + window_size[1], + softcap, + rotary_interleaved, + scheduler_metadata, + num_splits, + pack_gqa, + sm_margin, + ) + # return (out, softmax_lse) if return_softmax_lse else out + return (out, softmax_lse, *rest) if return_softmax_lse else out diff --git a/sgl-kernel/tests/test_flash_attention.py b/sgl-kernel/tests/test_flash_attention.py new file mode 100644 index 000000000..37e50cbd7 --- /dev/null +++ b/sgl-kernel/tests/test_flash_attention.py @@ -0,0 +1,841 @@ +# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py +import itertools +import math +import os + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +apply_rotary_emb = None + +from sgl_kernel.flash_attn import flash_attn_with_kvcache + +DISABLE_BACKWARD = True +# For CI test, we close them to True. 
+# DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" +# DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" +# DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" +# DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE" +# DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE" +# DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE" +# DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE" +# DISABLE_FP8 = ( +# os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" +# or torch.cuda.get_device_capability("cuda")[0] < 9 +# ) + +DISABLE_SPLIT = True +DISABLE_PAGEDKV = True +DISABLE_APPENDKV = True +DISABLE_LOCAL = True +DISABLE_SOFTCAP = True +DISABLE_PACKGQA = True +DISABLE_FP16 = True +DISABLE_FP8 = True + + +# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/padding.py +def unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. + """ + all_masks = ( + (attention_mask + unused_mask) if unused_mask is not None else attention_mask + ) + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. + return ( + rearrange(hidden_states, "b s ... -> (b s) ...")[indices], + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def generate_random_padding_mask( + max_seqlen, batch_size, device, mode="random", zero_lengths=False +): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full( + (batch_size, 1), max_seqlen, device=device, dtype=torch.int32 + ) + elif mode == "random": + lengths = torch.randint( + max(0 if zero_lengths else 1, max_seqlen - 20), + max_seqlen + 1, + (batch_size, 1), + device=device, + ) + elif mode == "third": + lengths = torch.randint( + max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device + ) + + if zero_lengths: + # Generate zero-lengths every 5 batches and the last batch. 
+ for i in range(batch_size): + if i % 5 == 0: + lengths[i] = 0 + lengths[-1] = 0 + padding_mask = ( + repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) + < lengths + ) + return padding_mask + + +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[1:] + output = torch.zeros( + (batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype + ) + output[indices] = hidden_states + return rearrange(output, "(b s) ... -> b s ...", b=batch) + + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + sink_token_length=0, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, +): + row_idx = rearrange( + torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1" + ) + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + torch.logical_and( + col_idx < row_idx + sk - sq - window_size[0], + col_idx >= sink_token_length, + ), + ) + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=(-1, -1), # -1 means infinite window size + sink_token_length=0, + softcap=0.0, + upcast=True, + reorder_ops=False, + intermediate_dtype=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads, head_dim) + v: (batch_size, seqlen_k, nheads, head_dim_v) + qv: (batch_size, seqlen_q, nheads, head_dim_v) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. 
+ Output: + output: (batch_size, seqlen_q, nheads, head_dim_v) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + qv = qv.float() if qv is not None else None + if q_descale is not None: + q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) + q = (q.float() * q_descale).to(q.dtype) + qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None + if k_descale is not None: + k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) + if v_descale is not None: + v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype) + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + dv = v.shape[-1] + softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv) + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if qv is not None: + scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v) + if softcap > 0: + scores = torch.tanh(scores / softcap) * softcap + if key_padding_mask is not None: + scores.masked_fill_( + rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf") + ) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + sink_token_length, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + attention = torch.softmax(scores, dim=-1).to(v.dtype) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill( + rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0 + ) + # Without this we might get NaN in dv + if key_padding_mask is not None: + attention = attention.masked_fill( + rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0 + ) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill( + torch.all(local_mask, dim=-1, keepdim=True), 0.0 + ) + dropout_scaling = 1.0 / (1 - dropout_p) + # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling + # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + if intermediate_dtype is not None: + attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize( + "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []) +) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +# 
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) +# @pytest.mark.parametrize("new_kv", [True]) +# @pytest.mark.parametrize( +# "causal,local", +# [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else []), +# ) +# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) +@pytest.mark.parametrize("causal,local", [(False, False)]) +@pytest.mark.parametrize( + "seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True] +) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False]) +@pytest.mark.parametrize( + "rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False] +) +# @pytest.mark.parametrize("rotary_interleaved", [True]) +@pytest.mark.parametrize( + "rotary_fraction", + ( + [0.0, 0.5, 1.0] + if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) + else [0.0] + ), +) +# @pytest.mark.parametrize("rotary_fraction", [0.0]) +@pytest.mark.parametrize( + "page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else []) +) +# @pytest.mark.parametrize("page_size", [None]) +# @pytest.mark.parametrize("has_leftpad", [False, True]) +@pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False, True]) +@pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("varlen_q", [False, True]) +@pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +@pytest.mark.parametrize("d", [64]) +# @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + # (1, 128 * 1024), + # (16, 128 * 1024), + (128, 128), + (256, 512), # To test appending KV with more than 1 block + (2048, 3577), # Enough tile to test persistent scheduler + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_kvcache( + seqlen_q, + seqlen_k, + d, + varlen_q, + has_batch_idx, + has_leftpad, + page_size, + rotary_fraction, + rotary_interleaved, + has_rotary_seqlens, + seqlen_new_eq_seqlen_q, + causal, + local, + new_kv, + mha_type, + dtype, +): + if page_size is not None and seqlen_k % page_size != 0: + pytest.skip() + if seqlen_q > seqlen_k and new_kv: + pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + # batch_size = 1 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 + nheads = 6 + # nheads = 1 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + for dv in dv_vals: + 
has_qv = d == 64 and dv >= 256 + q = ( + torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + if has_qv: + qv = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random" + ) + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input( + q, query_padding_mask + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = ( + rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None + ) + else: + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + + seqlen_new = ( + seqlen_q + if seqlen_new_eq_seqlen_q + else torch.randint(1, seqlen_q + 1, (1,)).item() + ) + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = ( + torch.randn( + batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + v = ( + torch.randn( + batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask( + seqlen_new, batch_size, device, mode="random" + ) + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input( + k, key_new_padding_mask + ) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v + else: + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + d, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + v_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + dv, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, + page_size, + batch_size_cache, + nheads_k, + d, + dv, + device, + dtype, + dtype_ref, + ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + ( + seqlen_k + - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + + 1 + ) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat( + [ + ( + torch.randint( + 0, + cache_seqlens[i].item(), + (1,), + dtype=torch.int32, + device=device, + ) + if cache_seqlens[i].item() > 0 + else torch.zeros(1, dtype=torch.int32, device=device) + ) + for i in range(batch_size) + ] + ) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm( + batch_size_cache, dtype=torch.int32, device=device + )[:batch_size] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = ( + key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + ) + key_padding_mask = 
arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, + arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k), + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = ( + k_cache if not has_batch_idx else k_cache[cache_batch_idx] + ).clone() + v_cache_ref = ( + v_cache if not has_batch_idx else v_cache[cache_batch_idx] + ).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, + arange < cache_seqlens_expanded + k_new_seqlens, + ) + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... -> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat( + k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + v_cache_rep = repeat( + v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + key_leftpad=cache_leftpad, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + precompute_metadata_vals = [False] + for num_splits, precompute_metadata in itertools.product( + 
num_splits_vals, precompute_metadata_vals + ): + scheduler_metadata = None + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): + if page_size is None: + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) + else: + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + return_softmax_lse=True, + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) + if not has_batch_idx + else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) + if not has_batch_idx + else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + if dtype is not torch.float8_e4m3fn: + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose( + v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3 + ) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3 + ) + else: + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1 + ) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * ( + out_pt - out_ref + ).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * ( + out_pt - out_ref + ).abs().mean().item() + + +def _generate_block_kvcache( + seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref +): + num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + k_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + v_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache = rearrange( + k_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache = rearrange( + v_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + + +if __name__ == "__main__": + pytest.main([__file__])
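
For reference, below is a minimal decode-style sketch of the `flash_attn_with_kvcache` entry point this patch adds. It is illustrative only, not part of the patch: the shapes, page size, and cache lengths are made up, and it assumes a Hopper (SM90) GPU with an sgl-kernel build that includes FA3.

```python
# Hypothetical usage sketch of the paged-KV decode path added in this PR.
import torch

from sgl_kernel.flash_attn import flash_attn_with_kvcache

torch.manual_seed(0)
device = "cuda"
batch, nheads, nheads_k, d = 2, 8, 2, 128        # GQA: 8 query heads share 2 KV heads
page_size, max_seqlen_k = 256, 1024              # the docstring above asks for a page size that is a multiple of 256
num_pages = batch * (max_seqlen_k // page_size)  # each sequence owns max_seqlen_k / page_size pages

q = torch.randn(batch, 1, nheads, d, device=device, dtype=torch.bfloat16)  # one new query token per sequence
k_cache = torch.randn(num_pages, page_size, nheads_k, d, device=device, dtype=torch.bfloat16)
v_cache = torch.randn(num_pages, page_size, nheads_k, d, device=device, dtype=torch.bfloat16)
page_table = torch.arange(num_pages, device=device, dtype=torch.int32).reshape(batch, -1)
cache_seqlens = torch.tensor([300, 17], device=device, dtype=torch.int32)  # tokens already present in the cache

out, lse, *rest = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    page_table=page_table,
    cache_seqlens=cache_seqlens,
    causal=True,
    return_softmax_lse=True,
)
print(out.shape, lse.shape)  # out: (batch, 1, nheads, d); lse: log-sum-exp per query row
```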