[Fix] HiCache Bugfix & Mooncake Error Handling Enhance (#8901)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost
|
||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -240,28 +240,38 @@ class HiCacheController:
|
||||
self.io_backend = io_backend
|
||||
|
||||
self.enable_storage = False
|
||||
self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost)
|
||||
|
||||
# todo: move backend initialization to storage backend module
|
||||
if storage_backend is not None:
|
||||
self.storage_backend_type = storage_backend
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
|
||||
from sglang.srt.mem_cache.hicache_storage import get_hash_str
|
||||
|
||||
self.get_hash_str = get_hash_str
|
||||
|
||||
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
|
||||
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
|
||||
# In MLA backend, only one rank needs to backup the KV cache
|
||||
self.backup_skip = (
|
||||
is_mla_backend
|
||||
# todo: for load balancing, decide which rank to backup the KV cache by hash value
|
||||
and get_tensor_model_parallel_rank() != 0
|
||||
# todo: support other storage backends
|
||||
and self.storage_backend_type in ["file", "mooncake"]
|
||||
)
|
||||
if storage_backend == "file":
|
||||
self.storage_backend = HiCacheFile(is_mla=self.is_mla)
|
||||
self.get_hash_str = get_hash_str
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheFile
|
||||
|
||||
self.storage_backend = HiCacheFile(is_mla_backend=is_mla_backend)
|
||||
elif storage_backend == "nixl":
|
||||
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
|
||||
|
||||
self.storage_backend = HiCacheNixl()
|
||||
self.get_hash_str = get_hash_str
|
||||
elif storage_backend == "mooncake":
|
||||
from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
|
||||
MooncakeStore,
|
||||
get_hash_str_mooncake,
|
||||
)
|
||||
|
||||
self.storage_backend = MooncakeStore(is_mla=self.is_mla)
|
||||
self.get_hash_str = get_hash_str_mooncake
|
||||
self.storage_backend = MooncakeStore(is_mla_backend=is_mla_backend)
|
||||
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
|
||||
assert self.mem_pool_host.layout == "page_first"
|
||||
elif storage_backend == "hf3fs":
|
||||
@@ -281,7 +291,6 @@ class HiCacheController:
|
||||
self.storage_backend = HiCacheHF3FS.from_env_config(
|
||||
bytes_per_page, dtype
|
||||
)
|
||||
self.get_hash_str = get_hash_str
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported storage backend: {storage_backend}"
|
||||
@@ -400,15 +409,6 @@ class HiCacheController:
|
||||
self.prefetch_thread.start()
|
||||
self.backup_thread.start()
|
||||
|
||||
@property
|
||||
def backup_skip(self):
|
||||
return (
|
||||
self.is_mla
|
||||
and get_tensor_model_parallel_rank() != 0
|
||||
# todo: only support file and mooncake
|
||||
and self.storage_backend_type in ["file", "mooncake"]
|
||||
)
|
||||
|
||||
def write(
|
||||
self,
|
||||
device_indices: torch.Tensor,
|
||||
@@ -570,57 +570,91 @@ class HiCacheController:
|
||||
operation.mark_done()
|
||||
return operation.completed_tokens, operation.hash_value
|
||||
|
||||
def zerocopy_page_transfer(self, operation, batch_size=8):
|
||||
# zero copy
|
||||
def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
|
||||
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
|
||||
operation.hash_value, operation.host_indices
|
||||
hash_values, host_indices
|
||||
)
|
||||
for i in range(0, len(hashes), batch_size):
|
||||
page_hashes = hashes[i : i + batch_size]
|
||||
page_dsts = dsts[i : i + batch_size]
|
||||
page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
|
||||
if page_data is None:
|
||||
logger.warning(
|
||||
f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
|
||||
)
|
||||
break
|
||||
completed_tokens = operation.completed_tokens
|
||||
if operation.increment(self.page_size * len(page_hashes)):
|
||||
for i in range(len(page_hashes)):
|
||||
completed_tokens += self.page_size
|
||||
else:
|
||||
break
|
||||
page_data = self.storage_backend.batch_get(hashes, dsts)
|
||||
if page_data:
|
||||
operation.increment(self.page_size * len(hashes))
|
||||
else:
|
||||
logger.warning(
|
||||
f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
|
||||
)
|
||||
|
||||
def generic_page_transfer(self, operation, batch_size=8):
|
||||
for i in range(0, len(operation.hash_value), batch_size):
|
||||
page_hashes = operation.hash_value[i : i + batch_size]
|
||||
# todo: zero copy
|
||||
dummy_page_dst = [
|
||||
self.mem_pool_host.get_dummy_flat_data_page()
|
||||
for _ in range(len(page_hashes))
|
||||
]
|
||||
page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
|
||||
if page_data is None:
|
||||
logger.warning(
|
||||
f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
|
||||
)
|
||||
break
|
||||
completed_tokens = operation.completed_tokens
|
||||
if operation.increment(self.page_size * len(page_hashes)):
|
||||
for i in range(len(page_hashes)):
|
||||
self.mem_pool_host.set_from_flat_data_page(
|
||||
operation.host_indices[completed_tokens],
|
||||
page_data[i],
|
||||
)
|
||||
completed_tokens += self.page_size
|
||||
else:
|
||||
break
|
||||
|
||||
def mooncake_page_transfer(self, operation):
|
||||
# zero copy
|
||||
def _mooncake_page_get(self, operation, hash_values, host_indices):
|
||||
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
|
||||
operation.hash_value, operation.host_indices
|
||||
hash_values,
|
||||
host_indices,
|
||||
)
|
||||
self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
|
||||
operation.increment(len(operation.hash_value) * self.page_size)
|
||||
get_result = self.storage_backend.batch_get(
|
||||
key_strs,
|
||||
target_location=buffer_ptrs,
|
||||
target_sizes=buffer_sizes,
|
||||
)
|
||||
if get_result != len(hash_values):
|
||||
logger.warning(
|
||||
f"Prefetch operation {operation.request_id} failed or partially failed."
|
||||
)
|
||||
if get_result != 0:
|
||||
operation.increment(get_result * self.page_size)
|
||||
|
||||
# non-zero copy
|
||||
def _generic_page_get(self, operation, hash_values, host_indices):
    """Fetch one batch of pages from storage via staging buffers (non-zero copy).

    Args:
        operation: Prefetch operation being served. ``operation.increment`` is
            called once per page copied into host memory and
            ``operation.request_id`` is used in the warning message.
        hash_values: Storage keys of the pages in this batch.
        host_indices: Host indices for *this batch only* (the caller passes a
            slice), so page ``i`` of the batch starts at
            ``host_indices[i * self.page_size]``.

    Returns:
        None. Progress is reported through ``operation.increment``; the caller
        detects a partial batch by comparing ``operation.completed_tokens``.
    """
    # todo: zero copy
    # Allocate one staging page per hash. ``[page] * n`` would repeat a
    # reference to a single buffer, so a backend that fills the destinations
    # in place would overwrite every page with the last one read.
    dummy_page_dst = [
        self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
    ]
    page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
    if page_data is None:
        return
    for i, page_hash in enumerate(hash_values):
        page = page_data[i]
        if page is None:
            logger.warning(
                f"Prefetch operation {operation.request_id} failed to retrieve page {page_hash}."
            )
            break
        # Index relative to this batch: ``operation.completed_tokens`` counts
        # tokens across all batches and would run past the end of the
        # batch-local ``host_indices`` from the second batch onwards.
        # The page is written before the counter is advanced so that the
        # counter never covers data that has not been stored yet.
        self.mem_pool_host.set_from_flat_data_page(
            host_indices[i * self.page_size],
            page,
        )
        if not operation.increment(self.page_size):
            break  # Operation terminated by controller
