2025-01-10 20:22:01 -08:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
Copyright 2023-2025 SGLang Team
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import logging
|
2025-02-23 21:56:30 -08:00
|
|
|
import math
|
2025-01-10 20:22:01 -08:00
|
|
|
import threading
|
2025-08-08 17:09:28 +08:00
|
|
|
import time
|
2025-02-23 21:56:30 -08:00
|
|
|
from queue import Empty, Full, PriorityQueue, Queue
|
2025-09-08 22:18:50 -07:00
|
|
|
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
|
2025-01-10 20:22:01 -08:00
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
2025-10-10 15:22:05 +08:00
|
|
|
from sglang.srt.mem_cache.hicache_storage import (
|
|
|
|
|
HiCacheStorageConfig,
|
|
|
|
|
HiCacheStorageExtraInfo,
|
|
|
|
|
)
|
2025-08-27 08:55:20 +08:00
|
|
|
|
2025-06-22 12:37:18 +08:00
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
|
|
|
|
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
2025-01-10 20:22:01 -08:00
|
|
|
|
2025-08-27 08:55:20 +08:00
|
|
|
from sglang.srt.distributed import (
|
|
|
|
|
get_tensor_model_parallel_rank,
|
|
|
|
|
get_tensor_model_parallel_world_size,
|
|
|
|
|
)
|
|
|
|
|
from sglang.srt.layers.dp_attention import (
|
2025-09-06 07:52:55 +08:00
|
|
|
get_attention_dp_rank,
|
2025-08-27 08:55:20 +08:00
|
|
|
get_attention_tp_rank,
|
|
|
|
|
get_attention_tp_size,
|
|
|
|
|
is_dp_attention_enabled,
|
|
|
|
|
)
|
2025-08-26 10:05:10 +08:00
|
|
|
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
|
2025-07-18 00:20:19 -07:00
|
|
|
|
2025-01-10 20:22:01 -08:00
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
2025-09-08 22:18:50 -07:00
|
|
|
class LayerLoadingEvent:
    """Per-layer CUDA events marking when each layer's KV data has landed on device."""

    def __init__(self, num_layers: int):
        self._num_layers = num_layers
        # One event per layer, recorded as that layer finishes loading.
        self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
        # Recorded on the controller stream before the load begins.
        self.start_event = torch.cuda.Event()

    def complete(self, layer_index: int):
        """Record the given layer's event on the current stream."""
        assert 0 <= layer_index < self._num_layers
        self.load_events[layer_index].record()

    def wait(self, layer_index: int):
        """Make the current stream wait until the given layer has loaded."""
        event = self.load_events[layer_index]
        torch.cuda.current_stream().wait_event(event)

    @property
    def finish_event(self):
        """Event of the final layer, i.e. completion of the entire load."""
        return self.load_events[-1]
|
|
|
|
|
|
|
|
|
|
|
2025-03-12 11:22:35 -07:00
|
|
|
class LayerDoneCounter:
    """Rotating pool of LayerLoadingEvent slots coordinating loader and consumer."""

    def __init__(self, num_layers: int):
        self.num_layers = num_layers
        # extra producer and consumer counters for overlap mode
        self.num_counters = 3
        self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
        self.producer_index = -1
        self.consumer_index = -1

    def update_producer(self):
        """Advance to the next event slot and return its index."""
        next_index = (self.producer_index + 1) % self.num_counters
        self.producer_index = next_index
        ready = self.events[next_index].finish_event.query()
        assert ready, "Producer finish event should be ready before being reused."
        return next_index

    def set_consumer(self, index: int):
        """Select which event slot the consumer side waits on."""
        self.consumer_index = index

    def wait_until(self, threshold: int):
        """Block the current stream until layer `threshold` of the active slot is loaded."""
        if self.consumer_index < 0:
            return
        self.events[self.consumer_index].wait(threshold)

    def reset(self):
        """Deactivate both cursors; no slot is considered live."""
        self.producer_index = -1
        self.consumer_index = -1
|
2025-03-12 11:22:35 -07:00
|
|
|
|
|
|
|
|
|
2025-01-10 20:22:01 -08:00
|
|
|
class CacheOperation:
    """A batched host<->device KV-cache transfer request.

    Carries paired host/device index tensors plus the radix-tree node ids the
    transfer serves.  Instances order by ``priority`` (defaults to creation
    order) so they can sit in priority queues.
    """

    # Monotonically increasing id shared by all instances.
    counter = 0

    def __init__(
        self,
        host_indices: torch.Tensor,
        device_indices: torch.Tensor,
        node_id: int,
        priority: Optional[int] = None,
    ):
        self.host_indices = host_indices
        self.device_indices = device_indices
        self.node_ids = [node_id]
        self.data = None

        self.id = CacheOperation.counter
        CacheOperation.counter += 1
        # default priority is the order of creation
        self.priority = self.id if priority is None else priority

    @staticmethod
    def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
        """Fuse several operations into one.

        Index tensors are concatenated, node ids are collected in order, and
        the smallest (most urgent) priority wins.  A singleton list is
        returned as-is without copying.
        """
        assert len(ops) > 0
        if len(ops) == 1:
            return ops[0]

        merged = CacheOperation(
            torch.cat([op.host_indices for op in ops]),
            torch.cat([op.device_indices for op in ops]),
            -1,  # placeholder; replaced by the collected node id list below
            min(op.priority for op in ops),
        )
        merged.node_ids = [nid for op in ops for nid in op.node_ids]
        return merged

    def __lt__(self, other: CacheOperation):
        return self.priority < other.priority
