Add mamba kernel (#10234)

This commit is contained in:
Yi Zhang
2025-09-10 03:58:43 +08:00
committed by GitHub
parent 8471e5e616
commit 8cbe1538ef
8 changed files with 1418 additions and 0 deletions

View File

@@ -724,3 +724,27 @@ void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc,
// Declaration only; the definition is not visible in this hunk.
// Takes `input` by const reference and `output` by mutable reference.
// NOTE(review): presumably copies `input` into `output` on the GPU; what
// "no_ce" stands for is not established here -- confirm against the definition.
void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
// Declaration only; the definition is not visible in this hunk.
// All three tensors are passed by value (tensor handles), and nothing is returned.
// NOTE(review): presumably `k` is the destination assembled from `k_nope` and
// `k_rope` -- confirm against the definition; shapes/dtypes are not shown here.
void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
/*
* From csrc/mamba
*/
// Declaration of the causal conv1d "update" entry point; the definition is not
// visible in this hunk (the section comment above attributes it to csrc/mamba).
// Returns void, so any result must be produced through the tensor arguments.
//
// Parameters (types are from the signature; semantics are unverified here):
//   x                    - input tensor. NOTE(review): layout/dtype not shown -- confirm.
//   conv_state           - NOTE(review): presumably the cached convolution state,
//                          possibly updated in place despite the const reference
//                          (const at::Tensor& does not prevent writes to the
//                          underlying storage) -- confirm against the kernel.
//   weight               - NOTE(review): presumably the convolution filter weights.
//   bias_                - optional tensor; may be std::nullopt.
//   silu_activation      - NOTE(review): presumably enables a SiLU activation on
//                          the output -- confirm.
//   cache_seqlens_       - optional tensor; may be std::nullopt. Meaning not shown here.
//   conv_state_indices_  - optional tensor; may be std::nullopt. NOTE(review):
//                          presumably selects entries of conv_state per batch item.
//   pad_slot_id          - NOTE(review): presumably a sentinel index marking
//                          padded entries to skip -- confirm.
void causal_conv1d_update(
const at::Tensor& x,
const at::Tensor& conv_state,
const at::Tensor& weight,
const std::optional<at::Tensor>& bias_,
bool silu_activation,
const std::optional<at::Tensor>& cache_seqlens_,
const std::optional<at::Tensor>& conv_state_indices_,
int64_t pad_slot_id);
// Declaration of the causal conv1d forward entry point; the definition is not
// visible in this hunk. Returns void, so any result must be produced through
// the tensor arguments.
//
// Parameters (types are from the signature; semantics are unverified here):
//   x                  - input tensor. NOTE(review): layout/dtype not shown; possibly
//                        also the output buffer since nothing is returned -- confirm.
//   weight             - NOTE(review): presumably the convolution filter weights.
//   bias_              - optional tensor; may be std::nullopt.
//   conv_states        - optional tensor; may be std::nullopt. NOTE(review):
//                        presumably cached convolution states read and/or written.
//   query_start_loc    - optional tensor; may be std::nullopt. NOTE(review):
//                        presumably per-sequence start offsets into x -- confirm.
//   cache_indices      - optional tensor; may be std::nullopt. NOTE(review):
//                        presumably maps sequences to entries of conv_states.
//   has_initial_state  - optional tensor; may be std::nullopt. NOTE(review):
//                        presumably flags sequences that start from a cached state.
//   silu_activation    - NOTE(review): presumably enables a SiLU activation on
//                        the output -- confirm.
//   pad_slot_id        - NOTE(review): presumably a sentinel index marking
//                        padded entries to skip -- confirm.
void causal_conv1d_fwd(
const at::Tensor& x,
const at::Tensor& weight,
const std::optional<at::Tensor>& bias_,
const std::optional<at::Tensor>& conv_states,
const std::optional<at::Tensor>& query_start_loc,
const std::optional<at::Tensor>& cache_indices,
const std::optional<at::Tensor>& has_initial_state,
bool silu_activation,
int64_t pad_slot_id);