fix bug that gpu0 occupies more memory when hicache is turned on (#5778)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -268,98 +268,97 @@ class HiCacheController:
|
|||||||
"""
|
"""
|
||||||
Directly write through KV caches to host memory without buffering.
|
Directly write through KV caches to host memory without buffering.
|
||||||
"""
|
"""
|
||||||
with torch.cuda.stream(self.write_stream):
|
torch.cuda.set_stream(self.write_stream)
|
||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
operation = self.write_queue.get(block=True, timeout=1)
|
operation = self.write_queue.get(block=True, timeout=1)
|
||||||
self.mem_pool_host.write_page_all_layers(
|
self.mem_pool_host.write_page_all_layers(
|
||||||
operation.host_indices,
|
operation.host_indices,
|
||||||
operation.device_indices,
|
operation.device_indices,
|
||||||
self.mem_pool_device,
|
self.mem_pool_device,
|
||||||
)
|
)
|
||||||
self.write_stream.synchronize()
|
self.write_stream.synchronize()
|
||||||
self.mem_pool_host.complete_io(operation.host_indices)
|
self.mem_pool_host.complete_io(operation.host_indices)
|
||||||
for node_id in operation.node_ids:
|
for node_id in operation.node_ids:
|
||||||
if node_id != 0:
|
if node_id != 0:
|
||||||
self.ack_write_queue.put(node_id)
|
self.ack_write_queue.put(node_id)
|
||||||
except Empty:
|
except Empty:
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
|
|
||||||
def load_thread_func_direct(self):
|
def load_thread_func_direct(self):
|
||||||
"""
|
"""
|
||||||
Directly load KV caches from host memory to device memory without buffering.
|
Directly load KV caches from host memory to device memory without buffering.
|
||||||
"""
|
"""
|
||||||
with torch.cuda.stream(self.load_stream):
|
torch.cuda.set_stream(self.load_stream)
|
||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
operation = self.load_queue.get(block=True, timeout=1)
|
operation = self.load_queue.get(block=True, timeout=1)
|
||||||
# time.sleep(18e-6 * len(operation.host_indices))
|
# time.sleep(18e-6 * len(operation.host_indices))
|
||||||
operation.data = self.mem_pool_host.get_flat_data(
|
operation.data = self.mem_pool_host.get_flat_data(
|
||||||
operation.host_indices
|
operation.host_indices
|
||||||
)
|
)
|
||||||
self.mem_pool_device.transfer(
|
self.mem_pool_device.transfer(operation.device_indices, operation.data)
|
||||||
operation.device_indices, operation.data
|
self.mem_pool_host.complete_io(operation.host_indices)
|
||||||
)
|
for node_id in operation.node_ids:
|
||||||
self.mem_pool_host.complete_io(operation.host_indices)
|
if node_id != 0:
|
||||||
for node_id in operation.node_ids:
|
self.ack_load_queue.put(node_id)
|
||||||
if node_id != 0:
|
except Empty:
|
||||||
self.ack_load_queue.put(node_id)
|
continue
|
||||||
except Empty:
|
except Exception as e:
|
||||||
continue
|
logger.error(e)
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
|
|
||||||
def load_thread_func_layer_by_layer(self):
|
def load_thread_func_layer_by_layer(self):
|
||||||
"""
|
"""
|
||||||
Load KV caches from host memory to device memory layer by layer.
|
Load KV caches from host memory to device memory layer by layer.
|
||||||
"""
|
"""
|
||||||
with torch.cuda.stream(self.load_stream):
|
torch.cuda.set_stream(self.load_stream)
|
||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
self.load_cache_event.wait(timeout=1)
|
self.load_cache_event.wait(timeout=1)
|
||||||
if not self.load_cache_event.is_set():
|
if not self.load_cache_event.is_set():
|
||||||
continue
|
continue
|
||||||
self.load_cache_event.clear()
|
self.load_cache_event.clear()
|
||||||
|
|
||||||
batch_operation = None
|
batch_operation = None
|
||||||
while self.load_queue.qsize() > 0:
|
while self.load_queue.qsize() > 0:
|
||||||
op = self.load_queue.get(block=True)
|
op = self.load_queue.get(block=True)
|
||||||
if batch_operation is None:
|
|
||||||
batch_operation = op
|
|
||||||
else:
|
|
||||||
batch_operation.merge(op)
|
|
||||||
if batch_operation is None:
|
if batch_operation is None:
|
||||||
continue
|
batch_operation = op
|
||||||
|
else:
|
||||||
|
batch_operation.merge(op)
|
||||||
|
if batch_operation is None:
|
||||||
|
continue
|
||||||
|
|
||||||
self.layer_done_counter.reset()
|
self.layer_done_counter.reset()
|
||||||
for i in range(self.mem_pool_host.layer_num):
|
for i in range(self.mem_pool_host.layer_num):
|
||||||
if self.page_size == 1:
|
if self.page_size == 1:
|
||||||
flat_data = self.mem_pool_host.get_flat_data_by_layer(
|
flat_data = self.mem_pool_host.get_flat_data_by_layer(
|
||||||
batch_operation.host_indices, i
|
batch_operation.host_indices, i
|
||||||
)
|
)
|
||||||
self.mem_pool_device.transfer_per_layer(
|
self.mem_pool_device.transfer_per_layer(
|
||||||
batch_operation.device_indices, flat_data, i
|
batch_operation.device_indices, flat_data, i
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.mem_pool_host.load_page_per_layer(
|
self.mem_pool_host.load_page_per_layer(
|
||||||
batch_operation.host_indices,
|
batch_operation.host_indices,
|
||||||
batch_operation.device_indices,
|
batch_operation.device_indices,
|
||||||
self.mem_pool_device,
|
self.mem_pool_device,
|
||||||
i,
|
i,
|
||||||
)
|
)
|
||||||
self.load_stream.synchronize()
|
self.load_stream.synchronize()
|
||||||
self.layer_done_counter.increment()
|
self.layer_done_counter.increment()
|
||||||
|
|
||||||
self.mem_pool_host.complete_io(batch_operation.host_indices)
|
self.mem_pool_host.complete_io(batch_operation.host_indices)
|
||||||
for node_id in batch_operation.node_ids:
|
for node_id in batch_operation.node_ids:
|
||||||
if node_id != 0:
|
if node_id != 0:
|
||||||
self.ack_load_queue.put(node_id)
|
self.ack_load_queue.put(node_id)
|
||||||
|
|
||||||
def write_aux_func(self, no_wait=False):
|
def write_aux_func(self, no_wait=False):
|
||||||
"""
|
"""
|
||||||
Auxiliary function to prepare the buffer for write operations.
|
Auxiliary function to prepare the buffer for write operations.
|
||||||
"""
|
"""
|
||||||
|
torch.cuda.set_stream(self.write_stream)
|
||||||
|
|
||||||
def _to_op(op_):
|
def _to_op(op_):
|
||||||
assert op_.device_indices.is_cuda, "Device indices should be on GPU"
|
assert op_.device_indices.is_cuda, "Device indices should be on GPU"
|
||||||
@@ -370,44 +369,42 @@ class HiCacheController:
|
|||||||
return op_
|
return op_
|
||||||
|
|
||||||
buffer = None
|
buffer = None
|
||||||
with torch.cuda.stream(self.write_stream):
|
while not self.stop_event.is_set():
|
||||||
while not self.stop_event.is_set():
|
try:
|
||||||
try:
|
operation = self.write_queue.get(block=True, timeout=1)
|
||||||
operation = self.write_queue.get(block=True, timeout=1)
|
factor = (
|
||||||
factor = (
|
len(operation.device_indices) // self.write_buffer.max_buffer_size
|
||||||
len(operation.device_indices)
|
)
|
||||||
// self.write_buffer.max_buffer_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if factor >= 1:
|
if factor >= 1:
|
||||||
if buffer is not None:
|
if buffer is not None:
|
||||||
_to_op(buffer)
|
|
||||||
buffer = None
|
|
||||||
|
|
||||||
if factor < 2:
|
|
||||||
_to_op(operation)
|
|
||||||
else:
|
|
||||||
split_ops = operation.split(factor)
|
|
||||||
for op_ in split_ops:
|
|
||||||
_to_op(op_)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if buffer is None:
|
|
||||||
buffer = operation
|
|
||||||
else:
|
|
||||||
buffer.merge(operation)
|
|
||||||
if (
|
|
||||||
no_wait
|
|
||||||
or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
|
|
||||||
or self.write_queue.empty()
|
|
||||||
or self.write_buffer.empty()
|
|
||||||
):
|
|
||||||
_to_op(buffer)
|
_to_op(buffer)
|
||||||
buffer = None
|
buffer = None
|
||||||
except Empty:
|
|
||||||
|
if factor < 2:
|
||||||
|
_to_op(operation)
|
||||||
|
else:
|
||||||
|
split_ops = operation.split(factor)
|
||||||
|
for op_ in split_ops:
|
||||||
|
_to_op(op_)
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
if buffer is None:
|
||||||
|
buffer = operation
|
||||||
|
else:
|
||||||
|
buffer.merge(operation)
|
||||||
|
if (
|
||||||
|
no_wait
|
||||||
|
or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
|
||||||
|
or self.write_queue.empty()
|
||||||
|
or self.write_buffer.empty()
|
||||||
|
):
|
||||||
|
_to_op(buffer)
|
||||||
|
buffer = None
|
||||||
|
except Empty:
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
|
||||||
def load_aux_func(self):
|
def load_aux_func(self):
|
||||||
"""
|
"""
|
||||||
@@ -484,19 +481,18 @@ class HiCacheController:
|
|||||||
aux_thread.join()
|
aux_thread.join()
|
||||||
|
|
||||||
def load_thread_func_buffer(self):
|
def load_thread_func_buffer(self):
|
||||||
|
torch.cuda.set_stream(self.load_stream)
|
||||||
aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
|
aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
|
||||||
aux_thread.start()
|
aux_thread.start()
|
||||||
|
while not self.stop_event.is_set():
|
||||||
with torch.cuda.stream(self.load_stream):
|
operation = self.load_buffer.get()
|
||||||
while not self.stop_event.is_set():
|
if operation is None:
|
||||||
operation = self.load_buffer.get()
|
continue
|
||||||
if operation is None:
|
self.mem_pool_device.transfer(operation.device_indices, operation.data)
|
||||||
continue
|
self.mem_pool_host.complete_io(operation.host_indices)
|
||||||
self.mem_pool_device.transfer(operation.device_indices, operation.data)
|
for node_id in operation.node_ids:
|
||||||
self.mem_pool_host.complete_io(operation.host_indices)
|
if node_id != 0:
|
||||||
for node_id in operation.node_ids:
|
self.ack_load_queue.put(node_id)
|
||||||
if node_id != 0:
|
|
||||||
self.ack_load_queue.put(node_id)
|
|
||||||
aux_thread.join()
|
aux_thread.join()
|
||||||
|
|
||||||
def evict_device(
|
def evict_device(
|
||||||
|
|||||||
Reference in New Issue
Block a user