|
2025-02-23 21:56:30 -08:00
|
|
|
|
|
|
|
|
|
2025-09-08 22:18:50 -07:00
|
|
|
class HiCacheAck(NamedTuple):
    """Acknowledgement handle for an asynchronous KV transfer.

    Consumers must wait on ``finish_event`` before trusting the data behind
    the listed radix-tree ``node_ids``.
    """

    # Recorded on the controller stream just before the transfer was queued.
    start_event: torch.cuda.Event
    # Recorded on the transfer stream once the whole transfer is done.
    finish_event: torch.cuda.Event
    # Radix-tree node ids covered by this transfer.
    node_ids: List[int]
|
2025-01-10 20:22:01 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransferBuffer:
    """
    Overlapping buffer preparation and transfer operations to improve throughput.
    """

    def __init__(
        self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1024
    ) -> None:
        self.stop_event = stop_event
        self.buffers = Queue(maxsize=buffer_count)
        # todo: adjust the buffer size based on throughput profile of the system
        self.max_buffer_size = max_buffer_size

    def full(self) -> bool:
        """True when every buffer slot is occupied."""
        return self.buffers.full()

    def empty(self) -> bool:
        """True when nothing is queued."""
        return self.buffers.empty()

    def put(self, item, block=True, timeout=1) -> None:
        """Enqueue `item`, retrying on timeout until stop_event is set.

        In non-blocking mode a single failed attempt gives up silently.
        """
        while not self.stop_event.is_set():
            try:
                self.buffers.put(item, block=block, timeout=timeout)
            except Full:
                if block:
                    continue
                break
            except Exception as e:
                logger.error(e)
            else:
                break

    def get(self, block=True, timeout=1) -> Optional[CacheOperation]:
        """Dequeue one item; returns None on timeout.  Unexpected errors are logged."""
        try:
            return self.buffers.get(block=block, timeout=timeout)
        except Empty:
            return None
        except Exception as e:
            logger.error(e)

    def clear(self):
        """Drop everything currently buffered."""
        self.buffers.queue.clear()
|
|
|
|
|
|
2025-01-10 20:22:01 -08:00
|
|
|
|
2025-07-18 00:20:19 -07:00
|
|
|
class StorageOperation:
    """A storage-tier IO request over a contiguous span of tokens."""

    # Monotonically increasing id shared by all instances.
    counter = 0

    def __init__(
        self,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
        hash_value: Optional[List[str]] = None,
        prefix_keys: Optional[List[str]] = None,
    ):
        self.host_indices = host_indices
        self.token_ids = token_ids
        self.last_hash = last_hash
        # Tokens already transferred for this operation.
        self.completed_tokens = 0
        self.hash_value = [] if hash_value is None else hash_value
        self.prefix_keys = prefix_keys

        self.id = StorageOperation.counter
        StorageOperation.counter += 1

    def __lt__(self, other: "StorageOperation"):
        # FIFO ordering: earlier-created operations sort first.
        return self.id < other.id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PrefetchOperation(StorageOperation):
    """A storage -> host prefetch request that can be terminated mid-flight.

    ``increment`` (IO thread) and ``mark_terminate`` (scheduler thread) race;
    the lock keeps the terminated flag and ``completed_tokens`` mutually
    consistent, so a terminated operation's progress count is frozen.
    """

    def __init__(
        self,
        request_id: str,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
        prefix_keys: Optional[List[str]] = None,
    ):
        self.request_id = request_id

        # Protects _terminated_flag and completed_tokens across threads.
        self._lock = threading.Lock()
        self._terminated_flag = False
        # Wall-clock start of the prefetch, for latency accounting.
        self.start_time = time.monotonic()

        super().__init__(host_indices, token_ids, last_hash, prefix_keys=prefix_keys)

    def increment(self, num_tokens: int):
        """Add prefetch progress; returns False if already terminated."""
        with self._lock:
            if self._terminated_flag:
                return False
            self.completed_tokens += num_tokens
            return True

    def mark_terminate(self):
        """Freeze progress: subsequent increment() calls are rejected."""
        with self._lock:
            self._terminated_flag = True

    def is_terminated(self) -> bool:
        # Lock-free read; a momentarily stale False only lets one extra
        # increment through, which increment() itself re-checks under the lock.
        return self._terminated_flag
