[Feat] QWen-1M context support[1/2]: Update block sparse attention backend utils kernel (#5847)
Co-authored-by: sighingnow <sighingnow@gmail.com>
This commit is contained in:
@@ -353,6 +353,36 @@ std::vector<at::Tensor> mha_varlen_fwd_sparse(
|
||||
c10::optional<at::Generator> gen_);
|
||||
} // namespace flash
|
||||
|
||||
// Forward declaration of convert_vertical_slash_indexes; the definition is
// not part of this chunk. The per-parameter shape notes below are the
// original author's annotations.
// NOTE(review): the first four tensors are taken by non-const reference and
// the remaining tensors by value -- presumably the former are outputs derived
// from vertical_indexes / slash_indexes; confirm against the definition.
// NOTE(review): context_size, block_size_M, block_size_N and causal are not
// documented here; confirm their meaning and units in the implementation.
void convert_vertical_slash_indexes(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
||||
torch::Tensor q_seqlens, // [BATCH, ]
|
||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
int64_t context_size,
|
||||
int64_t block_size_M,
|
||||
int64_t block_size_N,
|
||||
bool causal);
|
||||
|
||||
// Forward declaration of convert_vertical_slash_indexes_mergehead; the
// definition is not part of this chunk. Its parameter list matches
// convert_vertical_slash_indexes with two extra tensors inserted before
// context_size: vertical_indices_count and slash_indices_count.
// NOTE(review): the first four tensors are taken by non-const reference and
// the remaining tensors by value -- presumably the former are outputs; confirm
// against the definition.
// NOTE(review): context_size, block_size_M, block_size_N and causal are not
// documented here; confirm their meaning and units in the implementation.
void convert_vertical_slash_indexes_mergehead(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
||||
torch::Tensor q_seqlens, // [BATCH, ]
|
||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
torch::Tensor vertical_indices_count, // [N_HEADS, ]
|
||||
torch::Tensor slash_indices_count, // NOTE(review): shape not annotated; presumably [N_HEADS, ] like vertical_indices_count -- confirm
|
||||
int64_t context_size,
|
||||
int64_t block_size_M,
|
||||
int64_t block_size_N,
|
||||
bool causal);
|
||||
|
||||
/*
|
||||
* From XGrammar
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user