From f792e3c561b427b9b00648cc4f23f54f457f46cc Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 13 Oct 2025 20:51:45 -0700 Subject: [PATCH] Revert "[NVIDIA] BUMP FA3 (#11444)" (#11582) --- sgl-kernel/CMakeLists.txt | 4 +- sgl-kernel/csrc/flash_extension.cc | 55 ++++++++-------- sgl-kernel/include/sgl_flash_kernel_ops.h | 73 +++++++++++----------- sgl-kernel/python/sgl_kernel/flash_attn.py | 9 +-- 4 files changed, 66 insertions(+), 75 deletions(-) diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index f2ab3151e..7133ad652 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -90,7 +90,7 @@ FetchContent_Populate(repo-flashinfer) FetchContent_Declare( repo-flash-attention GIT_REPOSITORY https://github.com/sgl-project/sgl-attn - GIT_TAG 36f9456cd48ec57c8d75d8d6b90933d4bedffb6b + GIT_TAG f9af0c2a1d82ab1812e6987e9338363cc2bf0f8d GIT_SHALLOW OFF ) FetchContent_Populate(repo-flash-attention) @@ -99,7 +99,7 @@ FetchContent_Populate(repo-flash-attention) FetchContent_Declare( repo-flash-attention-origin GIT_REPOSITORY https://github.com/Dao-AILab/flash-attention.git - GIT_TAG 5a5a65b48dc99fc7483d2a7d5cfb1d8befa89389 + GIT_TAG 203b9b3dba39d5d08dffb49c09aa622984dff07d GIT_SHALLOW OFF ) FetchContent_Populate(repo-flash-attention-origin) diff --git a/sgl-kernel/csrc/flash_extension.cc b/sgl-kernel/csrc/flash_extension.cc index df6024dfa..f80db673f 100644 --- a/sgl-kernel/csrc/flash_extension.cc +++ b/sgl-kernel/csrc/flash_extension.cc @@ -23,43 +23,40 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { * From flash-attention */ m.def( - "fwd(Tensor q," // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - " Tensor k," // (b_k, s_k, h_k, d) or (total_k, h_k, d) or paged - " Tensor v," // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) or paged - " Tensor? k_new," // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) - " Tensor? v_new," // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) - " Tensor? q_v," // (b, s_q, h, dv) or (total_q_new, h, dv) - " Tensor? out," // (b, s_q, h, dv) or (total_q, h, dv) - " Tensor? cu_seqlens_q," // b+1 - " Tensor? cu_seqlens_k," // b+1 - " Tensor? cu_seqlens_k_new," // b+1 - " Tensor? seqused_q," // b - " Tensor? seqused_k," // b + "fwd(Tensor! q," + " Tensor k," + " Tensor v," + " Tensor? k_new," + " Tensor? v_new," + " Tensor? q_v," + " Tensor!? out," + " Tensor? cu_seqlens_q," + " Tensor? cu_seqlens_k," + " Tensor? cu_seqlens_k_new," + " Tensor? seqused_q," + " Tensor? seqused_k," " int? max_seqlen_q," - " int? max_seqlen_k," // TODO: check if needed - " Tensor? page_table," // (b_k, max_num_pages_per_seq) - " Tensor? kv_batch_idx," // b - " Tensor? leftpad_k," // b - " Tensor? rotary_cos," // seqlen_ro x (rotary_dim / 2) - " Tensor? rotary_sin," // seqlen_ro x (rotary_dim / 2) - " Tensor? seqlens_rotary," // b - " Tensor? q_descale," // (b, h_k) - " Tensor? k_descale," // (b, h_k) - " Tensor? v_descale," // (b, h_k) - " float? softmax_scale," // now optional + " int? max_seqlen_k," + " Tensor? page_table," + " Tensor? kv_batch_idx," + " Tensor? leftpad_k," + " Tensor? rotary_cos," + " Tensor? rotary_sin," + " Tensor? seqlens_rotary," + " Tensor? q_descale," + " Tensor? k_descale," + " Tensor? v_descale," + " float softmax_scale," " bool is_causal," " int window_size_left," " int window_size_right," - " int attention_chunk," // NEW - " float softcap," // promoted to double in C++; schema float is fine + " float softcap," " bool is_rotary_interleaved," - " Tensor? scheduler_metadata," // (b + 1) + " Tensor? 
scheduler_metadata,"
       " int num_splits,"
       " bool? pack_gqa,"
       " int sm_margin,"
-      " Tensor? sinks"
-      ") -> (Tensor, Tensor, Tensor, Tensor)"); // NEW return type: tuple of 4 tensors
-
+      " Tensor? sinks) -> Tensor[]");
   m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
 }
diff --git a/sgl-kernel/include/sgl_flash_kernel_ops.h b/sgl-kernel/include/sgl_flash_kernel_ops.h
index b36af6b69..383e207c3 100644
--- a/sgl-kernel/include/sgl_flash_kernel_ops.h
+++ b/sgl-kernel/include/sgl_flash_kernel_ops.h
@@ -42,44 +42,45 @@ limitations under the License.
 /*
  * From flash-attention
  */
-std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_fwd(
-    at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    at::Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
-                  // h_k, d) if there is page_table.
-    at::Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
-                  // page_size, h_k, dv) if there is page_table.
-    std::optional<at::Tensor> k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
-    std::optional<at::Tensor> v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
-    std::optional<at::Tensor> q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
-    std::optional<at::Tensor> out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
-    std::optional<at::Tensor> cu_seqlens_q_, // b+1
-    std::optional<at::Tensor> cu_seqlens_k_, // b+1
-    std::optional<at::Tensor> cu_seqlens_k_new_, // b+1
-    std::optional<at::Tensor>
+std::vector<at::Tensor> mha_fwd(
+    at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
+                         // h_k, d) if there is page_table.
+    const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
+                         // page_size, h_k, dv) if there is page_table.
+    std::optional<const at::Tensor>&
+        k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
+    std::optional<const at::Tensor>&
+        v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
+    std::optional<const at::Tensor>& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
+    std::optional<at::Tensor>& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
+    std::optional<const at::Tensor>& cu_seqlens_q_, // b+1
+    std::optional<const at::Tensor>& cu_seqlens_k_, // b+1
+    std::optional<const at::Tensor>& cu_seqlens_k_new_, // b+1
+    std::optional<const at::Tensor>&
        seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
-    std::optional<at::Tensor>
+    std::optional<const at::Tensor>&
        seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
-    std::optional<int64_t> max_seqlen_q_,
+    std::optional<int> max_seqlen_q_,
     // TODO: check if we need max_seqlen_k
-    std::optional<int64_t> max_seqlen_k_,
-    std::optional<at::Tensor> page_table_, // (b_k, max_num_pages_per_seq)
-    std::optional<at::Tensor> kv_batch_idx_, // b. indices to index into the KV cache
-    std::optional<at::Tensor> leftpad_k_, // b
-    std::optional<at::Tensor> rotary_cos_, // seqlen_ro x (rotary_dim / 2)
-    std::optional<at::Tensor> rotary_sin_, // seqlen_ro x (rotary_dim / 2)
-    std::optional<at::Tensor> seqlens_rotary_, // b
-    std::optional<at::Tensor> q_descale_, // (b, h_k), not (b, h)
-    std::optional<at::Tensor> k_descale_, // (b, h_k)
-    std::optional<at::Tensor> v_descale_, // (b, h_k)
-    std::optional<double> softmax_scale_,
+    std::optional<int> max_seqlen_k_,
+    std::optional<const at::Tensor>& page_table_, // (b_k, max_num_pages_per_seq)
+    std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
+    std::optional<const at::Tensor>& leftpad_k_, // b
+    std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2)
+    std::optional<const at::Tensor>& rotary_sin_, // seqlen_ro x (rotary_dim / 2)
+    std::optional<const at::Tensor>& seqlens_rotary_, // b
+    std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h)
+    std::optional<at::Tensor>& k_descale_, // (b, h_k)
+    std::optional<at::Tensor>& v_descale_, // (b, h_k)
+    float const softmax_scale,
     bool is_causal,
-    int64_t window_size_left,
-    int64_t window_size_right,
-    int64_t attention_chunk,
-    double softcap,
-    bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-    std::optional<at::Tensor> scheduler_metadata_, // (b + 1)
-    int64_t num_splits,
+    int window_size_left,
+    int window_size_right,
+    float const softcap,
+    bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+    std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
+    int num_splits,
     std::optional<bool> pack_gqa_,
-    int64_t sm_margin,
-    std::optional<at::Tensor>& sinks_); // (h)
+    int const sm_margin,
+    std::optional<at::Tensor>& sinks_);
diff --git a/sgl-kernel/python/sgl_kernel/flash_attn.py b/sgl-kernel/python/sgl_kernel/flash_attn.py
index e0584c06b..ea70abb18 100644
--- a/sgl-kernel/python/sgl_kernel/flash_attn.py
+++ b/sgl-kernel/python/sgl_kernel/flash_attn.py
@@ -43,7 +43,7 @@ def flash_attn_with_kvcache(
     qv=None,
     rotary_cos=None,
     rotary_sin=None,
-    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
+    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
     cache_batch_idx: Optional[torch.Tensor] = None,
     cache_leftpad: Optional[torch.Tensor] = None,
     page_table: Optional[torch.Tensor] = None,
@@ -57,7 +57,6 @@ def flash_attn_with_kvcache(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
-    attention_chunk: Optional[int] = None,
     softcap=0.0,  # 0.0 means deactivated
     rotary_interleaved=True,
     scheduler_metadata=None,
@@ -136,7 +135,6 @@ def flash_attn_with_kvcache(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
-        attention_chunk: Optional[int]. If not None, splits the query into chunks of this size to save memory.
         softcap: float. Anything > 0 activates softcapping attention.
         rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. If True,
             rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. 
If False, @@ -216,7 +214,6 @@ def flash_attn_with_kvcache( ] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] rotary_seqlens = maybe_contiguous(rotary_seqlens) - attention_chunk = 0 if attention_chunk is None else int(attention_chunk) out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( q, @@ -246,7 +243,6 @@ def flash_attn_with_kvcache( causal, window_size[0], window_size[1], - attention_chunk, softcap, rotary_interleaved, scheduler_metadata, @@ -276,7 +272,6 @@ def flash_attn_varlen_func( k_descale=None, v_descale=None, window_size=(-1, -1), - attention_chunk: Optional[int] = None, softcap=0.0, num_splits=1, pack_gqa=None, @@ -326,7 +321,6 @@ def flash_attn_varlen_func( softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** ( -0.5 ) - attention_chunk = 0 if attention_chunk is None else int(attention_chunk) out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( q, @@ -356,7 +350,6 @@ def flash_attn_varlen_func( causal, window_size[0], window_size[1], - attention_chunk, softcap, is_rotary_interleaved=False, scheduler_metadata=None,
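
The user-visible effect of the Python changes above is that flash_attn_with_kvcache (and flash_attn_varlen_func) no longer accept an attention_chunk argument after this revert. The snippet below is a usage sketch only, not part of the patch: the leading positional arguments q, k_cache, v_cache, the shapes, and the dtypes are assumptions taken from the shape comments in the diff, and it requires a CUDA device with a built sgl-kernel.

import torch
from sgl_kernel.flash_attn import flash_attn_with_kvcache

# Illustrative sizes only: 2 sequences decoding 1 token each against a paged
# KV cache laid out as (num_pages, page_size, h_k, d), per the comments above.
batch, seqlen_q, nheads, nheads_k, headdim = 2, 1, 8, 8, 128
num_pages, page_size = 16, 256
device, dtype = "cuda", torch.bfloat16

q = torch.randn(batch, seqlen_q, nheads, headdim, device=device, dtype=dtype)
k_cache = torch.randn(num_pages, page_size, nheads_k, headdim, device=device, dtype=dtype)
v_cache = torch.randn(num_pages, page_size, nheads_k, headdim, device=device, dtype=dtype)
cache_seqlens = torch.full((batch,), 512, dtype=torch.int32, device=device)  # b
page_table = torch.arange(num_pages, dtype=torch.int32, device=device).view(batch, -1)  # (b_k, max_num_pages_per_seq)

# After the revert there is no attention_chunk keyword; softmax_scale left as
# None defaults to 1 / sqrt(headdim), as documented in the docstring above.
out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    cache_seqlens=cache_seqlens,
    page_table=page_table,
    causal=True,
)
print(out.shape)  # expected (batch, seqlen_q, nheads, headdim)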