Page first direct IO kernel (#10060)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
huangtingwei
2025-09-10 13:35:34 +08:00
committed by GitHub
parent 737d73ed5b
commit 5be8c2f7f7
5 changed files with 358 additions and 2 deletions


@@ -331,6 +331,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      "transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
      "page_size) -> ()");
  m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
  m.def(
      "transfer_kv_per_layer_direct_pf_lf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, "
      "Tensor dst_indices, int layer_id, int page_size) -> ()");
  m.impl("transfer_kv_per_layer_direct_pf_lf", torch::kCUDA, &transfer_kv_per_layer_direct_pf_lf);
  m.def(
      "transfer_kv_all_layer_direct_lf_pf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, "
      "Tensor dst_indices, int page_size) -> ()");
  m.impl("transfer_kv_all_layer_direct_lf_pf", torch::kCUDA, &transfer_kv_all_layer_direct_lf_pf);

  /*
   * From csrc/memory


@@ -437,8 +437,8 @@ void transfer_kv_all_layer_mla_lf_pf(
 }
 inline void transfer_page_direct(
-    const at::Tensor& src_buffer,
-    at::Tensor& dst_buffer,
+    const at::Tensor src_buffer,
+    at::Tensor dst_buffer,
     int64_t src_page_index,
     int64_t dst_page_index,
     int64_t page_size) {
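This hunk drops the reference qualifiers: transfer_page_direct now takes its buffers by value. at::Tensor is a cheap reference-counted handle, and a by-value dst_buffer lets the temporary views produced by select() in the code added below be passed directly (a temporary cannot bind to a non-const at::Tensor&).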
@@ -493,3 +493,81 @@ void transfer_kv_direct(
    start_index = end_index;
  }
}
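
// transfer_kv_page_first_direct_impl moves whole pages between a page-first
// pool (a tensor indexed [page, layer, token-in-page, ...]; two such tensors
// when K and V are stored separately) and layer-first buffers (one tensor per
// layer, indexed by token slot). IsLf2Pf picks the direction: true scatters
// layer-first sources into the pool, false gathers pages out of it. A single
// tensor on the pool side marks the MLA case, where K and V share storage.
// Indices are token-granular and must come in page_size-aligned runs; the
// pool side divides by page_size to recover the page number.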
template <bool IsLf2Pf>
inline void transfer_kv_page_first_direct_impl(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t start_layer_id,
    int64_t page_size) {
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(page_size > 0, "Page size must be positive");
  TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");

  auto src_indices_cpu = src_indices.cpu();
  auto dst_indices_cpu = dst_indices.cpu();
  const int64_t num_pages = src_indices_cpu.size(0) / page_size;

  if constexpr (IsLf2Pf) {
    const bool is_mla = dst_ptrs.size() == 1;
    const int64_t num_layers = is_mla ? src_ptrs.size() : src_ptrs.size() / 2;
    for (const auto i : c10::irange(num_pages)) {
      auto s_index = src_indices_cpu[i * page_size].item<int64_t>();
      auto d_index = dst_indices_cpu[i * page_size].item<int64_t>() / page_size;
      for (int64_t j = 0; j < num_layers; ++j) {
        transfer_page_direct(
            src_ptrs[j], dst_ptrs[0].select(0, d_index).select(0, start_layer_id + j), s_index, 0, page_size);
        if (!is_mla) {
          transfer_page_direct(
              src_ptrs[j + num_layers],
              dst_ptrs[1].select(0, d_index).select(0, start_layer_id + j),
              s_index,
              0,
              page_size);
        }
      }
    }
  } else {
    const bool is_mla = src_ptrs.size() == 1;
    const int64_t num_layers = is_mla ? dst_ptrs.size() : dst_ptrs.size() / 2;
    for (const auto i : c10::irange(num_pages)) {
      auto s_index = src_indices_cpu[i * page_size].item<int64_t>() / page_size;
      auto d_index = dst_indices_cpu[i * page_size].item<int64_t>();
      for (int64_t j = 0; j < num_layers; ++j) {
        transfer_page_direct(
            src_ptrs[0].select(0, s_index).select(0, start_layer_id + j), dst_ptrs[j], 0, d_index, page_size);
        if (!is_mla) {
          transfer_page_direct(
              src_ptrs[1].select(0, s_index).select(0, start_layer_id + j),
              dst_ptrs[j + num_layers],
              0,
              d_index,
              page_size);
        }
      }
    }
  }
}
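
// Gather direction (page-first -> layer-first): copy pages of the single
// layer layer_id from the page-first pool in src_ptrs into that layer's
// destination buffers (for separate K/V, dst_ptrs holds the layer's {k, v}
// pair).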
void transfer_kv_per_layer_direct_pf_lf(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t layer_id,
    int64_t page_size) {
  transfer_kv_page_first_direct_impl<false>(src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size);
}
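
// Scatter direction (layer-first -> page-first): copy every layer from the
// per-layer source buffers into the page-first pool, starting at layer 0.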
void transfer_kv_all_layer_direct_lf_pf(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t page_size) {
  transfer_kv_page_first_direct_impl<true>(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
}
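
A minimal host-side sketch of the gather path, assuming the declarations above are in scope; the shapes, dtype, and pinned-memory choice are illustrative assumptions, not taken from this commit:

#include <torch/torch.h>

#include <vector>

// Declared in the sgl-kernel sources above (repeated here so the sketch is
// self-contained).
void transfer_kv_per_layer_direct_pf_lf(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t layer_id,
    int64_t page_size);

int main() {
  const int64_t num_pages = 8, num_layers = 4, page_size = 16;
  const int64_t num_heads = 2, head_dim = 64;

  // Page-first host pool: [pages, layers, tokens-per-page, heads, head_dim],
  // pinned so host<->device copies can be issued directly.
  auto k_pool =
      torch::zeros({num_pages, num_layers, page_size, num_heads, head_dim}, torch::kHalf).pin_memory();
  auto v_pool = torch::zeros_like(k_pool).pin_memory();

  // Layer-first device buffers for a single layer: [token slots, heads, head_dim].
  auto dev = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto k_layer = torch::zeros({num_pages * page_size, num_heads, head_dim}, dev);
  auto v_layer = torch::zeros_like(k_layer);

  // Gather pool pages 2 and 5 into device token slots [0, 32). Indices are
  // token-granular and must come in page_size-aligned runs.
  auto src_indices = torch::cat({torch::arange(2 * page_size, 3 * page_size, torch::kLong),
                                 torch::arange(5 * page_size, 6 * page_size, torch::kLong)});
  auto dst_indices = torch::arange(0, 2 * page_size, torch::kLong);

  transfer_kv_per_layer_direct_pf_lf(
      {k_pool, v_pool}, {k_layer, v_layer}, src_indices, dst_indices, /*layer_id=*/1, page_size);
  return 0;
}

For an MLA-style cache, where K and V share storage, src_ptrs would hold the single pooled tensor and dst_ptrs the single per-layer buffer; the implementation detects this from the pool-side list having length one.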