Fix errors of hicache kernels in sgl-kernel for ROCm (#10339)
This commit is contained in:
@@ -163,6 +163,14 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
|||||||
"transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
|
"transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
|
||||||
"page_size) -> ()");
|
"page_size) -> ()");
|
||||||
m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
|
m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
|
||||||
|
m.def(
|
||||||
|
"transfer_kv_per_layer_direct_pf_lf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, "
|
||||||
|
"Tensor dst_indices, int layer_id, int page_size)->() ");
|
||||||
|
m.impl("transfer_kv_per_layer_direct_pf_lf", torch::kCUDA, &transfer_kv_per_layer_direct_pf_lf);
|
||||||
|
m.def(
|
||||||
|
"transfer_kv_all_layer_direct_lf_pf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, "
|
||||||
|
"Tensor dst_indices, int page_size) ->() ");
|
||||||
|
m.impl("transfer_kv_all_layer_direct_lf_pf", torch::kCUDA, &transfer_kv_all_layer_direct_lf_pf);
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_EXTENSION(common_ops)
|
REGISTER_EXTENSION(common_ops)
|
||||||
|
|||||||
Reference in New Issue
Block a user