[hicache] Optimization for DMA copy (#8245)
This commit is contained in:
@@ -433,7 +433,9 @@ class HiCacheController:
|
|||||||
if self.io_backend == "kernel":
|
if self.io_backend == "kernel":
|
||||||
return host_indices.to(self.mem_pool_device.device), device_indices
|
return host_indices.to(self.mem_pool_device.device), device_indices
|
||||||
elif self.io_backend == "direct":
|
elif self.io_backend == "direct":
|
||||||
return host_indices, device_indices.cpu()
|
device_indices = device_indices.cpu()
|
||||||
|
host_indices, idx = host_indices.sort()
|
||||||
|
return host_indices, device_indices.index_select(0, idx)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported io backend")
|
raise ValueError(f"Unsupported io backend")
|
||||||
|
|
||||||
|
|||||||
@@ -451,15 +451,33 @@ void transfer_kv_direct(
|
|||||||
auto src_indices_cpu = src_indices.cpu();
|
auto src_indices_cpu = src_indices.cpu();
|
||||||
auto dst_indices_cpu = dst_indices.cpu();
|
auto dst_indices_cpu = dst_indices.cpu();
|
||||||
|
|
||||||
const int64_t num_pages = src_indices_cpu.size(0) / page_size;
|
const auto num_indices = src_indices_cpu.numel();
|
||||||
const int64_t num_layers = src_layers.size();
|
const int64_t num_layers = src_layers.size();
|
||||||
|
int64_t* src_indices_ptr = src_indices_cpu.data_ptr<int64_t>();
|
||||||
|
int64_t* dst_indices_ptr = dst_indices_cpu.data_ptr<int64_t>();
|
||||||
|
|
||||||
for (int64_t i = 0; i < num_pages; ++i) {
|
int64_t start_index = 0;
|
||||||
auto src_index = src_indices_cpu[i * page_size].item<int64_t>();
|
int64_t end_index = 0;
|
||||||
auto dst_index = dst_indices_cpu[i * page_size].item<int64_t>();
|
|
||||||
|
for (int64_t i = 0; i < num_indices; ++i) {
|
||||||
|
if (i < num_indices - 1) {
|
||||||
|
auto src_diff = src_indices_ptr[i + 1] - src_indices_ptr[i];
|
||||||
|
auto dst_diff = dst_indices_ptr[i + 1] - dst_indices_ptr[i];
|
||||||
|
|
||||||
|
if (src_diff == 1 && dst_diff == 1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
end_index = i + 1;
|
||||||
|
} else { // last batch
|
||||||
|
end_index = num_indices;
|
||||||
|
}
|
||||||
|
auto src_index = src_indices_ptr[start_index];
|
||||||
|
auto dst_index = dst_indices_ptr[start_index];
|
||||||
|
auto num_tokens = end_index - start_index;
|
||||||
|
|
||||||
for (int64_t j = 0; j < num_layers; ++j) {
|
for (int64_t j = 0; j < num_layers; ++j) {
|
||||||
transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, page_size);
|
transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, num_tokens);
|
||||||
}
|
}
|
||||||
|
start_index = end_index;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user