[NVIDIA] FA3/FA4 Fix (#11606)

Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
Johnny
2025-10-20 02:10:10 +02:00
committed by GitHub
parent cbb5fc2edc
commit 252dc4e112
10 changed files with 382 additions and 219 deletions

View File

@@ -23,40 +23,43 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
* From flash-attention
*/
// Declares the schema of the `fwd` operator on `m`. The adjacent string
// literals below are concatenated by the compiler into one schema string.
//
// NOTE(review): this listing is a rendered diff hunk whose +/- markers were
// stripped, so superseded and replacement lines sit back to back (two `fwd(`
// openers, two `sinks` tails). The lines carrying trailing shape comments look
// like the post-change version -- confirm against the real file before reusing
// any of this text as code.
m.def(
// NOTE(review): presumably the pre-change opening of the argument list
// (`q` and `out` are marked with `!` here and are not in the block that follows).
"fwd(Tensor! q,"
" Tensor k,"
" Tensor v,"
" Tensor? k_new,"
" Tensor? v_new,"
" Tensor? q_v,"
" Tensor!? out,"
" Tensor? cu_seqlens_q,"
" Tensor? cu_seqlens_k,"
" Tensor? cu_seqlens_k_new,"
" Tensor? seqused_q,"
" Tensor? seqused_k,"
// NOTE(review): presumably the post-change opening of the argument list; the
// trailing comments give the tensor shapes as written by the commit author.
"fwd(Tensor q," // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
" Tensor k," // (b_k, s_k, h_k, d) or (total_k, h_k, d) or paged
" Tensor v," // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) or paged
" Tensor? k_new," // (b, s_k_new, h_k, d) or (total_k_new, h_k, d)
" Tensor? v_new," // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv)
" Tensor? q_v," // (b, s_q, h, dv) or (total_q_new, h, dv)
" Tensor? out," // (b, s_q, h, dv) or (total_q, h, dv)
" Tensor? cu_seqlens_q," // b+1
" Tensor? cu_seqlens_k," // b+1
" Tensor? cu_seqlens_k_new," // b+1
" Tensor? seqused_q," // b
" Tensor? seqused_k," // b
" int? max_seqlen_q,"
// NOTE(review): the uncommented run from here through `float softmax_scale`
// is repeated below with comments; it looks like the superseded text.
" int? max_seqlen_k,"
" Tensor? page_table,"
" Tensor? kv_batch_idx,"
" Tensor? leftpad_k,"
" Tensor? rotary_cos,"
" Tensor? rotary_sin,"
" Tensor? seqlens_rotary,"
" Tensor? q_descale,"
" Tensor? k_descale,"
" Tensor? v_descale,"
" float softmax_scale,"
" int? max_seqlen_k," // TODO: check if needed
" Tensor? page_table," // (b_k, max_num_pages_per_seq)
" Tensor? kv_batch_idx," // b
" Tensor? leftpad_k," // b
" Tensor? rotary_cos," // seqlen_ro x (rotary_dim / 2)
" Tensor? rotary_sin," // seqlen_ro x (rotary_dim / 2)
" Tensor? seqlens_rotary," // b
" Tensor? q_descale," // (b, h_k)
" Tensor? k_descale," // (b, h_k)
" Tensor? v_descale," // (b, h_k)
" float? softmax_scale," // now optional
" bool is_causal,"
" int window_size_left,"
" int window_size_right,"
// NOTE(review): `float softcap` appears twice; the second occurrence follows
// the `attention_chunk` argument that the author marks as NEW.
" float softcap,"
" int attention_chunk," // NEW
" float softcap," // promoted to double in C++; schema float is fine
" bool is_rotary_interleaved,"
" Tensor? scheduler_metadata,"
" Tensor? scheduler_metadata," // (b + 1)
" int num_splits,"
" bool? pack_gqa,"
" int sm_margin,"
// NOTE(review): the next line terminates the m.def( call with a `Tensor[]`
// return. The two lines after it repeat the tail with a 4-tensor tuple return
// and look like its replacement; as literal code they would be left dangling.
" Tensor? sinks) -> Tensor[]");
" Tensor? sinks"
") -> (Tensor, Tensor, Tensor, Tensor)"); // NEW return type: tuple of 4 tensors
// Registers `mha_fwd`, passed through `make_pytorch_shim`, as the
// `torch::kCUDA` implementation of the `fwd` operator.
// NOTE(review): `make_pytorch_shim` and `mha_fwd` are defined outside this
// view; presumably the shim adapts argument/return types between the schema
// and the C++ signature -- confirm against their definitions.
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
}