[perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 (#6929)
This commit is contained in:
@@ -109,8 +109,10 @@ void cutlass_mla_decode(
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace);
|
||||
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0);
|
||||
torch::Tensor const& workspace,
|
||||
int64_t num_kv_splits = -1);
|
||||
int64_t cutlass_mla_get_workspace_size(
|
||||
int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0, int64_t num_kv_splits = -1);
|
||||
/*
|
||||
* From csrc/elementwise
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user