[Refactor] Remove Hicache Load & Write threads (#10127)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -201,41 +201,57 @@ class HiRadixCache(RadixCache):
|
||||
if write_back:
|
||||
# blocking till all write back complete
|
||||
while len(self.ongoing_write_through) > 0:
|
||||
ack_id = self.cache_controller.ack_write_queue.get()
|
||||
del self.ongoing_write_through[ack_id]
|
||||
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
|
||||
finish_event.synchronize()
|
||||
for ack_id in ack_list:
|
||||
del self.ongoing_write_through[ack_id]
|
||||
self.cache_controller.ack_write_queue.clear()
|
||||
assert len(self.ongoing_write_through) == 0
|
||||
return
|
||||
queue_size = torch.tensor(
|
||||
self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
|
||||
)
|
||||
|
||||
# NOTE: all ranks have the same ongoing_write_through, so the sync can be skipped when it is empty
|
||||
if len(self.ongoing_write_through) == 0:
|
||||
return
|
||||
|
||||
finish_count = 0
|
||||
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
|
||||
if not finish_event.query():
|
||||
break
|
||||
finish_count += 1
|
||||
queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
|
||||
if self.tp_world_size > 1:
|
||||
# synchrnoize TP workers to make the same update to radix cache
|
||||
# synchronize TP workers to make the same update to radix cache
|
||||
torch.distributed.all_reduce(
|
||||
queue_size,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_group,
|
||||
)
|
||||
for _ in range(queue_size.item()):
|
||||
ack_id = self.cache_controller.ack_write_queue.get()
|
||||
backuped_node = self.ongoing_write_through[ack_id]
|
||||
self.dec_lock_ref(backuped_node)
|
||||
del self.ongoing_write_through[ack_id]
|
||||
if self.enable_storage:
|
||||
self.write_backup_storage(backuped_node)
|
||||
|
||||
finish_count = int(queue_size.item())
|
||||
while finish_count > 0:
|
||||
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
|
||||
finish_event.synchronize()
|
||||
for ack_id in ack_list:
|
||||
backuped_node = self.ongoing_write_through.pop(ack_id)
|
||||
self.dec_lock_ref(backuped_node)
|
||||
if self.enable_storage:
|
||||
self.write_backup_storage(backuped_node)
|
||||
finish_count -= 1
|
||||
|
||||
def loading_check(self):
|
||||
while not self.cache_controller.ack_load_queue.empty():
|
||||
try:
|
||||
ack_id = self.cache_controller.ack_load_queue.get_nowait()
|
||||
start_node, end_node = self.ongoing_load_back[ack_id]
|
||||
self.dec_lock_ref(end_node)
|
||||
while end_node != start_node:
|
||||
assert end_node.loading
|
||||
end_node.loading = False
|
||||
end_node = end_node.parent
|
||||
# clear the reference
|
||||
del self.ongoing_load_back[ack_id]
|
||||
except Exception:
|
||||
finish_count = 0
|
||||
for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
|
||||
if not finish_event.query():
|
||||
# the KV cache loading is still ongoing
|
||||
break
|
||||
finish_count += 1
|
||||
# no need to sync across TP workers as batch forwarding is synced
|
||||
for ack_id in ack_list:
|
||||
end_node = self.ongoing_load_back.pop(ack_id)
|
||||
self.dec_lock_ref(end_node)
|
||||
|
||||
# ACK until all events are processed
|
||||
del self.cache_controller.ack_load_queue[:finish_count]
|
||||
|
||||
def evictable_size(self):
|
||||
return self.evictable_size_
|
||||
@@ -360,12 +376,11 @@ class HiRadixCache(RadixCache):
|
||||
# no sufficient GPU memory to load back KV caches
|
||||
return None
|
||||
|
||||
self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
|
||||
self.ongoing_load_back[last_hit_node.id] = last_hit_node
|
||||
offset = 0
|
||||
for node in nodes_to_load:
|
||||
node.value = device_indices[offset : offset + len(node.host_value)]
|
||||
offset += len(node.host_value)
|
||||
node.loading = True
|
||||
self.evictable_size_ += len(device_indices)
|
||||
self.inc_lock_ref(last_hit_node)
|
||||
|
||||
@@ -394,10 +409,12 @@ class HiRadixCache(RadixCache):
|
||||
last_node,
|
||||
)
|
||||
|
||||
def ready_to_load_host_cache(self):
|
||||
producer_index = self.cache_controller.layer_done_counter.next_producer()
|
||||
self.load_cache_event.set()
|
||||
return producer_index
|
||||
def ready_to_load_host_cache(self) -> int:
|
||||
"""
|
||||
Notify the cache controller to start the KV cache loading.
|
||||
Return the consumer index for the schedule batch manager to track.
|
||||
"""
|
||||
return self.cache_controller.start_loading()
|
||||
|
||||
def check_hicache_events(self):
|
||||
self.writing_check()
|
||||
@@ -702,7 +719,6 @@ class HiRadixCache(RadixCache):
|
||||
new_node.parent = child.parent
|
||||
new_node.lock_ref = child.lock_ref
|
||||
new_node.key = child.key[:split_len]
|
||||
new_node.loading = child.loading
|
||||
new_node.hit_count = child.hit_count
|
||||
|
||||
# split value and host value if exists
|
||||
|
||||
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
|
||||
"""
|
||||
@@ -27,7 +29,7 @@ KVCache actually holds the physical kv cache.
|
||||
import abc
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -38,6 +40,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.cache_controller import LayerDoneCounter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GB = 1024 * 1024 * 1024
|
||||
@@ -175,7 +180,7 @@ class KVCache(abc.ABC):
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def register_layer_transfer_counter(self, layer_transfer_counter):
|
||||
def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
|
||||
self.layer_transfer_counter = layer_transfer_counter
|
||||
|
||||
def get_cpu_copy(self, indices):
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
import threading
|
||||
from enum import IntEnum
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
@@ -169,7 +170,7 @@ class HostKVCache(abc.ABC):
|
||||
return len(self.free_slots)
|
||||
|
||||
@synchronized()
|
||||
def alloc(self, need_size: int) -> torch.Tensor:
|
||||
def alloc(self, need_size: int) -> Optional[torch.Tensor]:
|
||||
assert (
|
||||
need_size % self.page_size == 0
|
||||
), "The requested size should be a multiple of the page size."
|
||||
|
||||
@@ -53,8 +53,6 @@ class TreeNode:
|
||||
self.last_access_time = time.monotonic()
|
||||
|
||||
self.hit_count = 0
|
||||
# indicating the node is loading KV cache from host
|
||||
self.loading = False
|
||||
# indicating the node is locked to protect from eviction
|
||||
# incremented when the node is referenced by a storage operation
|
||||
self.host_ref_counter = 0
|
||||
|
||||
@@ -60,8 +60,6 @@ class TreeNode:
|
||||
self.last_access_time = time.monotonic()
|
||||
|
||||
self.hit_count = 0
|
||||
# indicating the node is loading KV cache from host
|
||||
self.loading = False
|
||||
# store the host indices of KV cache
|
||||
self.host_value = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user