[sgl-kernel] support flashmla libtorch (#11717)

This commit is contained in:
Fan Yin
2025-10-22 12:17:50 +08:00
committed by GitHub
parent 9d61205dac
commit 23afdfd1c2
6 changed files with 819 additions and 15 deletions

View File

@@ -842,6 +842,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& workspace);
/*
* From fast-hadamard-transform
*/
@@ -850,3 +851,47 @@ torch::Tensor fast_hadamard_transform_12N(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_20N(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_28N(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_40N(torch::Tensor& x, double scale);
/*
 * From flash-mla
 * NOTE(review): this header previously read "From csrc/fastertransformer",
 * which looks like a copy-paste of an unrelated section label — the functions
 * below are the FlashMLA entry points; verify the actual source directory.
 */
std::vector<at::Tensor> get_mla_decoding_metadata(
at::Tensor& seqlens_k,
const int64_t num_q_tokens_per_head_k,
const int64_t h_k,
const std::optional<int64_t> h_q,
const bool is_fp8_kvcache,
const std::optional<int64_t> topk);
std::vector<at::Tensor> fwd_kvcache_mla(
at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor& kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or
// num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
const int64_t head_size_v,
const at::Tensor& seqlens_k, // batch_size
const at::Tensor& block_table, // batch_size x max_num_blocks_per_seq
const double softmax_scale,
bool is_causal,
const at::Tensor& tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
const at::Tensor& num_splits, // batch_size + 1
const bool& is_fp8,
const std::optional<at::Tensor>& indices // None, or batch_size x seqlen_q x topk
);
void FMHACutlassSM100FwdRun(
at::Tensor workspace_buffer,
at::Tensor q,
at::Tensor k,
at::Tensor v,
at::Tensor cumulative_seqlen_q,
at::Tensor cumulative_seqlen_kv,
at::Tensor o,
at::Tensor lse,
int64_t mask_mode_code,
double softmax_scale,
int64_t max_seqlen_q,
int64_t max_seqlen_kv,
bool is_varlen);
std::vector<at::Tensor>
sparse_prefill_fwd(const at::Tensor& q, const at::Tensor& kv, const at::Tensor& indices, double sm_scale, int64_t d_v);