HiCache Storage tp fix (#8878)
This commit is contained in:
@@ -570,10 +570,6 @@ class HiCacheController:
|
|||||||
)
|
)
|
||||||
completed_tokens += self.page_size
|
completed_tokens += self.page_size
|
||||||
else:
|
else:
|
||||||
# operation terminated by controller, release pre-allocated memory
|
|
||||||
self.mem_pool_host.free(
|
|
||||||
operation.host_indices[operation.completed_tokens :]
|
|
||||||
)
|
|
||||||
break
|
break
|
||||||
|
|
||||||
def mooncake_page_transfer(self, operation):
|
def mooncake_page_transfer(self, operation):
|
||||||
@@ -599,6 +595,14 @@ class HiCacheController:
|
|||||||
self.generic_page_transfer(operation, batch_size=128)
|
self.generic_page_transfer(operation, batch_size=128)
|
||||||
else:
|
else:
|
||||||
self.generic_page_transfer(operation)
|
self.generic_page_transfer(operation)
|
||||||
|
|
||||||
|
if self.tp_world_size > 1:
|
||||||
|
# to ensure all TP workers release the host memory at the same time
|
||||||
|
torch.distributed.barrier(group=self.prefetch_tp_group)
|
||||||
|
# operation terminated by controller, release pre-allocated memory
|
||||||
|
self.mem_pool_host.free(
|
||||||
|
operation.host_indices[operation.completed_tokens :]
|
||||||
|
)
|
||||||
except Empty:
|
except Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -626,7 +630,9 @@ class HiCacheController:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
storage_hit_count = 0
|
storage_hit_count = 0
|
||||||
if self.prefetch_rate_limit_check():
|
if (
|
||||||
|
operation.host_indices is not None
|
||||||
|
) and self.prefetch_rate_limit_check():
|
||||||
last_hash = operation.last_hash
|
last_hash = operation.last_hash
|
||||||
tokens_to_fetch = operation.token_ids
|
tokens_to_fetch = operation.token_ids
|
||||||
|
|
||||||
@@ -670,7 +676,8 @@ class HiCacheController:
|
|||||||
if storage_hit_count < self.prefetch_threshold:
|
if storage_hit_count < self.prefetch_threshold:
|
||||||
# not to prefetch if not enough benefits
|
# not to prefetch if not enough benefits
|
||||||
self.prefetch_revoke_queue.put(operation.request_id)
|
self.prefetch_revoke_queue.put(operation.request_id)
|
||||||
self.mem_pool_host.free(operation.host_indices)
|
if operation.host_indices is not None:
|
||||||
|
self.mem_pool_host.free(operation.host_indices)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
|
f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -471,6 +471,10 @@ class HiRadixCache(RadixCache):
|
|||||||
req_id
|
req_id
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if operation.host_indices is None:
|
||||||
|
# prefetch has not been issued due to insufficient host memory
|
||||||
|
return True
|
||||||
|
|
||||||
if not self.can_terminate_prefetch(operation):
|
if not self.can_terminate_prefetch(operation):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -565,10 +569,6 @@ class HiRadixCache(RadixCache):
|
|||||||
if host_indices is None:
|
if host_indices is None:
|
||||||
self.evict_host(prefetch_length)
|
self.evict_host(prefetch_length)
|
||||||
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
|
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
|
||||||
if host_indices is None:
|
|
||||||
last_host_node.release_host()
|
|
||||||
# no sufficient host memory to prefetch
|
|
||||||
return
|
|
||||||
operation = self.cache_controller.prefetch(
|
operation = self.cache_controller.prefetch(
|
||||||
req_id, host_indices, new_input_tokens, last_hash
|
req_id, host_indices, new_input_tokens, last_hash
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user