kvcache io kernels and test case (#7382)

This commit is contained in:
Zhiqiang Xie
2025-06-23 11:58:59 -07:00
committed by GitHub
parent 76139bfba0
commit 34c3f9b2d3
7 changed files with 845 additions and 0 deletions

View File

@@ -371,6 +371,89 @@ void segment_packbits(
int64_t batch_size,
int64_t cuda_stream = 0);
/*
* From csrc/kvcacheio
*/
void transfer_kv_per_layer(
const at::Tensor src_k,
at::Tensor dst_k,
const at::Tensor src_v,
at::Tensor dst_v,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_per_layer_direct(
const at::Tensor src_k,
at::Tensor dst_k,
const at::Tensor src_v,
at::Tensor dst_v,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t page_size);
void transfer_kv_all_layer(
const at::Tensor src_k,
at::Tensor dst_k,
const at::Tensor src_v,
at::Tensor dst_v,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t num_layers,
int64_t src_layer_offset,
int64_t dst_layer_offset,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_all_layer_direct(
const at::Tensor src_k,
at::Tensor dst_k,
const at::Tensor src_v,
at::Tensor dst_v,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t page_size,
int64_t num_layers);
void transfer_kv_per_layer_mla(
const at::Tensor src,
at::Tensor dst,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_per_layer_mla_direct(
const at::Tensor src,
at::Tensor dst,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t page_size);
void transfer_kv_all_layer_mla(
const at::Tensor src,
at::Tensor dst,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t num_layers,
int64_t src_layer_offset,
int64_t dst_layer_offset,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_all_layer_mla_direct(
const at::Tensor src,
at::Tensor dst,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t page_size,
int64_t num_layers);
/*
* From FlashInfer
*/