|
||||
|
||||
def _page_transfer(self, operation):
    """Prefetch the pages of ``operation`` from storage into host memory.

    A backend-specific getter and batch size are selected first (128 pages
    per batch for mooncake and hf3fs, 8 otherwise). The hashes in
    ``operation.hash_value`` are then fetched batch by batch; transfer stops
    at the first batch that does not complete in full. Finally the part of
    ``operation.host_indices`` beyond ``operation.completed_tokens`` is
    returned to the host pool.

    Args:
        operation: Prefetch operation providing ``hash_value``,
            ``host_indices`` and the ``completed_tokens`` counter that the
            getters advance.

    Raises:
        ValueError: if the hf3fs backend is used with a host memory layout
            other than ``"page_first"`` or ``"layer_first"``.
    """
    # Select the get function and batch size
    if self.is_mooncake_backend():
        get_func = self._mooncake_page_get
        batch_size = 128
    elif self.storage_backend_type == "hf3fs":
        layout = self.mem_pool_host.layout
        if layout == "page_first":
            get_func = self._3fs_zero_copy_page_get
        elif layout == "layer_first":
            get_func = self._generic_page_get
        else:
            # Without this branch ``get_func`` would stay unbound and the
            # loop below would fail with an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported host memory layout for hf3fs backend: {layout}"
            )
        batch_size = 128
    else:
        get_func = self._generic_page_get
        batch_size = 8

    # Transfer batch by batch
    for i in range(0, len(operation.hash_value), batch_size):
        batch_hashes = operation.hash_value[i : i + batch_size]
        batch_tokens = len(batch_hashes) * self.page_size
        # Host indices covering exactly the pages of this batch.
        batch_host_indices = operation.host_indices[
            i * self.page_size : i * self.page_size + batch_tokens
        ]
        prev_completed_tokens = operation.completed_tokens
        # Get one batch token, and update the completed_tokens if succeed
        get_func(operation, batch_hashes, batch_host_indices)
        # Check termination
        if operation.completed_tokens != prev_completed_tokens + batch_tokens:
            break  # Some operations fail or operation terminated by controller

    # release pre-allocated memory
    self.mem_pool_host.free(operation.host_indices[operation.completed_tokens :])
