[sgl-kernel] support flashmla libtorch (#11717)
This commit is contained in:
@@ -842,6 +842,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& workspace);
|
||||
|
||||
/*
|
||||
* From fast-hadamard-transform
|
||||
*/
|
||||
@@ -850,3 +851,47 @@ torch::Tensor fast_hadamard_transform_12N(torch::Tensor& x, double scale);
|
||||
// Declaration only (definition is not in this header): takes the input tensor
// by non-const reference plus a double `scale`, and returns a tensor.
// NOTE(review): the `20N` suffix presumably selects the size-20*N variant of
// the transform — confirm against the definition.
torch::Tensor fast_hadamard_transform_20N(torch::Tensor& x, double scale);
|
||||
// Declaration only (definition is not in this header): takes the input tensor
// by non-const reference plus a double `scale`, and returns a tensor.
// NOTE(review): the `28N` suffix presumably selects the size-28*N variant of
// the transform — confirm against the definition.
torch::Tensor fast_hadamard_transform_28N(torch::Tensor& x, double scale);
|
||||
// Declaration only (definition is not in this header): takes the input tensor
// by non-const reference plus a double `scale`, and returns a tensor.
// NOTE(review): the `40N` suffix presumably selects the size-40*N variant of
// the transform — confirm against the definition.
torch::Tensor fast_hadamard_transform_40N(torch::Tensor& x, double scale);
|
||||
|
||||
/*
|
||||
* From FlashMLA
|
||||
*/
|
||||
/**
 * @brief Declaration of the MLA decoding metadata entry point.
 *
 * Only the prototype lives here; the definition is provided elsewhere.
 * NOTE(review): the returned tensors are presumably the scheduler metadata
 * and split counts later handed to fwd_kvcache_mla — confirm against the
 * definition before relying on the order or count of the returned vector.
 *
 * @param seqlens_k               Tensor passed by non-const reference.
 * @param num_q_tokens_per_head_k Integer count, passed by value.
 * @param h_k                     Integer, passed by value
 *                                (presumably the number of K heads — confirm).
 * @param h_q                     Optional integer; may be std::nullopt.
 * @param is_fp8_kvcache          Flag describing the KV-cache format.
 * @param topk                    Optional integer; may be std::nullopt.
 * @return A vector of tensors.
 */
std::vector<at::Tensor> get_mla_decoding_metadata(
    at::Tensor& seqlens_k,
    const int64_t num_q_tokens_per_head_k,
    const int64_t h_k,
    const std::optional<int64_t> h_q,
    const bool is_fp8_kvcache,
    const std::optional<int64_t> topk);
|
||||
|
||||
/**
 * @brief Declaration of the MLA forward pass over a paged KV cache.
 *
 * Only the prototype lives here; the definition is provided elsewhere.
 * Shapes below are the ones stated by the original inline annotations.
 *
 * @param q                       batch_size x seqlen_q x num_heads x head_size
 * @param kcache                  num_blocks x page_block_size x num_heads_k x head_size
 *                                (when is_fp8 is False), or
 *                                num_blocks x num_heads_k x (page_block_size*656)
 *                                (when is_fp8 is True)
 * @param head_size_v             Integer, passed by value.
 * @param seqlens_k               batch_size
 * @param block_table             batch_size x max_num_blocks_per_seq
 * @param softmax_scale           Scale factor, passed as double.
 * @param is_causal               Causal flag.
 * @param tile_scheduler_metadata num_sm_parts x TileSchedulerMetaDataSize
 * @param num_splits              batch_size + 1
 * @param is_fp8                  Selects which kcache layout applies (see kcache).
 *                                Taken by const reference; the parameter type is
 *                                part of the signature the definition must match.
 * @param indices                 None, or batch_size x seqlen_q x topk
 * @return A vector of tensors.
 */
std::vector<at::Tensor> fwd_kvcache_mla(
    at::Tensor& q,
    const at::Tensor& kcache,
    const int64_t head_size_v,
    const at::Tensor& seqlens_k,
    const at::Tensor& block_table,
    const double softmax_scale,
    bool is_causal,
    const at::Tensor& tile_scheduler_metadata,
    const at::Tensor& num_splits,
    const bool& is_fp8,
    const std::optional<at::Tensor>& indices);
|
||||
|
||||
/**
 * @brief Declaration of the FMHA forward runner (name indicates CUTLASS / SM100).
 *
 * Only the prototype lives here; the definition is provided elsewhere.
 * All tensors are taken by value and nothing is returned.
 * NOTE(review): `o` and `lse` are presumably outputs written in place —
 * confirm against the definition.
 *
 * @param workspace_buffer      Tensor (presumably scratch space — confirm).
 * @param q                     Tensor.
 * @param k                     Tensor.
 * @param v                     Tensor.
 * @param cumulative_seqlen_q   Tensor.
 * @param cumulative_seqlen_kv  Tensor.
 * @param o                     Tensor.
 * @param lse                   Tensor.
 * @param mask_mode_code        Integer code selecting the mask mode.
 * @param softmax_scale         Scale factor, passed as double.
 * @param max_seqlen_q          Integer.
 * @param max_seqlen_kv         Integer.
 * @param is_varlen             Flag for the variable-length path.
 */
void FMHACutlassSM100FwdRun(
    at::Tensor workspace_buffer,
    at::Tensor q,
    at::Tensor k,
    at::Tensor v,
    at::Tensor cumulative_seqlen_q,
    at::Tensor cumulative_seqlen_kv,
    at::Tensor o,
    at::Tensor lse,
    int64_t mask_mode_code,
    double softmax_scale,
    int64_t max_seqlen_q,
    int64_t max_seqlen_kv,
    bool is_varlen);
|
||||
|
||||
std::vector<at::Tensor>
|
||||
sparse_prefill_fwd(const at::Tensor& q, const at::Tensor& kv, const at::Tensor& indices, double sm_scale, int64_t d_v);
|
||||
|
||||
Reference in New Issue
Block a user