Simple prefetch policy (#8692)

This commit is contained in:
pansicheng
2025-08-08 17:09:28 +08:00
committed by GitHub
parent 7490e3f67d
commit e2fd2b9c7e
6 changed files with 148 additions and 36 deletions

View File

@@ -16,6 +16,7 @@ limitations under the License.
import logging
import math
import threading
import time
from queue import Empty, Full, PriorityQueue, Queue
from typing import TYPE_CHECKING, List, Optional
@@ -195,6 +196,8 @@ class PrefetchOperation(StorageOperation):
self._done_flag = False
self._lock = threading.Lock()
self.start_time = time.monotonic()
super().__init__(host_indices, token_ids, last_hash)
def increment(self, num_tokens: int):
@@ -278,6 +281,12 @@ class HiCacheController:
self.enable_storage = True
# todo: threshold policy for prefetching
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
self.prefetch_capacity_limit = int(
0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
)
# tracking the number of tokens locked in prefetching, updated by the main scheduler thread
self.prefetch_tokens_occupied = 0
# create a new communication group for synchronizing storage operations across TP workers
self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
if self.tp_world_size > 1:
@@ -525,7 +534,7 @@ class HiCacheController:
host_indices: torch.Tensor,
new_input_tokens: List[int],
last_hash: Optional[str] = None,
) -> int:
) -> PrefetchOperation:
"""
Prefetch KV caches from storage backend to host memory.
"""
@@ -586,11 +595,23 @@ class HiCacheController:
operation = self.prefetch_buffer.get(block=True, timeout=1)
if self.is_mooncake_backend():
self.mooncake_page_transfer(operation)
elif self.storage_backend_type == "hf3fs":
self.generic_page_transfer(operation, batch_size=128)
else:
self.generic_page_transfer(operation)
except Empty:
continue
def prefetch_rate_limit_check(self) -> bool:
    """
    Rate limit the prefetching operations to avoid overwhelming the storage backend.

    Returns True when a new prefetch may be admitted; False when the number of
    tokens currently locked by in-flight prefetches has reached the capacity limit.
    """
    # todo: more sophisticated rate limiting based on storage backend performance
    # Admit only while occupied tokens stay below the host-memory capacity budget.
    return self.prefetch_tokens_occupied < self.prefetch_capacity_limit
