Simple prefetch policy (#8692)
This commit is contained in:
@@ -16,6 +16,7 @@ limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
from queue import Empty, Full, PriorityQueue, Queue
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
@@ -195,6 +196,8 @@ class PrefetchOperation(StorageOperation):
|
||||
self._done_flag = False
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self.start_time = time.monotonic()
|
||||
|
||||
super().__init__(host_indices, token_ids, last_hash)
|
||||
|
||||
def increment(self, num_tokens: int):
|
||||
@@ -278,6 +281,12 @@ class HiCacheController:
|
||||
self.enable_storage = True
|
||||
# todo: threshold policy for prefetching
|
||||
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
|
||||
self.prefetch_capacity_limit = int(
|
||||
0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
|
||||
)
|
||||
# tracking the number of tokens locked in prefetching, updated by the main scheduler thread
|
||||
self.prefetch_tokens_occupied = 0
|
||||
|
||||
# create a new communication group for synchronizing storage operations across TP workers
|
||||
self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
|
||||
if self.tp_world_size > 1:
|
||||
@@ -525,7 +534,7 @@ class HiCacheController:
|
||||
host_indices: torch.Tensor,
|
||||
new_input_tokens: List[int],
|
||||
last_hash: Optional[str] = None,
|
||||
) -> int:
|
||||
) -> PrefetchOperation:
|
||||
"""
|
||||
Prefetch KV caches from storage backend to host memory.
|
||||
"""
|
||||
@@ -586,11 +595,23 @@ class HiCacheController:
|
||||
operation = self.prefetch_buffer.get(block=True, timeout=1)
|
||||
if self.is_mooncake_backend():
|
||||
self.mooncake_page_transfer(operation)
|
||||
elif self.storage_backend_type == "hf3fs":
|
||||
self.generic_page_transfer(operation, batch_size=128)
|
||||
else:
|
||||
self.generic_page_transfer(operation)
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
def prefetch_rate_limit_check(self) -> bool:
|
||||
"""
|
||||
Rate limit the prefetching operations to avoid overwhelming the storage backend.
|
||||
"""
|
||||
# cancel prefetch if too much memory is occupied
|
||||
if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
|
||||
return False
|
||||
# todo: more sophisticated rate limiting based on storage backend performance
|
||||
return True
|
||||
|
||||
def prefetch_thread_func(self):
|
||||
"""
|
||||
Manage prefetching operations from storage backend to host memory.
|
||||
@@ -604,34 +625,36 @@ class HiCacheController:
|
||||
if operation is None:
|
||||
continue
|
||||
|
||||
last_hash = operation.last_hash
|
||||
tokens_to_fetch = operation.token_ids
|
||||
|
||||
storage_hit_count = 0
|
||||
remaining_tokens = len(tokens_to_fetch)
|
||||
hash_value = []
|
||||
while remaining_tokens >= self.page_size:
|
||||
last_hash = self.get_hash_str(
|
||||
tokens_to_fetch[
|
||||
storage_hit_count : storage_hit_count + self.page_size
|
||||
],
|
||||
last_hash,
|
||||
)
|
||||
if self.prefetch_rate_limit_check():
|
||||
last_hash = operation.last_hash
|
||||
tokens_to_fetch = operation.token_ids
|
||||
|
||||
# todo, more unified interface
|
||||
if not self.is_mooncake_backend():
|
||||
if not self.storage_backend.exists(last_hash):
|
||||
break
|
||||
hash_value.append(last_hash)
|
||||
storage_hit_count += self.page_size
|
||||
remaining_tokens -= self.page_size
|
||||
remaining_tokens = len(tokens_to_fetch)
|
||||
hash_value = []
|
||||
while remaining_tokens >= self.page_size:
|
||||
last_hash = self.get_hash_str(
|
||||
tokens_to_fetch[
|
||||
storage_hit_count : storage_hit_count + self.page_size
|
||||
],
|
||||
last_hash,
|
||||
)
|
||||
|
||||
if self.is_mooncake_backend():
|
||||
# deferring to batch exists for mooncake store
|
||||
exist_result = self.storage_backend.exists(hash_value)
|
||||
storage_hit_count = (
|
||||
sum(1 for v in exist_result.values() if v != 0) * self.page_size
|
||||
)
|
||||
# todo, more unified interface
|
||||
if not self.is_mooncake_backend():
|
||||
if not self.storage_backend.exists(last_hash):
|
||||
break
|
||||
hash_value.append(last_hash)
|
||||
storage_hit_count += self.page_size
|
||||
remaining_tokens -= self.page_size
|
||||
|
||||
if self.is_mooncake_backend():
|
||||
# deferring to batch exists for mooncake store
|
||||
exist_result = self.storage_backend.exists(hash_value)
|
||||
storage_hit_count = (
|
||||
sum(1 for v in exist_result.values() if v != 0)
|
||||
* self.page_size
|
||||
)
|
||||
|
||||
if self.tp_world_size > 1:
|
||||
storage_hit_count_tensor = torch.tensor(
|
||||
@@ -750,6 +773,8 @@ class HiCacheController:
|
||||
|
||||
if self.is_mooncake_backend():
|
||||
self.mooncake_page_backup(operation)
|
||||
elif self.storage_backend_type == "hf3fs":
|
||||
self.generic_page_backup(operation, batch_size=128)
|
||||
else:
|
||||
self.generic_page_backup(operation)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user