Blackwell Cutlass MLA kernel (#5142)

This commit is contained in:
Trevor Morris
2025-04-11 22:16:51 -07:00
committed by GitHub
parent 5ad0571903
commit f65b8d5c89
7 changed files with 371 additions and 3 deletions

View File

@@ -87,7 +87,14 @@ void lightning_attention_decode(
const torch::Tensor& slope,
torch::Tensor output,
torch::Tensor new_kv);
void cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& q_nope_and_q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace);
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0);
/*
* From csrc/elementwise
*/