diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 7133ad652..f2ab3151e 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -90,7 +90,7 @@ FetchContent_Populate(repo-flashinfer) FetchContent_Declare( repo-flash-attention GIT_REPOSITORY https://github.com/sgl-project/sgl-attn - GIT_TAG f9af0c2a1d82ab1812e6987e9338363cc2bf0f8d + GIT_TAG 36f9456cd48ec57c8d75d8d6b90933d4bedffb6b GIT_SHALLOW OFF ) FetchContent_Populate(repo-flash-attention) @@ -99,7 +99,7 @@ FetchContent_Populate(repo-flash-attention) FetchContent_Declare( repo-flash-attention-origin GIT_REPOSITORY https://github.com/Dao-AILab/flash-attention.git - GIT_TAG 203b9b3dba39d5d08dffb49c09aa622984dff07d + GIT_TAG 5a5a65b48dc99fc7483d2a7d5cfb1d8befa89389 GIT_SHALLOW OFF ) FetchContent_Populate(repo-flash-attention-origin) diff --git a/sgl-kernel/csrc/flash_extension.cc b/sgl-kernel/csrc/flash_extension.cc index f80db673f..df6024dfa 100644 --- a/sgl-kernel/csrc/flash_extension.cc +++ b/sgl-kernel/csrc/flash_extension.cc @@ -23,40 +23,43 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { * From flash-attention */ m.def( - "fwd(Tensor! q," - " Tensor k," - " Tensor v," - " Tensor? k_new," - " Tensor? v_new," - " Tensor? q_v," - " Tensor!? out," - " Tensor? cu_seqlens_q," - " Tensor? cu_seqlens_k," - " Tensor? cu_seqlens_k_new," - " Tensor? seqused_q," - " Tensor? seqused_k," + "fwd(Tensor q," // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + " Tensor k," // (b_k, s_k, h_k, d) or (total_k, h_k, d) or paged + " Tensor v," // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) or paged + " Tensor? k_new," // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) + " Tensor? v_new," // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) + " Tensor? q_v," // (b, s_q, h, dv) or (total_q_new, h, dv) + " Tensor? out," // (b, s_q, h, dv) or (total_q, h, dv) + " Tensor? cu_seqlens_q," // b+1 + " Tensor? cu_seqlens_k," // b+1 + " Tensor? cu_seqlens_k_new," // b+1 + " Tensor? 
seqused_q," // b + " Tensor? seqused_k," // b " int? max_seqlen_q," - " int? max_seqlen_k," - " Tensor? page_table," - " Tensor? kv_batch_idx," - " Tensor? leftpad_k," - " Tensor? rotary_cos," - " Tensor? rotary_sin," - " Tensor? seqlens_rotary," - " Tensor? q_descale," - " Tensor? k_descale," - " Tensor? v_descale," - " float softmax_scale," + " int? max_seqlen_k," // TODO: check if needed + " Tensor? page_table," // (b_k, max_num_pages_per_seq) + " Tensor? kv_batch_idx," // b + " Tensor? leftpad_k," // b + " Tensor? rotary_cos," // seqlen_ro x (rotary_dim / 2) + " Tensor? rotary_sin," // seqlen_ro x (rotary_dim / 2) + " Tensor? seqlens_rotary," // b + " Tensor? q_descale," // (b, h_k) + " Tensor? k_descale," // (b, h_k) + " Tensor? v_descale," // (b, h_k) + " float? softmax_scale," // now optional " bool is_causal," " int window_size_left," " int window_size_right," - " float softcap," + " int attention_chunk," // NEW + " float softcap," // promoted to double in C++; schema float is fine " bool is_rotary_interleaved," - " Tensor? scheduler_metadata," + " Tensor? scheduler_metadata," // (b + 1) " int num_splits," " bool? pack_gqa," " int sm_margin," - " Tensor? sinks) -> Tensor[]"); + " Tensor? sinks" + ") -> (Tensor, Tensor, Tensor, Tensor)"); // NEW return type: tuple of 4 tensors + m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd)); } diff --git a/sgl-kernel/include/sgl_flash_kernel_ops.h b/sgl-kernel/include/sgl_flash_kernel_ops.h index 383e207c3..b36af6b69 100644 --- a/sgl-kernel/include/sgl_flash_kernel_ops.h +++ b/sgl-kernel/include/sgl_flash_kernel_ops.h @@ -42,45 +42,44 @@ limitations under the License. /* * From flash-attention */ -std::vector mha_fwd( - at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, - // h_k, d) if there is page_table. 
- const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, - // page_size, h_k, dv) if there is page_table. - std::optional& - k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional& - v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new - std::optional& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q - std::optional& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q - std::optional& cu_seqlens_q_, // b+1 - std::optional& cu_seqlens_k_, // b+1 - std::optional& cu_seqlens_k_new_, // b+1 - std::optional& +std::tuple mha_fwd( + at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + at::Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, + // h_k, d) if there is page_table. + at::Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, + // page_size, h_k, dv) if there is page_table. + std::optional k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new + std::optional v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + std::optional out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional& + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. 
- std::optional max_seqlen_q_, + std::optional max_seqlen_q_, // TODO: check if we need max_seqlen_k - std::optional max_seqlen_k_, - std::optional& page_table_, // (b_k, max_num_pages_per_seq) - std::optional& kv_batch_idx_, // b. indices to index into the KV cache - std::optional& leftpad_k_, // b - std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional& seqlens_rotary_, // b - std::optional& q_descale_, // (b, h_k), not (b, h) - std::optional& k_descale_, // (b, h_k) - std::optional& v_descale_, // (b, h_k) - float const softmax_scale, + std::optional max_seqlen_k_, + std::optional page_table_, // (b_k, max_num_pages_per_seq) + std::optional kv_batch_idx_, // b. indices to index into the KV cache + std::optional leftpad_k_, // b + std::optional rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional seqlens_rotary_, // b + std::optional q_descale_, // (b, h_k), not (b, h) + std::optional k_descale_, // (b, h_k) + std::optional v_descale_, // (b, h_k) + std::optional softmax_scale_, bool is_causal, - int window_size_left, - int window_size_right, - float const softcap, - bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - std::optional& scheduler_metadata_, // (b + 1) - int num_splits, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, + double softcap, + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional scheduler_metadata_, // (b + 1) + int64_t num_splits, std::optional pack_gqa_, - int const sm_margin, - std::optional& sinks_); + int64_t sm_margin, + std::optional& sinks_); // (h) diff --git a/sgl-kernel/python/sgl_kernel/flash_attn.py b/sgl-kernel/python/sgl_kernel/flash_attn.py index ea70abb18..e0584c06b 100644 --- 
a/sgl-kernel/python/sgl_kernel/flash_attn.py +++ b/sgl-kernel/python/sgl_kernel/flash_attn.py @@ -43,7 +43,7 @@ def flash_attn_with_kvcache( qv=None, rotary_cos=None, rotary_sin=None, - cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, + cache_seqlens: Optional[Union[int, torch.Tensor]] = None, cache_batch_idx: Optional[torch.Tensor] = None, cache_leftpad: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, @@ -57,6 +57,7 @@ def flash_attn_with_kvcache( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + attention_chunk: Optional[int] = None, softcap=0.0, # 0.0 means deactivated rotary_interleaved=True, scheduler_metadata=None, @@ -135,6 +136,7 @@ def flash_attn_with_kvcache( Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. + attention_chunk: Optional[int]. Chunk size forwarded to the attention kernel as an int; None (the default) is passed as 0. NOTE(review): upstream FlashAttention-3 appears to use this for chunked local attention masking rather than for saving memory - confirm against the kernel. softcap: float. Anything > 0 activates softcapping attention. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. 
If False, @@ -214,6 +216,7 @@ def flash_attn_with_kvcache( ] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] rotary_seqlens = maybe_contiguous(rotary_seqlens) + attention_chunk = 0 if attention_chunk is None else int(attention_chunk) out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( q, @@ -243,6 +246,7 @@ def flash_attn_with_kvcache( causal, window_size[0], window_size[1], + attention_chunk, softcap, rotary_interleaved, scheduler_metadata, @@ -272,6 +276,7 @@ def flash_attn_varlen_func( k_descale=None, v_descale=None, window_size=(-1, -1), + attention_chunk: Optional[int] = None, softcap=0.0, num_splits=1, pack_gqa=None, @@ -321,6 +326,7 @@ def flash_attn_varlen_func( softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** ( -0.5 ) + attention_chunk = 0 if attention_chunk is None else int(attention_chunk) out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( q, @@ -350,6 +356,7 @@ def flash_attn_varlen_func( causal, window_size[0], window_size[1], + attention_chunk, softcap, is_rotary_interleaved=False, scheduler_metadata=None,