Fix a bug where GPU 0 occupies more memory when HiCache is turned on (#5778)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
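The diff below replaces the "with torch.cuda.stream(...)" blocks inside the HiCache worker threads with a single torch.cuda.set_stream(...) call at the top of each thread function. A minimal, hypothetical sketch of that per-thread pattern follows; the worker, stream, and stop_event names are illustrative only and are not part of HiCacheController:

import threading

import torch


def worker(stream: torch.cuda.Stream, stop_event: threading.Event) -> None:
    # PyTorch tracks the current CUDA stream per thread, so selecting it once
    # here pins every copy/kernel issued by this thread to `stream`.
    torch.cuda.set_stream(stream)
    while not stop_event.is_set():
        # ... dequeue an operation and issue its transfers on `stream` ...
        stream.synchronize()


if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    stop_event = threading.Event()
    thread = threading.Thread(target=worker, args=(stream, stop_event), daemon=True)
    thread.start()
    stop_event.set()
    thread.join()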
@@ -268,98 +268,97 @@ class HiCacheController:
         """
         Directly write through KV caches to host memory without buffering.
         """
-        with torch.cuda.stream(self.write_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.write_queue.get(block=True, timeout=1)
-                    self.mem_pool_host.write_page_all_layers(
-                        operation.host_indices,
-                        operation.device_indices,
-                        self.mem_pool_device,
-                    )
-                    self.write_stream.synchronize()
-                    self.mem_pool_host.complete_io(operation.host_indices)
-                    for node_id in operation.node_ids:
-                        if node_id != 0:
-                            self.ack_write_queue.put(node_id)
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+        torch.cuda.set_stream(self.write_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
+                )
+                self.write_stream.synchronize()
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_write_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_direct(self):
         """
         Directly load KV caches from host memory to device memory without buffering.
         """
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.load_queue.get(block=True, timeout=1)
-                    # time.sleep(18e-6 * len(operation.host_indices))
-                    operation.data = self.mem_pool_host.get_flat_data(
-                        operation.host_indices
-                    )
-                    self.mem_pool_device.transfer(
-                        operation.device_indices, operation.data
-                    )
-                    self.mem_pool_host.complete_io(operation.host_indices)
-                    for node_id in operation.node_ids:
-                        if node_id != 0:
-                            self.ack_load_queue.put(node_id)
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.load_queue.get(block=True, timeout=1)
+                # time.sleep(18e-6 * len(operation.host_indices))
+                operation.data = self.mem_pool_host.get_flat_data(
+                    operation.host_indices
+                )
+                self.mem_pool_device.transfer(operation.device_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.
         """
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                self.load_cache_event.wait(timeout=1)
-                if not self.load_cache_event.is_set():
-                    continue
-                self.load_cache_event.clear()
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            self.load_cache_event.wait(timeout=1)
+            if not self.load_cache_event.is_set():
+                continue
+            self.load_cache_event.clear()
 
-                batch_operation = None
-                while self.load_queue.qsize() > 0:
-                    op = self.load_queue.get(block=True)
-                    if batch_operation is None:
-                        batch_operation = op
-                    else:
-                        batch_operation.merge(op)
-                if batch_operation is None:
-                    continue
+            batch_operation = None
+            while self.load_queue.qsize() > 0:
+                op = self.load_queue.get(block=True)
+                if batch_operation is None:
+                    batch_operation = op
+                else:
+                    batch_operation.merge(op)
+            if batch_operation is None:
+                continue
 
-                self.layer_done_counter.reset()
-                for i in range(self.mem_pool_host.layer_num):
-                    if self.page_size == 1:
-                        flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                            batch_operation.host_indices, i
-                        )
-                        self.mem_pool_device.transfer_per_layer(
-                            batch_operation.device_indices, flat_data, i
-                        )
-                    else:
-                        self.mem_pool_host.load_page_per_layer(
-                            batch_operation.host_indices,
-                            batch_operation.device_indices,
-                            self.mem_pool_device,
-                            i,
-                        )
-                    self.load_stream.synchronize()
-                    self.layer_done_counter.increment()
+            self.layer_done_counter.reset()
+            for i in range(self.mem_pool_host.layer_num):
+                if self.page_size == 1:
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                else:
+                    self.mem_pool_host.load_page_per_layer(
+                        batch_operation.host_indices,
+                        batch_operation.device_indices,
+                        self.mem_pool_device,
+                        i,
+                    )
+                self.load_stream.synchronize()
+                self.layer_done_counter.increment()
 
-                self.mem_pool_host.complete_io(batch_operation.host_indices)
-                for node_id in batch_operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
+            self.mem_pool_host.complete_io(batch_operation.host_indices)
+            for node_id in batch_operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
 
     def write_aux_func(self, no_wait=False):
         """
         Auxiliary function to prepare the buffer for write operations.
         """
+        torch.cuda.set_stream(self.write_stream)
 
         def _to_op(op_):
             assert op_.device_indices.is_cuda, "Device indices should be on GPU"
@@ -370,44 +369,42 @@ class HiCacheController:
             return op_
 
         buffer = None
-        with torch.cuda.stream(self.write_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.write_queue.get(block=True, timeout=1)
-                    factor = (
-                        len(operation.device_indices)
-                        // self.write_buffer.max_buffer_size
-                    )
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                factor = (
+                    len(operation.device_indices) // self.write_buffer.max_buffer_size
+                )
 
-                    if factor >= 1:
-                        if buffer is not None:
-                            _to_op(buffer)
-                            buffer = None
+                if factor >= 1:
+                    if buffer is not None:
+                        _to_op(buffer)
+                        buffer = None
 
-                        if factor < 2:
-                            _to_op(operation)
-                        else:
-                            split_ops = operation.split(factor)
-                            for op_ in split_ops:
-                                _to_op(op_)
-                        continue
+                    if factor < 2:
+                        _to_op(operation)
+                    else:
+                        split_ops = operation.split(factor)
+                        for op_ in split_ops:
+                            _to_op(op_)
+                    continue
 
-                    if buffer is None:
-                        buffer = operation
-                    else:
-                        buffer.merge(operation)
-                    if (
-                        no_wait
-                        or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                        or self.write_queue.empty()
-                        or self.write_buffer.empty()
-                    ):
-                        _to_op(buffer)
-                        buffer = None
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    no_wait
+                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                    or self.write_queue.empty()
+                    or self.write_buffer.empty()
+                ):
+                    _to_op(buffer)
+                    buffer = None
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_aux_func(self):
         """
@@ -484,19 +481,18 @@ class HiCacheController:
         aux_thread.join()
 
     def load_thread_func_buffer(self):
+        torch.cuda.set_stream(self.load_stream)
         aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
         aux_thread.start()
-
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                operation = self.load_buffer.get()
-                if operation is None:
-                    continue
-                self.mem_pool_device.transfer(operation.device_indices, operation.data)
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
+        while not self.stop_event.is_set():
+            operation = self.load_buffer.get()
+            if operation is None:
+                continue
+            self.mem_pool_device.transfer(operation.device_indices, operation.data)
+            self.mem_pool_host.complete_io(operation.host_indices)
+            for node_id in operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
         aux_thread.join()
 
     def evict_device(