diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 8acce8ac7..2d5711984 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -324,6 +324,22 @@ class HiCacheController: group_ranks, backend="gloo" ) + # Select the get and set functions + self.page_get_func = self._generic_page_get + self.page_set_func = self._generic_page_set + self.batch_exists_func = self.storage_backend.batch_exists + self.is_3fs_zerocopy = ( + self.storage_backend_type == "hf3fs" + and self.mem_pool_host.layout == "page_first" + ) + if self.storage_backend_type == "mooncake": + self.page_get_func = self._mooncake_page_get + self.page_set_func = self._mooncake_page_set + elif self.is_3fs_zerocopy: + self.page_get_func = self._3fs_zero_copy_page_get + self.page_set_func = self._3fs_zero_copy_page_set + self.batch_exists_func = self._3fs_zero_copy_batch_exists + self.load_cache_event = load_cache_event self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num) self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter) @@ -617,13 +633,19 @@ class HiCacheController: for chunk in chunks: self.host_mem_release_queue.put(chunk) + def _3fs_zero_copy_batch_exists(self, batch_hashes): + _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes) + hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor + return hit_page_num + def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices): - hashes, dsts = self.mem_pool_host.get_buffer_with_hash( + hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash( hash_values, host_indices ) page_data = self.storage_backend.batch_get(hashes, dsts) if page_data: - operation.increment(self.page_size * len(hashes)) + inc = self.page_size * len(hashes) // factor + operation.increment(inc) else: logger.warning( f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}." @@ -670,17 +692,6 @@ class HiCacheController: break # Operation terminated by controller def _page_transfer(self, operation): - # Select the get function and batch size - if self.storage_backend_type == "mooncake": - get_func = self._mooncake_page_get - elif ( - self.storage_backend_type == "hf3fs" - and self.mem_pool_host.layout == "page_first" - ): - get_func = self._3fs_zero_copy_page_get - else: - get_func = self._generic_page_get - # Transfer batch by batch for i in range(0, len(operation.hash_value), self.storage_batch_size): batch_hashes = operation.hash_value[i : i + self.storage_batch_size] @@ -689,7 +700,7 @@ class HiCacheController: ] prev_completed_tokens = operation.completed_tokens # Get one batch token, and update the completed_tokens if succeed - get_func(operation, batch_hashes, batch_host_indices) + self.page_get_func(operation, batch_hashes, batch_host_indices) # Check termination if ( operation.completed_tokens @@ -746,7 +757,7 @@ class HiCacheController: batch_tokens[i : i + self.page_size], last_hash ) batch_hashes.append(last_hash) - hit_page_num = self.storage_backend.batch_exists(batch_hashes) + hit_page_num = self.batch_exists_func(batch_hashes) hash_value.extend(batch_hashes[:hit_page_num]) storage_query_count += hit_page_num * self.page_size if hit_page_num < len(batch_hashes): @@ -839,23 +850,13 @@ class HiCacheController: # zero copy def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool: - hashes, dsts = self.mem_pool_host.get_buffer_with_hash( + hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash( hash_values, host_indices ) return self.storage_backend.batch_set(hashes, dsts) # Backup batch by batch def _page_backup(self, operation): - # Select the set function and batch size - if self.storage_backend_type == "mooncake": - backup_set_func = self._mooncake_page_set - elif ( - self.storage_backend_type == "hf3fs" - and self.mem_pool_host.layout == "page_first" - ): - backup_set_func = self._3fs_zero_copy_page_set - else: - backup_set_func = self._generic_page_set # Backup batch by batch for i in range(0, len(operation.hash_value), self.storage_batch_size): batch_hashes = operation.hash_value[i : i + self.storage_batch_size] @@ -864,7 +865,7 @@ class HiCacheController: ] # Set one batch token, and record if success. # todo: allow partial success - success = backup_set_func(batch_hashes, batch_host_indices) + success = self.page_set_func(batch_hashes, batch_host_indices) if not success: logger.warning( f"Write page to storage: {len(batch_hashes)} pages failed." diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index c216a1387..9b9553238 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -500,20 +500,23 @@ class MHATokenToKVPoolHost(HostKVCache): element_size_list = [element_size] * len(key_list) return key_list, ptr_list, element_size_list - def get_buffer_with_hash(self, keys, indices): + def get_buffer_with_hash(self, keys, indices=None): assert self.layout == "page_first" - assert len(keys) == (len(indices) // self.page_size) + assert indices is None or (len(keys) == (len(indices) // self.page_size)) key_list = [] buf_list = [] - for key, i in zip(keys, range(0, len(indices), self.page_size)): + for i in range(len(keys)): + key = keys[i] key_list.append(f"{key}-k") - buf_list.append(self.k_buffer[i : i + self.page_size]) key_list.append(f"{key}-v") - buf_list.append(self.v_buffer[i : i + self.page_size]) + if indices is not None: + index = indices[i * self.page_size] + buf_list.append(self.k_buffer[index : index + self.page_size]) + buf_list.append(self.v_buffer[index : index + self.page_size]) - return key_list, buf_list + return key_list, buf_list, 2 class MLATokenToKVPoolHost(HostKVCache): @@ -728,13 +731,15 @@ class MLATokenToKVPoolHost(HostKVCache): element_size_list = [element_size] * len(key_list) return key_list, ptr_list, element_size_list - def get_buffer_with_hash(self, keys, indices): + def get_buffer_with_hash(self, keys, indices=None): assert self.layout == "page_first" - assert len(keys) == (len(indices) // self.page_size) + assert indices is None or (len(keys) == (len(indices) // self.page_size)) buf_list = [] - for i in range(0, len(indices), self.page_size): - buf_list.append(self.kv_buffer[i : i + self.page_size]) + if indices is not None: + for i in range(len(keys)): + index = indices[i * self.page_size] + buf_list.append(self.kv_buffer[index : index + self.page_size]) - return keys, buf_list + return keys, buf_list, 1 diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py index fe27673c4..48d545889 100644 --- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py @@ -415,22 +415,12 @@ class HiCacheHF3FS(HiCacheStorage): return result[0] if result else False def batch_exists(self, keys: List[str]) -> int: - if self.is_page_first_layout and not self.is_mla_model: - query_keys = [] - # Compatible with page_first layout's key format, Refer to memory_pool_host.py#get_buffer_with_hash - for key in keys: - query_keys.append(f"{key}-k") - query_keys.append(f"{key}-v") - key_multiplier = 2 - else: - query_keys = keys - key_multiplier = 1 + results = self.metadata_client.exists(self.rank, keys) + for i in range(len(keys)): + if not results[i]: + return i - exist_result = self.metadata_client.exists(self.rank, query_keys) - for i in range(len(query_keys)): - if not exist_result[i]: - return i // key_multiplier - return len(query_keys) // key_multiplier + return len(keys) def clear(self) -> bool: try: