[HiCache] Cleaning the deprecated host memory state (#10778)
This commit removes the deprecated MemoryStateInt state machine from the host KV cache: the protect_write / protect_load / complete_io transitions and their is_* / update_* queries are deleted, evict_device and evict_host free slots directly, and the synchronized decorator drops its debug_only variant.
@@ -462,7 +462,6 @@ class HiCacheController:
         host_indices = self.mem_pool_host.alloc(len(device_indices))
         if host_indices is None:
             return None
-        self.mem_pool_host.protect_write(host_indices)
         self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
@@ -486,7 +485,6 @@ class HiCacheController:
         self.mem_pool_host.backup_from_device_all_layer(
             self.mem_pool_device, host_indices, device_indices, self.io_backend
         )
-        self.mem_pool_host.complete_io(op.host_indices)
         finish_event.record()
         # NOTE: We must save the host indices and device indices here,
         # this is because we need to guarantee that these tensors are
@@ -510,7 +508,6 @@ class HiCacheController:
         device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
         if device_indices is None:
             return None
-        self.mem_pool_host.protect_load(host_indices)
         self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
@@ -555,7 +552,6 @@ class HiCacheController:
             self.io_backend,
         )
         producer_event.complete(i)
-        self.mem_pool_host.complete_io(op.host_indices)
         # NOTE: We must save the host indices and device indices here,
         # this is because we need to guarantee that these tensors are
         # still alive when the load stream is executing.
@@ -573,29 +569,16 @@ class HiCacheController:
         )
         return producer_id

-    def evict_device(
-        self, device_indices: torch.Tensor, host_indices: torch.Tensor
-    ) -> int:
-        if self.mem_pool_host.is_synced(host_indices):
-            self.mem_pool_device_allocator.free(device_indices)
-            self.mem_pool_host.update_backup(host_indices)
-            return len(device_indices)
-        else:
-            raise ValueError(
-                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
-            )
+    def evict_device(self, device_indices: torch.Tensor) -> int:
+        self.mem_pool_device_allocator.free(device_indices)
+        return len(device_indices)

     def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
         if not backup_only:
             raise ValueError("Other eviction policies are not supported yet.")

-        if self.mem_pool_host.is_backup(host_indices):
-            self.mem_pool_host.free(host_indices)
-            return len(host_indices)
-        else:
-            raise ValueError(
-                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
-            )
+        self.mem_pool_host.free(host_indices)
+        return len(host_indices)

     def prefetch(
         self,
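For illustration, here is a minimal runnable sketch of the simplified eviction path this hunk introduces; ToyAllocator and ToyController are hypothetical stand-ins for the device allocator and HiCacheController, not code from the repository.

import torch


class ToyAllocator:
    """Hypothetical stand-in for the device allocator's free-list behavior."""

    def __init__(self, size: int):
        self.free_slots = torch.arange(size, dtype=torch.int64)

    def free(self, indices: torch.Tensor) -> None:
        # Return freed slots to the free list.
        self.free_slots = torch.cat([self.free_slots, indices])


class ToyController:
    """Hypothetical stand-in for HiCacheController after this change."""

    def __init__(self):
        self.mem_pool_device_allocator = ToyAllocator(16)

    def evict_device(self, device_indices: torch.Tensor) -> int:
        # Device slots are freed unconditionally; the SYNCED/BACKUP state
        # check that previously guarded this path is gone.
        self.mem_pool_device_allocator.free(device_indices)
        return len(device_indices)


controller = ToyController()
assert controller.evict_device(torch.tensor([0, 1, 2])) == 3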
@@ -305,7 +305,7 @@ class HiRadixCache(RadixCache):

     def _evict_backuped(self, node: TreeNode):
         # evict a node already written to host
-        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
+        num_evicted = self.cache_controller.evict_device(node.value)
         assert num_evicted > 0
         self.evictable_size_ -= num_evicted
         node.value = None
@@ -576,8 +576,6 @@ class HiRadixCache(RadixCache):
                 written_indices,
                 hash_value[: min_completed_tokens // self.page_size],
             )
-            if len(written_indices):
-                self.cache_controller.mem_pool_host.update_prefetch(written_indices)

             self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
             self.cache_controller.append_host_mem_release(
@@ -775,7 +773,6 @@ class HiRadixCache(RadixCache):
             # change the reference if the node is evicted
             # this often happens in the case of KV cache recomputation
             node.value = value[:prefix_len]
-            self.token_to_kv_pool_host.update_synced(node.host_value)
             self.evictable_size_ += len(node.value)
         else:
             self._inc_hit_count(node, chunked)
@@ -785,7 +782,6 @@ class HiRadixCache(RadixCache):
             new_node = self._split_node(node.key, node, prefix_len)
             if new_node.evicted:
                 new_node.value = value[:prefix_len]
-                self.token_to_kv_pool_host.update_synced(new_node.host_value)
                 self.evictable_size_ += len(new_node.value)
             else:
                 self._inc_hit_count(new_node, chunked)
@@ -31,27 +31,13 @@ if not (_is_npu or _is_xpu):
 logger = logging.getLogger(__name__)


-class MemoryStateInt(IntEnum):
-    IDLE = 0
-    RESERVED = 1
-    PROTECTED = 2
-    SYNCED = 3
-    BACKUP = 4
-
-
-def synchronized(debug_only=False):
-    def _decorator(func):
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            if (not debug_only) or self.debug:
-                with self.lock:
-                    return func(self, *args, **kwargs)
-            else:
-                return True
-
-        return wrapper
-
-    return _decorator
+def synchronized(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with self.lock:
+            return func(self, *args, **kwargs)
+
+    return wrapper


 class HostKVCache(abc.ABC):
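For reference, a minimal runnable sketch of the simplified decorator this hunk settles on; Counter is a hypothetical example class, assuming only that the decorated object exposes a self.lock, as HostKVCache does.

import threading
from functools import wraps


def synchronized(func):
    # Serialize every call through the instance's reentrant lock.
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with self.lock:
            return func(self, *args, **kwargs)

    return wrapper


class Counter:
    """Hypothetical example; HostKVCache similarly holds an RLock."""

    def __init__(self):
        self.lock = threading.RLock()
        self.value = 0

    @synchronized
    def increment(self) -> int:
        self.value += 1
        return self.value


c = Counter()
assert c.increment() == 1

Note that the removed debug_only=True variant returned True without ever calling func when self.debug was off, so the state checks it guarded were already no-ops outside debug runs, which presumably motivates deleting them outright.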
@@ -110,7 +96,6 @@ class HostKVCache(abc.ABC):

         # A lock for synchronized operations on memory allocation and state transitions.
         self.lock = threading.RLock()
-        self.debug = logger.isEnabledFor(logging.DEBUG)
         self.clear()

     @abc.abstractmethod
@@ -161,7 +146,7 @@ class HostKVCache(abc.ABC):
         """
         raise NotImplementedError()

-    @synchronized()
+    @synchronized
     def clear(self):
         # Initialize memory states and tracking structures.
         self.mem_state = torch.zeros(
@@ -172,7 +157,7 @@ class HostKVCache(abc.ABC):
     def available_size(self):
         return len(self.free_slots)

-    @synchronized()
+    @synchronized
    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
@@ -183,92 +168,13 @@ class HostKVCache(abc.ABC):
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]

-        if self.debug:
-            self.mem_state[select_index] = MemoryStateInt.RESERVED
-
         return select_index

-    @synchronized()
+    @synchronized
     def free(self, indices: torch.Tensor) -> int:
         self.free_slots = torch.cat([self.free_slots, indices])
-        if self.debug:
-            self.mem_state[indices] = MemoryStateInt.IDLE
         return len(indices)

-    @synchronized(debug_only=True)
-    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
-        assert len(indices) > 0, "The indices should not be empty"
-        states = self.mem_state[indices]
-        assert (
-            states == states[0]
-        ).all(), "The memory slots should have the same state {}".format(states)
-        return MemoryStateInt(states[0].item())
-
-    @synchronized(debug_only=True)
-    def is_reserved(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.RESERVED
-
-    @synchronized(debug_only=True)
-    def is_protected(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def is_synced(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def is_backup(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_backup(self, indices: torch.Tensor):
-        if not self.is_synced(indices):
-            raise ValueError(
-                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_prefetch(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be in RESERVED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_synced(self, indices: torch.Tensor):
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def protect_write(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be RESERVED before write operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def protect_load(self, indices: torch.Tensor):
-        if not self.is_backup(indices):
-            raise ValueError(
-                f"The host memory slots should be in BACKUP state before load operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def complete_io(self, indices: torch.Tensor):
-        if not self.is_protected(indices):
-            raise ValueError(
-                f"The host memory slots should be PROTECTED during I/O operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-

 class MHATokenToKVPoolHost(HostKVCache):
     device_pool: MHATokenToKVPool
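After this cleanup, host allocation is a plain tensor-backed free list. Below is a minimal self-contained sketch of that pattern; FreeList is a hypothetical name for illustration, while the real alloc/free live on HostKVCache.

from typing import Optional

import torch


class FreeList:
    """Hypothetical tensor-backed free list mirroring the surviving alloc/free."""

    def __init__(self, size: int, page_size: int):
        self.page_size = page_size
        self.free_slots = torch.arange(size, dtype=torch.int64)

    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
        # Allocations must be page-aligned, as asserted in the diff above.
        assert need_size % self.page_size == 0
        if need_size > len(self.free_slots):
            return None  # not enough host slots available
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, indices: torch.Tensor) -> int:
        # Freed slots simply rejoin the free list; no per-slot state to update.
        self.free_slots = torch.cat([self.free_slots, indices])
        return len(indices)


pool = FreeList(size=8, page_size=2)
block = pool.alloc(4)
assert block is not None and pool.free(block) == 4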