diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py
index 629dd71a2..91f6ef37d 100644
--- a/python/sglang/srt/managers/cache_controller.py
+++ b/python/sglang/srt/managers/cache_controller.py
@@ -231,16 +231,7 @@ class HiCacheController:
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
         self.page_size = page_size
-        # using kernel for small page KV cache transfer and DMA for large pages
-        if not io_backend:
-            IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
-            self.io_backend = (
-                "direct"
-                if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
-                else "kernel"
-            )
-        else:
-            self.io_backend = io_backend
+        self.io_backend = io_backend

         self.enable_storage = False
         # todo: move backend initialization to storage backend module
@@ -447,11 +438,8 @@ class HiCacheController:
             host_indices, device_indices = self.move_indices(
                 operation.host_indices, operation.device_indices
            )
-            self.mem_pool_device.backup_to_host_all_layer(
-                self.mem_pool_host,
-                host_indices,
-                device_indices,
-                self.io_backend,
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
             )
             self.write_stream.synchronize()
             self.mem_pool_host.complete_io(operation.host_indices)
@@ -491,8 +479,8 @@
                 batch_operation.host_indices, batch_operation.device_indices
             )
             for i in range(self.mem_pool_host.layer_num):
-                self.mem_pool_device.load_from_host_per_layer(
-                    self.mem_pool_host,
+                self.mem_pool_host.load_to_device_per_layer(
+                    self.mem_pool_device,
                     host_indices,
                     device_indices,
                     i,
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index d71f02275..c5998cdec 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -588,6 +588,7 @@ class Scheduler(
                     == "fa3"  # hot fix for incompatibility
                     else server_args.hicache_io_backend
                 ),
+                hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py
index 5a2ff6fb8..7b26fa8a7 100644
--- a/python/sglang/srt/mem_cache/hiradix_cache.py
+++ b/python/sglang/srt/mem_cache/hiradix_cache.py
@@ -35,16 +35,33 @@ class HiRadixCache(RadixCache):
         hicache_size: int,
         hicache_write_policy: str,
         hicache_io_backend: str,
+        hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
     ):
+
+        if hicache_io_backend == "direct":
+            if hicache_mem_layout == "page_first":
+                hicache_mem_layout = "layer_first"
+                logger.warning(
+                    "Page first layout is not supported with direct IO backend, switching to layer first layout"
+                )
+
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, hicache_size, page_size
+                self.kv_cache,
+                hicache_ratio,
+                hicache_size,
+                page_size,
+                hicache_mem_layout,
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, hicache_size, page_size
+                self.kv_cache,
+                hicache_ratio,
+                hicache_size,
+                page_size,
+                hicache_mem_layout,
             )
         else:
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 2af8838b9..cc3faea0a 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -31,21 +31,17 @@ from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
-import torch.distributed as dist
 import triton
 import triton.language as tl

 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2

 logger = logging.getLogger(__name__)

 GB = 1024 * 1024 * 1024
 _is_cuda = is_cuda()
-_is_npu = is_npu()
-if not _is_npu:
-    from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla


 class ReqToTokenPool:
@@ -153,18 +149,6 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()

-    @abc.abstractmethod
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError()
-
     def register_layer_transfer_counter(self, layer_transfer_counter):
         self.layer_transfer_counter = layer_transfer_counter

@@ -253,12 +237,18 @@
             )
             for _ in range(self.layer_num)
         ]
-        self.token_stride = self.head_num * self.head_dim
-        self.data_ptrs = torch.tensor(
-            [x.data_ptr() for x in self.k_buffer + self.v_buffer],
+
+        self.k_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_buffer],
             dtype=torch.uint64,
             device=self.device,
         )
+        self.v_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.v_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
+        self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
         self.data_strides = torch.tensor(
             [
                 np.prod(x.shape[1:]) * x.dtype.itemsize
@@ -347,47 +337,6 @@
             self.v_buffer[layer_id][chunk_indices] = v_chunk
         torch.cuda.synchronize()

-    def load_from_host_per_layer(
-        self,
-        host_pool,
-        host_indices,
-        device_indices,
-        layer_id,
-        io_backend,
-    ):
-        transfer_kv_per_layer(
-            src_k=host_pool.k_buffer[layer_id],
-            dst_k=self.k_buffer[layer_id],
-            src_v=host_pool.v_buffer[layer_id],
-            dst_v=self.v_buffer[layer_id],
-            src_indices=host_indices,
-            dst_indices=device_indices,
-            io_backend=io_backend,
-            page_size=self.page_size,
-            item_size=self.token_stride,
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
-        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
-            if layer_id - self.start_layer >= len(host_pool.k_buffer):
-                raise ValueError(
-                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
-                )
-            transfer_kv_per_layer(
-                src_k=self.k_buffer[layer_id],
-                dst_k=host_pool.k_buffer[layer_id],
-                src_v=self.v_buffer[layer_id],
-                dst_v=host_pool.v_buffer[layer_id],
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                io_backend=io_backend,
-                page_size=self.page_size,
-                item_size=self.token_stride,
-            )
-
     def _get_key_buffer(self, layer_id: int):
         # for internal use of referencing
         if self.store_dtype != self.dtype:
@@ -602,16 +551,6 @@ class SWAKVPool(KVCache):
             layer_id_override=layer_id_pool,
         )

-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError("HiCache not supported for SWAKVPool.")
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError("HiCache not supported for SWAKVPool.")
-


 class AscendTokenToKVPool(MHATokenToKVPool):
@@ -823,7 +762,11 @@
             for _ in range(layer_num)
         ]

-        self.token_stride = kv_lora_rank + qk_rope_head_dim
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.kv_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
         self.layer_transfer_counter = None

         kv_size = self.get_kv_size_bytes()
@@ -909,38 +852,6 @@
             self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
         )

-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        transfer_kv_per_layer_mla(
-            src=host_pool.kv_buffer[layer_id],
-            dst=self.kv_buffer[layer_id],
-            src_indices=host_indices,
-            dst_indices=device_indices,
-            io_backend=io_backend,
-            page_size=self.page_size,
-            item_size=self.token_stride,
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
-        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
-            if layer_id - self.start_layer >= len(host_pool.kv_buffer):
-                raise ValueError(
-                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
-                )
-            transfer_kv_per_layer_mla(
-                src=self.kv_buffer[layer_id],
-                dst=host_pool.kv_buffer[layer_id],
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                io_backend=io_backend,
-                page_size=self.page_size,
-                item_size=self.token_stride,
-            )
-
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
@@ -1131,20 +1042,6 @@
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
         self.label_buffer[layer_id - self.start_layer][loc] = cache_label

-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError(
-            "HiCache not supported for DoubleSparseTokenToKVPool."
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError(
-            "HiCache not supported for DoubleSparseTokenToKVPool."
-        )
-

 @triton.jit
 def copy_all_layer_kv_cache(
diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py
index 5d9a88f35..fc0ba09bc 100644
--- a/python/sglang/srt/mem_cache/memory_pool_host.py
+++ b/python/sglang/srt/mem_cache/memory_pool_host.py
@@ -8,6 +8,21 @@ import psutil
 import torch

 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+if not _is_npu:
+    from sgl_kernel.kvcacheio import (
+        transfer_kv_all_layer,
+        transfer_kv_all_layer_lf_pf,
+        transfer_kv_all_layer_mla,
+        transfer_kv_all_layer_mla_lf_pf,
+        transfer_kv_direct,
+        transfer_kv_per_layer,
+        transfer_kv_per_layer_mla,
+        transfer_kv_per_layer_mla_pf_lf,
+        transfer_kv_per_layer_pf_lf,
+    )

 logger = logging.getLogger(__name__)

@@ -42,15 +57,18 @@ class HostKVCache(abc.ABC):
         device_pool: KVCache,
         host_to_device_ratio: float,
         host_size: int,
+        page_size: int,
+        layout: str,
         pin_memory: bool,
         device: str,
-        page_size: int,
     ):
         self.device_pool = device_pool
-        self.dtype = device_pool.store_dtype
+        self.page_size = page_size
+        self.layout = layout
         self.pin_memory = pin_memory
         self.device = device
-        self.page_size = page_size
+
+        self.dtype = device_pool.store_dtype
         self.size_per_token = self.get_size_per_token()
         if host_size > 0:
             self.size = int(host_size * 1e9 // self.size_per_token)
@@ -98,6 +116,24 @@
     def init_kv_buffer(self):
         raise NotImplementedError()

+    @abc.abstractmethod
+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ) -> None:
+        """
+        Load KV data from the host memory pool to the device memory pool for a specific layer.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ) -> None:
+        """
+        Backup KV data from the device memory pool to the host memory pool for all layers.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def get_flat_data_page(self, index) -> torch.Tensor:
         """
@@ -238,11 +274,30 @@
         host_to_device_ratio: float,
         host_size: int,
         page_size: int,
+        layout: str,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
+        )
+        self.k_data_refs = [self.k_buffer[i] for i in range(self.layer_num)]
+        self.v_data_refs = [self.v_buffer[i] for i in range(self.layer_num)]
+        self.k_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
+        )
+        self.v_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.v_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
         )

     def get_size_per_token(self):
@@ -253,16 +308,128 @@ class MHATokenToKVPoolHost(HostKVCache):
         return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2

     def init_kv_buffer(self):
+        if self.layout == "layer_first":
+            dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
+        elif self.layout == "page_first":
+            dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
         return torch.empty(
-            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dims,
             dtype=self.dtype,
             device=self.device,
             pin_memory=self.pin_memory,
         )

-    # todo, page first memory layout
+    @property
+    def k_buffer(self):
+        return self.kv_buffer[0]
+
+    @property
+    def v_buffer(self):
+        return self.kv_buffer[1]
+
+    def load_to_device_per_layer(
+        self,
+        device_pool,
+        host_indices,
+        device_indices,
+        layer_id,
+        io_backend,
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer(
+                    src_k=self.k_buffer[layer_id],
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer[layer_id],
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_pf_lf(
+                    src_k=self.k_buffer,
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer,
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
+                dst_layers=[
+                    device_pool.k_buffer[layer_id],
+                    device_pool.v_buffer[layer_id],
+                ],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k_layers=self.k_data_ptrs,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v_layers=self.v_data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_lf_pf(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k=self.k_buffer,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v=self.v_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.k_buffer + device_pool.v_buffer,
+                dst_layers=self.k_data_refs + self.v_data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
     def get_flat_data_page(self, index) -> torch.Tensor:
-        return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")

     def get_dummy_flat_data_page(self) -> torch.Tensor:
         return torch.zeros(
@@ -273,13 +440,24 @@
         ).flatten()

     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-        self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
-            2,
-            self.layer_num,
-            self.page_size,
-            self.head_num,
-            self.head_dim,
-        )
+        if self.layout == "layer_first":
+            self.kv_buffer[:, :, index : index + self.page_size, :, :] = (
+                data_page.reshape(
+                    2,
+                    self.layer_num,
+                    self.page_size,
+                    self.head_num,
+                    self.head_dim,
+                )
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :, :] = (
+                data_page.reshape(
+                    2, self.page_size, self.layer_num, self.head_num, self.head_dim
+                )
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")

     def get_buffer_meta(self, keys, indices):
         ptr_list = []
@@ -318,14 +496,6 @@
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list

-    @property
-    def k_buffer(self):
-        return self.kv_buffer[0]
-
-    @property
-    def v_buffer(self):
-        return self.kv_buffer[1]
-

 class MLATokenToKVPoolHost(HostKVCache):
     device_pool: MLATokenToKVPool
@@ -336,11 +506,24 @@
         host_to_device_ratio: float,
         host_size: int,
         page_size: int,
+        layout: str,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
+        )
+        self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)]
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
         )

     def get_size_per_token(self):
@@ -356,20 +539,115 @@
         )

     def init_kv_buffer(self):
-        return torch.empty(
-            (
+        if self.layout == "layer_first":
+            dims = (
                 self.layer_num,
                 self.size,
                 1,
                 self.kv_lora_rank + self.qk_rope_head_dim,
-            ),
+            )
+        elif self.layout == "page_first":
+            dims = (
+                self.size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = (
+            self.kv_lora_rank + self.qk_rope_head_dim
+        ) * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
+
+        return torch.empty(
+            dims,
             dtype=self.dtype,
             device=self.device,
             pin_memory=self.pin_memory,
         )

+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer_mla(
+                    src=self.kv_buffer[layer_id],
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_mla_pf_lf(
+                    src=self.kv_buffer,
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.kv_buffer[layer_id]],
+                dst_layers=[device_pool.kv_buffer[layer_id]],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer_mla(
+                    src_layers=device_pool.data_ptrs,
+                    dst_layers=self.data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_mla_lf_pf(
+                    src_layers=device_pool.data_ptrs,
+                    dst_k=self.kv_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.kv_buffer,
+                dst_layers=self.data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
     def get_flat_data_page(self, index) -> torch.Tensor:
-        return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")

     def get_dummy_flat_data_page(self) -> torch.Tensor:
         return torch.zeros(
@@ -385,12 +663,22 @@
         ).flatten()

     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-        self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
-            self.layer_num,
-            self.page_size,
-            1,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-        )
+        if self.layout == "layer_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[index : index + self.page_size, :, :, :] = data_page.reshape(
+                self.page_size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")

     def get_buffer_meta(self, keys, indices):
         ptr_list = []
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 2927a7071..0b442dede 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -198,7 +198,8 @@ class ServerArgs:
     hicache_ratio: float = 2.0
     hicache_size: int = 0
     hicache_write_policy: str = "write_through_selective"
-    hicache_io_backend: str = ""
+    hicache_io_backend: str = "kernel"
+    hicache_mem_layout: str = "layer_first"
     hicache_storage_backend: Optional[str] = None

     # Double Sparsity
@@ -1487,6 +1488,14 @@
         parser.add_argument(
             "--hicache-io-backend",
             type=str,
             default=ServerArgs.hicache_io_backend,
             help="The IO backend for KV cache transfer between CPU and GPU",
         )
+        parser.add_argument(
+            "--hicache-mem-layout",
+            type=str,
+            choices=["layer_first", "page_first"],
+            default=ServerArgs.hicache_mem_layout,
+            help="The layout of host memory pool for hierarchical cache.",
+        )
+
         parser.add_argument(
             "--hicache-storage-backend",
             type=str,
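
Usage note for reviewers, not part of the patch: the new option threads from `--hicache-mem-layout` through `scheduler.py` into `HiRadixCache` and down to the host pool constructors. Below is a minimal sketch of exercising it via `ServerArgs`; the model path is a placeholder, and `enable_hierarchical_cache` is a pre-existing flag rather than one introduced by this diff.

    # Sketch only: select the page-first host layout with the kernel IO backend.
    # CLI equivalent (flags per server_args.py above; model path is a placeholder):
    #   python -m sglang.launch_server --model-path <model> \
    #       --enable-hierarchical-cache \
    #       --hicache-io-backend kernel --hicache-mem-layout page_first
    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="<model>",  # placeholder
        enable_hierarchical_cache=True,
        hicache_io_backend="kernel",  # new default in this diff
        hicache_mem_layout="page_first",  # new flag: layer_first | page_first
    )

With `hicache_io_backend="direct"`, a `page_first` request is overridden back to `layer_first` inside `HiRadixCache.__init__` (with a warning), since `transfer_kv_direct` only handles the layer-first host layout.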