Hicache IO kernel refactoring (#8264)

This commit is contained in:
Zhiqiang Xie
2025-07-23 01:49:03 -07:00
committed by GitHub
parent 8abd3e77fe
commit b43263307f
5 changed files with 545 additions and 280 deletions

View File

@@ -399,16 +399,7 @@ void transfer_kv_per_layer(
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_per_layer_direct(
const at::Tensor src_k,
at::Tensor dst_k,
const at::Tensor src_v,
at::Tensor dst_v,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t page_size);
void transfer_kv_all_layer(
void transfer_kv_per_layer_pf_lf(
const at::Tensor src_k,
at::Tensor dst_k,
const at::Tensor src_v,
@@ -416,21 +407,34 @@ void transfer_kv_all_layer(
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t num_layers,
int64_t src_layer_offset,
int64_t dst_layer_offset,
int64_t src_layout_dim,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_all_layer_direct(
const at::Tensor src_k,
void transfer_kv_all_layer(
const at::Tensor src_k_layers,
const at::Tensor dst_k_layers,
const at::Tensor src_v_layers,
const at::Tensor dst_v_layers,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t num_layers,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_all_layer_lf_pf(
const at::Tensor src_k_layers,
at::Tensor dst_k,
const at::Tensor src_v,
const at::Tensor src_v_layers,
at::Tensor dst_v,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t page_size,
int64_t num_layers);
int64_t item_size,
int64_t dst_layout_dim,
int64_t num_layers,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_per_layer_mla(
const at::Tensor src,
@@ -441,32 +445,43 @@ void transfer_kv_per_layer_mla(
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_per_layer_mla_direct(
const at::Tensor src,
at::Tensor dst,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t page_size);
void transfer_kv_all_layer_mla(
void transfer_kv_per_layer_mla_pf_lf(
const at::Tensor src,
at::Tensor dst,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t num_layers,
int64_t src_layer_offset,
int64_t dst_layer_offset,
int64_t src_layout_dim,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_all_layer_mla_direct(
const at::Tensor src,
void transfer_kv_all_layer_mla(
const at::Tensor src_layers,
const at::Tensor dst_layers,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t num_layers,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_all_layer_mla_lf_pf(
const at::Tensor src_layers,
at::Tensor dst,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t page_size,
int64_t num_layers);
int64_t item_size,
int64_t dst_layout_dim,
int64_t num_layers,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_direct(
const std::vector<at::Tensor>& src_layers,
std::vector<at::Tensor> dst_layers,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t page_size);
/*
* From csrc/moe/cutlass_moe/w4a8