Kernels for efficient KV cache IO (#7313)
This commit is contained in:
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import math
|
||||
import threading
|
||||
@@ -169,12 +168,23 @@ class HiCacheController:
|
||||
page_size: int,
|
||||
load_cache_event: threading.Event = None,
|
||||
write_policy: str = "write_through_selective",
|
||||
io_backend: str = "",
|
||||
):
|
||||
self.mem_pool_device_allocator = token_to_kv_pool_allocator
|
||||
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
||||
self.mem_pool_host = mem_pool_host
|
||||
self.write_policy = write_policy
|
||||
self.page_size = page_size
|
||||
# using kernel for small page KV cache transfer and DMA for large pages
|
||||
if not io_backend:
|
||||
IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
|
||||
self.io_backend = (
|
||||
"direct"
|
||||
if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
|
||||
else "kernel"
|
||||
)
|
||||
else:
|
||||
self.io_backend = io_backend
|
||||
|
||||
self.load_cache_event = load_cache_event
|
||||
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
||||
@@ -203,12 +213,7 @@ class HiCacheController:
|
||||
self.load_stream = torch.cuda.Stream()
|
||||
|
||||
self.write_thread = threading.Thread(
|
||||
target=(
|
||||
self.write_thread_func_buffer
|
||||
if self.page_size == 1
|
||||
else self.write_thread_func_direct
|
||||
),
|
||||
daemon=True,
|
||||
target=self.write_thread_func_direct, daemon=True
|
||||
)
|
||||
self.load_thread = threading.Thread(
|
||||
target=self.load_thread_func_layer_by_layer, daemon=True
|
||||
@@ -229,12 +234,7 @@ class HiCacheController:
|
||||
self.ack_load_queue.queue.clear()
|
||||
|
||||
self.write_thread = threading.Thread(
|
||||
target=(
|
||||
self.write_thread_func_buffer
|
||||
if self.page_size == 1
|
||||
else self.write_thread_func_direct
|
||||
),
|
||||
daemon=True,
|
||||
target=self.write_thread_func_direct, daemon=True
|
||||
)
|
||||
self.load_thread = threading.Thread(
|
||||
target=self.load_thread_func_layer_by_layer, daemon=True
|
||||
@@ -281,6 +281,15 @@ class HiCacheController:
|
||||
)
|
||||
return device_indices
|
||||
|
||||
def move_indices(self, host_indices, device_indices):
|
||||
# move indices to GPU if using kernels, to host if using direct indexing
|
||||
if self.io_backend == "kernel":
|
||||
return host_indices.to(self.mem_pool_device.device), device_indices
|
||||
elif self.io_backend == "direct":
|
||||
return host_indices, device_indices.cpu()
|
||||
else:
|
||||
raise ValueError(f"Unsupported io backend")
|
||||
|
||||
def write_thread_func_direct(self):
|
||||
"""
|
||||
Directly write through KV caches to host memory without buffering.
|
||||
@@ -289,10 +298,14 @@ class HiCacheController:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.write_queue.get(block=True, timeout=1)
|
||||
self.mem_pool_host.write_page_all_layers(
|
||||
operation.host_indices,
|
||||
operation.device_indices,
|
||||
self.mem_pool_device,
|
||||
host_indices, device_indices = self.move_indices(
|
||||
operation.host_indices, operation.device_indices
|
||||
)
|
||||
self.mem_pool_device.backup_to_host_all_layer(
|
||||
self.mem_pool_host,
|
||||
host_indices,
|
||||
device_indices,
|
||||
self.io_backend,
|
||||
)
|
||||
self.write_stream.synchronize()
|
||||
self.mem_pool_host.complete_io(operation.host_indices)
|
||||
@@ -304,27 +317,6 @@ class HiCacheController:
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def load_thread_func_direct(self):
|
||||
"""
|
||||
Directly load KV caches from host memory to device memory without buffering.
|
||||
"""
|
||||
torch.cuda.set_stream(self.load_stream)
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.load_queue.get(block=True, timeout=1)
|
||||
operation.data = self.mem_pool_host.get_flat_data(
|
||||
operation.host_indices
|
||||
)
|
||||
self.mem_pool_device.transfer(operation.device_indices, operation.data)
|
||||
self.mem_pool_host.complete_io(operation.host_indices)
|
||||
for node_id in operation.node_ids:
|
||||
if node_id != 0:
|
||||
self.ack_load_queue.put(node_id)
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def load_thread_func_layer_by_layer(self):
|
||||
"""
|
||||
Load KV caches from host memory to device memory layer by layer.
|
||||
@@ -349,22 +341,18 @@ class HiCacheController:
|
||||
|
||||
# start layer-wise KV cache transfer from CPU to GPU
|
||||
self.layer_done_counter.reset()
|
||||
host_indices, device_indices = self.move_indices(
|
||||
batch_operation.host_indices, batch_operation.device_indices
|
||||
)
|
||||
for i in range(self.mem_pool_host.layer_num):
|
||||
if self.page_size == 1:
|
||||
flat_data = self.mem_pool_host.get_flat_data_by_layer(
|
||||
batch_operation.host_indices, i
|
||||
)
|
||||
self.mem_pool_device.transfer_per_layer(
|
||||
batch_operation.device_indices, flat_data, i
|
||||
)
|
||||
else:
|
||||
self.mem_pool_host.load_page_per_layer(
|
||||
batch_operation.host_indices,
|
||||
batch_operation.device_indices,
|
||||
self.mem_pool_device,
|
||||
i,
|
||||
)
|
||||
self.load_stream.synchronize()
|
||||
self.mem_pool_device.load_from_host_per_layer(
|
||||
self.mem_pool_host,
|
||||
host_indices,
|
||||
device_indices,
|
||||
i,
|
||||
self.io_backend,
|
||||
)
|
||||
self.load_stream.synchronize()
|
||||
self.layer_done_counter.increment()
|
||||
|
||||
self.mem_pool_host.complete_io(batch_operation.host_indices)
|
||||
@@ -372,148 +360,6 @@ class HiCacheController:
|
||||
if node_id != 0:
|
||||
self.ack_load_queue.put(node_id)
|
||||
|
||||
def write_aux_func(self, no_wait=False):
|
||||
"""
|
||||
Auxiliary function to prepare the buffer for write operations.
|
||||
"""
|
||||
torch.cuda.set_stream(self.write_stream)
|
||||
|
||||
def _to_op(op_):
|
||||
assert op_.device_indices.is_cuda, "Device indices should be on GPU"
|
||||
op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
|
||||
self.mem_pool_host.device
|
||||
)
|
||||
self.write_buffer.put(op_)
|
||||
return op_
|
||||
|
||||
buffer = None
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.write_queue.get(block=True, timeout=1)
|
||||
factor = (
|
||||
len(operation.device_indices) // self.write_buffer.max_buffer_size
|
||||
)
|
||||
|
||||
if factor >= 1:
|
||||
if buffer is not None:
|
||||
_to_op(buffer)
|
||||
buffer = None
|
||||
|
||||
if factor < 2:
|
||||
_to_op(operation)
|
||||
else:
|
||||
split_ops = operation.split(factor)
|
||||
for op_ in split_ops:
|
||||
_to_op(op_)
|
||||
continue
|
||||
|
||||
if buffer is None:
|
||||
buffer = operation
|
||||
else:
|
||||
buffer.merge(operation)
|
||||
if (
|
||||
no_wait
|
||||
or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
|
||||
or self.write_queue.empty()
|
||||
or self.write_buffer.empty()
|
||||
):
|
||||
_to_op(buffer)
|
||||
buffer = None
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def load_aux_func(self):
|
||||
"""
|
||||
Auxiliary function to prepare the buffer for load operations.
|
||||
"""
|
||||
|
||||
def _pin_op(op_, put=True):
|
||||
op_.data = (
|
||||
self.mem_pool_host.get_flat_data(op_.host_indices)
|
||||
.contiguous()
|
||||
.pin_memory()
|
||||
)
|
||||
if put:
|
||||
self.load_buffer.put(op_)
|
||||
return op_
|
||||
|
||||
buffer = None
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.load_queue.get(block=True, timeout=1)
|
||||
factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
|
||||
|
||||
if factor >= 1:
|
||||
if buffer is not None:
|
||||
_pin_op(buffer)
|
||||
buffer = None
|
||||
|
||||
if factor < 2:
|
||||
_pin_op(operation)
|
||||
else:
|
||||
split_ops = operation.split(factor)
|
||||
split_args = [(op_, True) for op_ in split_ops[:-1]]
|
||||
split_args.append((split_ops[-1], False))
|
||||
# Spawn threads to pin each op concurrently
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
pinned_ops = list(
|
||||
executor.map(
|
||||
lambda x: _pin_op(x[0], put=x[1]), split_args
|
||||
)
|
||||
)
|
||||
# preserve the order of last op to ensure correct ack
|
||||
self.load_buffer.put(pinned_ops[-1])
|
||||
continue
|
||||
|
||||
if buffer is None:
|
||||
buffer = operation
|
||||
else:
|
||||
buffer.merge(operation)
|
||||
if (
|
||||
len(buffer.host_indices) >= self.load_buffer.max_buffer_size
|
||||
or self.load_queue.empty()
|
||||
or self.load_buffer.empty()
|
||||
):
|
||||
_pin_op(buffer)
|
||||
buffer = None
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
# todo (zhiqiang): double buffering to be deprecated
|
||||
def write_thread_func_buffer(self):
|
||||
aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
|
||||
aux_thread.start()
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
operation = self.write_buffer.get()
|
||||
if operation is None:
|
||||
continue
|
||||
self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
|
||||
self.mem_pool_host.complete_io(operation.host_indices)
|
||||
for node_id in operation.node_ids:
|
||||
if node_id != 0:
|
||||
self.ack_write_queue.put(node_id)
|
||||
aux_thread.join()
|
||||
|
||||
def load_thread_func_buffer(self):
|
||||
torch.cuda.set_stream(self.load_stream)
|
||||
aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
|
||||
aux_thread.start()
|
||||
while not self.stop_event.is_set():
|
||||
operation = self.load_buffer.get()
|
||||
if operation is None:
|
||||
continue
|
||||
self.mem_pool_device.transfer(operation.device_indices, operation.data)
|
||||
self.mem_pool_host.complete_io(operation.host_indices)
|
||||
for node_id in operation.node_ids:
|
||||
if node_id != 0:
|
||||
self.ack_load_queue.put(node_id)
|
||||
aux_thread.join()
|
||||
|
||||
def evict_device(
|
||||
self, device_indices: torch.Tensor, host_indices: torch.Tensor
|
||||
) -> int:
|
||||
|
||||
@@ -591,6 +591,12 @@ class Scheduler(
|
||||
hicache_ratio=server_args.hicache_ratio,
|
||||
hicache_size=server_args.hicache_size,
|
||||
hicache_write_policy=server_args.hicache_write_policy,
|
||||
hicache_io_backend=(
|
||||
"direct"
|
||||
if server_args.attention_backend
|
||||
== "fa3" # hot fix for incompatibility
|
||||
else server_args.hicache_io_backend
|
||||
),
|
||||
)
|
||||
self.tp_worker.register_hicache_layer_transfer_counter(
|
||||
self.tree_cache.cache_controller.layer_done_counter
|
||||
|
||||
Reference in New Issue
Block a user