[HiCache] Cleaning the deprecated host memory state (#10778)
@@ -462,7 +462,6 @@ class HiCacheController:
         host_indices = self.mem_pool_host.alloc(len(device_indices))
         if host_indices is None:
             return None
-        self.mem_pool_host.protect_write(host_indices)
         self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
@@ -486,7 +485,6 @@ class HiCacheController:
             self.mem_pool_host.backup_from_device_all_layer(
                 self.mem_pool_device, host_indices, device_indices, self.io_backend
             )
-            self.mem_pool_host.complete_io(op.host_indices)
             finish_event.record()
             # NOTE: We must save the host indices and device indices here,
             # this is because we need to guarantee that these tensors are
@@ -510,7 +508,6 @@ class HiCacheController:
         device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
         if device_indices is None:
             return None
-        self.mem_pool_host.protect_load(host_indices)
         self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
@@ -555,7 +552,6 @@ class HiCacheController:
                     self.io_backend,
                 )
                 producer_event.complete(i)
-            self.mem_pool_host.complete_io(op.host_indices)
             # NOTE: We must save the host indices and device indices here,
             # this is because we need to guarantee that these tensors are
             # still alive when the load stream is executing.
@@ -573,29 +569,16 @@ class HiCacheController:
         )
         return producer_id
 
-    def evict_device(
-        self, device_indices: torch.Tensor, host_indices: torch.Tensor
-    ) -> int:
-        if self.mem_pool_host.is_synced(host_indices):
-            self.mem_pool_device_allocator.free(device_indices)
-            self.mem_pool_host.update_backup(host_indices)
-            return len(device_indices)
-        else:
-            raise ValueError(
-                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
-            )
+    def evict_device(self, device_indices: torch.Tensor) -> int:
+        self.mem_pool_device_allocator.free(device_indices)
+        return len(device_indices)
 
     def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
         if not backup_only:
             raise ValueError("Other eviction policies are not supported yet.")
 
-        if self.mem_pool_host.is_backup(host_indices):
-            self.mem_pool_host.free(host_indices)
-            return len(host_indices)
-        else:
-            raise ValueError(
-                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
-            )
+        self.mem_pool_host.free(host_indices)
+        return len(host_indices)
 
     def prefetch(
         self,
@@ -305,7 +305,7 @@ class HiRadixCache(RadixCache):
 
     def _evict_backuped(self, node: TreeNode):
         # evict a node already written to host
-        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
+        num_evicted = self.cache_controller.evict_device(node.value)
         assert num_evicted > 0
         self.evictable_size_ -= num_evicted
         node.value = None
@@ -576,8 +576,6 @@ class HiRadixCache(RadixCache):
             written_indices,
             hash_value[: min_completed_tokens // self.page_size],
         )
-        if len(written_indices):
-            self.cache_controller.mem_pool_host.update_prefetch(written_indices)
 
         self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
         self.cache_controller.append_host_mem_release(
@@ -775,7 +773,6 @@ class HiRadixCache(RadixCache):
                     # change the reference if the node is evicted
                     # this often happens in the case of KV cache recomputation
                     node.value = value[:prefix_len]
-                    self.token_to_kv_pool_host.update_synced(node.host_value)
                     self.evictable_size_ += len(node.value)
                 else:
                     self._inc_hit_count(node, chunked)
@@ -785,7 +782,6 @@ class HiRadixCache(RadixCache):
                 new_node = self._split_node(node.key, node, prefix_len)
                 if new_node.evicted:
                     new_node.value = value[:prefix_len]
-                    self.token_to_kv_pool_host.update_synced(new_node.host_value)
                     self.evictable_size_ += len(new_node.value)
                 else:
                     self._inc_hit_count(new_node, chunked)
@@ -31,27 +31,13 @@ if not (_is_npu or _is_xpu):
 logger = logging.getLogger(__name__)
 
 
-class MemoryStateInt(IntEnum):
-    IDLE = 0
-    RESERVED = 1
-    PROTECTED = 2
-    SYNCED = 3
-    BACKUP = 4
-
-
-def synchronized(debug_only=False):
-    def _decorator(func):
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            if (not debug_only) or self.debug:
-                with self.lock:
-                    return func(self, *args, **kwargs)
-            else:
-                return True
-
-        return wrapper
-
-    return _decorator
+def synchronized(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with self.lock:
+            return func(self, *args, **kwargs)
+
+    return wrapper
 
 
 class HostKVCache(abc.ABC):
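The new `synchronized` decorator above simply serializes each decorated method on the instance's own `self.lock`, replacing the old `debug_only` decorator factory. A minimal standalone sketch of the same pattern, assuming a toy `Counter` class with its own `threading.RLock` (illustrative only, not part of this commit):

    import threading
    from functools import wraps


    def synchronized(func):
        # Serialize method calls on the instance's reentrant lock.
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            with self.lock:
                return func(self, *args, **kwargs)

        return wrapper


    class Counter:
        def __init__(self):
            self.lock = threading.RLock()
            self.value = 0

        @synchronized
        def increment(self):
            self.value += 1


    c = Counter()
    workers = [threading.Thread(target=c.increment) for _ in range(8)]
    for t in workers:
        t.start()
    for t in workers:
        t.join()
    print(c.value)  # 8

Because the lock lives on the instance, concurrent callers of any decorated method on the same object are mutually exclusive, which is all the host pool still needs once the per-slot state checks are gone.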
@@ -110,7 +96,6 @@ class HostKVCache(abc.ABC):
 
         # A lock for synchronized operations on memory allocation and state transitions.
         self.lock = threading.RLock()
-        self.debug = logger.isEnabledFor(logging.DEBUG)
         self.clear()
 
     @abc.abstractmethod
@@ -161,7 +146,7 @@ class HostKVCache(abc.ABC):
         """
         raise NotImplementedError()
 
-    @synchronized()
+    @synchronized
     def clear(self):
         # Initialize memory states and tracking structures.
         self.mem_state = torch.zeros(
@@ -172,7 +157,7 @@ class HostKVCache(abc.ABC):
     def available_size(self):
         return len(self.free_slots)
 
-    @synchronized()
+    @synchronized
     def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
@@ -183,92 +168,13 @@ class HostKVCache(abc.ABC):
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
 
-        if self.debug:
-            self.mem_state[select_index] = MemoryStateInt.RESERVED
-
         return select_index
 
-    @synchronized()
+    @synchronized
     def free(self, indices: torch.Tensor) -> int:
         self.free_slots = torch.cat([self.free_slots, indices])
-        if self.debug:
-            self.mem_state[indices] = MemoryStateInt.IDLE
         return len(indices)
 
-    @synchronized(debug_only=True)
-    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
-        assert len(indices) > 0, "The indices should not be empty"
-        states = self.mem_state[indices]
-        assert (
-            states == states[0]
-        ).all(), "The memory slots should have the same state {}".format(states)
-        return MemoryStateInt(states[0].item())
-
-    @synchronized(debug_only=True)
-    def is_reserved(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.RESERVED
-
-    @synchronized(debug_only=True)
-    def is_protected(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def is_synced(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def is_backup(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_backup(self, indices: torch.Tensor):
-        if not self.is_synced(indices):
-            raise ValueError(
-                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_prefetch(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be in RESERVED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_synced(self, indices: torch.Tensor):
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def protect_write(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be RESERVED before write operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def protect_load(self, indices: torch.Tensor):
-        if not self.is_backup(indices):
-            raise ValueError(
-                f"The host memory slots should be in BACKUP state before load operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def complete_io(self, indices: torch.Tensor):
-        if not self.is_protected(indices):
-            raise ValueError(
-                f"The host memory slots should be PROTECTED during I/O operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
 
 class MHATokenToKVPoolHost(HostKVCache):
     device_pool: MHATokenToKVPool
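With the state machine gone, what remains of the host pool's allocator is a plain free-slot list: `alloc` slices the first `need_size` indices off `free_slots`, and `free` concatenates indices back, with no per-slot state to track. A minimal standalone sketch of that mechanism, assuming a toy `FreeListPool` class (illustrative only, not the actual `HostKVCache`):

    from typing import Optional

    import torch


    class FreeListPool:
        # Toy index pool mirroring the free-slot logic kept after this cleanup.
        def __init__(self, size: int):
            self.free_slots = torch.arange(size, dtype=torch.int64)

        def available_size(self) -> int:
            return len(self.free_slots)

        def alloc(self, need_size: int) -> Optional[torch.Tensor]:
            # Take the first need_size free indices, or fail if not enough remain.
            if need_size > len(self.free_slots):
                return None
            select_index = self.free_slots[:need_size]
            self.free_slots = self.free_slots[need_size:]
            return select_index

        def free(self, indices: torch.Tensor) -> int:
            # Return indices to the pool.
            self.free_slots = torch.cat([self.free_slots, indices])
            return len(indices)


    pool = FreeListPool(16)
    idx = pool.alloc(4)
    print(idx.tolist(), pool.available_size())  # [0, 1, 2, 3] 12
    pool.free(idx)
    print(pool.available_size())  # 16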