Page first direct IO kernel (#10060)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -331,6 +331,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
|
||||
"page_size) -> ()");
|
||||
m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
|
||||
m.def(
|
||||
"transfer_kv_per_layer_direct_pf_lf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, "
|
||||
"Tensor dst_indices, int layer_id, int page_size)->() ");
|
||||
m.impl("transfer_kv_per_layer_direct_pf_lf", torch::kCUDA, &transfer_kv_per_layer_direct_pf_lf);
|
||||
m.def(
|
||||
"transfer_kv_all_layer_direct_lf_pf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, "
|
||||
"Tensor dst_indices, int page_size) ->() ");
|
||||
m.impl("transfer_kv_all_layer_direct_lf_pf", torch::kCUDA, &transfer_kv_all_layer_direct_lf_pf);
|
||||
|
||||
/*
|
||||
* From csrc/memory
|
||||
|
||||
@@ -437,8 +437,8 @@ void transfer_kv_all_layer_mla_lf_pf(
|
||||
}
|
||||
|
||||
inline void transfer_page_direct(
|
||||
const at::Tensor& src_buffer,
|
||||
at::Tensor& dst_buffer,
|
||||
const at::Tensor src_buffer,
|
||||
at::Tensor dst_buffer,
|
||||
int64_t src_page_index,
|
||||
int64_t dst_page_index,
|
||||
int64_t page_size) {
|
||||
@@ -493,3 +493,81 @@ void transfer_kv_direct(
|
||||
start_index = end_index;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool IsLf2Pf>
|
||||
inline void transfer_kv_page_first_direct_impl(
|
||||
const std::vector<at::Tensor>& src_ptrs,
|
||||
std::vector<at::Tensor> dst_ptrs,
|
||||
const at::Tensor& src_indices,
|
||||
const at::Tensor& dst_indices,
|
||||
int64_t start_layer_id,
|
||||
int64_t page_size) {
|
||||
TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
|
||||
TORCH_CHECK(page_size > 0, "Page size must be positive");
|
||||
TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");
|
||||
|
||||
auto src_indices_cpu = src_indices.cpu();
|
||||
auto dst_indices_cpu = dst_indices.cpu();
|
||||
const int64_t num_pages = src_indices_cpu.size(0) / page_size;
|
||||
|
||||
if constexpr (IsLf2Pf) {
|
||||
const bool is_mla = dst_ptrs.size() == 1;
|
||||
const int64_t num_layers = is_mla ? src_ptrs.size() : src_ptrs.size() / 2;
|
||||
|
||||
for (const auto i : c10::irange(num_pages)) {
|
||||
auto s_index = src_indices_cpu[i * page_size].item<int64_t>();
|
||||
auto d_index = dst_indices_cpu[i * page_size].item<int64_t>() / page_size;
|
||||
for (int64_t j = 0; j < num_layers; ++j) {
|
||||
transfer_page_direct(
|
||||
src_ptrs[j], dst_ptrs[0].select(0, d_index).select(0, start_layer_id + j), s_index, 0, page_size);
|
||||
if (!is_mla) {
|
||||
transfer_page_direct(
|
||||
src_ptrs[j + num_layers],
|
||||
dst_ptrs[1].select(0, d_index).select(0, start_layer_id + j),
|
||||
s_index,
|
||||
0,
|
||||
page_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const bool is_mla = src_ptrs.size() == 1;
|
||||
const int64_t num_layers = is_mla ? dst_ptrs.size() : dst_ptrs.size() / 2;
|
||||
|
||||
for (const auto i : c10::irange(num_pages)) {
|
||||
auto s_index = src_indices_cpu[i * page_size].item<int64_t>() / page_size;
|
||||
auto d_index = dst_indices_cpu[i * page_size].item<int64_t>();
|
||||
for (int64_t j = 0; j < num_layers; ++j) {
|
||||
transfer_page_direct(
|
||||
src_ptrs[0].select(0, s_index).select(0, start_layer_id + j), dst_ptrs[j], 0, d_index, page_size);
|
||||
if (!is_mla) {
|
||||
transfer_page_direct(
|
||||
src_ptrs[1].select(0, s_index).select(0, start_layer_id + j),
|
||||
dst_ptrs[j + num_layers],
|
||||
0,
|
||||
d_index,
|
||||
page_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void transfer_kv_per_layer_direct_pf_lf(
|
||||
const std::vector<at::Tensor>& src_ptrs,
|
||||
std::vector<at::Tensor> dst_ptrs,
|
||||
const at::Tensor& src_indices,
|
||||
const at::Tensor& dst_indices,
|
||||
int64_t layer_id,
|
||||
int64_t page_size) {
|
||||
transfer_kv_page_first_direct_impl<false>(src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size);
|
||||
}
|
||||
|
||||
void transfer_kv_all_layer_direct_lf_pf(
|
||||
const std::vector<at::Tensor>& src_ptrs,
|
||||
std::vector<at::Tensor> dst_ptrs,
|
||||
const at::Tensor& src_indices,
|
||||
const at::Tensor& dst_indices,
|
||||
int64_t page_size) {
|
||||
transfer_kv_page_first_direct_impl<true>(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user