Refactor hierarchical caching and fix TP (tensor-parallelism) issue (#4082)
This commit is contained in:
@@ -30,6 +30,26 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LayerDoneCounter:
    """Thread-safe counter tracking how many layers have finished transfer.

    The loader thread calls ``increment`` after each layer's KV cache is
    copied to the device; consumers call ``wait_until`` to block until a
    given number of layers is ready.  The counter starts at ``num_layers``
    so that, before any load is in flight, every ``wait_until`` call
    returns immediately.
    """

    def __init__(self, num_layers):
        # Starts "full": all layers count as done until reset() marks the
        # beginning of a new layer-by-layer load.
        self.counter = num_layers
        self.condition = threading.Condition()

    def increment(self):
        """Mark one more layer as transferred and wake all waiters."""
        with self.condition:
            self.counter += 1
            self.condition.notify_all()

    def wait_until(self, threshold):
        """Block until the counter is strictly greater than ``threshold``."""
        with self.condition:
            # wait_for re-checks the predicate on each notification — the
            # idiomatic form of the manual `while counter <= t: wait()` loop.
            self.condition.wait_for(lambda: self.counter > threshold)

    def reset(self):
        """Restart counting at zero for a new layer-by-layer load."""
        with self.condition:
            self.counter = 0
|
||||
|
||||
|
||||
class CacheOperation:
|
||||
|
||||
counter = 0
|
||||
@@ -132,6 +152,7 @@ class HiCacheController:
|
||||
self,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
mem_pool_host: MHATokenToKVPoolHost,
|
||||
load_cache_event: threading.Event = None,
|
||||
write_policy: str = "write_through_selective",
|
||||
):
|
||||
self.mem_pool_device_allocator = token_to_kv_pool_allocator
|
||||
@@ -139,6 +160,10 @@ class HiCacheController:
|
||||
self.mem_pool_host = mem_pool_host
|
||||
self.write_policy = write_policy
|
||||
|
||||
self.load_cache_event = load_cache_event
|
||||
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
||||
self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
|
||||
|
||||
if write_policy not in [
|
||||
"write_through",
|
||||
"write_through_selective",
|
||||
@@ -165,7 +190,7 @@ class HiCacheController:
|
||||
target=self.write_thread_func_buffer, daemon=True
|
||||
)
|
||||
self.load_thread = threading.Thread(
|
||||
target=self.load_thread_func_buffer, daemon=True
|
||||
target=self.load_thread_func_layer_by_layer, daemon=True
|
||||
)
|
||||
self.write_thread.start()
|
||||
self.load_thread.start()
|
||||
@@ -186,7 +211,7 @@ class HiCacheController:
|
||||
target=self.write_thread_func_buffer, daemon=True
|
||||
)
|
||||
self.load_thread = threading.Thread(
|
||||
target=self.load_thread_func_buffer, daemon=True
|
||||
target=self.load_thread_func_layer_by_layer, daemon=True
|
||||
)
|
||||
self.stop_event.clear()
|
||||
self.write_thread.start()
|
||||
@@ -273,6 +298,42 @@ class HiCacheController:
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def load_thread_func_layer_by_layer(self):
    """
    Load KV caches from host memory to device memory layer by layer.

    Runs as a daemon loader thread: wakes on ``load_cache_event``, drains
    and merges all queued load operations into one batch, then copies the
    batch one layer at a time, bumping ``layer_done_counter`` after each
    layer so that waiters can overlap compute with the remaining copies.
    """
    # All device copies are issued on the dedicated load stream rather
    # than the default stream.
    with torch.cuda.stream(self.load_stream):
        while not self.stop_event.is_set():
            # Timeout so the loop re-checks stop_event at least once a
            # second even when no load is ever signaled.
            self.load_cache_event.wait(timeout=1)
            if not self.load_cache_event.is_set():
                continue
            self.load_cache_event.clear()

            # Drain the queue and merge everything into a single batch
            # operation.  NOTE(review): qsize()-then-get() is only safe if
            # this thread is the sole consumer of load_queue — appears to
            # be the case, but confirm against the rest of the controller.
            batch_operation = None
            while self.load_queue.qsize() > 0:
                op = self.load_queue.get(block=True)
                if batch_operation is None:
                    batch_operation = op
                else:
                    batch_operation.merge(op)
            if batch_operation is None:
                continue

            # New load starting: zero the counter before any layer copies,
            # so waiters block until the layers they need have landed.
            self.layer_done_counter.reset()
            for i in range(self.mem_pool_host.layer_num):
                flat_data = self.mem_pool_host.get_flat_data_by_layer(
                    batch_operation.host_indices, i
                )
                self.mem_pool_device.transfer_per_layer(
                    batch_operation.device_indices, flat_data, i
                )
                # Publish progress after each layer; consumers registered
                # via register_layer_transfer_counter wait on this.
                self.layer_done_counter.increment()

            # Release the host-side slots, then ack each finished node.
            self.mem_pool_host.complete_io(batch_operation.host_indices)
            for node_id in batch_operation.node_ids:
                # node_id 0 evidently acts as a "no ack needed" sentinel —
                # TODO confirm against the producer of node_ids.
                if node_id != 0:
                    self.ack_load_queue.put(node_id)
|
||||
|
||||
def write_aux_func(self, no_wait=False):
|
||||
"""
|
||||
Auxiliary function to prepare the buffer for write operations.
|
||||
|
||||
Reference in New Issue
Block a user