HiCache, check before terminate prefetching (#8372)
This commit is contained in:
@@ -201,8 +201,9 @@ class PrefetchOperation(StorageOperation):
|
|||||||
def increment(self, num_tokens: int):
|
def increment(self, num_tokens: int):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._done_flag:
|
if self._done_flag:
|
||||||
return
|
return False
|
||||||
self.completed_tokens += num_tokens
|
self.completed_tokens += num_tokens
|
||||||
|
return True
|
||||||
|
|
||||||
def mark_done(self):
|
def mark_done(self):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
@@ -528,12 +529,12 @@ class HiCacheController:
|
|||||||
f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
|
f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
self.mem_pool_host.set_from_flat_data_page(
|
if operation.increment(self.page_size):
|
||||||
operation.host_indices[operation.completed_tokens],
|
self.mem_pool_host.set_from_flat_data_page(
|
||||||
page_data,
|
operation.host_indices[operation.completed_tokens],
|
||||||
)
|
page_data,
|
||||||
operation.increment(self.page_size)
|
)
|
||||||
if operation.is_done():
|
else:
|
||||||
# operation terminated by controller, release pre-allocated memory
|
# operation terminated by controller, release pre-allocated memory
|
||||||
self.mem_pool_host.free(
|
self.mem_pool_host.free(
|
||||||
operation.host_indices[operation.completed_tokens :]
|
operation.host_indices[operation.completed_tokens :]
|
||||||
@@ -589,6 +590,7 @@ class HiCacheController:
|
|||||||
if storage_hit_count < self.prefetch_threshold:
|
if storage_hit_count < self.prefetch_threshold:
|
||||||
# not to prefetch if not enough benefits
|
# not to prefetch if not enough benefits
|
||||||
self.prefetch_revoke_queue.put(operation.request_id)
|
self.prefetch_revoke_queue.put(operation.request_id)
|
||||||
|
self.mem_pool_host.free(operation.host_indices)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
|
f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -365,10 +365,12 @@ class HiRadixCache(RadixCache):
|
|||||||
for _ in range(queue_size.item()):
|
for _ in range(queue_size.item()):
|
||||||
req_id = self.cache_controller.prefetch_revoke_queue.get()
|
req_id = self.cache_controller.prefetch_revoke_queue.get()
|
||||||
if req_id in self.ongoing_prefetch:
|
if req_id in self.ongoing_prefetch:
|
||||||
last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
|
last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
|
||||||
last_host_node.release_host()
|
last_host_node.release_host()
|
||||||
self.cache_controller.mem_pool_host.free(host_indices)
|
|
||||||
del self.ongoing_prefetch[req_id]
|
del self.ongoing_prefetch[req_id]
|
||||||
|
else:
|
||||||
|
# the revoked operation already got terminated
|
||||||
|
pass
|
||||||
|
|
||||||
def check_backup_progress(self):
|
def check_backup_progress(self):
|
||||||
queue_size = torch.tensor(
|
queue_size = torch.tensor(
|
||||||
@@ -403,6 +405,7 @@ class HiRadixCache(RadixCache):
|
|||||||
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
|
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
|
||||||
req_id
|
req_id
|
||||||
]
|
]
|
||||||
|
|
||||||
completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
|
completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
|
||||||
operation
|
operation
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user