Blackwell Cutlass MLA kernel (#5142)
This commit is contained in:
@@ -87,7 +87,14 @@ void lightning_attention_decode(
|
||||
const torch::Tensor& slope,
|
||||
torch::Tensor output,
|
||||
torch::Tensor new_kv);
|
||||
|
||||
void cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope_and_q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace);
|
||||
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0);
|
||||
/*
|
||||
* From csrc/elementwise
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user