Simple prefetch policy (#8692)

This commit is contained in:
pansicheng
2025-08-08 17:09:28 +08:00
committed by GitHub
parent 7490e3f67d
commit e2fd2b9c7e
6 changed files with 148 additions and 36 deletions

View File

@@ -16,6 +16,7 @@ limitations under the License.
import logging
import math
import threading
import time
from queue import Empty, Full, PriorityQueue, Queue
from typing import TYPE_CHECKING, List, Optional
@@ -195,6 +196,8 @@ class PrefetchOperation(StorageOperation):
self._done_flag = False
self._lock = threading.Lock()
self.start_time = time.monotonic()
super().__init__(host_indices, token_ids, last_hash)
def increment(self, num_tokens: int):
@@ -278,6 +281,12 @@ class HiCacheController:
self.enable_storage = True
# todo: threshold policy for prefetching
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
self.prefetch_capacity_limit = int(
0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
)
# tracking the number of tokens locked in prefetching, updated by the main scheduler thread
self.prefetch_tokens_occupied = 0
# create a new communication group for synchronizing storage operations across TP workers
self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
if self.tp_world_size > 1:
@@ -525,7 +534,7 @@ class HiCacheController:
host_indices: torch.Tensor,
new_input_tokens: List[int],
last_hash: Optional[str] = None,
) -> int:
) -> PrefetchOperation:
"""
Prefetch KV caches from storage backend to host memory.
"""
@@ -586,11 +595,23 @@ class HiCacheController:
operation = self.prefetch_buffer.get(block=True, timeout=1)
if self.is_mooncake_backend():
self.mooncake_page_transfer(operation)
elif self.storage_backend_type == "hf3fs":
self.generic_page_transfer(operation, batch_size=128)
else:
self.generic_page_transfer(operation)
except Empty:
continue
def prefetch_rate_limit_check(self) -> bool:
    """Decide whether a new prefetch operation may start.

    Returns False when the tokens currently locked by in-flight prefetches
    have reached the configured capacity limit, True otherwise.
    """
    # Refuse new prefetches once occupied tokens hit the capacity ceiling.
    # todo: more sophisticated rate limiting based on storage backend performance
    return self.prefetch_tokens_occupied < self.prefetch_capacity_limit
def prefetch_thread_func(self):
"""
Manage prefetching operations from storage backend to host memory.
@@ -604,34 +625,36 @@ class HiCacheController:
if operation is None:
continue
last_hash = operation.last_hash
tokens_to_fetch = operation.token_ids
storage_hit_count = 0
remaining_tokens = len(tokens_to_fetch)
hash_value = []
while remaining_tokens >= self.page_size:
last_hash = self.get_hash_str(
tokens_to_fetch[
storage_hit_count : storage_hit_count + self.page_size
],
last_hash,
)
if self.prefetch_rate_limit_check():
last_hash = operation.last_hash
tokens_to_fetch = operation.token_ids
# todo, more unified interface
if not self.is_mooncake_backend():
if not self.storage_backend.exists(last_hash):
break
hash_value.append(last_hash)
storage_hit_count += self.page_size
remaining_tokens -= self.page_size
remaining_tokens = len(tokens_to_fetch)
hash_value = []
while remaining_tokens >= self.page_size:
last_hash = self.get_hash_str(
tokens_to_fetch[
storage_hit_count : storage_hit_count + self.page_size
],
last_hash,
)
if self.is_mooncake_backend():
# deferring to batch exists for mooncake store
exist_result = self.storage_backend.exists(hash_value)
storage_hit_count = (
sum(1 for v in exist_result.values() if v != 0) * self.page_size
)
# todo, more unified interface
if not self.is_mooncake_backend():
if not self.storage_backend.exists(last_hash):
break
hash_value.append(last_hash)
storage_hit_count += self.page_size
remaining_tokens -= self.page_size
if self.is_mooncake_backend():
# deferring to batch exists for mooncake store
exist_result = self.storage_backend.exists(hash_value)
storage_hit_count = (
sum(1 for v in exist_result.values() if v != 0)
* self.page_size
)
if self.tp_world_size > 1:
storage_hit_count_tensor = torch.tensor(
@@ -750,6 +773,8 @@ class HiCacheController:
if self.is_mooncake_backend():
self.mooncake_page_backup(operation)
elif self.storage_backend_type == "hf3fs":
self.generic_page_backup(operation, batch_size=128)
else:
self.generic_page_backup(operation)