[NVIDIA] FA3/FA4 Fix (#11606)

Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
Johnny
2025-10-20 02:10:10 +02:00
committed by GitHub
parent cbb5fc2edc
commit 252dc4e112
10 changed files with 382 additions and 219 deletions

View File

@@ -42,45 +42,44 @@ limitations under the License.
/*
* From flash-attention
*/
std::vector<at::Tensor> mha_fwd(
at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
// h_k, d) if there is page_table.
const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
// page_size, h_k, dv) if there is page_table.
std::optional<const at::Tensor>&
k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
std::optional<const at::Tensor>&
v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
std::optional<const at::Tensor>& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
std::optional<at::Tensor>& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
std::optional<const at::Tensor>& cu_seqlens_q_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_new_, // b+1
std::optional<const at::Tensor>&
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_fwd(
at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
at::Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
// h_k, d) if there is page_table.
at::Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
// page_size, h_k, dv) if there is page_table.
std::optional<at::Tensor> k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
std::optional<at::Tensor> v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
std::optional<at::Tensor> q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
std::optional<at::Tensor> out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
std::optional<at::Tensor> cu_seqlens_q_, // b+1
std::optional<at::Tensor> cu_seqlens_k_, // b+1
std::optional<at::Tensor> cu_seqlens_k_new_, // b+1
std::optional<at::Tensor>
seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
std::optional<const at::Tensor>&
std::optional<at::Tensor>
seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
std::optional<int> max_seqlen_q_,
std::optional<int64_t> max_seqlen_q_,
// TODO: check if we need max_seqlen_k
std::optional<int> max_seqlen_k_,
std::optional<const at::Tensor>& page_table_, // (b_k, max_num_pages_per_seq)
std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
std::optional<const at::Tensor>& leftpad_k_, // b
std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& rotary_sin_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& seqlens_rotary_, // b
std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h)
std::optional<at::Tensor>& k_descale_, // (b, h_k)
std::optional<at::Tensor>& v_descale_, // (b, h_k)
float const softmax_scale,
std::optional<int64_t> max_seqlen_k_,
std::optional<at::Tensor> page_table_, // (b_k, max_num_pages_per_seq)
std::optional<at::Tensor> kv_batch_idx_, // b. indices to index into the KV cache
std::optional<at::Tensor> leftpad_k_, // b
std::optional<at::Tensor> rotary_cos_, // seqlen_ro x (rotary_dim / 2)
std::optional<at::Tensor> rotary_sin_, // seqlen_ro x (rotary_dim / 2)
std::optional<at::Tensor> seqlens_rotary_, // b
std::optional<at::Tensor> q_descale_, // (b, h_k), not (b, h)
std::optional<at::Tensor> k_descale_, // (b, h_k)
std::optional<at::Tensor> v_descale_, // (b, h_k)
std::optional<double> softmax_scale_,
bool is_causal,
int window_size_left,
int window_size_right,
float const softcap,
bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
int num_splits,
int64_t window_size_left,
int64_t window_size_right,
int64_t attention_chunk,
double softcap,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
std::optional<at::Tensor> scheduler_metadata_, // (b + 1)
int64_t num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin,
std::optional<const at::Tensor>& sinks_);
int64_t sm_margin,
std::optional<const at::Tensor>& sinks_); // (h)