[hicache] Optimization for DMA copy (#8245)
This commit is contained in:
@@ -433,7 +433,9 @@ class HiCacheController:
|
||||
if self.io_backend == "kernel":
|
||||
return host_indices.to(self.mem_pool_device.device), device_indices
|
||||
elif self.io_backend == "direct":
|
||||
return host_indices, device_indices.cpu()
|
||||
device_indices = device_indices.cpu()
|
||||
host_indices, idx = host_indices.sort()
|
||||
return host_indices, device_indices.index_select(0, idx)
|
||||
else:
|
||||
raise ValueError(f"Unsupported io backend")
|
||||
|
||||
|
||||
@@ -451,15 +451,33 @@ void transfer_kv_direct(
|
||||
auto src_indices_cpu = src_indices.cpu();
|
||||
auto dst_indices_cpu = dst_indices.cpu();
|
||||
|
||||
const int64_t num_pages = src_indices_cpu.size(0) / page_size;
|
||||
const auto num_indices = src_indices_cpu.numel();
|
||||
const int64_t num_layers = src_layers.size();
|
||||
int64_t* src_indices_ptr = src_indices_cpu.data_ptr<int64_t>();
|
||||
int64_t* dst_indices_ptr = dst_indices_cpu.data_ptr<int64_t>();
|
||||
|
||||
for (int64_t i = 0; i < num_pages; ++i) {
|
||||
auto src_index = src_indices_cpu[i * page_size].item<int64_t>();
|
||||
auto dst_index = dst_indices_cpu[i * page_size].item<int64_t>();
|
||||
int64_t start_index = 0;
|
||||
int64_t end_index = 0;
|
||||
|
||||
for (int64_t i = 0; i < num_indices; ++i) {
|
||||
if (i < num_indices - 1) {
|
||||
auto src_diff = src_indices_ptr[i + 1] - src_indices_ptr[i];
|
||||
auto dst_diff = dst_indices_ptr[i + 1] - dst_indices_ptr[i];
|
||||
|
||||
if (src_diff == 1 && dst_diff == 1) {
|
||||
continue;
|
||||
}
|
||||
end_index = i + 1;
|
||||
} else { // last batch
|
||||
end_index = num_indices;
|
||||
}
|
||||
auto src_index = src_indices_ptr[start_index];
|
||||
auto dst_index = dst_indices_ptr[start_index];
|
||||
auto num_tokens = end_index - start_index;
|
||||
|
||||
for (int64_t j = 0; j < num_layers; ++j) {
|
||||
transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, page_size);
|
||||
transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, num_tokens);
|
||||
}
|
||||
start_index = end_index;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user