Page first direct IO kernel (#10060)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -331,6 +331,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
|
||||
"page_size) -> ()");
|
||||
m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
|
||||
m.def(
|
||||
"transfer_kv_per_layer_direct_pf_lf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, "
|
||||
"Tensor dst_indices, int layer_id, int page_size)->() ");
|
||||
m.impl("transfer_kv_per_layer_direct_pf_lf", torch::kCUDA, &transfer_kv_per_layer_direct_pf_lf);
|
||||
m.def(
|
||||
"transfer_kv_all_layer_direct_lf_pf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, "
|
||||
"Tensor dst_indices, int page_size) ->() ");
|
||||
m.impl("transfer_kv_all_layer_direct_lf_pf", torch::kCUDA, &transfer_kv_all_layer_direct_lf_pf);
|
||||
|
||||
/*
|
||||
* From csrc/memory
|
||||
|
||||
@@ -437,8 +437,8 @@ void transfer_kv_all_layer_mla_lf_pf(
|
||||
}
|
||||
|
||||
inline void transfer_page_direct(
|
||||
const at::Tensor& src_buffer,
|
||||
at::Tensor& dst_buffer,
|
||||
const at::Tensor src_buffer,
|
||||
at::Tensor dst_buffer,
|
||||
int64_t src_page_index,
|
||||
int64_t dst_page_index,
|
||||
int64_t page_size) {
|
||||
@@ -493,3 +493,81 @@ void transfer_kv_direct(
|
||||
start_index = end_index;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool IsLf2Pf>
|
||||
inline void transfer_kv_page_first_direct_impl(
|
||||
const std::vector<at::Tensor>& src_ptrs,
|
||||
std::vector<at::Tensor> dst_ptrs,
|
||||
const at::Tensor& src_indices,
|
||||
const at::Tensor& dst_indices,
|
||||
int64_t start_layer_id,
|
||||
int64_t page_size) {
|
||||
TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
|
||||
TORCH_CHECK(page_size > 0, "Page size must be positive");
|
||||
TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");
|
||||
|
||||
auto src_indices_cpu = src_indices.cpu();
|
||||
auto dst_indices_cpu = dst_indices.cpu();
|
||||
const int64_t num_pages = src_indices_cpu.size(0) / page_size;
|
||||
|
||||
if constexpr (IsLf2Pf) {
|
||||
const bool is_mla = dst_ptrs.size() == 1;
|
||||
const int64_t num_layers = is_mla ? src_ptrs.size() : src_ptrs.size() / 2;
|
||||
|
||||
for (const auto i : c10::irange(num_pages)) {
|
||||
auto s_index = src_indices_cpu[i * page_size].item<int64_t>();
|
||||
auto d_index = dst_indices_cpu[i * page_size].item<int64_t>() / page_size;
|
||||
for (int64_t j = 0; j < num_layers; ++j) {
|
||||
transfer_page_direct(
|
||||
src_ptrs[j], dst_ptrs[0].select(0, d_index).select(0, start_layer_id + j), s_index, 0, page_size);
|
||||
if (!is_mla) {
|
||||
transfer_page_direct(
|
||||
src_ptrs[j + num_layers],
|
||||
dst_ptrs[1].select(0, d_index).select(0, start_layer_id + j),
|
||||
s_index,
|
||||
0,
|
||||
page_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const bool is_mla = src_ptrs.size() == 1;
|
||||
const int64_t num_layers = is_mla ? dst_ptrs.size() : dst_ptrs.size() / 2;
|
||||
|
||||
for (const auto i : c10::irange(num_pages)) {
|
||||
auto s_index = src_indices_cpu[i * page_size].item<int64_t>() / page_size;
|
||||
auto d_index = dst_indices_cpu[i * page_size].item<int64_t>();
|
||||
for (int64_t j = 0; j < num_layers; ++j) {
|
||||
transfer_page_direct(
|
||||
src_ptrs[0].select(0, s_index).select(0, start_layer_id + j), dst_ptrs[j], 0, d_index, page_size);
|
||||
if (!is_mla) {
|
||||
transfer_page_direct(
|
||||
src_ptrs[1].select(0, s_index).select(0, start_layer_id + j),
|
||||
dst_ptrs[j + num_layers],
|
||||
0,
|
||||
d_index,
|
||||
page_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void transfer_kv_per_layer_direct_pf_lf(
|
||||
const std::vector<at::Tensor>& src_ptrs,
|
||||
std::vector<at::Tensor> dst_ptrs,
|
||||
const at::Tensor& src_indices,
|
||||
const at::Tensor& dst_indices,
|
||||
int64_t layer_id,
|
||||
int64_t page_size) {
|
||||
transfer_kv_page_first_direct_impl<false>(src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size);
|
||||
}
|
||||
|
||||
void transfer_kv_all_layer_direct_lf_pf(
|
||||
const std::vector<at::Tensor>& src_ptrs,
|
||||
std::vector<at::Tensor> dst_ptrs,
|
||||
const at::Tensor& src_indices,
|
||||
const at::Tensor& dst_indices,
|
||||
int64_t page_size) {
|
||||
transfer_kv_page_first_direct_impl<true>(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
|
||||
}
|
||||
|
||||
@@ -569,6 +569,21 @@ void transfer_kv_direct(
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size);
|
||||
|
||||
void transfer_kv_per_layer_direct_pf_lf(
|
||||
const std::vector<at::Tensor>& src_ptrs,
|
||||
std::vector<at::Tensor> dst_ptrs,
|
||||
const at::Tensor& src_indices,
|
||||
const at::Tensor& dst_indices,
|
||||
int64_t layer_id,
|
||||
int64_t page_size);
|
||||
|
||||
void transfer_kv_all_layer_direct_lf_pf(
|
||||
const std::vector<at::Tensor>& src_ptrs,
|
||||
std::vector<at::Tensor> dst_ptrs,
|
||||
const at::Tensor& src_indices,
|
||||
const at::Tensor& dst_indices,
|
||||
int64_t page_size);
|
||||
|
||||
/*
|
||||
* From FlashInfer
|
||||
*/
|
||||
|
||||
@@ -128,6 +128,31 @@ def transfer_kv_direct(
|
||||
)
|
||||
|
||||
|
||||
def transfer_kv_per_layer_direct_pf_lf(
|
||||
src_ptrs: List[torch.Tensor],
|
||||
dst_ptrs: List[torch.Tensor],
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
layer_id: int,
|
||||
page_size: int,
|
||||
):
|
||||
torch.ops.sgl_kernel.transfer_kv_per_layer_direct_pf_lf(
|
||||
src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size
|
||||
)
|
||||
|
||||
|
||||
def transfer_kv_all_layer_direct_lf_pf(
|
||||
src_ptrs: List[torch.Tensor],
|
||||
dst_ptrs: List[torch.Tensor],
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
page_size: int,
|
||||
):
|
||||
torch.ops.sgl_kernel.transfer_kv_all_layer_direct_lf_pf(
|
||||
src_ptrs, dst_ptrs, src_indices, dst_indices, page_size
|
||||
)
|
||||
|
||||
|
||||
def transfer_kv_per_layer_mla(
|
||||
src: torch.Tensor,
|
||||
dst: torch.Tensor,
|
||||
|
||||
@@ -2,9 +2,11 @@ import pytest
|
||||
import torch
|
||||
from sgl_kernel.kvcacheio import (
|
||||
transfer_kv_all_layer,
|
||||
transfer_kv_all_layer_direct_lf_pf,
|
||||
transfer_kv_all_layer_mla,
|
||||
transfer_kv_direct,
|
||||
transfer_kv_per_layer,
|
||||
transfer_kv_per_layer_direct_pf_lf,
|
||||
transfer_kv_per_layer_mla,
|
||||
)
|
||||
|
||||
@@ -13,6 +15,21 @@ def ref_copy_with_indices(src_pool, dst_pool, src_indices, dst_indices):
|
||||
dst_pool[dst_indices] = src_pool[src_indices].to(dst_pool.device)
|
||||
|
||||
|
||||
def ref_copy_with_indices_pf_direct(
|
||||
src_pool, dst_pool, src_indices, dst_indices, page_size, layer_id, lf_to_pf=False
|
||||
):
|
||||
if lf_to_pf:
|
||||
for i in range(0, len(src_indices), page_size):
|
||||
dst_pool[dst_indices[i] // page_size][layer_id] = src_pool[layer_id][
|
||||
src_indices[i : i + page_size]
|
||||
].to(dst_pool.device)
|
||||
else:
|
||||
for i in range(0, len(src_indices), page_size):
|
||||
dst_pool[layer_id][dst_indices[i : i + page_size]] = src_pool[
|
||||
src_indices[i] // page_size
|
||||
][layer_id].to(dst_pool.device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("num_items_to_transfer", [1, 128, 1024])
|
||||
@pytest.mark.parametrize("page_size", [1, 16, 64])
|
||||
@@ -251,5 +268,218 @@ def test_transfer_kv(
|
||||
torch.set_default_dtype(original_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("num_items_to_transfer", [128, 1024, 8192])
|
||||
@pytest.mark.parametrize("page_size", [16, 64, 128])
|
||||
@pytest.mark.parametrize("item_size", [256])
|
||||
@pytest.mark.parametrize("total_items_in_pool", [20480])
|
||||
@pytest.mark.parametrize("is_mla", [False, True])
|
||||
@pytest.mark.parametrize("lf_to_pf", [False, True])
|
||||
def test_transfer_kv_pf_direct(
|
||||
dtype: torch.dtype,
|
||||
num_items_to_transfer: int,
|
||||
item_size: int,
|
||||
page_size: int,
|
||||
total_items_in_pool: int,
|
||||
is_mla: bool,
|
||||
lf_to_pf: bool,
|
||||
):
|
||||
original_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
device = "cuda"
|
||||
torch.cuda.manual_seed(42)
|
||||
|
||||
num_layers = 4
|
||||
|
||||
total_pages_in_pool = total_items_in_pool // page_size
|
||||
num_pages_to_transfer = num_items_to_transfer // page_size
|
||||
if num_pages_to_transfer == 0:
|
||||
torch.set_default_dtype(original_dtype)
|
||||
return
|
||||
page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
|
||||
src_indices_host = torch.cat(
|
||||
[
|
||||
torch.arange(p * page_size, (p + 1) * page_size)
|
||||
for p in page_indices[:num_pages_to_transfer]
|
||||
]
|
||||
)
|
||||
src_indices_device = src_indices_host.to(device)
|
||||
dst_indices_host = torch.cat(
|
||||
[
|
||||
torch.arange(p * page_size, (p + 1) * page_size)
|
||||
for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
|
||||
]
|
||||
)
|
||||
dst_indices_device = dst_indices_host.to(device)
|
||||
|
||||
# We will test the per-layer function on the first layer (index 0) of the pool.
|
||||
layer_idx_to_test = 0
|
||||
|
||||
if lf_to_pf:
|
||||
if is_mla:
|
||||
src_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(
|
||||
device
|
||||
)
|
||||
src_pool_ptrs = [src_pool[i] for i in range(num_layers)]
|
||||
dst_pool_ref = torch.zeros(
|
||||
total_pages_in_pool, num_layers, page_size, item_size
|
||||
).pin_memory()
|
||||
dst_pool_direct = torch.zeros_like(dst_pool_ref)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
transfer_kv_all_layer_direct_lf_pf(
|
||||
src_pool_ptrs,
|
||||
[dst_pool_direct],
|
||||
src_indices_host,
|
||||
dst_indices_host,
|
||||
page_size,
|
||||
)
|
||||
for i in range(num_layers):
|
||||
ref_copy_with_indices_pf_direct(
|
||||
src_pool,
|
||||
dst_pool_ref,
|
||||
src_indices_device,
|
||||
dst_indices_host,
|
||||
page_size,
|
||||
i,
|
||||
lf_to_pf=True,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
|
||||
|
||||
else:
|
||||
src_k_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(
|
||||
device
|
||||
)
|
||||
src_k_pool_ptrs = [src_k_pool[i] for i in range(num_layers)]
|
||||
src_v_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(
|
||||
device
|
||||
)
|
||||
src_v_pool_ptrs = [src_v_pool[i] for i in range(num_layers)]
|
||||
dst_k_pool_ref = torch.zeros(
|
||||
total_pages_in_pool, num_layers, page_size, item_size
|
||||
).pin_memory()
|
||||
dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref)
|
||||
dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
|
||||
dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
transfer_kv_all_layer_direct_lf_pf(
|
||||
src_k_pool_ptrs + src_v_pool_ptrs,
|
||||
[dst_k_pool_direct, dst_v_pool_direct],
|
||||
src_indices_host,
|
||||
dst_indices_host,
|
||||
page_size,
|
||||
)
|
||||
for i in range(num_layers):
|
||||
ref_copy_with_indices_pf_direct(
|
||||
src_k_pool,
|
||||
dst_k_pool_ref,
|
||||
src_indices_device,
|
||||
dst_indices_host,
|
||||
page_size,
|
||||
i,
|
||||
lf_to_pf=True,
|
||||
)
|
||||
ref_copy_with_indices_pf_direct(
|
||||
src_v_pool,
|
||||
dst_v_pool_ref,
|
||||
src_indices_device,
|
||||
dst_indices_host,
|
||||
page_size,
|
||||
i,
|
||||
lf_to_pf=True,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
|
||||
torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
|
||||
else:
|
||||
if is_mla:
|
||||
src_pool = torch.randn(
|
||||
total_pages_in_pool, num_layers, page_size, item_size
|
||||
).pin_memory()
|
||||
|
||||
dst_pool_ref = torch.zeros(num_layers, total_items_in_pool, item_size).to(
|
||||
device
|
||||
)
|
||||
dst_pool_direct = torch.zeros_like(dst_pool_ref)
|
||||
dst_pool_direct_ptrs = [dst_pool_direct[i] for i in range(num_layers)]
|
||||
torch.cuda.synchronize()
|
||||
|
||||
transfer_kv_per_layer_direct_pf_lf(
|
||||
[src_pool],
|
||||
[dst_pool_direct_ptrs[layer_idx_to_test]],
|
||||
src_indices_host,
|
||||
dst_indices_host,
|
||||
layer_idx_to_test,
|
||||
page_size,
|
||||
)
|
||||
ref_copy_with_indices_pf_direct(
|
||||
src_pool,
|
||||
dst_pool_ref,
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
page_size,
|
||||
layer_idx_to_test,
|
||||
lf_to_pf=False,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
|
||||
else:
|
||||
src_k_pool = torch.randn(
|
||||
total_pages_in_pool, num_layers, page_size, item_size
|
||||
).pin_memory()
|
||||
src_v_pool = torch.randn(
|
||||
total_pages_in_pool, num_layers, page_size, item_size
|
||||
).pin_memory()
|
||||
|
||||
dst_k_pool_ref = torch.zeros(num_layers, total_items_in_pool, item_size).to(
|
||||
device
|
||||
)
|
||||
dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
|
||||
dst_k_pool_direct_ptrs = [dst_k_pool_direct[i] for i in range(num_layers)]
|
||||
|
||||
dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref)
|
||||
dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
|
||||
dst_v_pool_direct_ptrs = [dst_v_pool_direct[i] for i in range(num_layers)]
|
||||
torch.cuda.synchronize()
|
||||
|
||||
transfer_kv_per_layer_direct_pf_lf(
|
||||
[src_k_pool, src_v_pool],
|
||||
[
|
||||
dst_k_pool_direct_ptrs[layer_idx_to_test],
|
||||
dst_v_pool_direct_ptrs[layer_idx_to_test],
|
||||
],
|
||||
src_indices_host,
|
||||
dst_indices_host,
|
||||
layer_idx_to_test,
|
||||
page_size,
|
||||
)
|
||||
|
||||
ref_copy_with_indices_pf_direct(
|
||||
src_k_pool,
|
||||
dst_k_pool_ref,
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
page_size,
|
||||
layer_idx_to_test,
|
||||
lf_to_pf=False,
|
||||
)
|
||||
ref_copy_with_indices_pf_direct(
|
||||
src_v_pool,
|
||||
dst_v_pool_ref,
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
page_size,
|
||||
layer_idx_to_test,
|
||||
lf_to_pf=False,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
|
||||
torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
|
||||
torch.set_default_dtype(original_dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user