[Feat] Add sparse attn to sgl-kernel (#5327)
This commit is contained in:
@@ -206,6 +206,28 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
|
||||
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
|
||||
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
|
||||
|
||||
/*
|
||||
* From Sparse Flash Attention
|
||||
*/
|
||||
// Declares the schema of the `fwd_sparse` operator on library handle `m` and
// then binds its CUDA implementation.
// The adjacent string literals below are concatenated by the compiler into a
// single schema string. In that string:
//   - `q` and `out` carry the `!` annotation (declared as possibly written by the op);
//   - `out` and `alibi_slopes` are optional (`?`), as is the `gen` Generator;
//   - the declared return type is a list of tensors (`Tensor[]`).
// NOTE(review): block_count / block_offset / column_count / column_index
// presumably describe the sparsity layout; their shapes and dtypes are not
// visible here — confirm against flash::mha_fwd_sparse.
// NOTE(review): the concatenated schema reads `...gen)-> Tensor[]` with no
// space before `->`; this relies on the schema parser not requiring
// whitespace there — confirm.
m.def(
|
||||
"fwd_sparse(Tensor! q, Tensor k, Tensor v, "
|
||||
"Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
|
||||
"Tensor!? out, Tensor? alibi_slopes, "
|
||||
"float p_dropout, float softmax_scale, bool is_causal, "
|
||||
"float softcap, bool return_softmax, Generator? gen)"
|
||||
"-> Tensor[]");
|
||||
// Register flash::mha_fwd_sparse as the implementation used for the CUDA dispatch key.
m.impl("fwd_sparse", torch::kCUDA, &flash::mha_fwd_sparse);
|
||||
|
||||
// Declares the schema of the `varlen_fwd_sparse` operator on library handle
// `m` and then binds its CUDA implementation.
// Compared with the `fwd_sparse` schema, this one additionally takes
// `cu_seqlens_q`, `cu_seqlens_k`, an optional `seqused_k`, the integer bounds
// `max_seqlen_q` / `max_seqlen_k`, and a `zero_tensors` flag.
// As in `fwd_sparse`, `q` and `out` carry the `!` annotation, `out`,
// `alibi_slopes` and `gen` are optional, and the return type is `Tensor[]`.
// NOTE(review): cu_seqlens_* are presumably cumulative sequence-length
// offsets for packed variable-length batches; their exact layout is not
// visible here — confirm against flash::mha_varlen_fwd_sparse.
m.def(
|
||||
"varlen_fwd_sparse(Tensor! q, Tensor k, Tensor v, "
|
||||
"Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
|
||||
"Tensor!? out, Tensor cu_seqlens_q, "
|
||||
"Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? alibi_slopes, "
|
||||
"int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
|
||||
"bool is_causal, float softcap, bool return_softmax, "
|
||||
"Generator? gen) -> Tensor[]");
|
||||
// Register flash::mha_varlen_fwd_sparse as the implementation used for the CUDA dispatch key.
m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);
|
||||
}
|
||||
|
||||
// NOTE(review): REGISTER_EXTENSION is a macro defined outside this view;
// presumably it emits the module-init entry point for the `common_ops`
// extension so the registrations above get loaded — confirm against the
// macro's definition.
REGISTER_EXTENSION(common_ops)
|
||||
|
||||
Reference in New Issue
Block a user