Simple prefetch policy (#8692)
This commit is contained in:
@@ -2,11 +2,12 @@ import heapq
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from queue import Queue
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
|
||||
hicache_io_backend: str,
|
||||
hicache_mem_layout: str,
|
||||
hicache_storage_backend: Optional[str] = None,
|
||||
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
|
||||
):
|
||||
|
||||
if hicache_io_backend == "direct":
|
||||
@@ -85,6 +87,13 @@ class HiRadixCache(RadixCache):
|
||||
prefetch_threshold=self.prefetch_threshold,
|
||||
)
|
||||
|
||||
self.prefetch_stop_policy = hicache_storage_prefetch_policy
|
||||
# todo: customizable storage prefetch timeout
|
||||
self.prefetch_timeout = 3 # seconds
|
||||
logger.info(
|
||||
f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
|
||||
)
|
||||
|
||||
# record the nodes with ongoing write through
|
||||
self.ongoing_write_through = {}
|
||||
# record the node segments with ongoing load back
|
||||
@@ -385,9 +394,10 @@ class HiRadixCache(RadixCache):
|
||||
for _ in range(queue_size.item()):
|
||||
req_id = self.cache_controller.prefetch_revoke_queue.get()
|
||||
if req_id in self.ongoing_prefetch:
|
||||
last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
|
||||
last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
|
||||
last_host_node.release_host()
|
||||
del self.ongoing_prefetch[req_id]
|
||||
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
||||
else:
|
||||
# the revoked operation already got terminated
|
||||
pass
|
||||
@@ -419,10 +429,41 @@ class HiRadixCache(RadixCache):
|
||||
host_node.release_host()
|
||||
del self.ongoing_backup[ack_id]
|
||||
|
||||
def check_prefetch_progress(self, req_id: str):
|
||||
# Decide whether the given prefetch operation may be terminated now, according
# to self.prefetch_stop_policy. Returns a truthy value when termination is
# allowed; check_prefetch_progress returns False without terminating otherwise.
#   operation: the PrefetchOperation being tracked; this method reads its
#              completed_tokens, hash_value and start_time attributes.
def can_terminate_prefetch(self, operation: PrefetchOperation):
|
||||
can_terminate = True
|
||||
|
||||
# "best_effort": always allow termination; returns before any cross-rank reduction.
if self.prefetch_stop_policy == "best_effort":
|
||||
return can_terminate
|
||||
|
||||
# The operation counts as completed when the fetched token count equals the
# number of hash entries times the page size.
# NOTE(review): presumably each hash_value entry covers one page — confirm.
completed = (
|
||||
operation.completed_tokens == len(operation.hash_value) * self.page_size
|
||||
)
|
||||
|
||||
# "wait_complete": only terminate once the operation is completed.
if self.prefetch_stop_policy == "wait_complete":
|
||||
can_terminate = completed
|
||||
# "timeout": terminate when completed, or when more than self.prefetch_timeout
# has elapsed on the monotonic clock since operation.start_time.
elif self.prefetch_stop_policy == "timeout":
|
||||
can_terminate = completed or (
|
||||
time.monotonic() - operation.start_time > self.prefetch_timeout
|
||||
)
|
||||
else:
|
||||
# unknown prefetch stop policy, just return True
|
||||
return True
|
||||
|
||||
# With more than one TP worker, take the MIN of the 0/1 decisions over
# self.tp_group, so the result is True only if every rank reported True.
# Only the "wait_complete" and "timeout" policies reach this reduction.
if self.tp_world_size > 1:
|
||||
can_terminate = torch.tensor(can_terminate, dtype=torch.int)
|
||||
torch.distributed.all_reduce(
|
||||
can_terminate,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_group,
|
||||
)
|
||||
can_terminate = bool(can_terminate.item())
|
||||
|
||||
return can_terminate
|
||||
|
||||
def check_prefetch_progress(self, req_id: str) -> bool:
|
||||
if req_id not in self.ongoing_prefetch:
|
||||
# there is no ongoing prefetch for this request or it has been revoked
|
||||
return
|
||||
return True
|
||||
|
||||
# todo: more policies for prefetch progress such as timeout
|
||||
# the current policy is to prefetch with best effort and terminate when queuing is over
|
||||
@@ -430,13 +471,16 @@ class HiRadixCache(RadixCache):
|
||||
req_id
|
||||
]
|
||||
|
||||
if not self.can_terminate_prefetch(operation):
|
||||
return False
|
||||
|
||||
completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
|
||||
operation
|
||||
)
|
||||
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
|
||||
|
||||
min_completed_tokens = completed_tokens
|
||||
if self.tp_world_size > 1:
|
||||
if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
|
||||
# synchronize TP workers to make the same update to hiradix cache
|
||||
completed_tokens_tensor = torch.tensor(
|
||||
min_completed_tokens, dtype=torch.int
|
||||
@@ -464,6 +508,9 @@ class HiRadixCache(RadixCache):
|
||||
)
|
||||
last_host_node.release_host()
|
||||
del self.ongoing_prefetch[req_id]
|
||||
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
||||
|
||||
return True
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs):
|
||||
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
@@ -531,6 +578,7 @@ class HiRadixCache(RadixCache):
|
||||
host_indices,
|
||||
operation,
|
||||
)
|
||||
self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
|
||||
|
||||
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
|
||||
node.last_access_time = time.monotonic()
|
||||
|
||||
Reference in New Issue
Block a user