kvcache io kernels and test case (#7382)
This commit is contained in:
@@ -230,6 +230,43 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"int cuda_stream) -> ()");
|
||||
m.impl("segment_packbits", torch::kCUDA, &segment_packbits);
|
||||
|
||||
/*
|
||||
* From csrc/kvcacheio
|
||||
*/
|
||||
// Declares the `transfer_kv_per_layer` op schema (separate K and V source/destination
// tensors, source/destination index tensors, and three int arguments; no return value)
// and binds the C++ function of the same name as its implementation for the CUDA
// dispatch key.
// NOTE(review): dst_k/dst_v are presumably written by the kernel, yet the schema carries
// no alias/mutation annotation (e.g. `Tensor(a!)`) — confirm this is intended.
m.def(
|
||||
"transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
|
||||
"dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer);
|
||||
// Declares the `transfer_kv_per_layer_direct` op schema (separate K and V
// source/destination tensors, source/destination index tensors, and an int `page_size`;
// no return value) and binds the C++ function of the same name as its implementation
// for the CUDA dispatch key.
// NOTE(review): dst_k/dst_v are presumably written by the implementation, yet the schema
// carries no alias/mutation annotation (e.g. `Tensor(a!)`) — confirm this is intended.
m.def(
|
||||
"transfer_kv_per_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
|
||||
"dst_indices, int page_size) -> ()");
|
||||
m.impl("transfer_kv_per_layer_direct", torch::kCUDA, &transfer_kv_per_layer_direct);
|
||||
// Declares the `transfer_kv_all_layer` op schema (separate K and V source/destination
// tensors, source/destination index tensors, and six int arguments including
// `num_layers` and per-side layer offsets; no return value) and binds the C++ function
// of the same name as its implementation for the CUDA dispatch key.
// NOTE(review): dst_k/dst_v are presumably written by the kernel, yet the schema carries
// no alias/mutation annotation (e.g. `Tensor(a!)`) — confirm this is intended.
m.def(
|
||||
"transfer_kv_all_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
|
||||
"dst_indices, int item_size, int num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int "
|
||||
"num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer);
|
||||
// Declares the `transfer_kv_all_layer_direct` op schema (separate K and V
// source/destination tensors, source/destination index tensors, and int `page_size` and
// `num_layers`; no return value) and binds the C++ function of the same name as its
// implementation for the CUDA dispatch key.
// NOTE(review): dst_k/dst_v are presumably written by the implementation, yet the schema
// carries no alias/mutation annotation (e.g. `Tensor(a!)`) — confirm this is intended.
m.def(
|
||||
"transfer_kv_all_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
|
||||
"dst_indices, int page_size, int num_layers) -> ()");
|
||||
m.impl("transfer_kv_all_layer_direct", torch::kCUDA, &transfer_kv_all_layer_direct);
|
||||
// Declares the `transfer_kv_per_layer_mla` op schema (a single source and a single
// destination tensor, source/destination index tensors, and three int arguments; no
// return value) and binds the C++ function of the same name as its implementation for
// the CUDA dispatch key.
// NOTE(review): dst is presumably written by the kernel, yet the schema carries no
// alias/mutation annotation (e.g. `Tensor(a!)`) — confirm this is intended.
m.def(
|
||||
"transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
|
||||
"block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla);
|
||||
// Declares the `transfer_kv_per_layer_mla_direct` op schema (a single source and a
// single destination tensor, source/destination index tensors, and an int `page_size`;
// no return value) and binds the C++ function of the same name as its implementation
// for the CUDA dispatch key.
// NOTE(review): dst is presumably written by the implementation, yet the schema carries
// no alias/mutation annotation (e.g. `Tensor(a!)`) — confirm this is intended.
m.def(
|
||||
"transfer_kv_per_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size) "
|
||||
"-> ()");
|
||||
m.impl("transfer_kv_per_layer_mla_direct", torch::kCUDA, &transfer_kv_per_layer_mla_direct);
|
||||
// Declares the `transfer_kv_all_layer_mla` op schema (a single source and a single
// destination tensor, source/destination index tensors, and six int arguments including
// `num_layers` and per-side layer offsets; no return value) and binds the C++ function
// of the same name as its implementation for the CUDA dispatch key.
// NOTE(review): dst is presumably written by the kernel, yet the schema carries no
// alias/mutation annotation (e.g. `Tensor(a!)`) — confirm this is intended.
m.def(
|
||||
"transfer_kv_all_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
|
||||
"num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla);
|
||||
// Declares the `transfer_kv_all_layer_mla_direct` op schema (a single source and a
// single destination tensor, source/destination index tensors, and int `page_size` and
// `num_layers`; no return value) and binds the C++ function of the same name as its
// implementation for the CUDA dispatch key.
// NOTE(review): dst is presumably written by the implementation, yet the schema carries
// no alias/mutation annotation (e.g. `Tensor(a!)`) — confirm this is intended.
m.def(
|
||||
"transfer_kv_all_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size, "
|
||||
"int num_layers) -> ()");
|
||||
m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct);
|
||||
|
||||
/*
|
||||
* From FlashInfer
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user