Support l3 cache (mooncake store) for hiradix cache (#7211)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu> Co-authored-by: AniZpZ <zhuangsen.zp@antgroup.com> Co-authored-by: zuoyuan <zhangzuo21@mails.tsinghua.edu.cn> Co-authored-by: @wangyueneng.wyn <wangyueneng.wyn@antgroup.com> Co-authored-by: JinYan Su <jinyansu792@gmail.com>
This commit is contained in:
@@ -26,6 +26,10 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
|
||||
from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
|
||||
MooncakeStore,
|
||||
get_hash_str_mooncake,
|
||||
)
|
||||
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -125,7 +129,7 @@ class TransferBuffer:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1000
|
||||
self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1024
|
||||
) -> None:
|
||||
self.stop_event = stop_event
|
||||
self.buffers = Queue(maxsize=buffer_count)
|
||||
@@ -260,6 +264,11 @@ class HiCacheController:
|
||||
|
||||
if storage_backend == "file":
|
||||
self.storage_backend = HiCacheFile()
|
||||
self.get_hash_str = get_hash_str
|
||||
elif storage_backend == "mooncake":
|
||||
self.storage_backend = MooncakeStore()
|
||||
self.get_hash_str = get_hash_str_mooncake
|
||||
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
|
||||
elif storage_backend == "hf3fs":
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
@@ -271,6 +280,7 @@ class HiCacheController:
|
||||
self.storage_backend = HiCacheHF3FS.from_env_config(
|
||||
rank, bytes_per_page, dtype
|
||||
)
|
||||
self.get_hash_str = get_hash_str
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported storage backend: {storage_backend}"
|
||||
@@ -532,6 +542,37 @@ class HiCacheController:
|
||||
operation.mark_done()
|
||||
return operation.completed_tokens, operation.hash_value
|
||||
|
||||
def generic_page_transfer(self, operation, batch_size=8):
    """Load the pages of a prefetch operation from storage into host memory.

    Pages are requested from ``self.storage_backend.batch_get`` in groups of
    ``batch_size`` hashes and written to the host pool one page at a time
    with ``self.mem_pool_host.set_from_flat_data_page``.

    Args:
        operation: Prefetch operation. This method reads ``hash_value``
            (one hash per page), ``host_indices``, ``completed_tokens`` and
            ``request_id``, and calls ``increment(num_tokens)``.
        batch_size: Number of page hashes sent to the backend per call.

    The transfer stops early when:
      * ``batch_get`` returns ``None`` for the whole batch;
      * an individual page in the returned batch is ``None`` (only the
        pages before it are loaded, so the loaded prefix stays contiguous);
      * ``operation.increment`` returns a falsy value, in which case the
        host indices from ``operation.completed_tokens`` onward are freed.
    """
    page_size = self.page_size
    all_hashes = operation.hash_value

    for start in range(0, len(all_hashes), batch_size):
        page_hashes = all_hashes[start : start + batch_size]
        page_data = self.storage_backend.batch_get(page_hashes)
        if page_data is None:
            logger.warning(
                f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
            )
            break

        # A batch may come back partially filled: count the leading run of
        # pages that were actually retrieved and ignore everything after
        # the first miss.
        num_valid = 0
        for page in page_data[: len(page_hashes)]:
            if page is None:
                break
            num_valid += 1

        batch_complete = num_valid == len(page_hashes)
        if not batch_complete:
            logger.warning(
                f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes[num_valid]}."
            )
            if num_valid == 0:
                break

        # Remember where this batch starts before increment() advances the
        # operation's own counter.
        completed_tokens = operation.completed_tokens
        if not operation.increment(page_size * num_valid):
            # operation terminated by controller, release pre-allocated memory
            self.mem_pool_host.free(
                operation.host_indices[operation.completed_tokens :]
            )
            break

        for page in page_data[:num_valid]:
            self.mem_pool_host.set_from_flat_data_page(
                operation.host_indices[completed_tokens],
                page,
            )
            completed_tokens += page_size

        if not batch_complete:
            break