|
2025-07-18 00:20:19 -07:00
|
|
|
|
|
|
|
|
|
2025-01-10 20:22:01 -08:00
|
|
|
class HiCacheController:
|
|
|
|
|
|
|
|
|
|
    def __init__(
        self,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        mem_pool_host: HostKVCache,
        page_size: int,
        tp_group: torch.distributed.ProcessGroup,
        load_cache_event: threading.Event,
        write_policy: str = "write_through_selective",
        io_backend: str = "",
        storage_backend: Optional[str] = None,
        prefetch_threshold: int = 256,
        model_name: Optional[str] = None,
        storage_backend_extra_config: Optional[dict] = None,
    ):
        """Wire up the device pool, host pool, and (optionally) a storage tier.

        When ``storage_backend`` is given, a backend instance is created via
        the factory and two daemon threads (prefetch, backup) are started.
        NOTE(review): ``load_cache_event`` is not referenced in this
        constructor — presumably consumed elsewhere in the class; confirm.
        """
        self.mem_pool_device_allocator = token_to_kv_pool_allocator
        self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
        self.mem_pool_host = mem_pool_host
        self.write_policy = write_policy
        self.page_size = page_size
        self.io_backend = io_backend
        self.enable_storage = False

        if storage_backend is not None:
            self.storage_backend_type = storage_backend
            # Local import to avoid paying for storage deps when unused.
            from sglang.srt.mem_cache.hicache_storage import get_hash_str

            self.get_hash_str = get_hash_str
            self.storage_config = self._generate_storage_config(
                model_name, storage_backend_extra_config
            )
            # for MLA models, only one rank needs to backup the KV cache
            self.backup_skip = (
                self.storage_config.is_mla_model
                # todo: load balancing
                and self.storage_config.tp_rank != 0
            )

            # Use storage backend factory for dynamic backend creation
            from sglang.srt.mem_cache.storage import StorageBackendFactory

            try:
                self.storage_backend = StorageBackendFactory.create_backend(
                    storage_backend, self.storage_config, self.mem_pool_host
                )
            except ValueError as e:
                raise ValueError(f"Failed to create storage backend: {e}") from e

            self.storage_backend.register_mem_pool_host(self.mem_pool_host)

            self.enable_storage = True
            # todo: threshold policy for prefetching
            self.prefetch_threshold = max(prefetch_threshold, self.page_size)
            # Cap prefetch at 80% of host-only capacity (host minus device size).
            self.prefetch_capacity_limit = int(
                0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
            )
            # granularity of batch storage IO operations, in number of pages
            self.storage_batch_size = 128
            # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
            self.prefetch_tokens_occupied = 0

            # create a new communication group for synchronizing storage operations across TP workers
            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
            if self.tp_world_size > 1:
                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
                # gloo backend: these syncs run on CPU threads, not the GPU.
                self.prefetch_tp_group = torch.distributed.new_group(
                    group_ranks, backend="gloo"
                )

            # Select the get and set functions
            self.page_get_func = self._generic_page_get
            self.page_set_func = self._generic_page_set

            # Backends with zero-copy batch APIs get the direct host-memory path.
            if self.storage_backend_type in ["hf3fs", "mooncake", "eic"]:
                self.page_get_func = self._page_get_zero_copy
                self.page_set_func = self._page_set_zero_copy

        self.device = self.mem_pool_device.device
        self.layer_num = self.mem_pool_device.layer_num
        self.layer_done_counter = LayerDoneCounter(self.layer_num)
        self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)

        if write_policy not in [
            "write_through",
            "write_through_selective",
            "write_back",
        ]:
            raise ValueError(f"Invalid write policy: {write_policy}")

        # self.write_queue = PriorityQueue[CacheOperation]()
        self.load_queue: List[CacheOperation] = []
        self.write_queue: List[CacheOperation] = []
        self.ack_load_queue: List[HiCacheAck] = []
        self.ack_write_queue: List[HiCacheAck] = []

        self.stop_event = threading.Event()
        self.write_buffer = TransferBuffer(self.stop_event)
        self.load_buffer = TransferBuffer(
            self.stop_event, buffer_count=10, max_buffer_size=100
        )

        # Dedicated side streams so transfers overlap with compute.
        self.write_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()

        if self.enable_storage:
            self.prefetch_thread = threading.Thread(
                target=self.prefetch_thread_func, daemon=True
            )
            self.backup_thread = threading.Thread(
                target=self.backup_thread_func, daemon=True
            )
            self.prefetch_queue = Queue()
            self.backup_queue = Queue()

            self.prefetch_revoke_queue = Queue()
            self.ack_backup_queue = Queue()
            self.host_mem_release_queue = Queue()

            self.prefetch_thread.start()
            self.backup_thread.start()
