Large page size aligned hierarchical caching (#4581)
This commit is contained in:
@@ -149,6 +149,7 @@ class HiCacheController:
|
||||
self,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
mem_pool_host: HostKVCache,
|
||||
page_size: int,
|
||||
load_cache_event: threading.Event = None,
|
||||
write_policy: str = "write_through_selective",
|
||||
):
|
||||
@@ -156,6 +157,7 @@ class HiCacheController:
|
||||
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
||||
self.mem_pool_host = mem_pool_host
|
||||
self.write_policy = write_policy
|
||||
self.page_size = page_size
|
||||
|
||||
self.load_cache_event = load_cache_event
|
||||
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
||||
@@ -184,7 +186,12 @@ class HiCacheController:
|
||||
self.load_stream = torch.cuda.Stream()
|
||||
|
||||
self.write_thread = threading.Thread(
|
||||
target=self.write_thread_func_buffer, daemon=True
|
||||
target=(
|
||||
self.write_thread_func_buffer
|
||||
if self.page_size == 1
|
||||
else self.write_thread_func_direct
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
self.load_thread = threading.Thread(
|
||||
target=self.load_thread_func_layer_by_layer, daemon=True
|
||||
@@ -205,7 +212,12 @@ class HiCacheController:
|
||||
self.ack_load_queue.queue.clear()
|
||||
|
||||
self.write_thread = threading.Thread(
|
||||
target=self.write_thread_func_buffer, daemon=True
|
||||
target=(
|
||||
self.write_thread_func_buffer
|
||||
if self.page_size == 1
|
||||
else self.write_thread_func_direct
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
self.load_thread = threading.Thread(
|
||||
target=self.load_thread_func_layer_by_layer, daemon=True
|
||||
@@ -260,10 +272,12 @@ class HiCacheController:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.write_queue.get(block=True, timeout=1)
|
||||
operation.data = self.mem_pool_device.get_flat_data(
|
||||
operation.device_indices
|
||||
self.mem_pool_host.write_page_all_layers(
|
||||
operation.host_indices,
|
||||
operation.device_indices,
|
||||
self.mem_pool_device,
|
||||
)
|
||||
self.mem_pool_host.transfer(operation.host_indices, operation.data)
|
||||
self.write_stream.synchronize()
|
||||
self.mem_pool_host.complete_io(operation.host_indices)
|
||||
for node_id in operation.node_ids:
|
||||
if node_id != 0:
|
||||
@@ -320,12 +334,21 @@ class HiCacheController:
|
||||
|
||||
self.layer_done_counter.reset()
|
||||
for i in range(self.mem_pool_host.layer_num):
|
||||
flat_data = self.mem_pool_host.get_flat_data_by_layer(
|
||||
batch_operation.host_indices, i
|
||||
)
|
||||
self.mem_pool_device.transfer_per_layer(
|
||||
batch_operation.device_indices, flat_data, i
|
||||
)
|
||||
if self.page_size == 1:
|
||||
flat_data = self.mem_pool_host.get_flat_data_by_layer(
|
||||
batch_operation.host_indices, i
|
||||
)
|
||||
self.mem_pool_device.transfer_per_layer(
|
||||
batch_operation.device_indices, flat_data, i
|
||||
)
|
||||
else:
|
||||
self.mem_pool_host.load_page_per_layer(
|
||||
batch_operation.host_indices,
|
||||
batch_operation.device_indices,
|
||||
self.mem_pool_device,
|
||||
i,
|
||||
)
|
||||
self.load_stream.synchronize()
|
||||
self.layer_done_counter.increment()
|
||||
|
||||
self.mem_pool_host.complete_io(batch_operation.host_indices)
|
||||
|
||||
Reference in New Issue
Block a user