Blackwell Cutlass MLA kernel (#5142)

This commit is contained in:
Trevor Morris
2025-04-11 22:16:51 -07:00
committed by GitHub
parent 5ad0571903
commit f65b8d5c89
7 changed files with 371 additions and 3 deletions

View File

@@ -45,6 +45,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
"new_kv) -> ()");
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
m.def(
"cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
"page_table, Tensor workspace) -> ()");
m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);
/*
* From csrc/elementwise