Add mamba kernel (#10234)
This commit is contained in:
@@ -438,6 +438,31 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.impl("copy_to_gpu_no_ce", torch::kCUDA, &copy_to_gpu_no_ce);
|
||||
m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()");
|
||||
m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);
|
||||
|
||||
/*
|
||||
* From csrc/mamba
|
||||
*/
|
||||
m.def(
|
||||
"causal_conv1d_update(Tensor! x,"
|
||||
"Tensor! conv_state,"
|
||||
"Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"bool silu_activation,"
|
||||
"Tensor? cache_seqlens_,"
|
||||
"Tensor? conv_state_indices,"
|
||||
"int pad_slot_id) -> ()");
|
||||
m.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
|
||||
|
||||
m.def(
|
||||
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"Tensor!? conv_states,"
|
||||
"Tensor? query_start_loc,"
|
||||
"Tensor? cache_indices,"
|
||||
"Tensor? has_initial_state,"
|
||||
"bool silu_activation,"
|
||||
"int pad_slot_id) -> ()");
|
||||
m.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
|
||||
Reference in New Issue
Block a user