[Fix] HiCache Bugfix & Mooncake Error Handling Enhance (#8901)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||||
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost
|
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -240,28 +240,38 @@ class HiCacheController:
|
|||||||
self.io_backend = io_backend
|
self.io_backend = io_backend
|
||||||
|
|
||||||
self.enable_storage = False
|
self.enable_storage = False
|
||||||
self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost)
|
|
||||||
# todo: move backend initialization to storage backend module
|
# todo: move backend initialization to storage backend module
|
||||||
if storage_backend is not None:
|
if storage_backend is not None:
|
||||||
self.storage_backend_type = storage_backend
|
self.storage_backend_type = storage_backend
|
||||||
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
|
from sglang.srt.mem_cache.hicache_storage import get_hash_str
|
||||||
|
|
||||||
|
self.get_hash_str = get_hash_str
|
||||||
|
|
||||||
|
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
|
||||||
|
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
|
||||||
|
# In MLA backend, only one rank needs to backup the KV cache
|
||||||
|
self.backup_skip = (
|
||||||
|
is_mla_backend
|
||||||
|
# todo: for load balancing, decide which rank to backup the KV cache by hash value
|
||||||
|
and get_tensor_model_parallel_rank() != 0
|
||||||
|
# todo: support other storage backends
|
||||||
|
and self.storage_backend_type in ["file", "mooncake"]
|
||||||
|
)
|
||||||
if storage_backend == "file":
|
if storage_backend == "file":
|
||||||
self.storage_backend = HiCacheFile(is_mla=self.is_mla)
|
from sglang.srt.mem_cache.hicache_storage import HiCacheFile
|
||||||
self.get_hash_str = get_hash_str
|
|
||||||
|
self.storage_backend = HiCacheFile(is_mla_backend=is_mla_backend)
|
||||||
elif storage_backend == "nixl":
|
elif storage_backend == "nixl":
|
||||||
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
|
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
|
||||||
|
|
||||||
self.storage_backend = HiCacheNixl()
|
self.storage_backend = HiCacheNixl()
|
||||||
self.get_hash_str = get_hash_str
|
|
||||||
elif storage_backend == "mooncake":
|
elif storage_backend == "mooncake":
|
||||||
from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
|
from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
|
||||||
MooncakeStore,
|
MooncakeStore,
|
||||||
get_hash_str_mooncake,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.storage_backend = MooncakeStore(is_mla=self.is_mla)
|
self.storage_backend = MooncakeStore(is_mla_backend=is_mla_backend)
|
||||||
self.get_hash_str = get_hash_str_mooncake
|
|
||||||
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
|
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
|
||||||
assert self.mem_pool_host.layout == "page_first"
|
assert self.mem_pool_host.layout == "page_first"
|
||||||
elif storage_backend == "hf3fs":
|
elif storage_backend == "hf3fs":
|
||||||
@@ -281,7 +291,6 @@ class HiCacheController:
|
|||||||
self.storage_backend = HiCacheHF3FS.from_env_config(
|
self.storage_backend = HiCacheHF3FS.from_env_config(
|
||||||
bytes_per_page, dtype
|
bytes_per_page, dtype
|
||||||
)
|
)
|
||||||
self.get_hash_str = get_hash_str
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Unsupported storage backend: {storage_backend}"
|
f"Unsupported storage backend: {storage_backend}"
|
||||||
@@ -400,15 +409,6 @@ class HiCacheController:
|
|||||||
self.prefetch_thread.start()
|
self.prefetch_thread.start()
|
||||||
self.backup_thread.start()
|
self.backup_thread.start()
|
||||||
|
|
||||||
@property
|
|
||||||
def backup_skip(self):
|
|
||||||
return (
|
|
||||||
self.is_mla
|
|
||||||
and get_tensor_model_parallel_rank() != 0
|
|
||||||
# todo: only support file and mooncake
|
|
||||||
and self.storage_backend_type in ["file", "mooncake"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def write(
|
def write(
|
||||||
self,
|
self,
|
||||||
device_indices: torch.Tensor,
|
device_indices: torch.Tensor,
|
||||||
@@ -570,57 +570,91 @@ class HiCacheController:
|
|||||||
operation.mark_done()
|
operation.mark_done()
|
||||||
return operation.completed_tokens, operation.hash_value
|
return operation.completed_tokens, operation.hash_value
|
||||||
|
|
||||||
def zerocopy_page_transfer(self, operation, batch_size=8):
|
# zero copy
|
||||||
|
def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
|
||||||
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
|
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
|
||||||
operation.hash_value, operation.host_indices
|
hash_values, host_indices
|
||||||
)
|
)
|
||||||
for i in range(0, len(hashes), batch_size):
|
page_data = self.storage_backend.batch_get(hashes, dsts)
|
||||||
page_hashes = hashes[i : i + batch_size]
|
if page_data:
|
||||||
page_dsts = dsts[i : i + batch_size]
|
operation.increment(self.page_size * len(hashes))
|
||||||
page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
|
else:
|
||||||
if page_data is None:
|
logger.warning(
|
||||||
logger.warning(
|
f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
|
||||||
f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
|
)
|
||||||
)
|
|
||||||
break
|
|
||||||
completed_tokens = operation.completed_tokens
|
|
||||||
if operation.increment(self.page_size * len(page_hashes)):
|
|
||||||
for i in range(len(page_hashes)):
|
|
||||||
completed_tokens += self.page_size
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
def generic_page_transfer(self, operation, batch_size=8):
|
# zero copy
|
||||||
for i in range(0, len(operation.hash_value), batch_size):
|
def _mooncake_page_get(self, operation, hash_values, host_indices):
|
||||||
page_hashes = operation.hash_value[i : i + batch_size]
|
|
||||||
# todo: zero copy
|
|
||||||
dummy_page_dst = [
|
|
||||||
self.mem_pool_host.get_dummy_flat_data_page()
|
|
||||||
for _ in range(len(page_hashes))
|
|
||||||
]
|
|
||||||
page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
|
|
||||||
if page_data is None:
|
|
||||||
logger.warning(
|
|
||||||
f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
|
|
||||||
)
|
|
||||||
break
|
|
||||||
completed_tokens = operation.completed_tokens
|
|
||||||
if operation.increment(self.page_size * len(page_hashes)):
|
|
||||||
for i in range(len(page_hashes)):
|
|
||||||
self.mem_pool_host.set_from_flat_data_page(
|
|
||||||
operation.host_indices[completed_tokens],
|
|
||||||
page_data[i],
|
|
||||||
)
|
|
||||||
completed_tokens += self.page_size
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
def mooncake_page_transfer(self, operation):
|
|
||||||
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
|
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
|
||||||
operation.hash_value, operation.host_indices
|
hash_values,
|
||||||
|
host_indices,
|
||||||
)
|
)
|
||||||
self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
|
get_result = self.storage_backend.batch_get(
|
||||||
operation.increment(len(operation.hash_value) * self.page_size)
|
key_strs,
|
||||||
|
target_location=buffer_ptrs,
|
||||||
|
target_sizes=buffer_sizes,
|
||||||
|
)
|
||||||
|
if get_result != len(hash_values):
|
||||||
|
logger.warning(
|
||||||
|
f"Prefetch operation {operation.request_id} failed or partially failed."
|
||||||
|
)
|
||||||
|
if get_result != 0:
|
||||||
|
operation.increment(get_result * self.page_size)
|
||||||
|
|
||||||
|
# non-zero copy
|
||||||
|
def _generic_page_get(self, operation, hash_values, host_indices):
|
||||||
|
# todo: zero copy
|
||||||
|
dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
|
||||||
|
hash_values
|
||||||
|
)
|
||||||
|
page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
|
||||||
|
if page_data is None:
|
||||||
|
return
|
||||||
|
for i in range(len(hash_values)):
|
||||||
|
if page_data[i] is None:
|
||||||
|
logger.warning(
|
||||||
|
f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
|
||||||
|
)
|
||||||
|
break
|
||||||
|
self.mem_pool_host.set_from_flat_data_page(
|
||||||
|
host_indices[operation.completed_tokens],
|
||||||
|
page_data[i],
|
||||||
|
)
|
||||||
|
if not operation.increment(self.page_size):
|
||||||
|
break # Operation terminated by controller
|
||||||
|
|
||||||
|
def _page_transfer(self, operation):
|
||||||
|
# Select the get function and batch size
|
||||||
|
if self.is_mooncake_backend():
|
||||||
|
get_func = self._mooncake_page_get
|
||||||
|
batch_size = 128
|
||||||
|
elif self.storage_backend_type == "hf3fs":
|
||||||
|
if self.mem_pool_host.layout == "page_first":
|
||||||
|
get_func = self._3fs_zero_copy_page_get
|
||||||
|
elif self.mem_pool_host.layout == "layer_first":
|
||||||
|
get_func = self._generic_page_get
|
||||||
|
batch_size = 128
|
||||||
|
else:
|
||||||
|
get_func = self._generic_page_get
|
||||||
|
batch_size = 8
|
||||||
|
|
||||||
|
# Transfer batch by batch
|
||||||
|
for i in range(0, len(operation.hash_value), batch_size):
|
||||||
|
batch_hashes = operation.hash_value[i : i + batch_size]
|
||||||
|
batch_host_indices = operation.host_indices[
|
||||||
|
i * self.page_size : (i + len(batch_hashes)) * self.page_size
|
||||||
|
]
|
||||||
|
prev_completed_tokens = operation.completed_tokens
|
||||||
|
# Get one batch token, and update the completed_tokens if succeed
|
||||||
|
get_func(operation, batch_hashes, batch_host_indices)
|
||||||
|
# Check termination
|
||||||
|
if (
|
||||||
|
operation.completed_tokens
|
||||||
|
!= prev_completed_tokens + len(batch_hashes) * self.page_size
|
||||||
|
):
|
||||||
|
break # Some operations fail or operation terminated by controller
|
||||||
|
# release pre-allocated memory
|
||||||
|
self.mem_pool_host.free(operation.host_indices[operation.completed_tokens :])
|
||||||
|
|
||||||
def is_mooncake_backend(self):
|
def is_mooncake_backend(self):
|
||||||
return self.storage_backend_type == "mooncake"
|
return self.storage_backend_type == "mooncake"
|
||||||
@@ -632,15 +666,7 @@ class HiCacheController:
|
|||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
operation = self.prefetch_buffer.get(block=True, timeout=1)
|
operation = self.prefetch_buffer.get(block=True, timeout=1)
|
||||||
if self.is_mooncake_backend():
|
self._page_transfer(operation)
|
||||||
self.mooncake_page_transfer(operation)
|
|
||||||
elif self.storage_backend_type == "hf3fs":
|
|
||||||
if self.mem_pool_host.layout == "page_first":
|
|
||||||
self.zerocopy_page_transfer(operation, batch_size=128)
|
|
||||||
elif self.mem_pool_host.layout == "layer_first":
|
|
||||||
self.generic_page_transfer(operation, batch_size=128)
|
|
||||||
else:
|
|
||||||
self.generic_page_transfer(operation)
|
|
||||||
|
|
||||||
if self.tp_world_size > 1:
|
if self.tp_world_size > 1:
|
||||||
# to ensure all TP workers release the host memory at the same time
|
# to ensure all TP workers release the host memory at the same time
|
||||||
@@ -662,6 +688,27 @@ class HiCacheController:
|
|||||||
# todo: more sophisticated rate limiting based on storage backend performance
|
# todo: more sophisticated rate limiting based on storage backend performance
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _generic_storage_hit_query(self, operation) -> tuple[list[str], int]:
|
||||||
|
last_hash = operation.last_hash
|
||||||
|
tokens_to_fetch = operation.token_ids
|
||||||
|
|
||||||
|
storage_query_count = 0
|
||||||
|
remaining_tokens = len(tokens_to_fetch)
|
||||||
|
hash_value = []
|
||||||
|
while remaining_tokens >= self.page_size:
|
||||||
|
last_hash = self.get_hash_str(
|
||||||
|
tokens_to_fetch[
|
||||||
|
storage_query_count : storage_query_count + self.page_size
|
||||||
|
],
|
||||||
|
last_hash,
|
||||||
|
)
|
||||||
|
hash_value.append(last_hash)
|
||||||
|
storage_query_count += self.page_size
|
||||||
|
remaining_tokens -= self.page_size
|
||||||
|
# deferring to batch exists
|
||||||
|
hit_page_num = self.storage_backend.batch_exists(hash_value)
|
||||||
|
return hash_value[:hit_page_num], hit_page_num * self.page_size
|
||||||
|
|
||||||
def prefetch_thread_func(self):
|
def prefetch_thread_func(self):
|
||||||
"""
|
"""
|
||||||
Manage prefetching operations from storage backend to host memory.
|
Manage prefetching operations from storage backend to host memory.
|
||||||
@@ -675,38 +722,12 @@ class HiCacheController:
|
|||||||
if operation is None:
|
if operation is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
storage_hit_count = 0
|
|
||||||
if (
|
if (
|
||||||
operation.host_indices is not None
|
operation.host_indices is not None
|
||||||
) and self.prefetch_rate_limit_check():
|
) and self.prefetch_rate_limit_check():
|
||||||
last_hash = operation.last_hash
|
hash_value, storage_hit_count = self._generic_storage_hit_query(
|
||||||
tokens_to_fetch = operation.token_ids
|
operation
|
||||||
|
)
|
||||||
remaining_tokens = len(tokens_to_fetch)
|
|
||||||
hash_value = []
|
|
||||||
while remaining_tokens >= self.page_size:
|
|
||||||
last_hash = self.get_hash_str(
|
|
||||||
tokens_to_fetch[
|
|
||||||
storage_hit_count : storage_hit_count + self.page_size
|
|
||||||
],
|
|
||||||
last_hash,
|
|
||||||
)
|
|
||||||
|
|
||||||
# todo, more unified interface
|
|
||||||
if not self.is_mooncake_backend():
|
|
||||||
if not self.storage_backend.exists(last_hash):
|
|
||||||
break
|
|
||||||
hash_value.append(last_hash)
|
|
||||||
storage_hit_count += self.page_size
|
|
||||||
remaining_tokens -= self.page_size
|
|
||||||
|
|
||||||
if self.is_mooncake_backend():
|
|
||||||
# deferring to batch exists for mooncake store
|
|
||||||
exist_result = self.storage_backend.exists(hash_value)
|
|
||||||
storage_hit_count = (
|
|
||||||
sum(1 for v in exist_result.values() if v != 0)
|
|
||||||
* self.page_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.tp_world_size > 1:
|
if self.tp_world_size > 1:
|
||||||
storage_hit_count_tensor = torch.tensor(
|
storage_hit_count_tensor = torch.tensor(
|
||||||
@@ -755,59 +776,64 @@ class HiCacheController:
|
|||||||
self.backup_queue.put(operation)
|
self.backup_queue.put(operation)
|
||||||
return operation.id
|
return operation.id
|
||||||
|
|
||||||
def zerocopy_page_backup(self, operation, batch_size=8):
|
# non-zero copy
|
||||||
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
|
def _generic_page_set(self, hash_values, host_indices) -> bool:
|
||||||
operation.hash_value, operation.host_indices
|
data = [
|
||||||
|
self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
|
||||||
|
for i in range(len(hash_values))
|
||||||
|
]
|
||||||
|
return self.storage_backend.batch_set(hash_values, data)
|
||||||
|
|
||||||
|
# zero copy
|
||||||
|
def _mooncake_page_set(self, hash_values, host_indices) -> bool:
|
||||||
|
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
|
||||||
|
hash_values,
|
||||||
|
host_indices,
|
||||||
)
|
)
|
||||||
for i in range(0, len(hashes), batch_size):
|
success = self.storage_backend.batch_set(
|
||||||
page_hashes = hashes[i : i + batch_size]
|
key_strs,
|
||||||
page_data = dsts[i : i + batch_size]
|
target_location=buffer_ptrs,
|
||||||
success = self.storage_backend.batch_set(page_hashes, page_data)
|
target_sizes=buffer_sizes,
|
||||||
if not success:
|
)
|
||||||
logger.warning(f"Failed to write page {page_hashes} to storage.")
|
return success
|
||||||
break
|
|
||||||
operation.completed_tokens += self.page_size * len(page_hashes)
|
|
||||||
|
|
||||||
def generic_page_backup(self, operation, batch_size=8):
|
# zero copy
|
||||||
|
def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
|
||||||
|
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
|
||||||
|
hash_values, host_indices
|
||||||
|
)
|
||||||
|
return self.storage_backend.batch_set(hashes, dsts)
|
||||||
|
|
||||||
|
# Backup batch by batch
|
||||||
|
def _page_backup(self, operation):
|
||||||
|
# Select the set function and batch size
|
||||||
|
if self.is_mooncake_backend():
|
||||||
|
backup_set_func = self._mooncake_page_set
|
||||||
|
batch_size = 128
|
||||||
|
elif self.storage_backend_type == "hf3fs":
|
||||||
|
if self.mem_pool_host.layout == "page_first":
|
||||||
|
backup_set_func = self._3fs_zero_copy_page_set
|
||||||
|
elif self.mem_pool_host.layout == "layer_first":
|
||||||
|
backup_set_func = self._generic_page_set
|
||||||
|
batch_size = 128
|
||||||
|
else:
|
||||||
|
backup_set_func = self._generic_page_set
|
||||||
|
batch_size = 8
|
||||||
|
# Backup batch by batch
|
||||||
for i in range(0, len(operation.hash_value), batch_size):
|
for i in range(0, len(operation.hash_value), batch_size):
|
||||||
page_hashes = operation.hash_value[i : i + batch_size]
|
batch_hashes = operation.hash_value[i : i + batch_size]
|
||||||
page_data = [
|
batch_host_indices = operation.host_indices[
|
||||||
self.mem_pool_host.get_flat_data_page(
|
i * self.page_size : (i + len(batch_hashes)) * self.page_size
|
||||||
operation.host_indices[j * self.page_size]
|
|
||||||
)
|
|
||||||
for j in range(i, i + len(page_hashes))
|
|
||||||
]
|
]
|
||||||
success = self.storage_backend.batch_set(page_hashes, page_data)
|
# Set one batch token, and record if success.
|
||||||
|
# todo: allow partial success
|
||||||
|
success = backup_set_func(batch_hashes, batch_host_indices)
|
||||||
if not success:
|
if not success:
|
||||||
logger.warning(f"Failed to write page {page_hashes} to storage.")
|
logger.warning(
|
||||||
|
f"Write page to storage: {len(batch_hashes)} pages failed."
|
||||||
|
)
|
||||||
break
|
break
|
||||||
operation.completed_tokens += self.page_size * len(page_hashes)
|
operation.completed_tokens += self.page_size * len(batch_hashes)
|
||||||
|
|
||||||
def mooncake_page_backup(self, operation):
|
|
||||||
if len(operation.hash_value):
|
|
||||||
exist_hashvalues = self.storage_backend.exists(operation.hash_value)
|
|
||||||
indices = operation.host_indices.tolist()
|
|
||||||
non_exist_keys = []
|
|
||||||
non_exist_indices = []
|
|
||||||
for i in range(len(operation.hash_value)):
|
|
||||||
if not exist_hashvalues[operation.hash_value[i]]:
|
|
||||||
non_exist_keys.append(operation.hash_value[i])
|
|
||||||
non_exist_indices.extend(
|
|
||||||
indices[i * self.page_size : (i + 1) * self.page_size]
|
|
||||||
)
|
|
||||||
if len(non_exist_keys) > 0:
|
|
||||||
key_strs, buffer_ptrs, buffer_sizes = (
|
|
||||||
self.mem_pool_host.get_buffer_meta(
|
|
||||||
non_exist_keys, non_exist_indices
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# TODO: check the return value of batch set to see how many tokens are set successfully
|
|
||||||
self.storage_backend.batch_set(
|
|
||||||
key_strs,
|
|
||||||
target_location=buffer_ptrs,
|
|
||||||
target_sizes=buffer_sizes,
|
|
||||||
)
|
|
||||||
operation.completed_tokens += len(operation.hash_value) * self.page_size
|
|
||||||
|
|
||||||
def backup_thread_func(self):
|
def backup_thread_func(self):
|
||||||
"""
|
"""
|
||||||
@@ -820,15 +846,7 @@ class HiCacheController:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if not self.backup_skip:
|
if not self.backup_skip:
|
||||||
if self.is_mooncake_backend():
|
self._page_backup(operation)
|
||||||
self.mooncake_page_backup(operation)
|
|
||||||
elif self.storage_backend_type == "hf3fs":
|
|
||||||
if self.mem_pool_host.layout == "page_first":
|
|
||||||
self.zerocopy_page_backup(operation, batch_size=128)
|
|
||||||
elif self.mem_pool_host.layout == "layer_first":
|
|
||||||
self.generic_page_backup(operation, batch_size=128)
|
|
||||||
else:
|
|
||||||
self.generic_page_backup(operation)
|
|
||||||
min_completed_tokens = operation.completed_tokens
|
min_completed_tokens = operation.completed_tokens
|
||||||
else:
|
else:
|
||||||
min_completed_tokens = len(operation.token_ids)
|
min_completed_tokens = len(operation.token_ids)
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class HiCacheStorage(ABC):
|
|||||||
keys: List[str],
|
keys: List[str],
|
||||||
target_locations: Optional[Any] = None,
|
target_locations: Optional[Any] = None,
|
||||||
target_sizes: Optional[Any] = None,
|
target_sizes: Optional[Any] = None,
|
||||||
) -> List[torch.Tensor | None]:
|
) -> List[torch.Tensor | None] | int:
|
||||||
"""
|
"""
|
||||||
Retrieve values for multiple keys.
|
Retrieve values for multiple keys.
|
||||||
Returns a list of tensors or None for each key.
|
Returns a list of tensors or None for each key.
|
||||||
@@ -96,17 +96,28 @@ class HiCacheStorage(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def exists(self, key: str) -> bool | dict:
|
def exists(self, key: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the key exists in the storage.
|
Check if the key exists in the storage.
|
||||||
Returns True if the key exists, False otherwise.
|
Returns True if the key exists, False otherwise.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def batch_exists(self, keys: List[str]) -> int:
|
||||||
|
"""
|
||||||
|
Check if the keys exist in the storage.
|
||||||
|
return the number of consecutive existing keys from the start.
|
||||||
|
Can be overridden by subclasses for more efficient implementation.
|
||||||
|
"""
|
||||||
|
for i in range(len(keys)):
|
||||||
|
if not self.exists(keys[i]):
|
||||||
|
return i
|
||||||
|
return len(keys)
|
||||||
|
|
||||||
|
|
||||||
class HiCacheFile(HiCacheStorage):
|
class HiCacheFile(HiCacheStorage):
|
||||||
|
|
||||||
def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False):
|
def __init__(self, file_path: str = "/tmp/hicache", is_mla_backend: bool = False):
|
||||||
self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
|
self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
|
||||||
if is_dp_attention_enabled():
|
if is_dp_attention_enabled():
|
||||||
tp_rank = get_attention_tp_rank()
|
tp_rank = get_attention_tp_rank()
|
||||||
@@ -115,7 +126,9 @@ class HiCacheFile(HiCacheStorage):
|
|||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
|
self.tp_suffix = (
|
||||||
|
f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla_backend else ""
|
||||||
|
)
|
||||||
if not os.path.exists(self.file_path) and tp_rank == 0:
|
if not os.path.exists(self.file_path) and tp_rank == 0:
|
||||||
os.makedirs(self.file_path)
|
os.makedirs(self.file_path)
|
||||||
logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
|
logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
|
||||||
|
|||||||
@@ -465,6 +465,7 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
|
|
||||||
def get_buffer_meta(self, keys, indices):
|
def get_buffer_meta(self, keys, indices):
|
||||||
|
local_rank = get_tensor_model_parallel_rank()
|
||||||
ptr_list = []
|
ptr_list = []
|
||||||
key_list = []
|
key_list = []
|
||||||
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
|
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
|
||||||
@@ -488,8 +489,8 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
ptr_list.append(k_ptr)
|
ptr_list.append(k_ptr)
|
||||||
ptr_list.append(v_ptr)
|
ptr_list.append(v_ptr)
|
||||||
key_ = keys[index // self.page_size]
|
key_ = keys[index // self.page_size]
|
||||||
key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_k")
|
key_list.append(f"{key_}_{local_rank}_k")
|
||||||
key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_v")
|
key_list.append(f"{key_}_{local_rank}_v")
|
||||||
element_size = (
|
element_size = (
|
||||||
self.layer_num
|
self.layer_num
|
||||||
* self.dtype.itemsize
|
* self.dtype.itemsize
|
||||||
@@ -704,6 +705,7 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
|
|
||||||
def get_buffer_meta(self, keys, indices):
|
def get_buffer_meta(self, keys, indices):
|
||||||
|
local_rank = get_tensor_model_parallel_rank()
|
||||||
ptr_list = []
|
ptr_list = []
|
||||||
key_list = []
|
key_list = []
|
||||||
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
|
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
|
||||||
@@ -717,7 +719,7 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
)
|
)
|
||||||
ptr_list.append(k_ptr)
|
ptr_list.append(k_ptr)
|
||||||
key_ = keys[index // self.page_size]
|
key_ = keys[index // self.page_size]
|
||||||
key_list.append(f"{key_}_k")
|
key_list.append(f"{key_}_{local_rank}_k")
|
||||||
element_size = (
|
element_size = (
|
||||||
self.layer_num
|
self.layer_num
|
||||||
* self.dtype.itemsize
|
* self.dtype.itemsize
|
||||||
|
|||||||
@@ -55,12 +55,11 @@ Launch Mooncake meta server:
|
|||||||
python -m mooncake.http_metadata_server
|
python -m mooncake.http_metadata_server
|
||||||
```
|
```
|
||||||
|
|
||||||
Start the SGLang server with Mooncake enabled. Mooncake configuration can be provided via environment variables:
|
Start the SGLang server with Mooncake enabled. Mooncake configuration can be provided via environment variables. Note that, for optimal performance, the Mooncake backend currently supports only the `page_first` layout.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
MOONCAKE_TE_META_DATA_SERVER="http://127.0.0.1:8080/metadata" \
|
MOONCAKE_TE_META_DATA_SERVER="http://127.0.0.1:8080/metadata" \
|
||||||
MOONCAKE_GLOBAL_SEGMENT_SIZE=4294967296 \
|
MOONCAKE_GLOBAL_SEGMENT_SIZE=4294967296 \
|
||||||
MOONCAKE_LOCAL_BUFFER_SIZE=134217728 \
|
|
||||||
MOONCAKE_PROTOCOL="rdma" \
|
MOONCAKE_PROTOCOL="rdma" \
|
||||||
MOONCAKE_DEVICE="erdma_0,erdma_1" \
|
MOONCAKE_DEVICE="erdma_0,erdma_1" \
|
||||||
MOONCAKE_MASTER=127.0.0.1:50051 \
|
MOONCAKE_MASTER=127.0.0.1:50051 \
|
||||||
|
|||||||
@@ -13,21 +13,11 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
|
|||||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
|
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
|
||||||
|
|
||||||
DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB
|
DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB
|
||||||
DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024 # 128 MB
|
DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024 # 16 MB
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
|
|
||||||
prefix_str = ""
|
|
||||||
if prior_hash:
|
|
||||||
prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
|
|
||||||
current_token_ids_bytes = np.array(token_ids).tobytes()
|
|
||||||
current_hash_object = hashlib.sha256(current_token_ids_bytes)
|
|
||||||
current_hash_hex = current_hash_object.hexdigest()
|
|
||||||
return f"{prefix_str}_{int(current_hash_hex[:16], 16)}"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MooncakeStoreConfig:
|
class MooncakeStoreConfig:
|
||||||
local_hostname: str
|
local_hostname: str
|
||||||
@@ -54,9 +44,8 @@ class MooncakeStoreConfig:
|
|||||||
global_segment_size=config.get(
|
global_segment_size=config.get(
|
||||||
"global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
|
"global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
|
||||||
),
|
),
|
||||||
local_buffer_size=config.get(
|
# Zero copy interface does not need local buffer
|
||||||
"local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
|
local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
|
||||||
),
|
|
||||||
protocol=config.get("protocol", "tcp"),
|
protocol=config.get("protocol", "tcp"),
|
||||||
device_name=config.get("device_name", "auto"),
|
device_name=config.get("device_name", "auto"),
|
||||||
master_server_address=config.get("master_server_address"),
|
master_server_address=config.get("master_server_address"),
|
||||||
@@ -79,9 +68,8 @@ class MooncakeStoreConfig:
|
|||||||
global_segment_size=int(
|
global_segment_size=int(
|
||||||
os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
|
os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
|
||||||
),
|
),
|
||||||
local_buffer_size=int(
|
# Zero copy interface does not need local buffer
|
||||||
os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE)
|
local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
|
||||||
),
|
|
||||||
protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"),
|
protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"),
|
||||||
device_name=os.getenv("MOONCAKE_DEVICE", "auto"),
|
device_name=os.getenv("MOONCAKE_DEVICE", "auto"),
|
||||||
master_server_address=os.getenv("MOONCAKE_MASTER"),
|
master_server_address=os.getenv("MOONCAKE_MASTER"),
|
||||||
@@ -96,7 +84,15 @@ class MooncakeStoreConfig:
|
|||||||
|
|
||||||
|
|
||||||
class MooncakeStore(HiCacheStorage):
|
class MooncakeStore(HiCacheStorage):
|
||||||
def __init__(self, is_mla: bool = False):
|
def __init__(self, is_mla_backend: bool = False):
|
||||||
|
"""
|
||||||
|
Initialize MooncakeStore.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
is_mla_backend: If the backend is MLA
|
||||||
|
"""
|
||||||
|
self.is_mla_backend = is_mla_backend
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from mooncake.store import MooncakeDistributedStore
|
from mooncake.store import MooncakeDistributedStore
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -126,7 +122,6 @@ class MooncakeStore(HiCacheStorage):
|
|||||||
logger.info("Connect to Mooncake store successfully.")
|
logger.info("Connect to Mooncake store successfully.")
|
||||||
self.warmup()
|
self.warmup()
|
||||||
logger.info("Mooncake store warmup successfully.")
|
logger.info("Mooncake store warmup successfully.")
|
||||||
self.is_mla = is_mla
|
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error("Configuration loading failed: %s", e)
|
logger.error("Configuration loading failed: %s", e)
|
||||||
@@ -135,14 +130,14 @@ class MooncakeStore(HiCacheStorage):
|
|||||||
logger.error("An error occurred while loading the configuration: %s", exc)
|
logger.error("An error occurred while loading the configuration: %s", exc)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
self.local_rank = get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
def warmup(self):
|
def warmup(self):
|
||||||
warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
|
warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
|
||||||
# 10 MB
|
warmup_value = bytes(4 * 1024) # 4 KB
|
||||||
warmup_value = bytes(10 * 1024 * 1024)
|
assert self.store.put(warmup_key, warmup_value) == 0
|
||||||
self.store.put(warmup_key, warmup_value)
|
|
||||||
assert self.store.is_exist(warmup_key) == 1
|
assert self.store.is_exist(warmup_key) == 1
|
||||||
self.store.get(warmup_key)
|
assert self.store.get(warmup_key) == warmup_value
|
||||||
self.store.remove(warmup_key)
|
|
||||||
|
|
||||||
def register_buffer(self, buffer: torch.Tensor) -> None:
|
def register_buffer(self, buffer: torch.Tensor) -> None:
|
||||||
try:
|
try:
|
||||||
@@ -162,78 +157,95 @@ class MooncakeStore(HiCacheStorage):
|
|||||||
target_location: Optional[List[int]] = None,
|
target_location: Optional[List[int]] = None,
|
||||||
target_sizes: Optional[List[int]] = None,
|
target_sizes: Optional[List[int]] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
assert len(key) == len(target_location) == len(target_sizes)
|
return self.batch_set([key], [value], [target_location], [target_sizes])
|
||||||
if len(key) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
for i in range(len(key)):
|
|
||||||
if key[i] is None or target_location[i] is None or target_sizes[i] is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._put_batch_zero_copy_impl(key, target_location, target_sizes)
|
|
||||||
|
|
||||||
def batch_set(
|
def batch_set(
|
||||||
self,
|
self,
|
||||||
keys: List[str],
|
keys: List[str],
|
||||||
value: Optional[Any] = None,
|
|
||||||
target_location: Optional[List[int]] = None,
|
target_location: Optional[List[int]] = None,
|
||||||
target_sizes: Optional[List[int]] = None,
|
target_sizes: Optional[List[int]] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
assert len(keys) == len(target_location) == len(target_sizes)
|
assert len(keys) == len(target_location) == len(target_sizes)
|
||||||
if len(keys) == 0:
|
if len(keys) == 0:
|
||||||
return
|
return False
|
||||||
|
|
||||||
for i in range(len(keys)):
|
for i in range(len(keys)):
|
||||||
if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
|
if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
|
||||||
return
|
return False
|
||||||
|
|
||||||
self._put_batch_zero_copy_impl(keys, target_location, target_sizes)
|
exist_result = self._batch_exist(keys)
|
||||||
|
set_keys = []
|
||||||
|
set_target_locations = []
|
||||||
|
set_target_sizes = []
|
||||||
|
set_indices = []
|
||||||
|
for i in range(len(keys)):
|
||||||
|
if exist_result[i] != 1:
|
||||||
|
set_keys.append(keys[i])
|
||||||
|
set_target_locations.append(target_location[i])
|
||||||
|
set_target_sizes.append(target_sizes[i])
|
||||||
|
set_indices.append(i)
|
||||||
|
# Only set non-existing keys to storage
|
||||||
|
put_result = self._put_batch_zero_copy_impl(
|
||||||
|
set_keys, set_target_locations, set_target_sizes
|
||||||
|
)
|
||||||
|
for i in range(len(set_indices)):
|
||||||
|
if put_result[i] == 0:
|
||||||
|
exist_result[set_indices[i]] = 1
|
||||||
|
|
||||||
|
success_count = 0
|
||||||
|
for i in range(len(keys)):
|
||||||
|
if exist_result[i] == 0:
|
||||||
|
break
|
||||||
|
success_count += 1
|
||||||
|
# TODO: return the number of consecutive successful operations from the start.
|
||||||
|
return success_count == len(keys)
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
self,
|
self,
|
||||||
key,
|
key,
|
||||||
target_location: Optional[Any] = None,
|
target_location: Optional[Any] = None,
|
||||||
target_sizes: Optional[Any] = None,
|
target_sizes: Optional[Any] = None,
|
||||||
) -> torch.Tensor | None:
|
) -> bool:
|
||||||
assert len(key) == len(target_location) == len(target_sizes)
|
return self.batch_get([key], [target_location], [target_sizes]) == 1
|
||||||
if len(key) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
for i in range(len(key)):
|
|
||||||
if key[i] is None or target_location[i] is None or target_sizes[i] is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
return self._get_batch_zero_copy_impl(key, target_location, target_sizes)
|
|
||||||
|
|
||||||
def batch_get(
|
def batch_get(
|
||||||
self,
|
self,
|
||||||
keys: List[str],
|
keys: List[str],
|
||||||
target_location: Optional[Any] = None,
|
target_location: Optional[Any] = None,
|
||||||
target_sizes: Optional[Any] = None,
|
target_sizes: Optional[Any] = None,
|
||||||
) -> torch.Tensor | None:
|
) -> int:
|
||||||
assert len(keys) == len(target_location) == len(target_sizes)
|
assert len(keys) == len(target_location) == len(target_sizes)
|
||||||
if len(keys) == 0:
|
if len(keys) == 0:
|
||||||
return
|
return 0
|
||||||
|
get_result = self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
|
||||||
|
if self.is_mla_backend:
|
||||||
|
key_multiplier = 1
|
||||||
|
else:
|
||||||
|
key_multiplier = 2
|
||||||
for i in range(len(keys)):
|
for i in range(len(keys)):
|
||||||
if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
|
if get_result[i] < 0:
|
||||||
return
|
return i // key_multiplier
|
||||||
|
return len(keys) // key_multiplier
|
||||||
|
|
||||||
return self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
|
def exists(self, key) -> bool:
|
||||||
|
return self.batch_exists([key]) > 0
|
||||||
|
|
||||||
def exists(self, keys) -> bool | dict:
|
def batch_exists(self, keys) -> int:
|
||||||
_keys = []
|
if self.is_mla_backend:
|
||||||
local_rank = get_tensor_model_parallel_rank()
|
query_keys = [f"{key}_k" for key in keys]
|
||||||
for key in keys:
|
key_multiplier = 1
|
||||||
if key is None:
|
else:
|
||||||
return None
|
query_keys = []
|
||||||
|
for key in keys:
|
||||||
|
query_keys.append(f"{key}_{self.local_rank}_k")
|
||||||
|
query_keys.append(f"{key}_{self.local_rank}_v")
|
||||||
|
key_multiplier = 2
|
||||||
|
|
||||||
if self.is_mla:
|
exist_result = self._batch_exist(query_keys)
|
||||||
_keys.append(f"{key}_k")
|
for i in range(len(query_keys)):
|
||||||
else:
|
if exist_result[i] != 1:
|
||||||
_keys.append(f"{key}_{local_rank}_k")
|
return i // key_multiplier
|
||||||
result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
|
return len(query_keys) // key_multiplier
|
||||||
return result
|
|
||||||
|
|
||||||
def delete(self, key) -> None:
|
def delete(self, key) -> None:
|
||||||
raise (NotImplementedError)
|
raise (NotImplementedError)
|
||||||
@@ -248,18 +260,13 @@ class MooncakeStore(HiCacheStorage):
|
|||||||
|
|
||||||
def _put_batch_zero_copy_impl(
|
def _put_batch_zero_copy_impl(
|
||||||
self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
|
self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
|
||||||
) -> None:
|
) -> List[int]:
|
||||||
try:
|
return self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
|
||||||
self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
|
|
||||||
except TypeError as err:
|
|
||||||
logger.error("Failed to put value to Mooncake Store: %s", err)
|
|
||||||
raise TypeError("Mooncake Store Put Type Error.") from err
|
|
||||||
|
|
||||||
def _get_batch_zero_copy_impl(
|
def _get_batch_zero_copy_impl(
|
||||||
self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
|
self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
|
||||||
) -> None:
|
) -> List[int]:
|
||||||
try:
|
return self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
|
||||||
self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
|
|
||||||
except TypeError as err:
|
def _batch_exist(self, key_strs: List[str]) -> List[int]:
|
||||||
logger.error("Failed to get value from Mooncake Store: %s", err)
|
return self.store.batch_is_exist(key_strs)
|
||||||
raise TypeError("Mooncake Store Get Type Error.") from err
|
|
||||||
|
|||||||
Reference in New Issue
Block a user