[fix] fix cutlass_mla_backend with cuda_graph and add sm_scale for sgl-kernel cutlass_mla (#7184)

This commit is contained in:
JieXin Liang
2025-06-15 03:45:41 +08:00
committed by GitHub
parent ed54bf9d19
commit ab1a4fa5cb
7 changed files with 29 additions and 17 deletions

View File

@@ -60,7 +60,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
m.def(
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
-"page_table, Tensor! workspace, int num_kv_splits) -> ()");
+"page_table, Tensor! workspace, float sm_scale, int num_kv_splits) -> ()");
m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);