kvcache io kernels and test case (#7382)
This commit is contained in:
@@ -371,6 +371,89 @@ void segment_packbits(
|
||||
int64_t batch_size,
|
||||
int64_t cuda_stream = 0);
|
||||
|
||||
/*
|
||||
* From csrc/kvcacheio
|
||||
*/
|
||||
void transfer_kv_per_layer(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_per_layer_direct(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size);
|
||||
|
||||
void transfer_kv_all_layer(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t num_layers,
|
||||
int64_t src_layer_offset,
|
||||
int64_t dst_layer_offset,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_all_layer_direct(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size,
|
||||
int64_t num_layers);
|
||||
|
||||
void transfer_kv_per_layer_mla(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_per_layer_mla_direct(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size);
|
||||
|
||||
void transfer_kv_all_layer_mla(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t num_layers,
|
||||
int64_t src_layer_offset,
|
||||
int64_t dst_layer_offset,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_all_layer_mla_direct(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size,
|
||||
int64_t num_layers);
|
||||
|
||||
/*
|
||||
* From FlashInfer
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user