[Feat] Add sparse attn to sgl-kernel (#5327)

Author: PGFLMG
Date: 2025-04-13 02:36:36 +08:00
Committed by: GitHub
Parent: bc92107b03
Commit: 4879e50c6d
5 changed files with 625 additions and 14 deletions


@@ -206,6 +206,28 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
/*
* From Sparse Flash Attention
*/
m.def(
"fwd_sparse(Tensor! q, Tensor k, Tensor v, "
"Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
"Tensor!? out, Tensor? alibi_slopes, "
"float p_dropout, float softmax_scale, bool is_causal, "
"float softcap, bool return_softmax, Generator? gen)"
"-> Tensor[]");
m.impl("fwd_sparse", torch::kCUDA, &flash::mha_fwd_sparse);
m.def(
"varlen_fwd_sparse(Tensor! q, Tensor k, Tensor v, "
"Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
"Tensor!? out, Tensor cu_seqlens_q, "
"Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? alibi_slopes, "
"int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
"bool is_causal, float softcap, bool return_softmax, "
"Generator? gen) -> Tensor[]");
m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);
}
REGISTER_EXTENSION(common_ops)
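
Once the extension is loaded, the two schemas above make the sparse-attention kernels callable from Python under torch.ops.sgl_kernel. The sketch below only illustrates the positional argument order of fwd_sparse as declared in the schema; the q/k/v layout, the shapes of the block/column index tensors, and taking the first element of the returned Tensor[] as the attention output are assumptions for illustration, not something this diff specifies.

import torch

# Assumed layout for q/k/v: (batch, seqlen, heads, head_dim), fp16 on CUDA.
batch, seqlen, heads, head_dim = 1, 1024, 8, 128
q = torch.randn(batch, seqlen, heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# block_count / block_offset / column_count / column_index describe the sparse
# pattern each query block attends to; the shapes used here are placeholders.
blocks = seqlen // 64
block_count = torch.zeros(batch, heads, blocks, dtype=torch.int32, device="cuda")
block_offset = torch.zeros(batch, heads, blocks, blocks, dtype=torch.int32, device="cuda")
column_count = torch.zeros(batch, heads, blocks, dtype=torch.int32, device="cuda")
column_index = torch.zeros(batch, heads, blocks, seqlen, dtype=torch.int32, device="cuda")

outputs = torch.ops.sgl_kernel.fwd_sparse(
    q, k, v,
    block_count, block_offset, column_count, column_index,
    None,               # out: let the kernel allocate the output
    None,               # alibi_slopes
    0.0,                # p_dropout
    head_dim ** -0.5,   # softmax_scale
    True,               # is_causal
    0.0,                # softcap
    False,              # return_softmax
    None,               # gen
)
out = outputs[0]  # assumed: first entry of the returned Tensor[] is the attention output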