|
|
|
|
|
|
2025-08-27 08:55:20 +08:00
|
|
|
def _generate_storage_config(
|
|
|
|
|
self,
|
|
|
|
|
model_name: Optional[str] = None,
|
2025-10-01 21:44:10 +08:00
|
|
|
storage_backend_extra_config: Optional[dict] = None,
|
2025-08-27 08:55:20 +08:00
|
|
|
):
|
|
|
|
|
|
|
|
|
|
if is_dp_attention_enabled():
|
|
|
|
|
self.tp_rank = get_attention_tp_rank()
|
|
|
|
|
self.tp_size = get_attention_tp_size()
|
2025-09-06 07:52:55 +08:00
|
|
|
self.dp_rank = get_attention_dp_rank()
|
2025-08-27 08:55:20 +08:00
|
|
|
else:
|
|
|
|
|
self.tp_rank = get_tensor_model_parallel_rank()
|
|
|
|
|
self.tp_size = get_tensor_model_parallel_world_size()
|
2025-09-06 07:52:55 +08:00
|
|
|
self.dp_rank = 0
|
2025-08-27 08:55:20 +08:00
|
|
|
|
|
|
|
|
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
|
|
|
|
|
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
|
|
|
|
|
|
|
|
|
|
return HiCacheStorageConfig(
|
|
|
|
|
tp_rank=self.tp_rank,
|
|
|
|
|
tp_size=self.tp_size,
|
|
|
|
|
is_mla_model=is_mla_backend,
|
2025-09-03 02:30:11 +08:00
|
|
|
is_page_first_layout=self.mem_pool_host.layout == "page_first",
|
2025-08-27 08:55:20 +08:00
|
|
|
model_name=model_name,
|
2025-10-01 21:44:10 +08:00
|
|
|
extra_config=storage_backend_extra_config,
|
2025-08-27 08:55:20 +08:00
|
|
|
)
|
|
|
|
|
|
2025-02-23 21:56:30 -08:00
|
|
|
    def reset(self):
        """Stop in-flight work, drop all queued state, and restart worker threads.

        Intended to be called while the scheduler is quiesced: storage threads
        are joined before their queues are cleared.
        """
        # Signal every worker loop (transfer buffers, storage threads) to exit.
        self.stop_event.set()

        self.write_queue.clear()
        self.load_queue.clear()
        self.write_buffer.clear()
        self.load_buffer.clear()
        self.ack_write_queue.clear()
        self.ack_load_queue.clear()
        if self.enable_storage:
            # Wait for the storage threads to observe stop_event and exit
            # before clearing the queues they may still be reading.
            self.prefetch_thread.join()
            self.backup_thread.join()
            self.prefetch_queue.queue.clear()
            self.backup_queue.queue.clear()
            self.prefetch_revoke_queue.queue.clear()
            self.ack_backup_queue.queue.clear()

        # Re-arm the stop flag so freshly started threads keep running.
        self.stop_event.clear()

        if self.enable_storage:
            self.prefetch_thread = threading.Thread(
                target=self.prefetch_thread_func, daemon=True
            )
            self.backup_thread = threading.Thread(
                target=self.backup_thread_func, daemon=True
            )
            self.prefetch_thread.start()
            self.backup_thread.start()
|
|
|
|
|
|
2025-01-10 20:22:01 -08:00
|
|
|
def write(
|
|
|
|
|
self,
|
|
|
|
|
device_indices: torch.Tensor,
|
|
|
|
|
priority: Optional[int] = None,
|
2025-09-08 22:18:50 -07:00
|
|
|
node_id: int = -1,
|
2025-01-10 20:22:01 -08:00
|
|
|
) -> Optional[torch.Tensor]:
|
|
|
|
|
"""
|
|
|
|
|
Back up KV caches from device memory to host memory.
|
|
|
|
|
"""
|
|
|
|
|
host_indices = self.mem_pool_host.alloc(len(device_indices))
|
|
|
|
|
if host_indices is None:
|
|
|
|
|
return None
|
2025-09-08 22:18:50 -07:00
|
|
|
self.write_queue.append(
|
2025-01-10 20:22:01 -08:00
|
|
|
CacheOperation(host_indices, device_indices, node_id, priority)
|
|
|
|
|
)
|
2025-09-08 22:18:50 -07:00
|
|
|
self.start_writing()
|
2025-01-10 20:22:01 -08:00
|
|
|
return host_indices
|
|
|
|
|
|
2025-09-08 22:18:50 -07:00
|
|
|
    def start_writing(self) -> None:
        """Launch all queued device->host backups as one batched async copy.

        The copy runs on the dedicated write stream; completion is published
        via a HiCacheAck on ack_write_queue.
        """
        if len(self.write_queue) == 0:
            return

        op = CacheOperation.merge_ops(self.write_queue)
        host_indices, device_indices = self.move_indices(op)
        self.write_queue.clear()

        start_event = torch.cuda.Event()
        finish_event = torch.cuda.Event()

        # Record on the current (controller) stream so the write stream only
        # starts after all preceding work on that stream.
        start_event.record()
        with torch.cuda.stream(self.write_stream):
            start_event.wait(self.write_stream)
            self.mem_pool_host.backup_from_device_all_layer(
                self.mem_pool_device, host_indices, device_indices, self.io_backend
            )
            finish_event.record()
            # NOTE: We must save the host indices and device indices here,
            # this is because we need to guarantee that these tensors are
            # still alive when the write stream is executing.
            if host_indices.is_cuda:
                host_indices.record_stream(self.write_stream)
            if device_indices.is_cuda:
                device_indices.record_stream(self.write_stream)

        self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
