Simple prefetch policy (#8692)
This commit is contained in:
@@ -16,6 +16,7 @@ limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
from queue import Empty, Full, PriorityQueue, Queue
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
@@ -195,6 +196,8 @@ class PrefetchOperation(StorageOperation):
|
||||
self._done_flag = False
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self.start_time = time.monotonic()
|
||||
|
||||
super().__init__(host_indices, token_ids, last_hash)
|
||||
|
||||
def increment(self, num_tokens: int):
|
||||
@@ -278,6 +281,12 @@ class HiCacheController:
|
||||
self.enable_storage = True
|
||||
# todo: threshold policy for prefetching
|
||||
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
|
||||
self.prefetch_capacity_limit = int(
|
||||
0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
|
||||
)
|
||||
# tracking the number of tokens locked in prefetching, updated by the main scheduler thread
|
||||
self.prefetch_tokens_occupied = 0
|
||||
|
||||
# create a new communication group for synchronizing storage operations across TP workers
|
||||
self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
|
||||
if self.tp_world_size > 1:
|
||||
@@ -525,7 +534,7 @@ class HiCacheController:
|
||||
host_indices: torch.Tensor,
|
||||
new_input_tokens: List[int],
|
||||
last_hash: Optional[str] = None,
|
||||
) -> int:
|
||||
) -> PrefetchOperation:
|
||||
"""
|
||||
Prefetch KV caches from storage backend to host memory.
|
||||
"""
|
||||
@@ -586,11 +595,23 @@ class HiCacheController:
|
||||
operation = self.prefetch_buffer.get(block=True, timeout=1)
|
||||
if self.is_mooncake_backend():
|
||||
self.mooncake_page_transfer(operation)
|
||||
elif self.storage_backend_type == "hf3fs":
|
||||
self.generic_page_transfer(operation, batch_size=128)
|
||||
else:
|
||||
self.generic_page_transfer(operation)
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
def prefetch_rate_limit_check(self) -> bool:
|
||||
"""
|
||||
Rate limit the prefetching operations to avoid overwhelming the storage backend.
|
||||
"""
|
||||
# cancel prefetch if too much memory is occupied
|
||||
if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
|
||||
return False
|
||||
# todo: more sophisticated rate limiting based on storage backend performance
|
||||
return True
|
||||
|
||||
def prefetch_thread_func(self):
|
||||
"""
|
||||
Manage prefetching operations from storage backend to host memory.
|
||||
@@ -604,34 +625,36 @@ class HiCacheController:
|
||||
if operation is None:
|
||||
continue
|
||||
|
||||
last_hash = operation.last_hash
|
||||
tokens_to_fetch = operation.token_ids
|
||||
|
||||
storage_hit_count = 0
|
||||
remaining_tokens = len(tokens_to_fetch)
|
||||
hash_value = []
|
||||
while remaining_tokens >= self.page_size:
|
||||
last_hash = self.get_hash_str(
|
||||
tokens_to_fetch[
|
||||
storage_hit_count : storage_hit_count + self.page_size
|
||||
],
|
||||
last_hash,
|
||||
)
|
||||
if self.prefetch_rate_limit_check():
|
||||
last_hash = operation.last_hash
|
||||
tokens_to_fetch = operation.token_ids
|
||||
|
||||
# todo, more unified interface
|
||||
if not self.is_mooncake_backend():
|
||||
if not self.storage_backend.exists(last_hash):
|
||||
break
|
||||
hash_value.append(last_hash)
|
||||
storage_hit_count += self.page_size
|
||||
remaining_tokens -= self.page_size
|
||||
remaining_tokens = len(tokens_to_fetch)
|
||||
hash_value = []
|
||||
while remaining_tokens >= self.page_size:
|
||||
last_hash = self.get_hash_str(
|
||||
tokens_to_fetch[
|
||||
storage_hit_count : storage_hit_count + self.page_size
|
||||
],
|
||||
last_hash,
|
||||
)
|
||||
|
||||
if self.is_mooncake_backend():
|
||||
# deferring to batch exists for mooncake store
|
||||
exist_result = self.storage_backend.exists(hash_value)
|
||||
storage_hit_count = (
|
||||
sum(1 for v in exist_result.values() if v != 0) * self.page_size
|
||||
)
|
||||
# todo, more unified interface
|
||||
if not self.is_mooncake_backend():
|
||||
if not self.storage_backend.exists(last_hash):
|
||||
break
|
||||
hash_value.append(last_hash)
|
||||
storage_hit_count += self.page_size
|
||||
remaining_tokens -= self.page_size
|
||||
|
||||
if self.is_mooncake_backend():
|
||||
# deferring to batch exists for mooncake store
|
||||
exist_result = self.storage_backend.exists(hash_value)
|
||||
storage_hit_count = (
|
||||
sum(1 for v in exist_result.values() if v != 0)
|
||||
* self.page_size
|
||||
)
|
||||
|
||||
if self.tp_world_size > 1:
|
||||
storage_hit_count_tensor = torch.tensor(
|
||||
@@ -750,6 +773,8 @@ class HiCacheController:
|
||||
|
||||
if self.is_mooncake_backend():
|
||||
self.mooncake_page_backup(operation)
|
||||
elif self.storage_backend_type == "hf3fs":
|
||||
self.generic_page_backup(operation, batch_size=128)
|
||||
else:
|
||||
self.generic_page_backup(operation)
|
||||
|
||||
|
||||
@@ -619,6 +619,7 @@ class Scheduler(
|
||||
),
|
||||
hicache_mem_layout=server_args.hicache_mem_layout,
|
||||
hicache_storage_backend=server_args.hicache_storage_backend,
|
||||
hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
|
||||
)
|
||||
self.tp_worker.register_hicache_layer_transfer_counter(
|
||||
self.tree_cache.cache_controller.layer_done_counter
|
||||
@@ -1572,7 +1573,10 @@ class Scheduler(
|
||||
break
|
||||
|
||||
if self.enable_hicache_storage:
|
||||
self.tree_cache.check_prefetch_progress(req.rid)
|
||||
prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
|
||||
if not prefetch_done:
|
||||
# skip staging requests that are ongoing prefetch
|
||||
continue
|
||||
|
||||
req.init_next_round_input(self.tree_cache)
|
||||
res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
|
||||
|
||||
@@ -2,11 +2,12 @@ import heapq
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from queue import Queue
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
|
||||
hicache_io_backend: str,
|
||||
hicache_mem_layout: str,
|
||||
hicache_storage_backend: Optional[str] = None,
|
||||
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
|
||||
):
|
||||
|
||||
if hicache_io_backend == "direct":
|
||||
@@ -85,6 +87,13 @@ class HiRadixCache(RadixCache):
|
||||
prefetch_threshold=self.prefetch_threshold,
|
||||
)
|
||||
|
||||
self.prefetch_stop_policy = hicache_storage_prefetch_policy
|
||||
# todo: customizable storage prefetch timeout
|
||||
self.prefetch_timeout = 3 # seconds
|
||||
logger.info(
|
||||
f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
|
||||
)
|
||||
|
||||
# record the nodes with ongoing write through
|
||||
self.ongoing_write_through = {}
|
||||
# record the node segments with ongoing load back
|
||||
@@ -385,9 +394,10 @@ class HiRadixCache(RadixCache):
|
||||
for _ in range(queue_size.item()):
|
||||
req_id = self.cache_controller.prefetch_revoke_queue.get()
|
||||
if req_id in self.ongoing_prefetch:
|
||||
last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
|
||||
last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
|
||||
last_host_node.release_host()
|
||||
del self.ongoing_prefetch[req_id]
|
||||
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
||||
else:
|
||||
# the revoked operation already got terminated
|
||||
pass
|
||||
@@ -419,10 +429,41 @@ class HiRadixCache(RadixCache):
|
||||
host_node.release_host()
|
||||
del self.ongoing_backup[ack_id]
|
||||
|
||||
def check_prefetch_progress(self, req_id: str):
|
||||
def can_terminate_prefetch(self, operation: PrefetchOperation):
|
||||
can_terminate = True
|
||||
|
||||
if self.prefetch_stop_policy == "best_effort":
|
||||
return can_terminate
|
||||
|
||||
completed = (
|
||||
operation.completed_tokens == len(operation.hash_value) * self.page_size
|
||||
)
|
||||
|
||||
if self.prefetch_stop_policy == "wait_complete":
|
||||
can_terminate = completed
|
||||
elif self.prefetch_stop_policy == "timeout":
|
||||
can_terminate = completed or (
|
||||
time.monotonic() - operation.start_time > self.prefetch_timeout
|
||||
)
|
||||
else:
|
||||
# unknown prefetch stop policy, just return True
|
||||
return True
|
||||
|
||||
if self.tp_world_size > 1:
|
||||
can_terminate = torch.tensor(can_terminate, dtype=torch.int)
|
||||
torch.distributed.all_reduce(
|
||||
can_terminate,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_group,
|
||||
)
|
||||
can_terminate = bool(can_terminate.item())
|
||||
|
||||
return can_terminate
|
||||
|
||||
def check_prefetch_progress(self, req_id: str) -> bool:
|
||||
if req_id not in self.ongoing_prefetch:
|
||||
# there is no ongoing prefetch for this request or it has been revoked
|
||||
return
|
||||
return True
|
||||
|
||||
# todo: more policies for prefetch progress such as timeout
|
||||
# the current policy is to prefetch with best effort and terminate when queuing is over
|
||||
@@ -430,13 +471,16 @@ class HiRadixCache(RadixCache):
|
||||
req_id
|
||||
]
|
||||
|
||||
if not self.can_terminate_prefetch(operation):
|
||||
return False
|
||||
|
||||
completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
|
||||
operation
|
||||
)
|
||||
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
|
||||
|
||||
min_completed_tokens = completed_tokens
|
||||
if self.tp_world_size > 1:
|
||||
if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
|
||||
# synchrnoize TP workers to make the same update to hiradix cache
|
||||
completed_tokens_tensor = torch.tensor(
|
||||
min_completed_tokens, dtype=torch.int
|
||||
@@ -464,6 +508,9 @@ class HiRadixCache(RadixCache):
|
||||
)
|
||||
last_host_node.release_host()
|
||||
del self.ongoing_prefetch[req_id]
|
||||
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
||||
|
||||
return True
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs):
|
||||
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
@@ -531,6 +578,7 @@ class HiRadixCache(RadixCache):
|
||||
host_indices,
|
||||
operation,
|
||||
)
|
||||
self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
|
||||
|
||||
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
|
||||
node.last_access_time = time.monotonic()
|
||||
|
||||
@@ -96,6 +96,8 @@ class Hf3fsClient:
|
||||
)
|
||||
self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
|
||||
self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
|
||||
self.shm_r.unlink()
|
||||
self.shm_w.unlink()
|
||||
|
||||
self.rlock = threading.RLock()
|
||||
self.wlock = threading.RLock()
|
||||
@@ -176,8 +178,6 @@ class Hf3fsClient:
|
||||
del self.iov_w
|
||||
self.shm_r.close()
|
||||
self.shm_w.close()
|
||||
self.shm_r.unlink()
|
||||
self.shm_w.unlink()
|
||||
|
||||
def flush(self) -> None:
|
||||
os.fsync(self.file)
|
||||
|
||||
@@ -203,6 +203,7 @@ class ServerArgs:
|
||||
hicache_io_backend: str = "kernel"
|
||||
hicache_mem_layout: str = "layer_first"
|
||||
hicache_storage_backend: Optional[str] = None
|
||||
hicache_storage_prefetch_policy: str = "best_effort"
|
||||
|
||||
# Double Sparsity
|
||||
enable_double_sparsity: bool = False
|
||||
@@ -1626,6 +1627,13 @@ class ServerArgs:
|
||||
default=ServerArgs.hicache_storage_backend,
|
||||
help="The storage backend for hierarchical KV cache.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hicache-storage-prefetch-policy",
|
||||
type=str,
|
||||
choices=["best_effort", "wait_complete", "timeout"],
|
||||
default=ServerArgs.hicache_storage_prefetch_policy,
|
||||
help="Control when prefetching from the storage backend should stop.",
|
||||
)
|
||||
|
||||
# Double Sparsity
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user