[Refactor] Remove Hicache Load & Write threads (#10127)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
DarkSharpness
2025-09-08 22:18:50 -07:00
committed by GitHub
parent cdc56ef6c1
commit 948b01a04c
10 changed files with 215 additions and 204 deletions

View File

@@ -201,41 +201,57 @@ class HiRadixCache(RadixCache):
if write_back:
# blocking till all write back complete
while len(self.ongoing_write_through) > 0:
ack_id = self.cache_controller.ack_write_queue.get()
del self.ongoing_write_through[ack_id]
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
finish_event.synchronize()
for ack_id in ack_list:
del self.ongoing_write_through[ack_id]
self.cache_controller.ack_write_queue.clear()
assert len(self.ongoing_write_through) == 0
return
queue_size = torch.tensor(
self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
)
# NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
if len(self.ongoing_write_through) == 0:
return
finish_count = 0
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
if not finish_event.query():
break
finish_count += 1
queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
if self.tp_world_size > 1:
# synchrnoize TP workers to make the same update to radix cache
# synchronize TP workers to make the same update to radix cache
torch.distributed.all_reduce(
queue_size,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
for _ in range(queue_size.item()):
ack_id = self.cache_controller.ack_write_queue.get()
backuped_node = self.ongoing_write_through[ack_id]
self.dec_lock_ref(backuped_node)
del self.ongoing_write_through[ack_id]
if self.enable_storage:
self.write_backup_storage(backuped_node)
finish_count = int(queue_size.item())
while finish_count > 0:
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
finish_event.synchronize()
for ack_id in ack_list:
backuped_node = self.ongoing_write_through.pop(ack_id)
self.dec_lock_ref(backuped_node)
if self.enable_storage:
self.write_backup_storage(backuped_node)
finish_count -= 1
def loading_check(self):
while not self.cache_controller.ack_load_queue.empty():
try:
ack_id = self.cache_controller.ack_load_queue.get_nowait()
start_node, end_node = self.ongoing_load_back[ack_id]
self.dec_lock_ref(end_node)
while end_node != start_node:
assert end_node.loading
end_node.loading = False
end_node = end_node.parent
# clear the reference
del self.ongoing_load_back[ack_id]
except Exception:
finish_count = 0
for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
if not finish_event.query():
# the KV cache loading is still ongoing
break
finish_count += 1
# no need to sync across TP workers as batch forwarding is synced
for ack_id in ack_list:
end_node = self.ongoing_load_back.pop(ack_id)
self.dec_lock_ref(end_node)
# ACK until all events are processed
del self.cache_controller.ack_load_queue[:finish_count]
def evictable_size(self):
return self.evictable_size_
@@ -360,12 +376,11 @@ class HiRadixCache(RadixCache):
# no sufficient GPU memory to load back KV caches
return None
self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
self.ongoing_load_back[last_hit_node.id] = last_hit_node
offset = 0
for node in nodes_to_load:
node.value = device_indices[offset : offset + len(node.host_value)]
offset += len(node.host_value)
node.loading = True
self.evictable_size_ += len(device_indices)
self.inc_lock_ref(last_hit_node)
@@ -394,10 +409,12 @@ class HiRadixCache(RadixCache):
last_node,
)
def ready_to_load_host_cache(self):
producer_index = self.cache_controller.layer_done_counter.next_producer()
self.load_cache_event.set()
return producer_index
def ready_to_load_host_cache(self) -> int:
"""
Notify the cache controller to start the KV cache loading.
Return the consumer index for the schedule batch manager to track.
"""
return self.cache_controller.start_loading()
def check_hicache_events(self):
self.writing_check()
@@ -702,7 +719,6 @@ class HiRadixCache(RadixCache):
new_node.parent = child.parent
new_node.lock_ref = child.lock_ref
new_node.key = child.key[:split_len]
new_node.loading = child.loading
new_node.hit_count = child.hit_count
# split value and host value if exists

View File

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import annotations
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
"""
@@ -27,7 +29,7 @@ KVCache actually holds the physical kv cache.
import abc
import logging
from contextlib import nullcontext
from typing import Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -38,6 +40,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.managers.cache_controller import LayerDoneCounter
logger = logging.getLogger(__name__)
GB = 1024 * 1024 * 1024
@@ -175,7 +180,7 @@ class KVCache(abc.ABC):
) -> None:
raise NotImplementedError()
def register_layer_transfer_counter(self, layer_transfer_counter):
def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
self.layer_transfer_counter = layer_transfer_counter
def get_cpu_copy(self, indices):

View File

@@ -3,6 +3,7 @@ import logging
import threading
from enum import IntEnum
from functools import wraps
from typing import Optional
import psutil
import torch
@@ -169,7 +170,7 @@ class HostKVCache(abc.ABC):
return len(self.free_slots)
@synchronized()
def alloc(self, need_size: int) -> torch.Tensor:
def alloc(self, need_size: int) -> Optional[torch.Tensor]:
assert (
need_size % self.page_size == 0
), "The requested size should be a multiple of the page size."

View File

@@ -53,8 +53,6 @@ class TreeNode:
self.last_access_time = time.monotonic()
self.hit_count = 0
# indicating the node is loading KV cache from host
self.loading = False
# indicating the node is locked to protect from eviction
# incremented when the node is referenced by a storage operation
self.host_ref_counter = 0

View File

@@ -60,8 +60,6 @@ class TreeNode:
self.last_access_time = time.monotonic()
self.hit_count = 0
# indicating the node is loading KV cache from host
self.loading = False
# store the host indices of KV cache
self.host_value = None