diff --git a/sgl-kernel/csrc/common_extension_rocm.cc b/sgl-kernel/csrc/common_extension_rocm.cc index 1f94d2615..f4e14d0d5 100644 --- a/sgl-kernel/csrc/common_extension_rocm.cc +++ b/sgl-kernel/csrc/common_extension_rocm.cc @@ -163,6 +163,14 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { "transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int " "page_size) -> ()"); m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct); + m.def( + "transfer_kv_per_layer_direct_pf_lf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, " + "Tensor dst_indices, int layer_id, int page_size)->() "); + m.impl("transfer_kv_per_layer_direct_pf_lf", torch::kCUDA, &transfer_kv_per_layer_direct_pf_lf); + m.def( + "transfer_kv_all_layer_direct_lf_pf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, " + "Tensor dst_indices, int page_size) ->() "); + m.impl("transfer_kv_all_layer_direct_lf_pf", torch::kCUDA, &transfer_kv_all_layer_direct_lf_pf); } REGISTER_EXTENSION(common_ops)