|
||||
|
||||
def mooncake_page_transfer(self, operation):
    """Fetch all pages of a prefetch operation with one backend call.

    ``self.mem_pool_host.get_buffer_meta`` translates the operation's page
    hashes and host indices into three values (keys, buffer pointers and
    buffer sizes) which are handed to ``self.storage_backend.batch_get``.
    The operation's progress is then advanced by the full token count
    (``len(operation.hash_value) * self.page_size``).

    Args:
        operation: Prefetch operation providing ``hash_value``,
            ``host_indices`` and ``increment(num_tokens)``.

    NOTE(review): the return values of ``batch_get`` and ``increment`` are
    not inspected here, so a failed fetch or a terminated operation is not
    detected by this method -- confirm that callers handle this.
    """
    num_pages = len(operation.hash_value)
    if num_pages:
        # Skip the backend round trip entirely when there is nothing to
        # fetch, mirroring the guard used on the mooncake backup path.
        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
            operation.hash_value, operation.host_indices
        )
        self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
    operation.increment(num_pages * self.page_size)
|
||||
|
||||
def prefetch_io_aux_func(self):
|
||||
"""
|
||||
Auxiliary function conducting IO operations for prefetching.
|
||||
@@ -539,26 +580,10 @@ class HiCacheController:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.prefetch_buffer.get(block=True, timeout=1)
|
||||
page_datas = self.storage_backend.batch_get(operation.hash_value)
|
||||
for h, page_data in zip(operation.hash_value, page_datas):
|
||||
if page_data is None:
|
||||
logger.warning(
|
||||
f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
|
||||
)
|
||||
break
|
||||
if operation.increment(self.page_size):
|
||||
self.mem_pool_host.set_from_flat_data_page(
|
||||
operation.host_indices[
|
||||
operation.completed_tokens - self.page_size
|
||||
],
|
||||
page_data,
|
||||
)
|
||||
else:
|
||||
# operation terminated by controller, release pre-allocated memory
|
||||
self.mem_pool_host.free(
|
||||
operation.host_indices[operation.completed_tokens :]
|
||||
)
|
||||
break
|
||||
if isinstance(self.storage_backend, MooncakeStore):
|
||||
self.mooncake_page_transfer(operation)
|
||||
else:
|
||||
self.generic_page_transfer(operation)
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
@@ -582,18 +607,27 @@ class HiCacheController:
|
||||
remaining_tokens = len(tokens_to_fetch)
|
||||
hash_value = []
|
||||
while remaining_tokens >= self.page_size:
|
||||
last_hash = get_hash_str(
|
||||
last_hash = self.get_hash_str(
|
||||
tokens_to_fetch[
|
||||
storage_hit_count : storage_hit_count + self.page_size
|
||||
],
|
||||
last_hash,
|
||||
)
|
||||
if self.storage_backend.exists(last_hash):
|
||||
storage_hit_count += self.page_size
|
||||
hash_value.append(last_hash)
|
||||
remaining_tokens -= self.page_size
|
||||
else:
|
||||
break
|
||||
|
||||
# todo, more unified interface
|
||||
if not isinstance(self.storage_backend, MooncakeStore):
|
||||
if not self.storage_backend.exists(last_hash):
|
||||
break
|
||||
hash_value.append(last_hash)
|
||||
storage_hit_count += self.page_size
|
||||
remaining_tokens -= self.page_size
|
||||
|
||||
if isinstance(self.storage_backend, MooncakeStore):
|
||||
# deferring to batch exists for mooncake store
|
||||
exist_result = self.storage_backend.exists(hash_value)
|
||||
storage_hit_count = (
|
||||
sum(1 for v in exist_result.values() if v != 0) * self.page_size
|
||||
)
|
||||
|
||||
if self.tp_world_size > 1:
|
||||
storage_hit_count_tensor = torch.tensor(
|
||||
@@ -641,6 +675,47 @@ class HiCacheController:
|
||||
self.backup_queue.put(operation)
|
||||
return operation.id
|
||||
|
||||
def generic_page_backup(self, operation, batch_size=8):
    """Write the pages of a backup operation to the storage backend.

    Pages are read from the host pool one at a time and written with
    ``self.storage_backend.batch_set`` in groups of ``batch_size``.
    ``operation.completed_tokens`` is advanced after every batch that the
    backend reports as successful; the first failing batch is logged and
    ends the backup.

    Args:
        operation: Backup operation providing ``hash_value`` (one hash per
            page), ``host_indices`` and a mutable ``completed_tokens``.
        batch_size: Number of pages written per ``batch_set`` call.
    """
    page_size = self.page_size
    all_hashes = operation.hash_value

    for start in range(0, len(all_hashes), batch_size):
        page_hashes = all_hashes[start : start + batch_size]
        # Page number p starts at token offset p * page_size in host_indices.
        # The per-page accessor get_flat_data_page is used here, matching the
        # other page reads/writes in this file (set_from_flat_data_page).
        page_data = [
            self.mem_pool_host.get_flat_data_page(
                operation.host_indices[page_no * page_size]
            )
            for page_no in range(start, start + len(page_hashes))
        ]
        success = self.storage_backend.batch_set(page_hashes, page_data)
        if not success:
            logger.warning(f"Failed to write page {page_hashes} to storage.")
            break
        operation.completed_tokens += page_size * len(page_hashes)
|
||||
|
||||
def mooncake_page_backup(self, operation):
    """Back up the pages of an operation that the backend does not hold yet.

    ``self.storage_backend.exists`` is queried once with every page hash;
    the result is indexed by hash and a falsy entry is treated as "not
    stored". Only those pages are resolved through
    ``self.mem_pool_host.get_buffer_meta`` and written with a single
    ``batch_set`` call. ``operation.completed_tokens`` is then advanced by
    the token count of all pages, whether newly written or already present.

    Args:
        operation: Backup operation providing ``hash_value`` (one hash per
            page), ``host_indices`` (must support ``tolist()``) and a
            mutable ``completed_tokens``.
    """
    hash_values = operation.hash_value
    page_size = self.page_size

    if len(hash_values):
        exist_result = self.storage_backend.exists(hash_values)
        indices = operation.host_indices.tolist()

        non_exist_keys = []
        non_exist_indices = []
        for page_no, key in enumerate(hash_values):
            if not exist_result[key]:
                non_exist_keys.append(key)
                # Each page owns page_size consecutive entries of host_indices.
                non_exist_indices.extend(
                    indices[page_no * page_size : (page_no + 1) * page_size]
                )

        if non_exist_keys:
            key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
                non_exist_keys, non_exist_indices
            )
            # TODO: check the return value of batch set to see how many tokens are set successfully
            self.storage_backend.batch_set(
                key_strs,
                target_location=buffer_ptrs,
                target_sizes=buffer_sizes,
            )

    # NOTE(review): nesting reconstructed from a flattened listing; this
    # update is assumed to run unconditionally -- confirm against the file.
    operation.completed_tokens += len(hash_values) * page_size
|
||||
|
||||
def backup_thread_func(self):
|
||||
"""
|
||||
Manage backup operations from host memory to storage backend.
|
||||
@@ -654,23 +729,25 @@ class HiCacheController:
|
||||
last_hash = operation.last_hash
|
||||
tokens_to_backup = operation.token_ids
|
||||
|
||||
last_hashes, data_pages = [], []
|
||||
for i in range(0, len(tokens_to_backup), self.page_size):
|
||||
last_hash = get_hash_str(
|
||||
tokens_to_backup[i : i + self.page_size], last_hash
|
||||
backup_hit_count = 0
|
||||
remaining_tokens = len(tokens_to_backup)
|
||||
hash_value = []
|
||||
while remaining_tokens >= self.page_size:
|
||||
last_hash = self.get_hash_str(
|
||||
tokens_to_backup[
|
||||
backup_hit_count : backup_hit_count + self.page_size
|
||||
],
|
||||
last_hash,
|
||||
)
|
||||
data_page = self.mem_pool_host.get_flat_data_page(
|
||||
operation.host_indices[i]
|
||||
)
|
||||
last_hashes.append(last_hash)
|
||||
data_pages.append(data_page)
|
||||
backup_hit_count += self.page_size
|
||||
hash_value.append(last_hash)
|
||||
remaining_tokens -= self.page_size
|
||||
operation.hash_value = hash_value
|
||||
|
||||
success = self.storage_backend.batch_set(last_hashes, data_pages)
|
||||
if not success:
|
||||
logger.warning(f"Failed to write page {last_hashes} to storage.")
|
||||
if isinstance(self.storage_backend, MooncakeStore):
|
||||
self.mooncake_page_backup(operation)
|
||||
else:
|
||||
operation.completed_tokens += len(tokens_to_backup)
|
||||
operation.hash_value.extend(last_hashes)
|
||||
self.generic_page_backup(operation)
|
||||
|
||||
min_completed_tokens = operation.completed_tokens
|
||||
if self.tp_world_size > 1:
|
||||
|
||||
Reference in New Issue
Block a user