|
|
|
|
|
|
2025-01-10 20:22:01 -08:00
|
|
|
def load(
|
|
|
|
|
self,
|
|
|
|
|
host_indices: torch.Tensor,
|
|
|
|
|
priority: Optional[int] = None,
|
2025-09-08 22:18:50 -07:00
|
|
|
node_id: int = -1,
|
2025-01-10 20:22:01 -08:00
|
|
|
) -> Optional[torch.Tensor]:
|
|
|
|
|
"""
|
|
|
|
|
Load KV caches from host memory to device memory.
|
|
|
|
|
"""
|
2025-03-07 00:58:20 -08:00
|
|
|
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
|
2025-01-10 20:22:01 -08:00
|
|
|
if device_indices is None:
|
|
|
|
|
return None
|
2025-09-08 22:18:50 -07:00
|
|
|
self.load_queue.append(
|
2025-01-10 20:22:01 -08:00
|
|
|
CacheOperation(host_indices, device_indices, node_id, priority)
|
|
|
|
|
)
|
|
|
|
|
return device_indices
|
|
|
|
|
|
2025-09-08 22:18:50 -07:00
|
|
|
def move_indices(self, op: CacheOperation):
|
|
|
|
|
host_indices, device_indices = op.host_indices, op.device_indices
|
2025-07-06 22:53:36 -07:00
|
|
|
# move indices to GPU if using kernels, to host if using direct indexing
|
|
|
|
|
if self.io_backend == "kernel":
|
2025-09-08 22:18:50 -07:00
|
|
|
if not host_indices.is_cuda:
|
|
|
|
|
host_indices = host_indices.to(self.device, non_blocking=True)
|
|
|
|
|
return host_indices, device_indices
|
2025-07-06 22:53:36 -07:00
|
|
|
elif self.io_backend == "direct":
|
2025-09-12 14:19:44 +08:00
|
|
|
if self.mem_pool_host.layout == "layer_first":
|
|
|
|
|
device_indices = device_indices.cpu()
|
|
|
|
|
host_indices, idx = host_indices.sort()
|
|
|
|
|
return host_indices, device_indices.index_select(0, idx)
|
|
|
|
|
elif self.mem_pool_host.layout == "page_first_direct":
|
|
|
|
|
return host_indices, device_indices.cpu()
|
2025-07-06 22:53:36 -07:00
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Unsupported io backend")
|
|
|
|
|
|
2025-09-08 22:18:50 -07:00
|
|
|
    def start_loading(self) -> int:
        """Launch all queued host->device loads layer by layer on the load stream.

        Returns the producer slot index used for this batch, or -1 when the
        queue was empty.  Each layer's completion is recorded so compute can
        consume layer i as soon as it lands (see LayerLoadingEvent).
        """
        if len(self.load_queue) == 0:
            return -1

        producer_id = self.layer_done_counter.update_producer()
        op = CacheOperation.merge_ops(self.load_queue)
        host_indices, device_indices = self.move_indices(op)
        self.load_queue.clear()
        producer_event = self.layer_done_counter.events[producer_id]
        # Recorded on the controller stream; the load stream waits on it so it
        # runs after all prior controller-stream work.
        producer_event.start_event.record()

        with torch.cuda.stream(self.load_stream):
            producer_event.start_event.wait(self.load_stream)
            for i in range(self.layer_num):
                self.mem_pool_host.load_to_device_per_layer(
                    self.mem_pool_device,
                    host_indices,
                    device_indices,
                    i,
                    self.io_backend,
                )
                # Publish per-layer completion for overlapped consumption.
                producer_event.complete(i)
            # NOTE: We must save the host indices and device indices here,
            # this is because we need to guarantee that these tensors are
            # still alive when the load stream is executing.
            if host_indices.is_cuda:
                host_indices.record_stream(self.load_stream)
            if device_indices.is_cuda:
                device_indices.record_stream(self.load_stream)

        self.ack_load_queue.append(
            HiCacheAck(
                start_event=producer_event.start_event,
                finish_event=producer_event.finish_event,
                node_ids=op.node_ids,
            )
        )
        return producer_id
