[Fix] HiCache Bugfix & Mooncake Error Handling Enhance (#8901)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
ykwd
2025-08-26 10:05:10 +08:00
committed by GitHub
parent 9b08d975a0
commit 80dc76e11a
5 changed files with 284 additions and 245 deletions

View File

@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
logger = logging.getLogger(__name__)
@@ -240,28 +240,38 @@ class HiCacheController:
self.io_backend = io_backend
self.enable_storage = False
self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost)
# todo: move backend initialization to storage backend module
if storage_backend is not None:
self.storage_backend_type = storage_backend
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
from sglang.srt.mem_cache.hicache_storage import get_hash_str
self.get_hash_str = get_hash_str
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
# In MLA backend, only one rank needs to backup the KV cache
self.backup_skip = (
is_mla_backend
# todo: for load balancing, decide which rank to backup the KV cache by hash value
and get_tensor_model_parallel_rank() != 0
# todo: support other storage backends
and self.storage_backend_type in ["file", "mooncake"]
)
if storage_backend == "file":
self.storage_backend = HiCacheFile(is_mla=self.is_mla)
self.get_hash_str = get_hash_str
from sglang.srt.mem_cache.hicache_storage import HiCacheFile
self.storage_backend = HiCacheFile(is_mla_backend=is_mla_backend)
elif storage_backend == "nixl":
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
self.storage_backend = HiCacheNixl()
self.get_hash_str = get_hash_str
elif storage_backend == "mooncake":
from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
MooncakeStore,
get_hash_str_mooncake,
)
self.storage_backend = MooncakeStore(is_mla=self.is_mla)
self.get_hash_str = get_hash_str_mooncake
self.storage_backend = MooncakeStore(is_mla_backend=is_mla_backend)
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
assert self.mem_pool_host.layout == "page_first"
elif storage_backend == "hf3fs":
@@ -281,7 +291,6 @@ class HiCacheController:
self.storage_backend = HiCacheHF3FS.from_env_config(
bytes_per_page, dtype
)
self.get_hash_str = get_hash_str
else:
raise NotImplementedError(
f"Unsupported storage backend: {storage_backend}"
@@ -400,15 +409,6 @@ class HiCacheController:
self.prefetch_thread.start()
self.backup_thread.start()
@property
def backup_skip(self):
return (
self.is_mla
and get_tensor_model_parallel_rank() != 0
# todo: only support file and mooncake
and self.storage_backend_type in ["file", "mooncake"]
)
def write(
self,
device_indices: torch.Tensor,
@@ -570,57 +570,91 @@ class HiCacheController:
operation.mark_done()
return operation.completed_tokens, operation.hash_value
def zerocopy_page_transfer(self, operation, batch_size=8):
# zero copy
def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
operation.hash_value, operation.host_indices
hash_values, host_indices
)
for i in range(0, len(hashes), batch_size):
page_hashes = hashes[i : i + batch_size]
page_dsts = dsts[i : i + batch_size]
page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
if page_data is None:
logger.warning(
f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
)
break
completed_tokens = operation.completed_tokens
if operation.increment(self.page_size * len(page_hashes)):
for i in range(len(page_hashes)):
completed_tokens += self.page_size
else:
break
page_data = self.storage_backend.batch_get(hashes, dsts)
if page_data:
operation.increment(self.page_size * len(hashes))
else:
logger.warning(
f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
)
def generic_page_transfer(self, operation, batch_size=8):
for i in range(0, len(operation.hash_value), batch_size):
page_hashes = operation.hash_value[i : i + batch_size]
# todo: zero copy
dummy_page_dst = [
self.mem_pool_host.get_dummy_flat_data_page()
for _ in range(len(page_hashes))
]
page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
if page_data is None:
logger.warning(
f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
)
break
completed_tokens = operation.completed_tokens
if operation.increment(self.page_size * len(page_hashes)):
for i in range(len(page_hashes)):
self.mem_pool_host.set_from_flat_data_page(
operation.host_indices[completed_tokens],
page_data[i],
)
completed_tokens += self.page_size
else:
break
def mooncake_page_transfer(self, operation):
# zero copy
def _mooncake_page_get(self, operation, hash_values, host_indices):
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
operation.hash_value, operation.host_indices
hash_values,
host_indices,
)
self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
operation.increment(len(operation.hash_value) * self.page_size)
get_result = self.storage_backend.batch_get(
key_strs,
target_location=buffer_ptrs,
target_sizes=buffer_sizes,
)
if get_result != len(hash_values):
logger.warning(
f"Prefetch operation {operation.request_id} failed or partially failed."
)
if get_result != 0:
operation.increment(get_result * self.page_size)
# non-zero copy
def _generic_page_get(self, operation, hash_values, host_indices):
# todo: zero copy
dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
hash_values
)
page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
if page_data is None:
return
for i in range(len(hash_values)):
if page_data[i] is None:
logger.warning(
f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
)
break
self.mem_pool_host.set_from_flat_data_page(
host_indices[operation.completed_tokens],
page_data[i],
)
if not operation.increment(self.page_size):
break # Operation terminated by controller
def _page_transfer(self, operation):
# Select the get function and batch size
if self.is_mooncake_backend():
get_func = self._mooncake_page_get
batch_size = 128
elif self.storage_backend_type == "hf3fs":
if self.mem_pool_host.layout == "page_first":
get_func = self._3fs_zero_copy_page_get
elif self.mem_pool_host.layout == "layer_first":
get_func = self._generic_page_get
batch_size = 128
else:
get_func = self._generic_page_get
batch_size = 8
# Transfer batch by batch
for i in range(0, len(operation.hash_value), batch_size):
batch_hashes = operation.hash_value[i : i + batch_size]
batch_host_indices = operation.host_indices[
i * self.page_size : (i + len(batch_hashes)) * self.page_size
]
prev_completed_tokens = operation.completed_tokens
# Get one batch token, and update the completed_tokens if succeed
get_func(operation, batch_hashes, batch_host_indices)
# Check termination
if (
operation.completed_tokens
!= prev_completed_tokens + len(batch_hashes) * self.page_size
):
break # Some operations fail or operation terminated by controller
# release pre-allocated memory
self.mem_pool_host.free(operation.host_indices[operation.completed_tokens :])
def is_mooncake_backend(self):
return self.storage_backend_type == "mooncake"
@@ -632,15 +666,7 @@ class HiCacheController:
while not self.stop_event.is_set():
try:
operation = self.prefetch_buffer.get(block=True, timeout=1)
if self.is_mooncake_backend():
self.mooncake_page_transfer(operation)
elif self.storage_backend_type == "hf3fs":
if self.mem_pool_host.layout == "page_first":
self.zerocopy_page_transfer(operation, batch_size=128)
elif self.mem_pool_host.layout == "layer_first":
self.generic_page_transfer(operation, batch_size=128)
else:
self.generic_page_transfer(operation)
self._page_transfer(operation)
if self.tp_world_size > 1:
# to ensure all TP workers release the host memory at the same time
@@ -662,6 +688,27 @@ class HiCacheController:
# todo: more sophisticated rate limiting based on storage backend performance
return True
def _generic_storage_hit_query(self, operation) -> tuple[list[str], int]:
last_hash = operation.last_hash
tokens_to_fetch = operation.token_ids
storage_query_count = 0
remaining_tokens = len(tokens_to_fetch)
hash_value = []
while remaining_tokens >= self.page_size:
last_hash = self.get_hash_str(
tokens_to_fetch[
storage_query_count : storage_query_count + self.page_size
],
last_hash,
)
hash_value.append(last_hash)
storage_query_count += self.page_size
remaining_tokens -= self.page_size
# deferring to batch exists
hit_page_num = self.storage_backend.batch_exists(hash_value)
return hash_value[:hit_page_num], hit_page_num * self.page_size
def prefetch_thread_func(self):
"""
Manage prefetching operations from storage backend to host memory.
@@ -675,38 +722,12 @@ class HiCacheController:
if operation is None:
continue
storage_hit_count = 0
if (
operation.host_indices is not None
) and self.prefetch_rate_limit_check():
last_hash = operation.last_hash
tokens_to_fetch = operation.token_ids
remaining_tokens = len(tokens_to_fetch)
hash_value = []
while remaining_tokens >= self.page_size:
last_hash = self.get_hash_str(
tokens_to_fetch[
storage_hit_count : storage_hit_count + self.page_size
],
last_hash,
)
# todo, more unified interface
if not self.is_mooncake_backend():
if not self.storage_backend.exists(last_hash):
break
hash_value.append(last_hash)
storage_hit_count += self.page_size
remaining_tokens -= self.page_size
if self.is_mooncake_backend():
# deferring to batch exists for mooncake store
exist_result = self.storage_backend.exists(hash_value)
storage_hit_count = (
sum(1 for v in exist_result.values() if v != 0)
* self.page_size
)
hash_value, storage_hit_count = self._generic_storage_hit_query(
operation
)
if self.tp_world_size > 1:
storage_hit_count_tensor = torch.tensor(
@@ -755,59 +776,64 @@ class HiCacheController:
self.backup_queue.put(operation)
return operation.id
def zerocopy_page_backup(self, operation, batch_size=8):
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
operation.hash_value, operation.host_indices
# non-zero copy
def _generic_page_set(self, hash_values, host_indices) -> bool:
data = [
self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
for i in range(len(hash_values))
]
return self.storage_backend.batch_set(hash_values, data)
# zero copy
def _mooncake_page_set(self, hash_values, host_indices) -> bool:
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
hash_values,
host_indices,
)
for i in range(0, len(hashes), batch_size):
page_hashes = hashes[i : i + batch_size]
page_data = dsts[i : i + batch_size]
success = self.storage_backend.batch_set(page_hashes, page_data)
if not success:
logger.warning(f"Failed to write page {page_hashes} to storage.")
break
operation.completed_tokens += self.page_size * len(page_hashes)
success = self.storage_backend.batch_set(
key_strs,
target_location=buffer_ptrs,
target_sizes=buffer_sizes,
)
return success
def generic_page_backup(self, operation, batch_size=8):
# zero copy
def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
hash_values, host_indices
)
return self.storage_backend.batch_set(hashes, dsts)
# Backup batch by batch
def _page_backup(self, operation):
# Select the set function and batch size
if self.is_mooncake_backend():
backup_set_func = self._mooncake_page_set
batch_size = 128
elif self.storage_backend_type == "hf3fs":
if self.mem_pool_host.layout == "page_first":
backup_set_func = self._3fs_zero_copy_page_set
elif self.mem_pool_host.layout == "layer_first":
backup_set_func = self._generic_page_set
batch_size = 128
else:
backup_set_func = self._generic_page_set
batch_size = 8
# Backup batch by batch
for i in range(0, len(operation.hash_value), batch_size):
page_hashes = operation.hash_value[i : i + batch_size]
page_data = [
self.mem_pool_host.get_flat_data_page(
operation.host_indices[j * self.page_size]
)
for j in range(i, i + len(page_hashes))
batch_hashes = operation.hash_value[i : i + batch_size]
batch_host_indices = operation.host_indices[
i * self.page_size : (i + len(batch_hashes)) * self.page_size
]
success = self.storage_backend.batch_set(page_hashes, page_data)
# Set one batch token, and record if success.
# todo: allow partial success
success = backup_set_func(batch_hashes, batch_host_indices)
if not success:
logger.warning(f"Failed to write page {page_hashes} to storage.")
logger.warning(
f"Write page to storage: {len(batch_hashes)} pages failed."
)
break
operation.completed_tokens += self.page_size * len(page_hashes)
def mooncake_page_backup(self, operation):
if len(operation.hash_value):
exist_hashvalues = self.storage_backend.exists(operation.hash_value)
indices = operation.host_indices.tolist()
non_exist_keys = []
non_exist_indices = []
for i in range(len(operation.hash_value)):
if not exist_hashvalues[operation.hash_value[i]]:
non_exist_keys.append(operation.hash_value[i])
non_exist_indices.extend(
indices[i * self.page_size : (i + 1) * self.page_size]
)
if len(non_exist_keys) > 0:
key_strs, buffer_ptrs, buffer_sizes = (
self.mem_pool_host.get_buffer_meta(
non_exist_keys, non_exist_indices
)
)
# TODO: check the return value of batch set to see how many tokens are set successfully
self.storage_backend.batch_set(
key_strs,
target_location=buffer_ptrs,
target_sizes=buffer_sizes,
)
operation.completed_tokens += len(operation.hash_value) * self.page_size
operation.completed_tokens += self.page_size * len(batch_hashes)
def backup_thread_func(self):
"""
@@ -820,15 +846,7 @@ class HiCacheController:
continue
if not self.backup_skip:
if self.is_mooncake_backend():
self.mooncake_page_backup(operation)
elif self.storage_backend_type == "hf3fs":
if self.mem_pool_host.layout == "page_first":
self.zerocopy_page_backup(operation, batch_size=128)
elif self.mem_pool_host.layout == "layer_first":
self.generic_page_backup(operation, batch_size=128)
else:
self.generic_page_backup(operation)
self._page_backup(operation)
min_completed_tokens = operation.completed_tokens
else:
min_completed_tokens = len(operation.token_ids)