Hicache IO kernel refactoring (#8264)
This commit is contained in:
@@ -399,16 +399,7 @@ void transfer_kv_per_layer(
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_per_layer_direct(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size);
|
||||
|
||||
void transfer_kv_all_layer(
|
||||
void transfer_kv_per_layer_pf_lf(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
@@ -416,21 +407,34 @@ void transfer_kv_all_layer(
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t num_layers,
|
||||
int64_t src_layer_offset,
|
||||
int64_t dst_layer_offset,
|
||||
int64_t src_layout_dim,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_all_layer_direct(
|
||||
const at::Tensor src_k,
|
||||
void transfer_kv_all_layer(
|
||||
const at::Tensor src_k_layers,
|
||||
const at::Tensor dst_k_layers,
|
||||
const at::Tensor src_v_layers,
|
||||
const at::Tensor dst_v_layers,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t num_layers,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_all_layer_lf_pf(
|
||||
const at::Tensor src_k_layers,
|
||||
at::Tensor dst_k,
|
||||
const at::Tensor src_v,
|
||||
const at::Tensor src_v_layers,
|
||||
at::Tensor dst_v,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size,
|
||||
int64_t num_layers);
|
||||
int64_t item_size,
|
||||
int64_t dst_layout_dim,
|
||||
int64_t num_layers,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_per_layer_mla(
|
||||
const at::Tensor src,
|
||||
@@ -441,32 +445,43 @@ void transfer_kv_per_layer_mla(
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_per_layer_mla_direct(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size);
|
||||
|
||||
void transfer_kv_all_layer_mla(
|
||||
void transfer_kv_per_layer_mla_pf_lf(
|
||||
const at::Tensor src,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t num_layers,
|
||||
int64_t src_layer_offset,
|
||||
int64_t dst_layer_offset,
|
||||
int64_t src_layout_dim,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_all_layer_mla_direct(
|
||||
const at::Tensor src,
|
||||
void transfer_kv_all_layer_mla(
|
||||
const at::Tensor src_layers,
|
||||
const at::Tensor dst_layers,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t item_size,
|
||||
int64_t num_layers,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_all_layer_mla_lf_pf(
|
||||
const at::Tensor src_layers,
|
||||
at::Tensor dst,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size,
|
||||
int64_t num_layers);
|
||||
int64_t item_size,
|
||||
int64_t dst_layout_dim,
|
||||
int64_t num_layers,
|
||||
int64_t block_quota,
|
||||
int64_t num_warps_per_block);
|
||||
|
||||
void transfer_kv_direct(
|
||||
const std::vector<at::Tensor>& src_layers,
|
||||
std::vector<at::Tensor> dst_layers,
|
||||
const at::Tensor src_indices,
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size);
|
||||
|
||||
/*
|
||||
* From csrc/moe/cutlass_moe/w4a8
|
||||
|
||||
Reference in New Issue
Block a user