|
2025-03-12 11:22:35 -07:00
|
|
|
|
2025-09-24 23:43:53 -07:00
|
|
|
def evict_device(self, device_indices: torch.Tensor) -> int:
|
|
|
|
|
self.mem_pool_device_allocator.free(device_indices)
|
|
|
|
|
return len(device_indices)
|
2025-01-10 20:22:01 -08:00
|
|
|
|
|
|
|
|
def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
|
|
|
|
|
if not backup_only:
|
|
|
|
|
raise ValueError("Other eviction policies are not supported yet.")
|
|
|
|
|
|
2025-09-24 23:43:53 -07:00
|
|
|
self.mem_pool_host.free(host_indices)
|
|
|
|
|
return len(host_indices)
|
2025-07-18 00:20:19 -07:00
|
|
|
|
|
|
|
|
def prefetch(
|
|
|
|
|
self,
|
|
|
|
|
request_id: str,
|
|
|
|
|
host_indices: torch.Tensor,
|
|
|
|
|
new_input_tokens: List[int],
|
|
|
|
|
last_hash: Optional[str] = None,
|
2025-10-10 15:22:05 +08:00
|
|
|
prefix_keys: Optional[List[str]] = None,
|
2025-08-08 17:09:28 +08:00
|
|
|
) -> PrefetchOperation:
|
2025-07-18 00:20:19 -07:00
|
|
|
"""
|
|
|
|
|
Prefetch KV caches from storage backend to host memory.
|
|
|
|
|
"""
|
|
|
|
|
operation = PrefetchOperation(
|
2025-10-10 15:22:05 +08:00
|
|
|
request_id, host_indices, new_input_tokens, last_hash, prefix_keys
|
2025-07-18 00:20:19 -07:00
|
|
|
)
|
|
|
|
|
self.prefetch_queue.put(operation)
|
|
|
|
|
return operation
|
|
|
|
|
|
|
|
|
|
def terminate_prefetch(self, operation):
|
2025-09-08 09:34:04 +08:00
|
|
|
operation.mark_terminate()
|
2025-07-18 00:20:19 -07:00
|
|
|
return operation.completed_tokens, operation.hash_value
|
|
|
|
|
|
2025-08-31 07:58:21 -07:00
|
|
|
def append_host_mem_release(self, host_indices: torch.Tensor):
|
2025-10-01 21:42:37 +08:00
|
|
|
if host_indices.numel() == 0:
|
|
|
|
|
return
|
|
|
|
|
pages = host_indices.split(self.mem_pool_host.page_size)
|
|
|
|
|
for page in pages:
|
|
|
|
|
self.host_mem_release_queue.put(page)
|
2025-08-31 07:58:21 -07:00
|
|
|
|
2025-10-10 15:22:05 +08:00
|
|
|
def _page_get_zero_copy(
|
|
|
|
|
self, operation, hash_values, host_indices, extra_info=None
|
|
|
|
|
):
|
|
|
|
|
results = self.storage_backend.batch_get_v1(
|
|
|
|
|
hash_values, host_indices, extra_info
|
|
|
|
|
)
|
2025-09-23 06:17:31 +08:00
|
|
|
inc = 0
|
|
|
|
|
for i in range(len(hash_values)):
|
|
|
|
|
if not results[i]:
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
|
|
|
|
|
)
|
|
|
|
|
break
|
|
|
|
|
inc += self.page_size
|
|
|
|
|
operation.increment(inc)
|
2025-08-26 10:05:10 +08:00
|
|
|
|
2025-09-23 06:17:31 +08:00
|
|
|
    # todo: deprecate — superseded by the zero-copy path for capable backends
    def _generic_page_get(self, operation, hash_values, host_indices, extra_info=None):
        """Fetch pages through staging buffers, then copy them into host KV memory.

        Stops at the first missing page or when the operation is terminated;
        progress is credited one page at a time via operation.increment().
        """
        # Staging destinations; the backend may fill these or return its own buffers.
        dummy_page_dst = [
            self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
        ]
        page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
        if page_data is None:
            return
        for i in range(len(hash_values)):
            if page_data[i] is None:
                logger.warning(
                    f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                )
                break
            # Must set the data before increasing the completed tokens.
            # Otherwise this page may be read before being set.
            self.mem_pool_host.set_from_flat_data_page(
                host_indices[i * self.page_size],
                page_data[i],
            )
            if not operation.increment(self.page_size):
                break  # Operation terminated by controller
|
2025-08-26 10:05:10 +08:00
|
|
|
|
|
|
|
|
    def _page_transfer(self, operation):
        """Run a prefetch operation's page fetches in storage-batch chunks.

        After each chunk, the exact token count is compared against the
        expectation; any shortfall (backend miss or controller termination)
        ends the transfer, and the untouched tail of the pre-allocated host
        memory is queued for release.
        """
        # Transfer batch by batch
        prefix_keys = operation.prefix_keys
        for i in range(0, len(operation.hash_value), self.storage_batch_size):
            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
            batch_host_indices = operation.host_indices[
                i * self.page_size : (i + len(batch_hashes)) * self.page_size
            ]
            prev_completed_tokens = operation.completed_tokens
            # Get one batch token, and update the completed_tokens if succeed
            extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
            self.page_get_func(operation, batch_hashes, batch_host_indices, extra_info)
            # Check termination
            if (
                operation.completed_tokens
                != prev_completed_tokens + len(batch_hashes) * self.page_size
            ):
                operation.mark_terminate()
                break  # Some operations fail or operation terminated by controller

            # Extend the key prefix so the next batch's lookup sees this one.
            # NOTE(review): only done when prefix_keys started non-empty —
            # presumably intentional for backends that ignore prefixes; confirm.
            if prefix_keys and len(prefix_keys) > 0:
                prefix_keys += batch_hashes

        # release pre-allocated memory
        self.append_host_mem_release(
            operation.host_indices[operation.completed_tokens :]
        )
