[perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 (#6929)

This commit is contained in:
JieXin Liang
2025-06-09 10:37:34 +08:00
committed by GitHub
parent de1350ea20
commit 18efb5e8e0
10 changed files with 2959 additions and 37 deletions

View File

@@ -109,8 +109,10 @@ void cutlass_mla_decode(
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace);
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0);
torch::Tensor const& workspace,
int64_t num_kv_splits = -1);
int64_t cutlass_mla_get_workspace_size(
int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0, int64_t num_kv_splits = -1);
/*
* From csrc/elementwise
*/