fix page first per layer pf2lf kernel (#8915)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
huangtingwei
2025-08-10 08:16:11 +08:00
committed by GitHub
parent 5c31b35db2
commit 86497d99f2
5 changed files with 15 additions and 5 deletions

View File

@@ -250,7 +250,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer);
m.def(
"transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
"dst_indices, int layer_id, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf);
m.def(
"transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, "
@@ -267,8 +267,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla);
m.def(
"transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, "
"int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
"transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int layer_id, "
"int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_mla_pf_lf", torch::kCUDA, &transfer_kv_per_layer_mla_pf_lf);
m.def(
"transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int "