Add mamba kernel (#10234)
This commit is contained in:
@@ -724,3 +724,27 @@ void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc,
|
||||
|
||||
void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
|
||||
void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
|
||||
|
||||
/*
|
||||
* From csrc/mamba
|
||||
*/
|
||||
void causal_conv1d_update(
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& conv_state,
|
||||
const at::Tensor& weight,
|
||||
const std::optional<at::Tensor>& bias_,
|
||||
bool silu_activation,
|
||||
const std::optional<at::Tensor>& cache_seqlens_,
|
||||
const std::optional<at::Tensor>& conv_state_indices_,
|
||||
int64_t pad_slot_id);
|
||||
|
||||
void causal_conv1d_fwd(
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& weight,
|
||||
const std::optional<at::Tensor>& bias_,
|
||||
const std::optional<at::Tensor>& conv_states,
|
||||
const std::optional<at::Tensor>& query_start_loc,
|
||||
const std::optional<at::Tensor>& cache_indices,
|
||||
const std::optional<at::Tensor>& has_initial_state,
|
||||
bool silu_activation,
|
||||
int64_t pad_slot_id);
|
||||
|
||||
Reference in New Issue
Block a user