Hicache IO kernel refactoring (#8264)

This commit is contained in:
Zhiqiang Xie
2025-07-23 01:49:03 -07:00
committed by GitHub
parent 8abd3e77fe
commit b43263307f
5 changed files with 545 additions and 280 deletions

View File

@@ -249,34 +249,39 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer);
m.def(
"transfer_kv_per_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int page_size) -> ()");
m.impl("transfer_kv_per_layer_direct", torch::kCUDA, &transfer_kv_per_layer_direct);
"transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf);
m.def(
"transfer_kv_all_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int item_size, int num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int "
"transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, "
"Tensor src_indices, Tensor dst_indices, int item_size, int num_layers, int block_quota, int "
"num_warps_per_block) -> ()");
m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer);
m.def(
"transfer_kv_all_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int page_size, int num_layers) -> ()");
m.impl("transfer_kv_all_layer_direct", torch::kCUDA, &transfer_kv_all_layer_direct);
"transfer_kv_all_layer_lf_pf(Tensor src_k_layers, Tensor dst_k, Tensor src_v_layers, Tensor dst_v, "
"Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int block_quota, int "
"num_warps_per_block) -> ()");
m.impl("transfer_kv_all_layer_lf_pf", torch::kCUDA, &transfer_kv_all_layer_lf_pf);
m.def(
"transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
"block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla);
m.def(
"transfer_kv_per_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size) "
"-> ()");
m.impl("transfer_kv_per_layer_mla_direct", torch::kCUDA, &transfer_kv_per_layer_mla_direct);
"transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, "
"int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_mla_pf_lf", torch::kCUDA, &transfer_kv_per_layer_mla_pf_lf);
m.def(
"transfer_kv_all_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
"num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int num_warps_per_block) -> ()");
"transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int "
"item_size, int num_layers, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla);
m.def(
"transfer_kv_all_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size, "
"int num_layers) -> ()");
m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct);
"transfer_kv_all_layer_mla_lf_pf(Tensor src_layers, Tensor dst, Tensor src_indices, Tensor dst_indices, "
"int item_size, int dst_layout_dim, int num_layers, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_all_layer_mla_lf_pf", torch::kCUDA, &transfer_kv_all_layer_mla_lf_pf);
m.def(
"transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
"page_size) -> ()");
m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
/*
* From csrc/moe/cutlass_moe/w4a8