diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index f9d45b2f7..384dceb31 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -532,9 +532,12 @@ class HiCacheController: host_indices = host_indices.to(self.device, non_blocking=True) return host_indices, device_indices elif self.io_backend == "direct": - device_indices = device_indices.cpu() - host_indices, idx = host_indices.sort() - return host_indices, device_indices.index_select(0, idx) + if self.mem_pool_host.layout == "layer_first": + device_indices = device_indices.cpu() + host_indices, idx = host_indices.sort() + return host_indices, device_indices.index_select(0, idx) + elif self.mem_pool_host.layout == "page_first_direct": + return host_indices, device_indices.cpu() else: raise ValueError(f"Unsupported io backend") diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index dc27eaa03..cab5f5ae2 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -16,11 +16,13 @@ _is_xpu = is_xpu() if not (_is_npu or _is_xpu): from sgl_kernel.kvcacheio import ( transfer_kv_all_layer, + transfer_kv_all_layer_direct_lf_pf, transfer_kv_all_layer_lf_pf, transfer_kv_all_layer_mla, transfer_kv_all_layer_mla_lf_pf, transfer_kv_direct, transfer_kv_per_layer, + transfer_kv_per_layer_direct_pf_lf, transfer_kv_per_layer_mla, transfer_kv_per_layer_mla_pf_lf, transfer_kv_per_layer_pf_lf, @@ -78,6 +80,7 @@ class HostKVCache(abc.ABC): self.size = int(device_pool.size * host_to_device_ratio) # Align the host memory pool size to the page size self.size = self.size - (self.size % self.page_size) + self.page_num = self.size // self.page_size self.start_layer = device_pool.start_layer self.end_layer = device_pool.end_layer @@ -317,6 +320,15 @@ class MHATokenToKVPoolHost(HostKVCache): dims = (2, self.layer_num, self.size, self.head_num, self.head_dim) elif self.layout == "page_first": dims = (2, self.size, self.layer_num, self.head_num, self.head_dim) + elif self.layout == "page_first_direct": + dims = ( + 2, + self.page_num, + self.layer_num, + self.page_size, + self.head_num, + self.head_dim, + ) else: raise ValueError(f"Unsupported layout: {self.layout}") self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize @@ -370,19 +382,31 @@ class MHATokenToKVPoolHost(HostKVCache): else: raise ValueError(f"Unsupported layout: {self.layout}") elif io_backend == "direct": - assert ( - self.layout == "layer_first" - ), f"Direct IO backend only supports layer_first layout." - transfer_kv_direct( - src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]], - dst_layers=[ - device_pool.k_buffer[layer_id], - device_pool.v_buffer[layer_id], - ], - src_indices=host_indices, - dst_indices=device_indices, - page_size=self.page_size, - ) + if self.layout == "layer_first": + transfer_kv_direct( + src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]], + dst_layers=[ + device_pool.k_buffer[layer_id], + device_pool.v_buffer[layer_id], + ], + src_indices=host_indices, + dst_indices=device_indices, + page_size=self.page_size, + ) + elif self.layout == "page_first_direct": + transfer_kv_per_layer_direct_pf_lf( + src_ptrs=[self.k_buffer, self.v_buffer], + dst_ptrs=[ + device_pool.k_buffer[layer_id], + device_pool.v_buffer[layer_id], + ], + src_indices=host_indices, + dst_indices=device_indices, + layer_id=layer_id, + page_size=self.page_size, + ) + else: + raise ValueError(f"Unsupported layout: {self.layout}") else: raise ValueError(f"Unsupported IO backend: {io_backend}") @@ -416,16 +440,24 @@ class MHATokenToKVPoolHost(HostKVCache): else: raise ValueError(f"Unsupported layout: {self.layout}") elif io_backend == "direct": - assert ( - self.layout == "layer_first" - ), f"Direct IO backend only supports layer_first layout." - transfer_kv_direct( - src_layers=device_pool.k_buffer + device_pool.v_buffer, - dst_layers=self.k_data_refs + self.v_data_refs, - src_indices=device_indices, - dst_indices=host_indices, - page_size=self.page_size, - ) + if self.layout == "layer_first": + transfer_kv_direct( + src_layers=device_pool.k_buffer + device_pool.v_buffer, + dst_layers=self.k_data_refs + self.v_data_refs, + src_indices=device_indices, + dst_indices=host_indices, + page_size=self.page_size, + ) + elif self.layout == "page_first_direct": + transfer_kv_all_layer_direct_lf_pf( + src_ptrs=device_pool.k_buffer + device_pool.v_buffer, + dst_ptrs=[self.k_buffer, self.v_buffer], + src_indices=device_indices, + dst_indices=host_indices, + page_size=self.page_size, + ) + else: + raise ValueError(f"Unsupported layout: {self.layout}") else: raise ValueError(f"Unsupported IO backend: {io_backend}") @@ -580,6 +612,14 @@ class MLATokenToKVPoolHost(HostKVCache): 1, self.kv_lora_rank + self.qk_rope_head_dim, ) + elif self.layout == "page_first_direct": + dims = ( + self.page_num, + self.layer_num, + self.page_size, + 1, + self.kv_lora_rank + self.qk_rope_head_dim, + ) else: raise ValueError(f"Unsupported layout: {self.layout}") self.token_stride_size = ( @@ -619,16 +659,25 @@ class MLATokenToKVPoolHost(HostKVCache): else: raise ValueError(f"Unsupported layout: {self.layout}") elif io_backend == "direct": - assert ( - self.layout == "layer_first" - ), f"Direct IO backend only supports layer_first layout." - transfer_kv_direct( - src_layers=[self.kv_buffer[layer_id]], - dst_layers=[device_pool.kv_buffer[layer_id]], - src_indices=host_indices, - dst_indices=device_indices, - page_size=self.page_size, - ) + if self.layout == "layer_first": + transfer_kv_direct( + src_layers=[self.kv_buffer[layer_id]], + dst_layers=[device_pool.kv_buffer[layer_id]], + src_indices=host_indices, + dst_indices=device_indices, + page_size=self.page_size, + ) + elif self.layout == "page_first_direct": + transfer_kv_per_layer_direct_pf_lf( + src_ptrs=[self.kv_buffer], + dst_ptrs=[device_pool.kv_buffer[layer_id]], + src_indices=host_indices, + dst_indices=device_indices, + layer_id=layer_id, + page_size=self.page_size, + ) + else: + raise ValueError(f"Unsupported layout: {self.layout}") def backup_from_device_all_layer( self, device_pool, host_indices, device_indices, io_backend @@ -656,16 +705,24 @@ class MLATokenToKVPoolHost(HostKVCache): else: raise ValueError(f"Unsupported layout: {self.layout}") elif io_backend == "direct": - assert ( - self.layout == "layer_first" - ), f"Direct IO backend only supports layer_first layout." - transfer_kv_direct( - src_layers=device_pool.kv_buffer, - dst_layers=self.data_refs, - src_indices=device_indices, - dst_indices=host_indices, - page_size=self.page_size, - ) + if self.layout == "layer_first": + transfer_kv_direct( + src_layers=device_pool.kv_buffer, + dst_layers=self.data_refs, + src_indices=device_indices, + dst_indices=host_indices, + page_size=self.page_size, + ) + elif self.layout == "page_first_direct": + transfer_kv_all_layer_direct_lf_pf( + src_ptrs=device_pool.kv_buffer, + dst_ptrs=[self.kv_buffer], + src_indices=device_indices, + dst_indices=host_indices, + page_size=self.page_size, + ) + else: + raise ValueError(f"Unsupported layout: {self.layout}") else: raise ValueError(f"Unsupported IO backend: {io_backend}") diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1af94c457..363b25f46 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -721,6 +721,13 @@ class ServerArgs: self.hicache_io_backend = "kernel" self.hicache_mem_layout = "page_first" + if self.hicache_mem_layout == "page_first_direct": + if self.hicache_io_backend != "direct": + self.hicache_io_backend = "direct" + logger.warning( + "Page first direct layout only support direct io backend" + ) + # Speculative Decoding if self.speculative_algorithm == "NEXTN": # NEXTN shares the same implementation of EAGLE @@ -1779,7 +1786,7 @@ class ServerArgs: parser.add_argument( "--hicache-mem-layout", type=str, - choices=["layer_first", "page_first"], + choices=["layer_first", "page_first", "page_first_direct"], default=ServerArgs.hicache_mem_layout, help="The layout of host memory pool for hierarchical cache.", )