Large page size aligned hierarchical caching (#4581)

This commit is contained in:
Zhiqiang Xie
2025-04-01 22:38:15 -07:00
committed by GitHub
parent 9eb49e878b
commit e119f04215
8 changed files with 242 additions and 71 deletions

View File

@@ -149,6 +149,7 @@ class HiCacheController:
         self,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
+        page_size: int,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
@@ -156,6 +157,7 @@ class HiCacheController:
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
+        self.page_size = page_size
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -184,7 +186,12 @@ class HiCacheController:
         self.load_stream = torch.cuda.Stream()
         self.write_thread = threading.Thread(
-            target=self.write_thread_func_buffer, daemon=True
+            target=(
+                self.write_thread_func_buffer
+                if self.page_size == 1
+                else self.write_thread_func_direct
+            ),
+            daemon=True,
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -205,7 +212,12 @@ class HiCacheController:
         self.ack_load_queue.queue.clear()
         self.write_thread = threading.Thread(
-            target=self.write_thread_func_buffer, daemon=True
+            target=(
+                self.write_thread_func_buffer
+                if self.page_size == 1
+                else self.write_thread_func_direct
+            ),
+            daemon=True,
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -260,10 +272,12 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
-                operation.data = self.mem_pool_device.get_flat_data(
-                    operation.device_indices
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
                 )
-                self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                self.write_stream.synchronize()
                 self.mem_pool_host.complete_io(operation.host_indices)
                 for node_id in operation.node_ids:
                     if node_id != 0:
@@ -320,12 +334,21 @@ class HiCacheController:
                 self.layer_done_counter.reset()
                 for i in range(self.mem_pool_host.layer_num):
-                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                        batch_operation.host_indices, i
-                    )
-                    self.mem_pool_device.transfer_per_layer(
-                        batch_operation.device_indices, flat_data, i
-                    )
+                    if self.page_size == 1:
+                        flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                            batch_operation.host_indices, i
+                        )
+                        self.mem_pool_device.transfer_per_layer(
+                            batch_operation.device_indices, flat_data, i
+                        )
+                    else:
+                        self.mem_pool_host.load_page_per_layer(
+                            batch_operation.host_indices,
+                            batch_operation.device_indices,
+                            self.mem_pool_device,
+                            i,
+                        )
                     self.load_stream.synchronize()
                     self.layer_done_counter.increment()
                 self.mem_pool_host.complete_io(batch_operation.host_indices)