From 2fc824b84c08c86429b494ae3cfa86383da4647d Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Sun, 6 Jul 2025 22:53:36 -0700 Subject: [PATCH] Kernels for efficient KV cache IO (#7313) --- .../sglang/srt/managers/cache_controller.py | 236 +++--------------- python/sglang/srt/managers/scheduler.py | 6 + python/sglang/srt/mem_cache/hiradix_cache.py | 2 + python/sglang/srt/mem_cache/memory_pool.py | 176 ++++++++----- .../sglang/srt/mem_cache/memory_pool_host.py | 115 +-------- python/sglang/srt/mem_cache/radix_cache.py | 12 +- python/sglang/srt/server_args.py | 8 + 7 files changed, 184 insertions(+), 371 deletions(-) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index bd2ddcc5c..cad1d74b7 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. """ -import concurrent.futures import logging import math import threading @@ -169,12 +168,23 @@ class HiCacheController: page_size: int, load_cache_event: threading.Event = None, write_policy: str = "write_through_selective", + io_backend: str = "", ): self.mem_pool_device_allocator = token_to_kv_pool_allocator self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache() self.mem_pool_host = mem_pool_host self.write_policy = write_policy self.page_size = page_size + # using kernel for small page KV cache transfer and DMA for large pages + if not io_backend: + IO_BACKEND_PAGE_SIZE_THRESHOLD = 64 + self.io_backend = ( + "direct" + if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD + else "kernel" + ) + else: + self.io_backend = io_backend self.load_cache_event = load_cache_event self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num) @@ -203,12 +213,7 @@ class HiCacheController: self.load_stream = torch.cuda.Stream() self.write_thread = threading.Thread( - target=( - self.write_thread_func_buffer - if self.page_size == 1 - else self.write_thread_func_direct - ), - daemon=True, + target=self.write_thread_func_direct, daemon=True ) self.load_thread = threading.Thread( target=self.load_thread_func_layer_by_layer, daemon=True @@ -229,12 +234,7 @@ class HiCacheController: self.ack_load_queue.queue.clear() self.write_thread = threading.Thread( - target=( - self.write_thread_func_buffer - if self.page_size == 1 - else self.write_thread_func_direct - ), - daemon=True, + target=self.write_thread_func_direct, daemon=True ) self.load_thread = threading.Thread( target=self.load_thread_func_layer_by_layer, daemon=True @@ -281,6 +281,15 @@ class HiCacheController: ) return device_indices + def move_indices(self, host_indices, device_indices): + # move indices to GPU if using kernels, to host if using direct indexing + if self.io_backend == "kernel": + return host_indices.to(self.mem_pool_device.device), device_indices + elif self.io_backend == "direct": + return host_indices, device_indices.cpu() + else: + raise ValueError(f"Unsupported io backend") + def write_thread_func_direct(self): """ Directly write through KV caches to host memory without buffering. @@ -289,10 +298,14 @@ class HiCacheController: while not self.stop_event.is_set(): try: operation = self.write_queue.get(block=True, timeout=1) - self.mem_pool_host.write_page_all_layers( - operation.host_indices, - operation.device_indices, - self.mem_pool_device, + host_indices, device_indices = self.move_indices( + operation.host_indices, operation.device_indices + ) + self.mem_pool_device.backup_to_host_all_layer( + self.mem_pool_host, + host_indices, + device_indices, + self.io_backend, ) self.write_stream.synchronize() self.mem_pool_host.complete_io(operation.host_indices) @@ -304,27 +317,6 @@ class HiCacheController: except Exception as e: logger.error(e) - def load_thread_func_direct(self): - """ - Directly load KV caches from host memory to device memory without buffering. - """ - torch.cuda.set_stream(self.load_stream) - while not self.stop_event.is_set(): - try: - operation = self.load_queue.get(block=True, timeout=1) - operation.data = self.mem_pool_host.get_flat_data( - operation.host_indices - ) - self.mem_pool_device.transfer(operation.device_indices, operation.data) - self.mem_pool_host.complete_io(operation.host_indices) - for node_id in operation.node_ids: - if node_id != 0: - self.ack_load_queue.put(node_id) - except Empty: - continue - except Exception as e: - logger.error(e) - def load_thread_func_layer_by_layer(self): """ Load KV caches from host memory to device memory layer by layer. @@ -349,22 +341,18 @@ class HiCacheController: # start layer-wise KV cache transfer from CPU to GPU self.layer_done_counter.reset() + host_indices, device_indices = self.move_indices( + batch_operation.host_indices, batch_operation.device_indices + ) for i in range(self.mem_pool_host.layer_num): - if self.page_size == 1: - flat_data = self.mem_pool_host.get_flat_data_by_layer( - batch_operation.host_indices, i - ) - self.mem_pool_device.transfer_per_layer( - batch_operation.device_indices, flat_data, i - ) - else: - self.mem_pool_host.load_page_per_layer( - batch_operation.host_indices, - batch_operation.device_indices, - self.mem_pool_device, - i, - ) - self.load_stream.synchronize() + self.mem_pool_device.load_from_host_per_layer( + self.mem_pool_host, + host_indices, + device_indices, + i, + self.io_backend, + ) + self.load_stream.synchronize() self.layer_done_counter.increment() self.mem_pool_host.complete_io(batch_operation.host_indices) @@ -372,148 +360,6 @@ class HiCacheController: if node_id != 0: self.ack_load_queue.put(node_id) - def write_aux_func(self, no_wait=False): - """ - Auxiliary function to prepare the buffer for write operations. - """ - torch.cuda.set_stream(self.write_stream) - - def _to_op(op_): - assert op_.device_indices.is_cuda, "Device indices should be on GPU" - op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to( - self.mem_pool_host.device - ) - self.write_buffer.put(op_) - return op_ - - buffer = None - while not self.stop_event.is_set(): - try: - operation = self.write_queue.get(block=True, timeout=1) - factor = ( - len(operation.device_indices) // self.write_buffer.max_buffer_size - ) - - if factor >= 1: - if buffer is not None: - _to_op(buffer) - buffer = None - - if factor < 2: - _to_op(operation) - else: - split_ops = operation.split(factor) - for op_ in split_ops: - _to_op(op_) - continue - - if buffer is None: - buffer = operation - else: - buffer.merge(operation) - if ( - no_wait - or len(buffer.host_indices) >= self.write_buffer.max_buffer_size - or self.write_queue.empty() - or self.write_buffer.empty() - ): - _to_op(buffer) - buffer = None - except Empty: - continue - except Exception as e: - logger.error(e) - - def load_aux_func(self): - """ - Auxiliary function to prepare the buffer for load operations. - """ - - def _pin_op(op_, put=True): - op_.data = ( - self.mem_pool_host.get_flat_data(op_.host_indices) - .contiguous() - .pin_memory() - ) - if put: - self.load_buffer.put(op_) - return op_ - - buffer = None - while not self.stop_event.is_set(): - try: - operation = self.load_queue.get(block=True, timeout=1) - factor = len(operation.host_indices) // self.load_buffer.max_buffer_size - - if factor >= 1: - if buffer is not None: - _pin_op(buffer) - buffer = None - - if factor < 2: - _pin_op(operation) - else: - split_ops = operation.split(factor) - split_args = [(op_, True) for op_ in split_ops[:-1]] - split_args.append((split_ops[-1], False)) - # Spawn threads to pin each op concurrently - with concurrent.futures.ThreadPoolExecutor() as executor: - pinned_ops = list( - executor.map( - lambda x: _pin_op(x[0], put=x[1]), split_args - ) - ) - # preserve the order of last op to ensure correct ack - self.load_buffer.put(pinned_ops[-1]) - continue - - if buffer is None: - buffer = operation - else: - buffer.merge(operation) - if ( - len(buffer.host_indices) >= self.load_buffer.max_buffer_size - or self.load_queue.empty() - or self.load_buffer.empty() - ): - _pin_op(buffer) - buffer = None - except Empty: - continue - except Exception as e: - logger.error(e) - - # todo (zhiqiang): double buffering to be deprecated - def write_thread_func_buffer(self): - aux_thread = threading.Thread(target=self.write_aux_func, daemon=True) - aux_thread.start() - - while not self.stop_event.is_set(): - operation = self.write_buffer.get() - if operation is None: - continue - self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data) - self.mem_pool_host.complete_io(operation.host_indices) - for node_id in operation.node_ids: - if node_id != 0: - self.ack_write_queue.put(node_id) - aux_thread.join() - - def load_thread_func_buffer(self): - torch.cuda.set_stream(self.load_stream) - aux_thread = threading.Thread(target=self.load_aux_func, daemon=True) - aux_thread.start() - while not self.stop_event.is_set(): - operation = self.load_buffer.get() - if operation is None: - continue - self.mem_pool_device.transfer(operation.device_indices, operation.data) - self.mem_pool_host.complete_io(operation.host_indices) - for node_id in operation.node_ids: - if node_id != 0: - self.ack_load_queue.put(node_id) - aux_thread.join() - def evict_device( self, device_indices: torch.Tensor, host_indices: torch.Tensor ) -> int: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 73789dc12..d8f164a10 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -591,6 +591,12 @@ class Scheduler( hicache_ratio=server_args.hicache_ratio, hicache_size=server_args.hicache_size, hicache_write_policy=server_args.hicache_write_policy, + hicache_io_backend=( + "direct" + if server_args.attention_backend + == "fa3" # hot fix for incompatibility + else server_args.hicache_io_backend + ), ) self.tp_worker.register_hicache_layer_transfer_counter( self.tree_cache.cache_controller.layer_done_counter diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 31918b150..cb7d95558 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -34,6 +34,7 @@ class HiRadixCache(RadixCache): hicache_ratio: float, hicache_size: int, hicache_write_policy: str, + hicache_io_backend: str, ): self.kv_cache = token_to_kv_pool_allocator.get_kvcache() if isinstance(self.kv_cache, MHATokenToKVPool): @@ -56,6 +57,7 @@ class HiRadixCache(RadixCache): page_size, load_cache_event=self.load_cache_event, write_policy=hicache_write_policy, + io_backend=hicache_io_backend, ) # record the nodes with ongoing write through diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 00ad66552..b8d280dfa 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -34,10 +34,11 @@ import torch import torch.distributed as dist import triton import triton.language as tl +from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2 +from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2 logger = logging.getLogger(__name__) @@ -150,13 +151,16 @@ class KVCache(abc.ABC): ) -> None: raise NotImplementedError() - def get_flat_data(self, indices): + @abc.abstractmethod + def load_from_host_per_layer( + self, host_pool, host_indices, device_indices, layer_id, io_backend + ): raise NotImplementedError() - def transfer(self, indices, flat_data): - raise NotImplementedError() - - def transfer_per_layer(self, indices, flat_data, layer_id): + @abc.abstractmethod + def backup_to_host_all_layer( + self, host_pool, host_indices, device_indices, io_backend + ): raise NotImplementedError() def register_layer_transfer_counter(self, layer_transfer_counter): @@ -247,7 +251,7 @@ class MHATokenToKVPool(KVCache): ) for _ in range(self.layer_num) ] - + self.token_stride = self.head_num * self.head_dim self.data_ptrs = torch.tensor( [x.data_ptr() for x in self.k_buffer + self.v_buffer], dtype=torch.uint64, @@ -281,24 +285,24 @@ class MHATokenToKVPool(KVCache): # layer_num x [seq_len, head_num, head_dim] # layer_num x [page_num, page_size, head_num, head_dim] kv_data_ptrs = [ - self.get_key_buffer(i).data_ptr() + self._get_key_buffer(i).data_ptr() for i in range(self.start_layer, self.start_layer + self.layer_num) ] + [ - self.get_value_buffer(i).data_ptr() + self._get_value_buffer(i).data_ptr() for i in range(self.start_layer, self.start_layer + self.layer_num) ] kv_data_lens = [ - self.get_key_buffer(i).nbytes + self._get_key_buffer(i).nbytes for i in range(self.start_layer, self.start_layer + self.layer_num) ] + [ - self.get_value_buffer(i).nbytes + self._get_value_buffer(i).nbytes for i in range(self.start_layer, self.start_layer + self.layer_num) ] kv_item_lens = [ - self.get_key_buffer(i)[0].nbytes * self.page_size + self._get_key_buffer(i)[0].nbytes * self.page_size for i in range(self.start_layer, self.start_layer + self.layer_num) ] + [ - self.get_value_buffer(i)[0].nbytes * self.page_size + self._get_value_buffer(i)[0].nbytes * self.page_size for i in range(self.start_layer, self.start_layer + self.layer_num) ] return kv_data_ptrs, kv_data_lens, kv_item_lens @@ -341,49 +345,73 @@ class MHATokenToKVPool(KVCache): self.v_buffer[layer_id][chunk_indices] = v_chunk torch.cuda.synchronize() - # Todo: different memory layout - def get_flat_data(self, indices): - # prepare a large chunk of contiguous data for efficient transfer - flatten = torch.stack( - [ - torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]), - torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]), - ] + def load_from_host_per_layer( + self, + host_pool, + host_indices, + device_indices, + layer_id, + io_backend, + ): + transfer_kv_per_layer( + src_k=host_pool.k_buffer[layer_id], + dst_k=self.k_buffer[layer_id], + src_v=host_pool.v_buffer[layer_id], + dst_v=self.v_buffer[layer_id], + src_indices=host_indices, + dst_indices=device_indices, + io_backend=io_backend, + page_size=self.page_size, + item_size=self.token_stride, ) - return flatten - @debug_timing - def transfer(self, indices, flat_data): - # transfer prepared data from host to device - flat_data = flat_data.to(device=self.device, non_blocking=False) - k_data, v_data = flat_data[0], flat_data[1] - for i in range(self.layer_num): - self.k_buffer[i][indices] = k_data[i] - self.v_buffer[i][indices] = v_data[i] - - def transfer_per_layer(self, indices, flat_data, layer_id): - # transfer prepared data from host to device - flat_data = flat_data.to(device=self.device, non_blocking=False) - k_data, v_data = flat_data[0], flat_data[1] - self.k_buffer[layer_id - self.start_layer][indices] = k_data - self.v_buffer[layer_id - self.start_layer][indices] = v_data - - def get_key_buffer(self, layer_id: int): - if self.layer_transfer_counter is not None: - self.layer_transfer_counter.wait_until(layer_id - self.start_layer) + def backup_to_host_all_layer( + self, host_pool, host_indices, device_indices, io_backend + ): + # todo: specialized all layer kernels for the layer-non-contiguous memory pool + for layer_id in range(self.start_layer, self.start_layer + self.layer_num): + if layer_id - self.start_layer >= len(host_pool.k_buffer): + raise ValueError( + f"Layer ID {layer_id} exceeds the number of layers in host pool." + ) + transfer_kv_per_layer( + src_k=self.k_buffer[layer_id], + dst_k=host_pool.k_buffer[layer_id], + src_v=self.v_buffer[layer_id], + dst_v=host_pool.v_buffer[layer_id], + src_indices=device_indices, + dst_indices=host_indices, + io_backend=io_backend, + page_size=self.page_size, + item_size=self.token_stride, + ) + def _get_key_buffer(self, layer_id: int): + # for internal use of referencing if self.store_dtype != self.dtype: return self.k_buffer[layer_id - self.start_layer].view(self.dtype) return self.k_buffer[layer_id - self.start_layer] - def get_value_buffer(self, layer_id: int): + def get_key_buffer(self, layer_id: int): + # note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading + # it is supposed to be used only by attention backend not for information purpose + # same applies to get_value_buffer and get_kv_buffer if self.layer_transfer_counter is not None: self.layer_transfer_counter.wait_until(layer_id - self.start_layer) + return self._get_key_buffer(layer_id) + + def _get_value_buffer(self, layer_id: int): + # for internal use of referencing if self.store_dtype != self.dtype: return self.v_buffer[layer_id - self.start_layer].view(self.dtype) return self.v_buffer[layer_id - self.start_layer] + def get_value_buffer(self, layer_id: int): + if self.layer_transfer_counter is not None: + self.layer_transfer_counter.wait_until(layer_id - self.start_layer) + return self._get_value_buffer(layer_id) + def get_kv_buffer(self, layer_id: int): return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id) @@ -761,6 +789,7 @@ class MLATokenToKVPool(KVCache): for _ in range(layer_num) ] + self.token_stride = kv_lora_rank + qk_rope_head_dim self.layer_transfer_counter = None kv_size = self.get_kv_size_bytes() @@ -846,21 +875,37 @@ class MLATokenToKVPool(KVCache): self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope ) - def get_flat_data(self, indices): - # prepare a large chunk of contiguous data for efficient transfer - return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)]) + def load_from_host_per_layer( + self, host_pool, host_indices, device_indices, layer_id, io_backend + ): + transfer_kv_per_layer_mla( + src=host_pool.kv_buffer[layer_id], + dst=self.kv_buffer[layer_id], + src_indices=host_indices, + dst_indices=device_indices, + io_backend=io_backend, + page_size=self.page_size, + item_size=self.token_stride, + ) - @debug_timing - def transfer(self, indices, flat_data): - # transfer prepared data from host to device - flat_data = flat_data.to(device=self.device, non_blocking=False) - for i in range(self.layer_num): - self.kv_buffer[i][indices] = flat_data[i] - - def transfer_per_layer(self, indices, flat_data, layer_id): - # transfer prepared data from host to device - flat_data = flat_data.to(device=self.device, non_blocking=False) - self.kv_buffer[layer_id - self.start_layer][indices] = flat_data + def backup_to_host_all_layer( + self, host_pool, host_indices, device_indices, io_backend + ): + # todo: specialized all layer kernels for the layer-non-contiguous memory pool + for layer_id in range(self.start_layer, self.start_layer + self.layer_num): + if layer_id - self.start_layer >= len(host_pool.kv_buffer): + raise ValueError( + f"Layer ID {layer_id} exceeds the number of layers in host pool." + ) + transfer_kv_per_layer_mla( + src=self.kv_buffer[layer_id], + dst=host_pool.kv_buffer[layer_id], + src_indices=device_indices, + dst_indices=host_indices, + io_backend=io_backend, + page_size=self.page_size, + item_size=self.token_stride, + ) def get_cpu_copy(self, indices): torch.cuda.synchronize() @@ -1046,14 +1091,19 @@ class DoubleSparseTokenToKVPool(KVCache): self.v_buffer[layer_id - self.start_layer][loc] = cache_v self.label_buffer[layer_id - self.start_layer][loc] = cache_label - def get_flat_data(self, indices): - pass + def load_from_host_per_layer( + self, host_pool, host_indices, device_indices, layer_id, io_backend + ): + raise NotImplementedError( + "HiCache not supported for DoubleSparseTokenToKVPool." + ) - def transfer(self, indices, flat_data): - pass - - def transfer_per_layer(self, indices, flat_data, layer_id): - pass + def backup_to_host_all_layer( + self, host_pool, host_indices, device_indices, io_backend + ): + raise NotImplementedError( + "HiCache not supported for DoubleSparseTokenToKVPool." + ) @triton.jit diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index 2a9d4c6e5..a5977fd1d 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -8,7 +8,6 @@ import psutil import torch from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool -from sglang.srt.utils import debug_timing logger = logging.getLogger(__name__) @@ -99,22 +98,6 @@ class HostKVCache(abc.ABC): def init_kv_buffer(self): raise NotImplementedError() - @abc.abstractmethod - def transfer(self, indices, flat_data): - raise NotImplementedError() - - @abc.abstractmethod - def get_flat_data(self, indices): - raise NotImplementedError() - - @abc.abstractmethod - def get_flat_data_by_layer(self, indices, layer_id): - raise NotImplementedError() - - @abc.abstractmethod - def assign_flat_data(self, indices, flat_data): - raise NotImplementedError() - @synchronized() def clear(self): # Initialize memory states and tracking structures. @@ -243,58 +226,13 @@ class MHATokenToKVPoolHost(HostKVCache): pin_memory=self.pin_memory, ) - @debug_timing - def transfer(self, indices, flat_data): - # backup prepared data from device to host - self.kv_buffer[:, :, indices] = flat_data.to( - device=self.device, non_blocking=False - ) + @property + def k_buffer(self): + return self.kv_buffer[0] - def get_flat_data(self, indices): - return self.kv_buffer[:, :, indices] - - def get_flat_data_by_layer(self, indices, layer_id): - return self.kv_buffer[:, layer_id - self.start_layer, indices] - - def assign_flat_data(self, indices, flat_data): - self.kv_buffer[:, :, indices] = flat_data - - def write_page_all_layers(self, host_indices, device_indices, device_pool): - device_indices_cpu = device_indices[:: self.page_size].cpu() - for i in range(len(device_indices_cpu)): - h_index = host_indices[i * self.page_size] - d_index = device_indices_cpu[i] - for j in range(self.layer_num): - self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_( - device_pool.k_buffer[j][d_index : d_index + self.page_size], - non_blocking=True, - ) - self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_( - device_pool.v_buffer[j][d_index : d_index + self.page_size], - non_blocking=True, - ) - - def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id): - device_indices_cpu = device_indices[:: self.page_size].cpu() - for i in range(len(device_indices_cpu)): - h_index = host_indices[i * self.page_size] - d_index = device_indices_cpu[i] - device_pool.k_buffer[layer_id - self.start_layer][ - d_index : d_index + self.page_size - ].copy_( - self.kv_buffer[ - 0, layer_id - self.start_layer, h_index : h_index + self.page_size - ], - non_blocking=True, - ) - device_pool.v_buffer[layer_id - self.start_layer][ - d_index : d_index + self.page_size - ].copy_( - self.kv_buffer[ - 1, layer_id - self.start_layer, h_index : h_index + self.page_size - ], - non_blocking=True, - ) + @property + def v_buffer(self): + return self.kv_buffer[1] class MLATokenToKVPoolHost(HostKVCache): @@ -337,44 +275,3 @@ class MLATokenToKVPoolHost(HostKVCache): device=self.device, pin_memory=self.pin_memory, ) - - @debug_timing - def transfer(self, indices, flat_data): - # backup prepared data from device to host - self.kv_buffer[:, indices] = flat_data.to( - device=self.device, non_blocking=False - ) - - def get_flat_data(self, indices): - return self.kv_buffer[:, indices] - - def get_flat_data_by_layer(self, indices, layer_id): - return self.kv_buffer[layer_id - self.start_layer, indices] - - def assign_flat_data(self, indices, flat_data): - self.kv_buffer[:, indices] = flat_data - - def write_page_all_layers(self, host_indices, device_indices, device_pool): - device_indices_cpu = device_indices[:: self.page_size].cpu() - for i in range(len(device_indices_cpu)): - h_index = host_indices[i * self.page_size] - d_index = device_indices_cpu[i] - for j in range(self.layer_num): - self.kv_buffer[j, h_index : h_index + self.page_size].copy_( - device_pool.kv_buffer[j][d_index : d_index + self.page_size], - non_blocking=True, - ) - - def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id): - device_indices_cpu = device_indices[:: self.page_size].cpu() - for i in range(len(device_indices_cpu)): - h_index = host_indices[i * self.page_size] - d_index = device_indices_cpu[i] - device_pool.kv_buffer[layer_id - self.start_layer][ - d_index : d_index + self.page_size - ].copy_( - self.kv_buffer[ - layer_id - self.start_layer, h_index : h_index + self.page_size - ], - non_blocking=True, - ) diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 72241b829..706432209 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -196,11 +196,13 @@ class RadixCache(BasePrefixCache): if self.page_size != 1: page_aligned_len = len(kv_indices) // self.page_size * self.page_size - page_aligned_kv_indices = kv_indices[:page_aligned_len].clone() + page_aligned_kv_indices = kv_indices[:page_aligned_len].to( + dtype=torch.int64, copy=True + ) self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) else: page_aligned_len = len(kv_indices) - page_aligned_kv_indices = kv_indices.clone() + page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) # Radix Cache takes one ref in memory pool new_prefix_len = self.insert( @@ -226,10 +228,12 @@ class RadixCache(BasePrefixCache): if self.page_size != 1: page_aligned_len = len(kv_indices) // self.page_size * self.page_size - page_aligned_kv_indices = kv_indices[:page_aligned_len].clone() + page_aligned_kv_indices = kv_indices[:page_aligned_len].to( + dtype=torch.int64, copy=True + ) else: page_aligned_len = len(kv_indices) - page_aligned_kv_indices = kv_indices.clone() + page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) page_aligned_token_ids = token_ids[:page_aligned_len] # Radix Cache takes one ref in memory pool diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 51e5ecc8b..7a5db80a7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -217,6 +217,7 @@ class ServerArgs: hicache_ratio: float = 2.0 hicache_size: int = 0 hicache_write_policy: str = "write_through_selective" + hicache_io_backend: str = "" flashinfer_mla_disable_ragged: bool = False disable_shared_experts_fusion: bool = False disable_chunked_prefix_cache: bool = False @@ -1530,6 +1531,13 @@ class ServerArgs: default=ServerArgs.hicache_write_policy, help="The write policy of hierarchical cache.", ) + parser.add_argument( + "--hicache-io-backend", + type=str, + choices=["direct", "kernel"], + default=ServerArgs.hicache_io_backend, + help="The IO backend for KV cache transfer between CPU and GPU", + ) parser.add_argument( "--flashinfer-mla-disable-ragged", action="store_true",