[perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 (#6929)

This commit is contained in:
JieXin Liang
2025-06-09 10:37:34 +08:00
committed by GitHub
parent de1350ea20
commit 18efb5e8e0
10 changed files with 2959 additions and 37 deletions

View File

@@ -60,7 +60,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
m.def(
"cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
"page_table, Tensor workspace) -> ()");
"page_table, Tensor! workspace, int num_kv_splits) -> ()");
m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);