[HiCache] Storage Refactoring (#9797)
Co-authored-by: pansicheng <27603155+pansicheng@users.noreply.github.com>
@@ -250,26 +250,21 @@ class HiCacheController:
         self.write_policy = write_policy
         self.page_size = page_size
         self.io_backend = io_backend

         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
             self.storage_backend_type = storage_backend
             from sglang.srt.mem_cache.hicache_storage import get_hash_str

             self.get_hash_str = get_hash_str
             self.storage_config = self._generate_storage_config(
                 model_name, storage_backend_extra_config
             )
-            # In MLA backend, only one rank needs to backup the KV cache
+            # for MLA models, only one rank needs to backup the KV cache
             self.backup_skip = (
                 self.storage_config.is_mla_model
-                # todo: for load balancing, decide which rank to backup the KV cache by hash value
+                # todo: load balancing
                 and self.storage_config.tp_rank != 0
                 # todo: support other storage backends
                 and self.storage_backend_type in ["file", "mooncake"]
             )

             if storage_backend == "file":
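For MLA models the compressed KV cache is replicated across tensor-parallel ranks, so only rank 0 needs to back it up. A minimal sketch (not part of the commit) of how the `backup_skip` condition reads; the helper name is hypothetical:

    def should_skip_backup(is_mla_model: bool, tp_rank: int, backend_type: str) -> bool:
        # mirrors the backup_skip expression above (illustrative only)
        return is_mla_model and tp_rank != 0 and backend_type in ["file", "mooncake"]

    assert should_skip_backup(True, 1, "file")
    assert not should_skip_backup(True, 0, "file")   # rank 0 still backs up
    assert not should_skip_backup(False, 1, "file")  # non-MLA: every rank backs up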
@@ -309,12 +304,15 @@ class HiCacheController:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
                 )

             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             self.prefetch_capacity_limit = int(
                 0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
             )
+            # granularity of batch storage IO operations, in number of pages
+            self.storage_batch_size = 128
+            # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
             self.prefetch_tokens_occupied = 0
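The capacity limit caps in-flight prefetches at 80% of the host-only headroom, i.e. the part of the host pool that exceeds the device pool. A worked example with hypothetical pool sizes (in tokens):

    mem_pool_host_size = 1_000_000
    mem_pool_device_size = 200_000
    prefetch_capacity_limit = int(0.8 * (mem_pool_host_size - mem_pool_device_size))
    assert prefetch_capacity_limit == 640_000  # tokens that prefetching may lock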
@@ -325,12 +323,6 @@ class HiCacheController:
             self.prefetch_tp_group = torch.distributed.new_group(
                 group_ranks, backend="gloo"
             )
-            self.prefetch_io_tp_group = torch.distributed.new_group(
-                group_ranks, backend="gloo"
-            )
-            self.backup_tp_group = torch.distributed.new_group(
-                group_ranks, backend="gloo"
-            )

         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -380,6 +372,7 @@ class HiCacheController:

         self.prefetch_revoke_queue = Queue()
         self.ack_backup_queue = Queue()
+        self.host_mem_release_queue = Queue()

         self.prefetch_thread.start()
         self.backup_thread.start()
@@ -618,7 +611,11 @@ class HiCacheController:
         operation.mark_done()
         return operation.completed_tokens, operation.hash_value

-    # zero copy
+    def append_host_mem_release(self, host_indices: torch.Tensor):
+        chunks = host_indices.split(self.mem_pool_host.page_size)
+        for chunk in chunks:
+            self.host_mem_release_queue.put(chunk)
+
     def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
         hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
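The new queue separates deciding to release host memory (done on IO threads) from actually freeing it (done on the scheduler thread, in `drain_storage_control_queues` further below), so all TP ranks free in lockstep. A minimal self-contained sketch of the pattern, with a stand-in page size:

    from queue import Queue
    import torch

    PAGE_SIZE = 4  # stand-in for mem_pool_host.page_size
    host_mem_release_queue: Queue = Queue()

    def append_host_mem_release(host_indices: torch.Tensor) -> None:
        # enqueue page-sized chunks; freeing happens later, on a single thread
        for chunk in host_indices.split(PAGE_SIZE):
            host_mem_release_queue.put(chunk)

    append_host_mem_release(torch.arange(10))
    drained = [host_mem_release_queue.get() for _ in range(host_mem_release_queue.qsize())]
    assert torch.cat(drained).numel() == 10  # chunks of 4 + 4 + 2, freed in one batch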
@@ -631,7 +628,6 @@ class HiCacheController:
                 f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
             )

-    # zero copy
     def _mooncake_page_get(self, operation, hash_values, host_indices):
         key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
             hash_values,
@@ -650,9 +646,7 @@ class HiCacheController:
         if get_result != 0:
             operation.increment(get_result * self.page_size)

-    # non-zero copy
     def _generic_page_get(self, operation, hash_values, host_indices):
-        # todo: zero copy
         dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
             hash_values
         )
@@ -675,22 +669,19 @@ class HiCacheController:

     def _page_transfer(self, operation):
         # Select the get function and batch size
-        if self.is_mooncake_backend():
+        if self.storage_backend_type == "mooncake":
             get_func = self._mooncake_page_get
-            batch_size = 128
-        elif (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        ):
-            get_func = self._3fs_zero_copy_page_get
-            batch_size = 128
+        elif self.storage_backend_type == "hf3fs":
+            if self.mem_pool_host.layout == "page_first":
+                get_func = self._3fs_zero_copy_page_get
+            elif self.mem_pool_host.layout == "layer_first":
+                get_func = self._generic_page_get
         else:
             get_func = self._generic_page_get
-            batch_size = 8

         # Transfer batch by batch
-        for i in range(0, len(operation.hash_value), batch_size):
-            batch_hashes = operation.hash_value[i : i + batch_size]
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
             batch_host_indices = operation.host_indices[
                 i * self.page_size : (i + len(batch_hashes)) * self.page_size
             ]
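The loop indexes `hash_value` per page but `host_indices` per token, so the token slice is scaled by `page_size`. A short sketch with hypothetical sizes showing the slicing stays aligned, including the final partial batch:

    page_size, storage_batch_size, num_pages = 64, 128, 300
    hash_value = [f"page-{i}" for i in range(num_pages)]
    host_indices = list(range(num_pages * page_size))

    batch_sizes = []
    for i in range(0, len(hash_value), storage_batch_size):
        batch_hashes = hash_value[i : i + storage_batch_size]
        batch_host_indices = host_indices[
            i * page_size : (i + len(batch_hashes)) * page_size
        ]
        assert len(batch_host_indices) == len(batch_hashes) * page_size
        batch_sizes.append(len(batch_hashes))
    assert batch_sizes == [128, 128, 44]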
@@ -704,10 +695,9 @@ class HiCacheController:
         ):
             break  # Some operations fail or operation terminated by controller
         # release pre-allocated memory
-        self.mem_pool_host.free(operation.host_indices[operation.completed_tokens :])
-
-    def is_mooncake_backend(self):
-        return self.storage_backend_type == "mooncake"
+        self.append_host_mem_release(
+            operation.host_indices[operation.completed_tokens :]
+        )

     def prefetch_io_aux_func(self):
         """
@@ -717,47 +707,49 @@ class HiCacheController:
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
                 self._page_transfer(operation)

-                if self.tp_world_size > 1:
-                    # to ensure all TP workers release the host memory at the same time
-                    torch.distributed.barrier(group=self.prefetch_io_tp_group)
                 # operation terminated by controller, release pre-allocated memory
-                self.mem_pool_host.free(
+                self.append_host_mem_release(
                     operation.host_indices[operation.completed_tokens :]
                 )
             except Empty:
                 continue

-    def prefetch_rate_limit_check(self) -> bool:
+    def prefetch_rate_limited(self) -> bool:
         """
         Rate limit the prefetching operations to avoid overwhelming the storage backend.
         """
         # cancel prefetch if too much memory is occupied
         if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
-            return False
+            return True
         # todo: more sophisticated rate limiting based on storage backend performance
-        return True
+        return False

-    def _generic_storage_hit_query(self, operation) -> tuple[list[str], int]:
+    def _storage_hit_query(self, operation) -> tuple[list[str], int]:
         last_hash = operation.last_hash
         tokens_to_fetch = operation.token_ids

         storage_query_count = 0
-        remaining_tokens = len(tokens_to_fetch)
         hash_value = []
-        while remaining_tokens >= self.page_size:
-            last_hash = self.get_hash_str(
-                tokens_to_fetch[
-                    storage_query_count : storage_query_count + self.page_size
-                ],
-                last_hash,
-            )
-            hash_value.append(last_hash)
-            storage_query_count += self.page_size
-            remaining_tokens -= self.page_size
-        # deferring to batch exists
-        hit_page_num = self.storage_backend.batch_exists(hash_value)
-        return hash_value[:hit_page_num], hit_page_num * self.page_size
+        for start in range(
+            0, len(tokens_to_fetch), self.page_size * self.storage_batch_size
+        ):
+            end = min(
+                start + self.page_size * self.storage_batch_size, len(tokens_to_fetch)
+            )
+            batch_tokens = tokens_to_fetch[start:end]
+            batch_hashes = []
+            for i in range(0, len(batch_tokens), self.page_size):
+                last_hash = self.get_hash_str(
+                    batch_tokens[i : i + self.page_size], last_hash
+                )
+                batch_hashes.append(last_hash)
+            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
+            hash_value.extend(batch_hashes[:hit_page_num])
+            storage_query_count += hit_page_num * self.page_size
+            if hit_page_num < len(batch_hashes):
+                break
+        return hash_value, storage_query_count

     def prefetch_thread_func(self):
         """
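The rewrite hashes one batch of pages at a time and stops at the first partial hit, instead of hashing every page up front and querying once. It relies on `batch_exists` returning the length of the matching prefix of the queried batch. A sketch with a toy backend and hash chain (both stand-ins, not the PR's implementations):

    import hashlib

    def get_hash_str(tokens, prior=None):
        h = hashlib.sha256((prior or "").encode())
        h.update(",".join(map(str, tokens)).encode())
        return h.hexdigest()

    class ToyBackend:
        def __init__(self, stored):
            self.stored = stored  # set of page hashes present in storage

        def batch_exists(self, hashes):
            # number of consecutive leading pages that exist in storage
            n = 0
            for h in hashes:
                if h not in self.stored:
                    break
                n += 1
            return n

    page = [1, 2, 3, 4]
    h1 = get_hash_str(page)
    backend = ToyBackend({h1})
    # second page misses, so only the one-page prefix counts as a hit
    assert backend.batch_exists([h1, get_hash_str(page, h1)]) == 1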
@@ -772,13 +764,7 @@ class HiCacheController:
                 if operation is None:
                     continue

-                if (
-                    operation.host_indices is not None
-                ) and self.prefetch_rate_limit_check():
-                    hash_value, storage_hit_count = self._generic_storage_hit_query(
-                        operation
-                    )
-
+                hash_value, storage_hit_count = self._storage_hit_query(operation)
                 if self.tp_world_size > 1:
                     storage_hit_count_tensor = torch.tensor(
                         storage_hit_count, dtype=torch.int
@@ -793,8 +779,7 @@ class HiCacheController:
                 if storage_hit_count < self.prefetch_threshold:
                     # not to prefetch if not enough benefits
                     self.prefetch_revoke_queue.put(operation.request_id)
-                    if operation.host_indices is not None:
-                        self.mem_pool_host.free(operation.host_indices)
+                    self.append_host_mem_release(operation.host_indices)
                     logger.debug(
                         f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                     )
@@ -803,7 +788,9 @@ class HiCacheController:
                         : (storage_hit_count // self.page_size)
                     ]
                     # free the pre-allocated memory for pages that are not hit
-                    self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                    self.append_host_mem_release(
+                        operation.host_indices[storage_hit_count:]
+                    )
                     operation.host_indices = operation.host_indices[:storage_hit_count]
                     logger.debug(
                         f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
@@ -858,21 +845,18 @@ class HiCacheController:
     # Backup batch by batch
     def _page_backup(self, operation):
         # Select the set function and batch size
-        if self.is_mooncake_backend():
+        if self.storage_backend_type == "mooncake":
             backup_set_func = self._mooncake_page_set
-            batch_size = 128
-        elif (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        ):
-            backup_set_func = self._3fs_zero_copy_page_set
-            batch_size = 128
+        elif self.storage_backend_type == "hf3fs":
+            if self.mem_pool_host.layout == "page_first":
+                backup_set_func = self._3fs_zero_copy_page_set
+            elif self.mem_pool_host.layout == "layer_first":
+                backup_set_func = self._generic_page_set
         else:
             backup_set_func = self._generic_page_set
-            batch_size = 8
         # Backup batch by batch
-        for i in range(0, len(operation.hash_value), batch_size):
-            batch_hashes = operation.hash_value[i : i + batch_size]
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
             batch_host_indices = operation.host_indices[
                 i * self.page_size : (i + len(batch_hashes)) * self.page_size
             ]
@@ -898,27 +882,7 @@ class HiCacheController:

                 if not self.backup_skip:
                     self._page_backup(operation)
-                    min_completed_tokens = operation.completed_tokens
-                else:
-                    min_completed_tokens = len(operation.token_ids)
-
-                if self.tp_world_size > 1:
-                    completed_tokens_tensor = torch.tensor(
-                        min_completed_tokens, dtype=torch.int
-                    )
-                    torch.distributed.all_reduce(
-                        completed_tokens_tensor,
-                        op=torch.distributed.ReduceOp.MIN,
-                        group=self.backup_tp_group,
-                    )
-                    min_completed_tokens = completed_tokens_tensor.item()
-
-                self.ack_backup_queue.put(
-                    (
-                        operation.id,
-                        min_completed_tokens,
-                    )
-                )
+                self.ack_backup_queue.put(operation.id)

             except Empty:
                 continue
@@ -104,9 +104,6 @@ class HiRadixCache(RadixCache):
         self.write_through_threshold = (
             1 if hicache_write_policy == "write_through" else 2
         )
-        self.write_through_threshold_storage = (
-            1 if hicache_write_policy == "write_through" else 3
-        )
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -174,14 +171,6 @@ class HiRadixCache(RadixCache):
         if node.hit_count >= self.write_through_threshold:
             # write to host if the node is not backuped
             self.write_backup(node)
-        else:
-            if (
-                self.enable_storage
-                and (not node.backuped_storage)
-                and node.hit_count >= self.write_through_threshold_storage
-            ):
-                # if the node is backuped on host memory but not on storage
-                self.write_backup_storage(node)

     def writing_check(self, write_back=False):
         if write_back:
@@ -202,8 +191,11 @@ class HiRadixCache(RadixCache):
             )
             for _ in range(queue_size.item()):
                 ack_id = self.cache_controller.ack_write_queue.get()
-                self.dec_lock_ref(self.ongoing_write_through[ack_id])
+                backuped_node = self.ongoing_write_through[ack_id]
+                self.dec_lock_ref(backuped_node)
                 del self.ongoing_write_through[ack_id]
+                if self.enable_storage:
+                    self.write_backup_storage(backuped_node)

     def loading_check(self):
         while not self.cache_controller.ack_load_queue.empty():
@@ -386,57 +378,54 @@ class HiRadixCache(RadixCache):
         self.writing_check()
         self.loading_check()
         if self.enable_storage:
-            self.check_revoked_prefetch()
-            self.check_backup_progress()
+            self.drain_storage_control_queues()

-    def check_revoked_prefetch(self):
-        queue_size = torch.tensor(
-            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
+    def drain_storage_control_queues(self):
+        """
+        Combine prefetch revoke, backup ack, and host mem release checks
+        to minimize TP synchronization and Python overhead.
+        """
+        cc = self.cache_controller
+
+        qsizes = torch.tensor(
+            [
+                cc.prefetch_revoke_queue.qsize(),
+                cc.ack_backup_queue.qsize(),
+                cc.host_mem_release_queue.qsize(),
+            ],
+            dtype=torch.int,
         )
         if self.tp_world_size > 1:
             # synchronize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
-                queue_size,
-                op=torch.distributed.ReduceOp.MIN,
-                group=self.tp_group,
+                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
             )
-        for _ in range(queue_size.item()):
-            req_id = self.cache_controller.prefetch_revoke_queue.get()
-            if req_id in self.ongoing_prefetch:
-                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
+
+        n_revoke, n_backup, n_release = map(int, qsizes.tolist())
+
+        # process prefetch revokes
+        for _ in range(n_revoke):
+            req_id = cc.prefetch_revoke_queue.get()
+            info = self.ongoing_prefetch.pop(req_id, None)
+            if info is not None:
+                last_host_node, token_ids, _, _ = info
                 last_host_node.release_host()
-                del self.ongoing_prefetch[req_id]
-                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
-            else:
-                # the revoked operation already got terminated
-                pass
+                cc.prefetch_tokens_occupied -= len(token_ids)
+            # else: the revoked operation already got terminated, nothing to do

-    def check_backup_progress(self):
-        queue_size = torch.tensor(
-            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
-        )
-        if self.tp_world_size > 1:
-            # synchronize TP workers to make the same update to hiradix cache
-            torch.distributed.all_reduce(
-                queue_size,
-                op=torch.distributed.ReduceOp.MIN,
-                group=self.tp_group,
-            )
-        for _ in range(queue_size.item()):
-            ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
-            host_node = self.ongoing_backup[ack_id]
-
-            if completed_tokens > 0:
-                if completed_tokens < len(host_node.key):
-                    # backup is only partially successful, split the node
-                    new_node = self._split_node(
-                        host_node.key, host_node, completed_tokens
-                    )
-                    new_node.backuped_storage = True
-                else:
-                    host_node.backuped_storage = True
-            host_node.release_host()
-            del self.ongoing_backup[ack_id]
+        # process backup acks
+        for _ in range(n_backup):
+            ack_id = cc.ack_backup_queue.get()
+            entry = self.ongoing_backup.pop(ack_id, None)
+            if entry is not None:
+                entry.release_host()
+
+        # release host memory
+        host_indices_list = []
+        for _ in range(n_release):
+            host_indices_list.append(cc.host_mem_release_queue.get())
+        if host_indices_list:
+            host_indices = torch.cat(host_indices_list, dim=0)
+            cc.mem_pool_host.free(host_indices)

     def can_terminate_prefetch(self, operation: PrefetchOperation):
         can_terminate = True
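Why `ReduceOp.MIN`: at the instant of the check, each rank may observe a different number of ready queue entries, and every rank must pop the same count for the replicated radix trees to stay identical. A toy illustration of the agreement step with hypothetical per-rank queue depths (no process group needed):

    # per-rank snapshots of [revoke, backup, release] queue depths
    rank_qsizes = [[3, 1, 5], [2, 1, 4], [3, 1, 6]]
    n_revoke, n_backup, n_release = (min(col) for col in zip(*rank_qsizes))
    assert (n_revoke, n_backup, n_release) == (2, 1, 4)
    # every rank drains exactly these counts; leftovers are picked up next tick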
@@ -519,7 +508,7 @@ class HiRadixCache(RadixCache):
             self.cache_controller.mem_pool_host.update_prefetch(written_indices)

         self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
-        self.cache_controller.mem_pool_host.free(
+        self.cache_controller.append_host_mem_release(
             host_indices[min_completed_tokens:completed_tokens]
         )
         last_host_node.release_host()
@@ -575,7 +564,11 @@ class HiRadixCache(RadixCache):
             len(new_input_tokens) % self.page_size
         )
         new_input_tokens = new_input_tokens[:prefetch_length]
-        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
+        if (
+            not self.enable_storage
+            or prefetch_length < self.prefetch_threshold
+            or self.cache_controller.prefetch_rate_limited()
+        ):
             return

         last_host_node.protect_host()
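Rate limiting now happens before a prefetch operation is created or host memory is allocated, rather than inside the controller's IO thread. A sketch of the admission condition (the helper name and values are hypothetical; the logic mirrors the guard above):

    def admit_prefetch(enable_storage, prefetch_length, prefetch_threshold, rate_limited):
        return (
            enable_storage
            and prefetch_length >= prefetch_threshold
            and not rate_limited
        )

    assert admit_prefetch(True, 512, 256, rate_limited=False)
    assert not admit_prefetch(True, 128, 256, rate_limited=False)  # too short to pay off
    assert not admit_prefetch(True, 512, 256, rate_limited=True)   # over capacity limit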
@@ -583,6 +576,10 @@ class HiRadixCache(RadixCache):
         if host_indices is None:
             self.evict_host(prefetch_length)
             host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
+        if host_indices is None:
+            last_host_node.release_host()
+            # no sufficient host memory for prefetch
+            return
         operation = self.cache_controller.prefetch(
             req_id, host_indices, new_input_tokens, last_hash
         )
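The added guard covers the case where eviction still cannot produce enough host pages, in which case the caller must undo `protect_host()` and bail out. A stand-in pool (not the PR's memory pool) illustrating the retry-then-give-up shape:

    class ToyHostPool:
        def __init__(self, free_pages):
            self.free_pages = free_pages

        def alloc(self, n):
            if n > self.free_pages:
                return None
            self.free_pages -= n
            return list(range(n))

        def evict(self, n):
            self.free_pages += n // 2  # eviction may reclaim only part of the request

    pool = ToyHostPool(free_pages=0)
    idx = pool.alloc(8)
    if idx is None:
        pool.evict(8)
        idx = pool.alloc(8)
    assert idx is None  # still short on memory: release protections and return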
@@ -62,7 +62,6 @@ class TreeNode:
         self.host_value: Optional[torch.Tensor] = None
         # store hash values of each page
         self.hash_value: Optional[List[str]] = None
-        self.backuped_storage = False

         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1