Remove 200us slow concat kernel (part 1: kernel) (#7145)

This commit is contained in:
fzyzcjy
2025-06-13 16:58:29 +08:00
committed by GitHub
parent 2f4ec752bc
commit aa46ed34d2
6 changed files with 79 additions and 48 deletions

View File

@@ -59,7 +59,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
m.def(
"cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
"page_table, Tensor! workspace, int num_kv_splits) -> ()");
m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);