diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py
index fca7a6a39..629dd71a2 100644
--- a/python/sglang/srt/managers/cache_controller.py
+++ b/python/sglang/srt/managers/cache_controller.py
@@ -25,12 +25,6 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
-from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
-from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
-    MooncakeStore,
-    get_hash_str_mooncake,
-)
-from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
 
 logger = logging.getLogger(__name__)
 
@@ -251,16 +245,7 @@ class HiCacheController:
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
-            # create a new communication group for synchronizing storage operations across TP workers
-            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
-            if self.tp_world_size > 1:
-                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
-                self.prefetch_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
-                self.backup_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
+            from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
 
             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
@@ -271,11 +256,19 @@ class HiCacheController:
                 self.storage_backend = HiCacheNixl()
                 self.get_hash_str = get_hash_str
             elif storage_backend == "mooncake":
+                from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+                    MooncakeStore,
+                    get_hash_str_mooncake,
+                )
+
                 self.storage_backend = MooncakeStore()
                 self.get_hash_str = get_hash_str_mooncake
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
             elif storage_backend == "hf3fs":
                 from sglang.srt.distributed import get_tensor_model_parallel_rank
+                from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
+                    HiCacheHF3FS,
+                )
 
                 rank = get_tensor_model_parallel_rank()
                 bytes_per_page = (
@@ -293,6 +286,16 @@ class HiCacheController:
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.prefetch_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )
+                self.backup_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )
 
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
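
The diff does two things: storage-backend modules are now imported inside the branch that selects them, so optional dependencies (mooncake, hf3fs) are only loaded when that backend is actually configured, and the gloo TP synchronization groups are only created after a backend has been initialized and storage is known to be enabled. A minimal standalone sketch of the deferred-import pattern is below; the `select_backend` helper is hypothetical and not part of the patch, but the import paths and class names are taken from the diff.

```python
def select_backend(name: str):
    """Return a (backend, hash_fn) pair, importing the implementation lazily.

    Mirrors the pattern adopted in the patch: each branch imports only what it
    needs, so a missing optional dependency only fails when that backend is
    actually requested, not at module import time.
    """
    if name == "file":
        from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str

        return HiCacheFile(), get_hash_str
    elif name == "mooncake":
        from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
            MooncakeStore,
            get_hash_str_mooncake,
        )

        return MooncakeStore(), get_hash_str_mooncake
    raise ValueError(f"unknown storage backend: {name}")
```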