def prefetch_thread_func(self):
"""
Manage prefetching operations from storage backend to host memory.
@@ -604,34 +625,36 @@ class HiCacheController:
if operation is None:
continue
last_hash = operation.last_hash
tokens_to_fetch = operation.token_ids
storage_hit_count = 0
remaining_tokens = len(tokens_to_fetch)
hash_value = []
while remaining_tokens >= self.page_size:
last_hash = self.get_hash_str(
tokens_to_fetch[
storage_hit_count : storage_hit_count + self.page_size
],
last_hash,
)
if self.prefetch_rate_limit_check():
last_hash = operation.last_hash
tokens_to_fetch = operation.token_ids
# todo, more unified interface
if not self.is_mooncake_backend():
if not self.storage_backend.exists(last_hash):
break
hash_value.append(last_hash)
storage_hit_count += self.page_size
remaining_tokens -= self.page_size
remaining_tokens = len(tokens_to_fetch)
hash_value = []
while remaining_tokens >= self.page_size:
last_hash = self.get_hash_str(
tokens_to_fetch[
storage_hit_count : storage_hit_count + self.page_size
],
last_hash,
)
if self.is_mooncake_backend():
# deferring to batch exists for mooncake store
exist_result = self.storage_backend.exists(hash_value)
storage_hit_count = (
sum(1 for v in exist_result.values() if v != 0) * self.page_size
)
# todo, more unified interface
if not self.is_mooncake_backend():
if not self.storage_backend.exists(last_hash):
break
hash_value.append(last_hash)
storage_hit_count += self.page_size
remaining_tokens -= self.page_size
if self.is_mooncake_backend():
# deferring to batch exists for mooncake store
exist_result = self.storage_backend.exists(hash_value)
storage_hit_count = (
sum(1 for v in exist_result.values() if v != 0)
* self.page_size
)
if self.tp_world_size > 1:
storage_hit_count_tensor = torch.tensor(
@@ -750,6 +773,8 @@ class HiCacheController:
if self.is_mooncake_backend():
self.mooncake_page_backup(operation)
elif self.storage_backend_type == "hf3fs":
self.generic_page_backup(operation, batch_size=128)
else:
self.generic_page_backup(operation)

View File

@@ -619,6 +619,7 @@ class Scheduler(
),
hicache_mem_layout=server_args.hicache_mem_layout,
hicache_storage_backend=server_args.hicache_storage_backend,
hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
)
self.tp_worker.register_hicache_layer_transfer_counter(
self.tree_cache.cache_controller.layer_done_counter
@@ -1572,7 +1573,10 @@ class Scheduler(
break
if self.enable_hicache_storage:
self.tree_cache.check_prefetch_progress(req.rid)
prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
if not prefetch_done:
# skip staging requests whose prefetch is still ongoing
continue
req.init_next_round_input(self.tree_cache)
res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))

View File

@@ -2,11 +2,12 @@ import heapq
import logging
import threading
import time
from queue import Queue
from typing import List, Optional
import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import (
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
hicache_io_backend: str,
hicache_mem_layout: str,
hicache_storage_backend: Optional[str] = None,
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
):
if hicache_io_backend == "direct":
@@ -85,6 +87,13 @@ class HiRadixCache(RadixCache):
prefetch_threshold=self.prefetch_threshold,
)
self.prefetch_stop_policy = hicache_storage_prefetch_policy
# todo: customizable storage prefetch timeout
self.prefetch_timeout = 3 # seconds
logger.info(
f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
)
# record the nodes with ongoing write through
self.ongoing_write_through = {}
# record the node segments with ongoing load back
@@ -385,9 +394,10 @@ class HiRadixCache(RadixCache):
for _ in range(queue_size.item()):
req_id = self.cache_controller.prefetch_revoke_queue.get()
if req_id in self.ongoing_prefetch:
last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
last_host_node.release_host()
del self.ongoing_prefetch[req_id]
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
else:
# the revoked operation already got terminated
pass
@@ -419,10 +429,41 @@ class HiRadixCache(RadixCache):
host_node.release_host()
del self.ongoing_backup[ack_id]
def check_prefetch_progress(self, req_id: str):
def can_terminate_prefetch(self, operation: PrefetchOperation):
    """
    Decide whether the given prefetch operation may be terminated now,
    according to the configured prefetch stop policy.

    Policies:
      - "best_effort": terminate immediately, regardless of progress.
      - "wait_complete": terminate only once all hashed pages are fetched.
      - "timeout": terminate on completion or once the timeout has elapsed.
      - anything else: treated as terminable (returns True).
    """
    policy = self.prefetch_stop_policy

    if policy == "best_effort":
        return True

    # An operation is complete when every hashed page has been fetched.
    is_complete = (
        operation.completed_tokens == len(operation.hash_value) * self.page_size
    )

    if policy == "wait_complete":
        verdict = is_complete
    elif policy == "timeout":
        elapsed = time.monotonic() - operation.start_time
        verdict = is_complete or elapsed > self.prefetch_timeout
    else:
        # unknown prefetch stop policy, just return True
        return True

    if self.tp_world_size > 1:
        # All TP workers must agree before terminating: MIN-reduce the boolean
        # so a single False vetoes termination everywhere.
        verdict_tensor = torch.tensor(verdict, dtype=torch.int)
        torch.distributed.all_reduce(
            verdict_tensor,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_group,
        )
        verdict = bool(verdict_tensor.item())

    return verdict
def check_prefetch_progress(self, req_id: str) -> bool:
if req_id not in self.ongoing_prefetch:
# there is no ongoing prefetch for this request or it has been revoked
return
return True
# todo: more policies for prefetch progress such as timeout
# the current policy is to prefetch with best effort and terminate when queuing is over
@@ -430,13 +471,16 @@ class HiRadixCache(RadixCache):
req_id
]
if not self.can_terminate_prefetch(operation):
return False
completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
operation
)
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
min_completed_tokens = completed_tokens
if self.tp_world_size > 1:
if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
# synchronize TP workers to make the same update to hiradix cache
completed_tokens_tensor = torch.tensor(
min_completed_tokens, dtype=torch.int
@@ -464,6 +508,9 @@ class HiRadixCache(RadixCache):
)
last_host_node.release_host()
del self.ongoing_prefetch[req_id]
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
return True
def match_prefix(self, key: List[int], **kwargs):
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
@@ -531,6 +578,7 @@ class HiRadixCache(RadixCache):
host_indices,
operation,
)
self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
node.last_access_time = time.monotonic()

View File

@@ -96,6 +96,8 @@ class Hf3fsClient:
)
self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
self.shm_r.unlink()
self.shm_w.unlink()
self.rlock = threading.RLock()
self.wlock = threading.RLock()
@@ -176,8 +178,6 @@ class Hf3fsClient:
del self.iov_w
self.shm_r.close()
self.shm_w.close()
self.shm_r.unlink()
self.shm_w.unlink()
def flush(self) -> None:
os.fsync(self.file)

View File

@@ -203,6 +203,7 @@ class ServerArgs:
hicache_io_backend: str = "kernel"
hicache_mem_layout: str = "layer_first"
hicache_storage_backend: Optional[str] = None
hicache_storage_prefetch_policy: str = "best_effort"
# Double Sparsity
enable_double_sparsity: bool = False
@@ -1626,6 +1627,13 @@ class ServerArgs:
default=ServerArgs.hicache_storage_backend,
help="The storage backend for hierarchical KV cache.",
)
parser.add_argument(
"--hicache-storage-prefetch-policy",
type=str,
choices=["best_effort", "wait_complete", "timeout"],
default=ServerArgs.hicache_storage_prefetch_policy,
help="Control when prefetching from the storage backend should stop.",
)
# Double Sparsity
parser.add_argument(