[perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 (#6929)
This commit is contained in:
@@ -60,7 +60,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
|
||||
m.def(
|
||||
"cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
|
||||
"page_table, Tensor workspace) -> ()");
|
||||
"page_table, Tensor! workspace, int num_kv_splits) -> ()");
|
||||
m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
|
||||
m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user