[Feat] Add sparse attn to sgl-kernel (#5327)

Author: PGFLMG
Date: 2025-04-13 02:36:36 +08:00
Committed by: GitHub
Parent: bc92107b03
Commit: 4879e50c6d
5 changed files with 625 additions and 14 deletions


@@ -256,18 +256,21 @@ void min_p_sampling_from_probs(
    double min_p_val,
    bool deterministic,
    int64_t cuda_stream);
void top_k_renorm_probs(
    at::Tensor probs,
    at::Tensor renorm_probs,
    std::optional<at::Tensor> maybe_top_k_arr,
    int64_t top_k_val,
    int64_t cuda_stream);
void top_p_renorm_probs(
    at::Tensor probs,
    at::Tensor renorm_probs,
    std::optional<at::Tensor> maybe_top_p_arr,
    double top_p_val,
    int64_t cuda_stream);
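
For context, a minimal caller sketch (not part of this commit) of how the two renorm entry points above might be driven from C++. The stream-handle convention (a raw cudaStream_t reinterpreted as int64_t) and the use of std::nullopt to select the scalar top-k path are assumptions read off the declarations alone.

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <optional>

at::Tensor renorm_top_k(const at::Tensor& probs, int64_t k) {
  at::Tensor renorm = at::empty_like(probs);  // output buffer, same shape as probs
  // Pass the current CUDA stream as an int64_t handle, matching the signatures above.
  int64_t stream = reinterpret_cast<int64_t>(at::cuda::getCurrentCUDAStream().stream());
  // std::nullopt: apply the scalar top_k_val to every row instead of a per-row tensor of k values.
  top_k_renorm_probs(probs, renorm, std::nullopt, k, stream);
  return renorm;
}
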
void top_k_top_p_sampling_from_probs(
    at::Tensor probs,
    at::Tensor uniform_samples,
@@ -279,6 +282,7 @@ void top_k_top_p_sampling_from_probs(
    double top_p_val,
    bool deterministic,
    int64_t cuda_stream);
void top_p_sampling_from_probs(
    at::Tensor probs,
    at::Tensor uniform_samples,
@@ -288,3 +292,49 @@ void top_p_sampling_from_probs(
    double top_p_val,
    bool deterministic,
    int64_t cuda_stream);
namespace flash {
/*
 * Sparse attention forward declarations, ported from FlashAttention-2 ("fa2 sparse").
 */
std::vector<at::Tensor> mha_fwd_sparse(
    at::Tensor& q,        // batch_size x seqlen_q x num_heads x head_size
    const at::Tensor& k,  // batch_size x seqlen_k x num_heads_k x head_size
    const at::Tensor& v,  // batch_size x seqlen_k x num_heads_k x head_size
    const at::Tensor& block_count,
    const at::Tensor& block_offset,
    const at::Tensor& column_count,
    const at::Tensor& column_index,
    const std::optional<at::Tensor>& out_,           // batch_size x seqlen_q x num_heads x head_size
    const std::optional<at::Tensor>& alibi_slopes_,  // num_heads or batch_size x num_heads
    const double p_dropout,
    const double softmax_scale,
    bool is_causal,
    const double softcap,
    const bool return_softmax,
    std::optional<at::Generator> gen_);
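
A minimal caller sketch (illustrative only, not part of the diff): running the sparse forward pass over pre-built block/column sparsity metadata. The metadata layout (per-head block counts/offsets plus per-row column counts/indices) follows the fa2 sparse convention; the concrete tensor shapes, the default flag values, and outs[0] being the attention output are assumptions.

#include <ATen/ATen.h>
#include <cmath>
#include <optional>

at::Tensor sparse_attn_fwd(
    at::Tensor q,                   // [B, Sq, H, D]
    const at::Tensor& k,            // [B, Sk, Hk, D]
    const at::Tensor& v,            // [B, Sk, Hk, D]
    const at::Tensor& block_count,  // how many KV blocks each query block attends to
    const at::Tensor& block_offset, // indices of those KV blocks
    const at::Tensor& column_count, // extra individual KV columns per row
    const at::Tensor& column_index) {
  // Standard 1/sqrt(head_dim) scaling.
  const double scale = 1.0 / std::sqrt(static_cast<double>(q.size(-1)));
  auto outs = flash::mha_fwd_sparse(
      q, k, v, block_count, block_offset, column_count, column_index,
      /*out_=*/std::nullopt, /*alibi_slopes_=*/std::nullopt,
      /*p_dropout=*/0.0, /*softmax_scale=*/scale,
      /*is_causal=*/true, /*softcap=*/0.0,
      /*return_softmax=*/false, /*gen_=*/std::nullopt);
  return outs[0];  // outs[0] is assumed to be the attention output
}
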
std::vector<at::Tensor> mha_varlen_fwd_sparse(
    at::Tensor& q,        // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
    const at::Tensor& k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
    const at::Tensor& v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
    const at::Tensor& block_count,
    const at::Tensor& block_offset,
    const at::Tensor& column_count,
    const at::Tensor& column_index,
    const std::optional<at::Tensor>& out_,       // total_q x num_heads x head_size
    const at::Tensor& cu_seqlens_q,              // b+1
    const at::Tensor& cu_seqlens_k,              // b+1
    const std::optional<at::Tensor>& seqused_k,  // b. If given, only this many elements of each batch element's keys are used.
    const std::optional<at::Tensor>& alibi_slopes_,  // num_heads or b x num_heads
    int64_t max_seqlen_q,
    const int64_t max_seqlen_k,
    const double p_dropout,
    const double softmax_scale,
    const bool zero_tensors,
    bool is_causal,
    const double softcap,
    const bool return_softmax,
    std::optional<at::Generator> gen_);
} // namespace flash
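
A sketch of the packed (varlen) variant, again illustrative rather than part of the commit: sequences are concatenated along dim 0 and delimited by cumulative-length tensors of size b+1. The int32 dtype for cu_seqlens, the default flag values, and outs[0] being the attention output are assumptions.

#include <ATen/ATen.h>
#include <cmath>
#include <optional>

at::Tensor sparse_attn_varlen_fwd(
    at::Tensor q,                  // [total_q, H, D]
    const at::Tensor& k,           // [total_k, Hk, D]
    const at::Tensor& v,           // [total_k, Hk, D]
    const at::Tensor& block_count,
    const at::Tensor& block_offset,
    const at::Tensor& column_count,
    const at::Tensor& column_index,
    const at::Tensor& seqlens_q,   // [b] per-sequence query lengths (int32, CUDA)
    const at::Tensor& seqlens_k) { // [b] per-sequence key lengths (int32, CUDA)
  // cu_seqlens_* are prefix sums with a leading 0, size b+1 (int32 assumed).
  at::Tensor zero = at::zeros({1}, seqlens_q.options());
  at::Tensor cu_q = at::cat({zero, seqlens_q.cumsum(0).to(at::kInt)});
  at::Tensor cu_k = at::cat({zero, seqlens_k.cumsum(0).to(at::kInt)});
  const int64_t max_q = seqlens_q.max().item<int64_t>();
  const int64_t max_k = seqlens_k.max().item<int64_t>();
  const double scale = 1.0 / std::sqrt(static_cast<double>(q.size(-1)));
  auto outs = flash::mha_varlen_fwd_sparse(
      q, k, v, block_count, block_offset, column_count, column_index,
      /*out_=*/std::nullopt, cu_q, cu_k,
      /*seqused_k=*/std::nullopt, /*alibi_slopes_=*/std::nullopt,
      max_q, max_k, /*p_dropout=*/0.0, /*softmax_scale=*/scale,
      /*zero_tensors=*/false, /*is_causal=*/true, /*softcap=*/0.0,
      /*return_softmax=*/false, /*gen_=*/std::nullopt);
  return outs[0];  // assumed attention output
}
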