Hicache Storage Layer Prototype (#7704)
@@ -25,6 +25,8 @@ if TYPE_CHECKING:
    from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
    from sglang.srt.mem_cache.memory_pool_host import HostKVCache

from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str

logger = logging.getLogger(__name__)


@@ -159,6 +161,57 @@ class TransferBuffer:
        self.buffers.queue.clear()


class StorageOperation:
    counter = 0

    def __init__(
        self,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
    ):
        self.host_indices = host_indices
        self.token_ids = token_ids
        self.last_hash = last_hash
        self.completed_tokens = 0
        self.hash_value = []

        self.id = StorageOperation.counter
        StorageOperation.counter += 1

    def __lt__(self, other: "StorageOperation"):
        return self.id < other.id


class PrefetchOperation(StorageOperation):
    def __init__(
        self,
        request_id: str,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
    ):
        self.request_id = request_id

        self._done_flag = False
        self._lock = threading.Lock()

        super().__init__(host_indices, token_ids, last_hash)

    def increment(self, num_tokens: int):
        with self._lock:
            if self._done_flag:
                return
            self.completed_tokens += num_tokens

    def mark_done(self):
        with self._lock:
            self._done_flag = True

    def is_done(self) -> bool:
        return self._done_flag


class HiCacheController:

    def __init__(
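Before the controller changes, a note on the two operation classes above: `StorageOperation` tags each transfer with a monotonically increasing `id` (and orders operations by it via `__lt__`), while `PrefetchOperation` adds a lock-guarded termination contract: once `mark_done()` is called, further `increment()` calls are ignored, so `completed_tokens` is stable by the time the caller reads it. A minimal sketch of that contract, assuming the classes above are importable and using made-up tensor values:

    import torch

    # Hypothetical values; real callers pass pre-allocated host indices and token ids.
    op = PrefetchOperation(
        request_id="req-0",
        host_indices=torch.arange(128),
        token_ids=list(range(128)),
        last_hash=None,
    )
    op.increment(64)
    op.mark_done()      # controller decides to terminate the prefetch
    op.increment(64)    # ignored: the operation is already marked done
    assert op.completed_tokens == 64
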
@@ -169,6 +222,8 @@ class HiCacheController:
        load_cache_event: threading.Event = None,
        write_policy: str = "write_through_selective",
        io_backend: str = "",
        storage_backend: Optional[str] = None,
        prefetch_threshold: int = 256,
    ):
        self.mem_pool_device_allocator = token_to_kv_pool_allocator
        self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -186,6 +241,19 @@ class HiCacheController:
        else:
            self.io_backend = io_backend

        self.enable_storage = False
        # todo: move backend initialization to storage backend module
        if storage_backend is not None:
            if storage_backend == "file":
                self.storage_backend = HiCacheFile()
                self.enable_storage = True
                # todo: threshold policy for prefetching
                self.prefetch_threshold = prefetch_threshold
            else:
                raise NotImplementedError(
                    f"Unsupported storage backend: {storage_backend}"
                )

        self.load_cache_event = load_cache_event
        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
        self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
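Only the `file` backend is wired up in this prototype. From the call sites later in this diff, the backend contract is three methods keyed by a page-hash string: `exists(key)`, `get(key)` returning the page data or None on a miss, and `set(key, value)`. A hypothetical minimal file-backed sketch of that contract (the real implementation is `HiCacheFile` in `sglang.srt.mem_cache.hicache_storage`; paths and serialization here are assumptions):

    import os
    import torch

    class FileBackendSketch:
        """Illustrative stand-in for HiCacheFile, not the real implementation."""

        def __init__(self, root: str = "/tmp/hicache"):
            os.makedirs(root, exist_ok=True)
            self.root = root

        def _path(self, key: str) -> str:
            return os.path.join(self.root, key)

        def exists(self, key: str) -> bool:
            return os.path.exists(self._path(key))

        def get(self, key: str):
            # A miss returns None, which the prefetch IO loop treats as a failure.
            if not self.exists(key):
                return None
            return torch.load(self._path(key))

        def set(self, key: str, value: torch.Tensor) -> None:
            torch.save(value, self._path(key))
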
@@ -218,9 +286,26 @@ class HiCacheController:
        self.load_thread = threading.Thread(
            target=self.load_thread_func_layer_by_layer, daemon=True
        )

        self.write_thread.start()
        self.load_thread.start()

        if self.enable_storage:
            self.prefetch_thread = threading.Thread(
                target=self.prefetch_thread_func, daemon=True
            )
            self.backup_thread = threading.Thread(
                target=self.backup_thread_func, daemon=True
            )
            self.prefetch_queue = Queue()
            self.backup_queue = Queue()

            self.prefetch_revoke_queue = Queue()
            self.ack_backup_queue = Queue()

            self.prefetch_thread.start()
            self.backup_thread.start()

    def reset(self):
        self.stop_event.set()
        self.write_thread.join()
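When storage is enabled, four queues connect the controller to its worker threads: `prefetch_queue` and `backup_queue` feed new operations to `prefetch_thread_func` and `backup_thread_func`, `prefetch_revoke_queue` reports request ids whose prefetch was abandoned below the hit threshold, and `ack_backup_queue` reports finished backups. A hedged sketch of how a scheduler might poll the outbound queues without blocking (`controller` is an assumed instance):

    from queue import Empty

    def drain(q):
        # Collect everything currently queued without waiting for more.
        items = []
        while True:
            try:
                items.append(q.get_nowait())
            except Empty:
                return items

    revoked_ids = drain(controller.prefetch_revoke_queue)  # request ids to re-plan
    finished = drain(controller.ack_backup_queue)          # (operation.id, hash_value) pairs
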
@@ -232,6 +317,13 @@ class HiCacheController:
        self.load_buffer.clear()
        self.ack_write_queue.queue.clear()
        self.ack_load_queue.queue.clear()
        if self.enable_storage:
            self.prefetch_thread.join()
            self.backup_thread.join()
            self.prefetch_queue.queue.clear()
            self.backup_queue.queue.clear()
            self.prefetch_revoke_queue.queue.clear()
            self.ack_backup_queue.queue.clear()

        self.write_thread = threading.Thread(
            target=self.write_thread_func_direct, daemon=True
@@ -243,6 +335,16 @@ class HiCacheController:
        self.write_thread.start()
        self.load_thread.start()

        if self.enable_storage:
            self.prefetch_thread = threading.Thread(
                target=self.prefetch_thread_func, daemon=True
            )
            self.backup_thread = threading.Thread(
                target=self.backup_thread_func, daemon=True
            )
            self.prefetch_thread.start()
            self.backup_thread.start()

    def write(
        self,
        device_indices: torch.Tensor,
@@ -383,3 +485,142 @@ class HiCacheController:
            raise ValueError(
                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
            )

    def prefetch(
        self,
        request_id: str,
        host_indices: torch.Tensor,
        new_input_tokens: List[int],
        last_hash: Optional[str] = None,
    ) -> PrefetchOperation:
        """
        Prefetch KV caches from the storage backend into host memory.
        """
        operation = PrefetchOperation(
            request_id, host_indices, new_input_tokens, last_hash
        )
        self.prefetch_queue.put(operation)
        return operation

    def terminate_prefetch(self, operation):
        operation.mark_done()
        return operation.completed_tokens, operation.hash_value

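`prefetch()` is fire-and-forget: it enqueues a `PrefetchOperation` and returns it immediately while the worker threads fill host memory in the background; the caller later calls `terminate_prefetch()` to stop further copying and learn how much of the prefix actually arrived. A hypothetical caller flow (`controller`, the indices, and the token list are assumptions):

    operation = controller.prefetch(
        request_id="req-0",
        host_indices=reserved_host_indices,  # pre-allocated host slots, assumed
        new_input_tokens=token_ids,
        last_hash=None,
    )
    # ... scheduling continues while IO proceeds in the background ...
    completed_tokens, hash_value = controller.terminate_prefetch(operation)
    # Only the first completed_tokens entries of reserved_host_indices hold valid pages.
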
    def prefetch_io_aux_func(self):
        """
        Auxiliary function conducting IO operations for prefetching.
        """
        while not self.stop_event.is_set():
            try:
                operation = self.prefetch_buffer.get(block=True, timeout=1)
                for h in operation.hash_value:
                    page_data = self.storage_backend.get(h)
                    if page_data is None:
                        logger.warning(
                            f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
                        )
                        break
                    self.mem_pool_host.set_from_flat_data_page(
                        operation.host_indices[operation.completed_tokens],
                        page_data,
                    )
                    operation.increment(self.page_size)
                    if operation.is_done():
                        # operation terminated by the controller; release the pre-allocated memory
                        self.mem_pool_host.free(
                            operation.host_indices[operation.completed_tokens :]
                        )
                        break
            except Empty:
                continue

    def prefetch_thread_func(self):
        """
        Manage prefetching operations from the storage backend to host memory.
        """
        self.prefetch_buffer = Queue()
        aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
        aux_thread.start()
        while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
            try:
                operation = self.prefetch_queue.get(block=True, timeout=1)
                if operation is None:
                    continue

                last_hash = operation.last_hash
                tokens_to_fetch = operation.token_ids

                storage_hit_count = 0
                remaining_tokens = len(tokens_to_fetch)
                hash_value = []
                while remaining_tokens >= self.page_size:
                    last_hash = get_hash_str(
                        tokens_to_fetch[
                            storage_hit_count : storage_hit_count + self.page_size
                        ],
                        last_hash,
                    )
                    if self.storage_backend.exists(last_hash):
                        storage_hit_count += self.page_size
                        hash_value.append(last_hash)
                        remaining_tokens -= self.page_size
                    else:
                        break

                if storage_hit_count < self.prefetch_threshold:
                    # skip prefetching when the storage hit is too small to be worth the IO
                    self.prefetch_revoke_queue.put(operation.request_id)
                else:
                    operation.hash_value = hash_value
                    logger.debug(
                        f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
                    )
                    self.prefetch_buffer.put(operation)

            except Empty:
                continue

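Both the lookup loop above and the backup loop below key pages by a chained hash: `get_hash_str` folds each page's token ids together with the previous page's hash, so a key identifies the page and its entire prefix, and lookups can stop at the first miss. A sketch of the chaining idea, assuming a SHA-256-style digest (the real `get_hash_str` lives in `sglang.srt.mem_cache.hicache_storage` and may differ):

    import hashlib
    from typing import List, Optional

    def chained_page_hash(page_tokens: List[int], prior_hash: Optional[str] = None) -> str:
        # Each page key commits to the previous key, hence to the whole prefix.
        h = hashlib.sha256()
        if prior_hash is not None:
            h.update(prior_hash.encode())
        for t in page_tokens:
            h.update(t.to_bytes(4, "little", signed=False))
        return h.hexdigest()
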
    def write_storage(
        self,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
    ) -> int:
        """
        Write KV caches from host memory to the storage backend.
        """
        operation = StorageOperation(host_indices, token_ids, last_hash)
        self.backup_queue.put(operation)
        return operation.id

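`write_storage()` likewise returns right away with the operation's `id`; completion is reported asynchronously on `ack_backup_queue` together with the per-page hash chain, which the caller can retain to seed `last_hash` for later writes of the same prefix. A hypothetical caller pairing a request with its ack (`controller` and `record_prefix_hashes` are assumptions):

    op_id = controller.write_storage(host_indices, token_ids, last_hash=None)
    ack_id, hash_value = controller.ack_backup_queue.get()  # blocks until a backup finishes
    if ack_id == op_id:
        # One chained hash per page that was persisted, in prefix order.
        record_prefix_hashes(token_ids, hash_value)
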
    def backup_thread_func(self):
        """
        Manage backup operations from host memory to the storage backend.
        """
        while not self.stop_event.is_set():
            try:
                operation = self.backup_queue.get(block=True, timeout=1)
                if operation is None:
                    continue

                last_hash = operation.last_hash
                tokens_to_backup = operation.token_ids

                for i in range(0, len(tokens_to_backup), self.page_size):
                    last_hash = get_hash_str(
                        tokens_to_backup[i : i + self.page_size], last_hash
                    )
                    # todo: handle failures in the storage backend
                    self.storage_backend.set(
                        last_hash,
                        self.mem_pool_host.get_flat_data_page(
                            operation.host_indices[i]
                        ),
                    )
                    operation.completed_tokens += self.page_size
                    operation.hash_value.append(last_hash)

                self.ack_backup_queue.put((operation.id, operation.hash_value))

            except Empty:
                continue