fix bug that gpu0 occupies more memory when hicache is turned on (#5778)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -268,7 +268,7 @@ class HiCacheController:
|
|||||||
"""
|
"""
|
||||||
Directly write through KV caches to host memory without buffering.
|
Directly write through KV caches to host memory without buffering.
|
||||||
"""
|
"""
|
||||||
with torch.cuda.stream(self.write_stream):
|
torch.cuda.set_stream(self.write_stream)
|
||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
operation = self.write_queue.get(block=True, timeout=1)
|
operation = self.write_queue.get(block=True, timeout=1)
|
||||||
@@ -291,7 +291,7 @@ class HiCacheController:
|
|||||||
"""
|
"""
|
||||||
Directly load KV caches from host memory to device memory without buffering.
|
Directly load KV caches from host memory to device memory without buffering.
|
||||||
"""
|
"""
|
||||||
with torch.cuda.stream(self.load_stream):
|
torch.cuda.set_stream(self.load_stream)
|
||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
operation = self.load_queue.get(block=True, timeout=1)
|
operation = self.load_queue.get(block=True, timeout=1)
|
||||||
@@ -299,9 +299,7 @@ class HiCacheController:
|
|||||||
operation.data = self.mem_pool_host.get_flat_data(
|
operation.data = self.mem_pool_host.get_flat_data(
|
||||||
operation.host_indices
|
operation.host_indices
|
||||||
)
|
)
|
||||||
self.mem_pool_device.transfer(
|
self.mem_pool_device.transfer(operation.device_indices, operation.data)
|
||||||
operation.device_indices, operation.data
|
|
||||||
)
|
|
||||||
self.mem_pool_host.complete_io(operation.host_indices)
|
self.mem_pool_host.complete_io(operation.host_indices)
|
||||||
for node_id in operation.node_ids:
|
for node_id in operation.node_ids:
|
||||||
if node_id != 0:
|
if node_id != 0:
|
||||||
@@ -315,7 +313,7 @@ class HiCacheController:
|
|||||||
"""
|
"""
|
||||||
Load KV caches from host memory to device memory layer by layer.
|
Load KV caches from host memory to device memory layer by layer.
|
||||||
"""
|
"""
|
||||||
with torch.cuda.stream(self.load_stream):
|
torch.cuda.set_stream(self.load_stream)
|
||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
self.load_cache_event.wait(timeout=1)
|
self.load_cache_event.wait(timeout=1)
|
||||||
if not self.load_cache_event.is_set():
|
if not self.load_cache_event.is_set():
|
||||||
@@ -360,6 +358,7 @@ class HiCacheController:
|
|||||||
"""
|
"""
|
||||||
Auxiliary function to prepare the buffer for write operations.
|
Auxiliary function to prepare the buffer for write operations.
|
||||||
"""
|
"""
|
||||||
|
torch.cuda.set_stream(self.write_stream)
|
||||||
|
|
||||||
def _to_op(op_):
|
def _to_op(op_):
|
||||||
assert op_.device_indices.is_cuda, "Device indices should be on GPU"
|
assert op_.device_indices.is_cuda, "Device indices should be on GPU"
|
||||||
@@ -370,13 +369,11 @@ class HiCacheController:
|
|||||||
return op_
|
return op_
|
||||||
|
|
||||||
buffer = None
|
buffer = None
|
||||||
with torch.cuda.stream(self.write_stream):
|
|
||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
operation = self.write_queue.get(block=True, timeout=1)
|
operation = self.write_queue.get(block=True, timeout=1)
|
||||||
factor = (
|
factor = (
|
||||||
len(operation.device_indices)
|
len(operation.device_indices) // self.write_buffer.max_buffer_size
|
||||||
// self.write_buffer.max_buffer_size
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if factor >= 1:
|
if factor >= 1:
|
||||||
@@ -484,10 +481,9 @@ class HiCacheController:
|
|||||||
aux_thread.join()
|
aux_thread.join()
|
||||||
|
|
||||||
def load_thread_func_buffer(self):
|
def load_thread_func_buffer(self):
|
||||||
|
torch.cuda.set_stream(self.load_stream)
|
||||||
aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
|
aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
|
||||||
aux_thread.start()
|
aux_thread.start()
|
||||||
|
|
||||||
with torch.cuda.stream(self.load_stream):
|
|
||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
operation = self.load_buffer.get()
|
operation = self.load_buffer.get()
|
||||||
if operation is None:
|
if operation is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user