[Feat] Add sparse attn to sgl-kernel (#5327)
This commit is contained in:
@@ -256,18 +256,21 @@ void min_p_sampling_from_probs(
|
||||
double min_p_val,
|
||||
bool deterministic,
|
||||
int64_t cuda_stream);
|
||||
|
||||
// Declaration only; the definition is not part of this view.
// Takes an input tensor `probs`, a second tensor `renorm_probs` (presumably the
// output buffer for the renormalized probabilities -- TODO confirm), an optional
// tensor `maybe_top_k_arr` and an integer `top_k_val`.
// NOTE(review): it looks like `top_k_val` is the fallback used when
// `maybe_top_k_arr` is not provided -- confirm against the implementation.
void top_k_renorm_probs(
|
||||
at::Tensor probs,
|
||||
at::Tensor renorm_probs,
|
||||
std::optional<at::Tensor> maybe_top_k_arr,
|
||||
int64_t top_k_val,
|
||||
// Stream is passed as a plain 64-bit integer (presumably a stream handle cast
// to an integer -- verify against the caller).
int64_t cuda_stream);
|
||||
|
||||
// Declaration only; the definition is not part of this view.
// Takes an input tensor `probs`, a second tensor `renorm_probs` (presumably the
// output buffer for the renormalized probabilities -- TODO confirm), an optional
// tensor `maybe_top_p_arr` and a floating-point `top_p_val`.
// NOTE(review): it looks like `top_p_val` is the fallback used when
// `maybe_top_p_arr` is not provided -- confirm against the implementation.
void top_p_renorm_probs(
|
||||
at::Tensor probs,
|
||||
at::Tensor renorm_probs,
|
||||
std::optional<at::Tensor> maybe_top_p_arr,
|
||||
double top_p_val,
|
||||
// Stream is passed as a plain 64-bit integer (presumably a stream handle cast
// to an integer -- verify against the caller).
int64_t cuda_stream);
|
||||
|
||||
void top_k_top_p_sampling_from_probs(
|
||||
at::Tensor probs,
|
||||
at::Tensor uniform_samples,
|
||||
@@ -279,6 +282,7 @@ void top_k_top_p_sampling_from_probs(
|
||||
double top_p_val,
|
||||
bool deterministic,
|
||||
int64_t cuda_stream);
|
||||
|
||||
void top_p_sampling_from_probs(
|
||||
at::Tensor probs,
|
||||
at::Tensor uniform_samples,
|
||||
@@ -288,3 +292,49 @@ void top_p_sampling_from_probs(
|
||||
double top_p_val,
|
||||
bool deterministic,
|
||||
int64_t cuda_stream);
|
||||
|
||||
// Declarations of the sparse attention forward entry points. Only the
// signatures are visible here; the definitions live elsewhere.
namespace flash {
|
||||
/*
|
||||
* From fa2 sparse
|
||||
*/
|
||||
// Sparse attention forward pass over batched q/k/v laid out as
// batch_size x seqlen x num_heads x head_size (per the parameter comments).
// Returns a vector of tensors; what each element holds is determined by the
// implementation, which is not visible here.
std::vector<at::Tensor> mha_fwd_sparse(
|
||||
// NOTE: q is the only tensor taken by non-const reference.
at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
// block_count, block_offset, column_count, column_index: presumably describe
// the sparsity pattern; shapes and dtypes are not shown here -- TODO confirm.
const at::Tensor& block_count,
|
||||
const at::Tensor& block_offset,
|
||||
const at::Tensor& column_count,
|
||||
const at::Tensor& column_index,
|
||||
const std::optional<at::Tensor>& out_, // batch_size x seqlen_q x num_heads x head_size
|
||||
const std::optional<at::Tensor>& alibi_slopes_, // num_heads or batch_size x num_heads
|
||||
const double p_dropout,
|
||||
const double softmax_scale,
|
||||
bool is_causal,
|
||||
const double softcap,
|
||||
const bool return_softmax,
|
||||
// Optional generator (presumably the RNG source for dropout -- confirm).
std::optional<at::Generator> gen_);
|
||||
|
||||
// Variable-length variant: q/k/v are packed along the first dimension
// (total_q / total_k rows, per the parameter comments) and the per-sequence
// boundaries are supplied through cu_seqlens_q / cu_seqlens_k (length b+1).
// NOTE(review): this declaration spells the optional type as c10::optional
// while mha_fwd_sparse above uses std::optional -- consider unifying once the
// matching definition is checked.
std::vector<at::Tensor> mha_varlen_fwd_sparse(
|
||||
at::Tensor& q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor& k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
|
||||
const at::Tensor& v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
|
||||
// block_count, block_offset, column_count, column_index: presumably describe
// the sparsity pattern; shapes and dtypes are not shown here -- TODO confirm.
const at::Tensor& block_count,
|
||||
const at::Tensor& block_offset,
|
||||
const at::Tensor& column_count,
|
||||
const at::Tensor& column_index,
|
||||
const c10::optional<at::Tensor>& out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor& cu_seqlens_q, // b+1
|
||||
const at::Tensor& cu_seqlens_k, // b+1
|
||||
const c10::optional<at::Tensor>&
|
||||
seqused_k, // b. If given, only this many elements of each batch element's keys are used.
|
||||
const c10::optional<at::Tensor>& alibi_slopes_, // num_heads or b x num_heads
|
||||
// NOTE: max_seqlen_q is taken by non-const value, unlike max_seqlen_k below.
int64_t max_seqlen_q,
|
||||
const int64_t max_seqlen_k,
|
||||
const double p_dropout,
|
||||
const double softmax_scale,
|
||||
const bool zero_tensors,
|
||||
bool is_causal,
|
||||
const double softcap,
|
||||
const bool return_softmax,
|
||||
// Optional generator (presumably the RNG source for dropout -- confirm).
c10::optional<at::Generator> gen_);
|
||||
} // namespace flash
|
||||
|
||||
Reference in New Issue
Block a user