|
||||
|
||||
def is_mooncake_backend(self):
    """Return True when the configured storage backend type is ``"mooncake"``."""
    backend_type = self.storage_backend_type
    return backend_type == "mooncake"
|
||||
@@ -632,15 +666,7 @@ class HiCacheController:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.prefetch_buffer.get(block=True, timeout=1)
|
||||
if self.is_mooncake_backend():
|
||||
self.mooncake_page_transfer(operation)
|
||||
elif self.storage_backend_type == "hf3fs":
|
||||
if self.mem_pool_host.layout == "page_first":
|
||||
self.zerocopy_page_transfer(operation, batch_size=128)
|
||||
elif self.mem_pool_host.layout == "layer_first":
|
||||
self.generic_page_transfer(operation, batch_size=128)
|
||||
else:
|
||||
self.generic_page_transfer(operation)
|
||||
self._page_transfer(operation)
|
||||
|
||||
if self.tp_world_size > 1:
|
||||
# to ensure all TP workers release the host memory at the same time
|
||||
@@ -662,6 +688,27 @@ class HiCacheController:
|
||||
# todo: more sophisticated rate limiting based on storage backend performance
|
||||
return True
|
||||
|
||||
def _generic_storage_hit_query(self, operation) -> tuple[list[str], int]:
    """Hash the operation's tokens page by page and ask storage how many hit.

    Each full page of ``operation.token_ids`` is hashed with
    ``self.get_hash_str``, chaining from ``operation.last_hash``; a trailing
    partial page is not hashed. All hashes are then passed to the backend in
    a single ``batch_exists`` call, whose result is used as the number of
    leading pages that are present.

    Returns:
        A tuple ``(hashes_of_hit_pages, hit_token_count)``.
    """
    page_size = self.page_size
    token_ids = operation.token_ids
    chained_hash = operation.last_hash

    page_hashes = []
    # Visit the start offset of every complete page.
    for start in range(0, len(token_ids) - page_size + 1, page_size):
        chained_hash = self.get_hash_str(
            token_ids[start : start + page_size], chained_hash
        )
        page_hashes.append(chained_hash)

    # deferring to batch exists
    hit_pages = self.storage_backend.batch_exists(page_hashes)
    return page_hashes[:hit_pages], hit_pages * page_size
