[Feat] QWen-1M context support[1/2]: Update block sparse attention backend utils kernel (#5847)

Co-authored-by: sighingnow <sighingnow@gmail.com>
This commit is contained in:
PGFLMG
2025-04-29 02:03:17 +08:00
committed by GitHub
parent d364b9b0f2
commit ee71ed8a41
6 changed files with 763 additions and 1 deletion

View File

@@ -234,6 +234,28 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"Generator? gen) -> Tensor[]");
m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);
// Sparse Attention utils
// Register the vertical-slash sparse index conversion op used by the
// block-sparse attention backend (MInference-style vertical + slash pattern).
// The first four `Tensor!` arguments are output buffers mutated in place;
// `q_seqlens`/`kv_seqlens` are the per-sequence query and key/value lengths.
// BUGFIX: the schema previously declared `q_seqlens` twice — duplicate
// argument names are rejected by the TorchScript schema parser, and the CUDA
// kernel expects a distinct key/value sequence-length tensor, so the second
// argument is correctly named `kv_seqlens`.
m.def(
    "convert_vertical_slash_indexes("
    "   Tensor! block_count, Tensor! block_offset, "
    "   Tensor! column_count, Tensor! column_index, "
    "   Tensor q_seqlens, Tensor kv_seqlens, "
    "   Tensor vertical_indexes, Tensor slash_indexes, "
    "   int context_size, int block_size_M, int block_size_N, "
    "   bool causal) -> ()");
m.impl("convert_vertical_slash_indexes", torch::kCUDA, &convert_vertical_slash_indexes);
// Register the merge-head variant of the vertical-slash index conversion:
// identical to `convert_vertical_slash_indexes` but additionally takes
// per-head counts (`vertical_indices_count`, `slash_indices_count`) so heads
// with different sparsity budgets can be processed in one launch.
// BUGFIX: as in the non-mergehead schema, `q_seqlens` was declared twice;
// the second sequence-length tensor is the key/value lengths and is renamed
// `kv_seqlens` to produce a valid, parseable schema matching the kernel.
m.def(
    "convert_vertical_slash_indexes_mergehead("
    "   Tensor! block_count, Tensor! block_offset, "
    "   Tensor! column_count, Tensor! column_index, "
    "   Tensor q_seqlens, Tensor kv_seqlens, "
    "   Tensor vertical_indexes, Tensor slash_indexes, "
    "   Tensor vertical_indices_count, Tensor slash_indices_count, "
    "   int context_size, int block_size_M, int block_size_N, "
    "   bool causal) -> ()");
m.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, &convert_vertical_slash_indexes_mergehead);
/*
* From XGrammar
*/