Add mamba kernel (#10234)

This commit is contained in:
Yi Zhang
2025-09-10 03:58:43 +08:00
committed by GitHub
parent 8471e5e616
commit 8cbe1538ef
8 changed files with 1418 additions and 0 deletions

View File

@@ -438,6 +438,31 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("copy_to_gpu_no_ce", torch::kCUDA, &copy_to_gpu_no_ce);
m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()");
m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);
/*
* From csrc/mamba
*/
m.def(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()");
m.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
m.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor!? conv_states,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation,"
"int pad_slot_id) -> ()");
m.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
}
REGISTER_EXTENSION(common_ops)