|
||||
|
||||
def prefetch_thread_func(self):
|
||||
"""
|
||||
Manage prefetching operations from storage backend to host memory.
|
||||
@@ -675,38 +722,12 @@ class HiCacheController:
|
||||
if operation is None:
|
||||
continue
|
||||
|
||||
storage_hit_count = 0
|
||||
if (
|
||||
operation.host_indices is not None
|
||||
) and self.prefetch_rate_limit_check():
|
||||
last_hash = operation.last_hash
|
||||
tokens_to_fetch = operation.token_ids
|
||||
|
||||
remaining_tokens = len(tokens_to_fetch)
|
||||
hash_value = []
|
||||
while remaining_tokens >= self.page_size:
|
||||
last_hash = self.get_hash_str(
|
||||
tokens_to_fetch[
|
||||
storage_hit_count : storage_hit_count + self.page_size
|
||||
],
|
||||
last_hash,
|
||||
)
|
||||
|
||||
# todo, more unified interface
|
||||
if not self.is_mooncake_backend():
|
||||
if not self.storage_backend.exists(last_hash):
|
||||
break
|
||||
hash_value.append(last_hash)
|
||||
storage_hit_count += self.page_size
|
||||
remaining_tokens -= self.page_size
|
||||
|
||||
if self.is_mooncake_backend():
|
||||
# deferring to batch exists for mooncake store
|
||||
exist_result = self.storage_backend.exists(hash_value)
|
||||
storage_hit_count = (
|
||||
sum(1 for v in exist_result.values() if v != 0)
|
||||
* self.page_size
|
||||
)
|
||||
hash_value, storage_hit_count = self._generic_storage_hit_query(
|
||||
operation
|
||||
)
|
||||
|
||||
if self.tp_world_size > 1:
|
||||
storage_hit_count_tensor = torch.tensor(
|
||||
@@ -755,59 +776,64 @@ class HiCacheController:
|
||||
self.backup_queue.put(operation)
|
||||
return operation.id
|
||||
|
||||
def zerocopy_page_backup(self, operation, batch_size=8):
|
||||
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
|
||||
operation.hash_value, operation.host_indices
|
||||
# non-zero copy
|
||||
def _generic_page_set(self, hash_values, host_indices) -> bool:
    """Back up one batch of pages by copying them out of host memory.

    For every hash, the flat page data is read from the host pool at the
    first host index of that page (``host_indices[k * self.page_size]``) and
    the collected pages are handed to ``storage_backend.batch_set``.

    Returns:
        Whatever ``storage_backend.batch_set`` returns for the batch.
    """
    pages = []
    for page_no in range(len(hash_values)):
        first_index = host_indices[page_no * self.page_size]
        pages.append(self.mem_pool_host.get_flat_data_page(first_index))
    return self.storage_backend.batch_set(hash_values, pages)
|
||||
|
||||
# zero copy
|
||||
def _mooncake_page_set(self, hash_values, host_indices) -> bool:
|
||||
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
|
||||
hash_values,
|
||||
host_indices,
|
||||
)
|
||||
for i in range(0, len(hashes), batch_size):
|
||||
page_hashes = hashes[i : i + batch_size]
|
||||
page_data = dsts[i : i + batch_size]
|
||||
success = self.storage_backend.batch_set(page_hashes, page_data)
|
||||
if not success:
|
||||
logger.warning(f"Failed to write page {page_hashes} to storage.")
|
||||
break
|
||||
operation.completed_tokens += self.page_size * len(page_hashes)
|
||||
success = self.storage_backend.batch_set(
|
||||
key_strs,
|
||||
target_location=buffer_ptrs,
|
||||
target_sizes=buffer_sizes,
|
||||
)
|
||||
return success
|
||||
|
||||
def generic_page_backup(self, operation, batch_size=8):
|
||||
# zero copy
|
||||
def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
    """Back up one batch of pages using buffers obtained from the host pool.

    ``mem_pool_host.get_buffer_with_hash`` supplies the keys and the matching
    buffers, which are passed unchanged to ``storage_backend.batch_set``.

    Returns:
        Whatever ``storage_backend.batch_set`` returns for the batch.
    """
    keys_and_buffers = self.mem_pool_host.get_buffer_with_hash(
        hash_values, host_indices
    )
    hashes, dsts = keys_and_buffers
    return self.storage_backend.batch_set(hashes, dsts)
|
||||
|
||||
# Backup batch by batch
|
||||
def _page_backup(self, operation):
|
||||
# Select the set function and batch size
|
||||
if self.is_mooncake_backend():
|
||||
backup_set_func = self._mooncake_page_set
|
||||
batch_size = 128
|
||||
elif self.storage_backend_type == "hf3fs":
|
||||
if self.mem_pool_host.layout == "page_first":
|
||||
backup_set_func = self._3fs_zero_copy_page_set
|
||||
elif self.mem_pool_host.layout == "layer_first":
|
||||
backup_set_func = self._generic_page_set
|
||||
batch_size = 128
|
||||
else:
|
||||
backup_set_func = self._generic_page_set
|
||||
batch_size = 8
|
||||
# Backup batch by batch
|
||||
for i in range(0, len(operation.hash_value), batch_size):
|
||||
page_hashes = operation.hash_value[i : i + batch_size]
|
||||
page_data = [
|
||||
self.mem_pool_host.get_flat_data_page(
|
||||
operation.host_indices[j * self.page_size]
|
||||
)
|
||||
for j in range(i, i + len(page_hashes))
|
||||
batch_hashes = operation.hash_value[i : i + batch_size]
|
||||
batch_host_indices = operation.host_indices[
|
||||
i * self.page_size : (i + len(batch_hashes)) * self.page_size
|
||||
]
|
||||
success = self.storage_backend.batch_set(page_hashes, page_data)
|
||||
# Set one batch token, and record if success.
|
||||
# todo: allow partial success
|
||||
success = backup_set_func(batch_hashes, batch_host_indices)
|
||||
if not success:
|
||||
logger.warning(f"Failed to write page {page_hashes} to storage.")
|
||||
logger.warning(
|
||||
f"Write page to storage: {len(batch_hashes)} pages failed."
|
||||
)
|
||||
break
|
||||
operation.completed_tokens += self.page_size * len(page_hashes)
|
||||
|
||||
def mooncake_page_backup(self, operation):
|
||||
if len(operation.hash_value):
|
||||
exist_hashvalues = self.storage_backend.exists(operation.hash_value)
|
||||
indices = operation.host_indices.tolist()
|
||||
non_exist_keys = []
|
||||
non_exist_indices = []
|
||||
for i in range(len(operation.hash_value)):
|
||||
if not exist_hashvalues[operation.hash_value[i]]:
|
||||
non_exist_keys.append(operation.hash_value[i])
|
||||
non_exist_indices.extend(
|
||||
indices[i * self.page_size : (i + 1) * self.page_size]
|
||||
)
|
||||
if len(non_exist_keys) > 0:
|
||||
key_strs, buffer_ptrs, buffer_sizes = (
|
||||
self.mem_pool_host.get_buffer_meta(
|
||||
non_exist_keys, non_exist_indices
|
||||
)
|
||||
)
|
||||
# TODO: check the return value of batch set to see how many tokens are set successfully
|
||||
self.storage_backend.batch_set(
|
||||
key_strs,
|
||||
target_location=buffer_ptrs,
|
||||
target_sizes=buffer_sizes,
|
||||
)
|
||||
operation.completed_tokens += len(operation.hash_value) * self.page_size
|
||||
operation.completed_tokens += self.page_size * len(batch_hashes)
|
||||
|
||||
def backup_thread_func(self):
|
||||
"""
|
||||
@@ -820,15 +846,7 @@ class HiCacheController:
|
||||
continue
|
||||
|
||||
if not self.backup_skip:
|
||||
if self.is_mooncake_backend():
|
||||
self.mooncake_page_backup(operation)
|
||||
elif self.storage_backend_type == "hf3fs":
|
||||
if self.mem_pool_host.layout == "page_first":
|
||||
self.zerocopy_page_backup(operation, batch_size=128)
|
||||
elif self.mem_pool_host.layout == "layer_first":
|
||||
self.generic_page_backup(operation, batch_size=128)
|
||||
else:
|
||||
self.generic_page_backup(operation)
|
||||
self._page_backup(operation)
|
||||
min_completed_tokens = operation.completed_tokens
|
||||
else:
|
||||
min_completed_tokens = len(operation.token_ids)
|
||||
|
||||
Reference in New Issue
Block a user