|
2025-08-02 06:58:17 +08:00
|
|
|
|
2025-07-18 00:20:19 -07:00
|
|
|
def prefetch_io_aux_func(self):
|
|
|
|
|
"""
|
|
|
|
|
Auxiliary function conducting IO operations for prefetching.
|
|
|
|
|
"""
|
|
|
|
|
while not self.stop_event.is_set():
|
|
|
|
|
try:
|
|
|
|
|
operation = self.prefetch_buffer.get(block=True, timeout=1)
|
2025-08-26 10:05:10 +08:00
|
|
|
self._page_transfer(operation)
|
2025-08-09 01:16:51 -07:00
|
|
|
# operation terminated by controller, release pre-allocated memory
|
2025-08-31 07:58:21 -07:00
|
|
|
self.append_host_mem_release(
|
2025-08-09 01:16:51 -07:00
|
|
|
operation.host_indices[operation.completed_tokens :]
|
|
|
|
|
)
|
2025-07-18 00:20:19 -07:00
|
|
|
except Empty:
|
|
|
|
|
continue
|
|
|
|
|
|
2025-08-31 07:58:21 -07:00
|
|
|
def prefetch_rate_limited(self) -> bool:
|
2025-08-08 17:09:28 +08:00
|
|
|
"""
|
|
|
|
|
Rate limit the prefetching operations to avoid overwhelming the storage backend.
|
|
|
|
|
"""
|
|
|
|
|
# cancel prefetch if too much memory is occupied
|
|
|
|
|
if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
|
2025-08-31 07:58:21 -07:00
|
|
|
return True
|
2025-08-08 17:09:28 +08:00
|
|
|
# todo: more sophisticated rate limiting based on storage backend performance
|
2025-08-31 07:58:21 -07:00
|
|
|
return False
|
2025-08-08 17:09:28 +08:00
|
|
|
|
2025-08-31 07:58:21 -07:00
|
|
|
def _storage_hit_query(self, operation) -> tuple[list[str], int]:
    """Probe the storage backend for the longest stored prefix of *operation*.

    Token ids are hashed page by page, each hash chained through the
    previous one (seeded by ``operation.last_hash``), and the backend is
    queried one batch at a time for how many leading pages exist. The scan
    stops at the first batch that is not fully present.

    Returns:
        ``(hash_value, storage_query_count)`` — the hashes of the hit pages
        in order, and the number of tokens those pages cover.
    """
    chained_hash = operation.last_hash
    tokens = operation.token_ids
    # Copy so the operation's own prefix keys are not mutated while we
    # extend them batch by batch for the backend queries.
    prefix_keys = operation.prefix_keys.copy() if operation.prefix_keys else None

    hit_token_count = 0
    hit_hashes = []

    batch_token_span = self.page_size * self.storage_batch_size
    for start in range(0, len(tokens), batch_token_span):
        batch_tokens = tokens[start : start + batch_token_span]
        batch_hashes = []
        for offset in range(0, len(batch_tokens), self.page_size):
            chained_hash = self.get_hash_str(
                batch_tokens[offset : offset + self.page_size], chained_hash
            )
            batch_hashes.append(chained_hash)
        extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
        hit_page_num = self.storage_backend.batch_exists(batch_hashes, extra_info)
        hit_hashes.extend(batch_hashes[:hit_page_num])
        hit_token_count += hit_page_num * self.page_size
        if hit_page_num < len(batch_hashes):
            # Partial hit: anything past this point cannot be a prefix hit.
            break
        if prefix_keys and len(prefix_keys) > 0:
            prefix_keys += batch_hashes
    return hit_hashes, hit_token_count
