Fix bug where GPU 0 occupies extra memory when HiCache is enabled (#5778)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -268,7 +268,7 @@ class HiCacheController:
|
||||
"""
|
||||
Directly write through KV caches to host memory without buffering.
|
||||
"""
|
||||
with torch.cuda.stream(self.write_stream):
|
||||
torch.cuda.set_stream(self.write_stream)
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.write_queue.get(block=True, timeout=1)
|
||||
@@ -291,7 +291,7 @@ class HiCacheController:
|
||||
"""
|
||||
Directly load KV caches from host memory to device memory without buffering.
|
||||
"""
|
||||
with torch.cuda.stream(self.load_stream):
|
||||
torch.cuda.set_stream(self.load_stream)
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.load_queue.get(block=True, timeout=1)
|
||||
@@ -299,9 +299,7 @@ class HiCacheController:
|
||||
operation.data = self.mem_pool_host.get_flat_data(
|
||||
operation.host_indices
|
||||
)
|
||||
self.mem_pool_device.transfer(
|
||||
operation.device_indices, operation.data
|
||||
)
|
||||
self.mem_pool_device.transfer(operation.device_indices, operation.data)
|
||||
self.mem_pool_host.complete_io(operation.host_indices)
|
||||
for node_id in operation.node_ids:
|
||||
if node_id != 0:
|
||||
@@ -315,7 +313,7 @@ class HiCacheController:
|
||||
"""
|
||||
Load KV caches from host memory to device memory layer by layer.
|
||||
"""
|
||||
with torch.cuda.stream(self.load_stream):
|
||||
torch.cuda.set_stream(self.load_stream)
|
||||
while not self.stop_event.is_set():
|
||||
self.load_cache_event.wait(timeout=1)
|
||||
if not self.load_cache_event.is_set():
|
||||
@@ -360,6 +358,7 @@ class HiCacheController:
|
||||
"""
|
||||
Auxiliary function to prepare the buffer for write operations.
|
||||
"""
|
||||
torch.cuda.set_stream(self.write_stream)
|
||||
|
||||
def _to_op(op_):
|
||||
assert op_.device_indices.is_cuda, "Device indices should be on GPU"
|
||||
@@ -370,13 +369,11 @@ class HiCacheController:
|
||||
return op_
|
||||
|
||||
buffer = None
|
||||
with torch.cuda.stream(self.write_stream):
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.write_queue.get(block=True, timeout=1)
|
||||
factor = (
|
||||
len(operation.device_indices)
|
||||
// self.write_buffer.max_buffer_size
|
||||
len(operation.device_indices) // self.write_buffer.max_buffer_size
|
||||
)
|
||||
|
||||
if factor >= 1:
|
||||
@@ -484,10 +481,9 @@ class HiCacheController:
|
||||
aux_thread.join()
|
||||
|
||||
def load_thread_func_buffer(self):
|
||||
torch.cuda.set_stream(self.load_stream)
|
||||
aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
|
||||
aux_thread.start()
|
||||
|
||||
with torch.cuda.stream(self.load_stream):
|
||||
while not self.stop_event.is_set():
|
||||
operation = self.load_buffer.get()
|
||||
if operation is None:
|
||||
|
||||
Reference in New Issue
Block a user