|
2025-08-26 10:05:10 +08:00
|
|
|
|
2025-07-18 00:20:19 -07:00
|
|
|
def prefetch_thread_func(self):
    """
    Manage prefetching operations from storage backend to host memory.
    """
    # Hand-off buffer for the IO worker started below; created here so it
    # exists before the worker thread first reads it.
    self.prefetch_buffer = Queue()
    aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
    aux_thread.start()
    # Keep draining queued operations even after stop is requested, until
    # the queue is empty.
    while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
        try:
            operation = self.prefetch_queue.get(block=True, timeout=1)
            if operation is None:
                continue

            hash_value, storage_hit_count = self._storage_hit_query(operation)
            if self.tp_world_size > 1:
                # All TP ranks must agree on how much to prefetch; reduce
                # with MIN so every rank fetches the same prefix length.
                storage_hit_count_tensor = torch.tensor(
                    storage_hit_count, dtype=torch.int
                )
                torch.distributed.all_reduce(
                    storage_hit_count_tensor,
                    op=torch.distributed.ReduceOp.MIN,
                    group=self.prefetch_tp_group,
                )
                storage_hit_count = storage_hit_count_tensor.item()

            if storage_hit_count < self.prefetch_threshold:
                # not to prefetch if not enough benefits
                self.prefetch_revoke_queue.put(operation.request_id)
                self.append_host_mem_release(operation.host_indices)
                logger.debug(
                    f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                )
            else:
                # storage_hit_count counts tokens; truncate to whole pages.
                operation.hash_value = hash_value[
                    : (storage_hit_count // self.page_size)
                ]
                # free the pre-allocated memory for pages that are not hit
                self.append_host_mem_release(
                    operation.host_indices[storage_hit_count:]
                )
                operation.host_indices = operation.host_indices[:storage_hit_count]
                logger.debug(
                    f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
                )
                # Hand off to prefetch_io_aux_func for the actual transfer.
                self.prefetch_buffer.put(operation)

        except Empty:
            continue
|
|
|
|
|
|
|
|
|
|
def write_storage(
    self,
    host_indices: torch.Tensor,
    token_ids: List[int],
    hash_value: Optional[List[str]] = None,
    prefix_keys: Optional[List[str]] = None,
) -> int:
    """
    Write KV caches from host memory to storage backend.

    The write is asynchronous: a StorageOperation is enqueued for the
    backup thread and its id is returned so the caller can later match
    the acknowledgement.
    """
    backup_op = StorageOperation(
        host_indices, token_ids, hash_value=hash_value, prefix_keys=prefix_keys
    )
    self.backup_queue.put(backup_op)
    return backup_op.id
|
|
|
|
|
|
2025-09-23 06:17:31 +08:00
|
|
|
# todo: deprecate
|
2025-10-10 15:22:05 +08:00
|
|
|
def _generic_page_set(self, hash_values, host_indices, extra_info=None) -> bool:
|
2025-08-26 10:05:10 +08:00
|
|
|
data = [
|
2025-09-23 06:17:31 +08:00
|
|
|
self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
|
2025-08-26 10:05:10 +08:00
|
|
|
for i in range(len(hash_values))
|
|
|
|
|
]
|
|
|
|
|
return self.storage_backend.batch_set(hash_values, data)
|
|
|
|
|
|
2025-10-10 15:22:05 +08:00
|
|
|
def _page_set_zero_copy(self, hash_values, host_indices, extra_info=None) -> bool:
|
|
|
|
|
return all(
|
|
|
|
|
self.storage_backend.batch_set_v1(hash_values, host_indices, extra_info)
|
|
|
|
|
)
|
2025-08-26 10:05:10 +08:00
|
|
|
|
|
|
|
|
# Backup batch by batch
def _page_backup(self, operation):
    """Persist *operation*'s pages to the storage backend, batch by batch.

    Advances ``operation.completed_tokens`` past every batch written
    successfully and stops at the first batch that fails (no partial
    success within a batch — see todo).
    """
    prefix_keys = operation.prefix_keys
    batch_size = self.storage_batch_size
    for offset in range(0, len(operation.hash_value), batch_size):
        batch_hashes = operation.hash_value[offset : offset + batch_size]
        index_lo = offset * self.page_size
        index_hi = (offset + len(batch_hashes)) * self.page_size
        batch_host_indices = operation.host_indices[index_lo:index_hi]
        # todo: allow partial success
        extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
        if not self.page_set_func(batch_hashes, batch_host_indices, extra_info):
            logger.warning(
                f"Write page to storage: {len(batch_hashes)} pages failed."
            )
            break

        # NOTE(review): this extends operation.prefix_keys in place (no
        # copy, unlike _storage_hit_query which copies first) — confirm
        # the aliasing is intentional.
        if prefix_keys and len(prefix_keys) > 0:
            prefix_keys += batch_hashes
        operation.completed_tokens += self.page_size * len(batch_hashes)
|
2025-07-31 14:15:51 +08:00
|
|
|
|
2025-07-18 00:20:19 -07:00
|
|
|
def backup_thread_func(self):
    """Worker loop draining backup operations from host memory to storage.

    Pulls operations off ``self.backup_queue``, writes their pages to the
    backend (unless backups are skipped), and always acknowledges via
    ``self.ack_backup_queue``. Exits once ``self.stop_event`` is set.
    """
    while not self.stop_event.is_set():
        try:
            op = self.backup_queue.get(block=True, timeout=1)
        except Empty:
            # Timed out waiting for work; re-check the stop flag.
            continue
        if op is None:
            continue

        if not self.backup_skip:
            self._page_backup(op)
        # Ack even when skipped so waiters are not blocked forever.
        self.ack_backup_queue.put(op)
|