Conditionally import HiCacheHF3FS (#8598)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
@@ -25,12 +25,6 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
-from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
-from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
-    MooncakeStore,
-    get_hash_str_mooncake,
-)
-from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
 
 logger = logging.getLogger(__name__)
 
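Note: the imports left under the "if TYPE_CHECKING:" guard above are needed only for type annotations, so they never execute at runtime; the runtime imports are the ones moved into the backend branches in the hunks below. A minimal, self-contained sketch of that idiom (the module and class names here are hypothetical, not from this repository):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime,
    # so the dependency is not required just to import this module.
    from heavy_backend import HeavyKVCache  # hypothetical optional dependency


def describe(cache: HeavyKVCache) -> str:
    # With postponed annotation evaluation, this reference never triggers the import.
    return f"cache of type {type(cache).__name__}"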
@@ -251,16 +245,7 @@ class HiCacheController:
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
-            # create a new communication group for synchronizing storage operations across TP workers
-            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
-            if self.tp_world_size > 1:
-                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
-                self.prefetch_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
-                self.backup_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
+            from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
 
             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
@@ -271,11 +256,19 @@ class HiCacheController:
                 self.storage_backend = HiCacheNixl()
                 self.get_hash_str = get_hash_str
             elif storage_backend == "mooncake":
+                from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+                    MooncakeStore,
+                    get_hash_str_mooncake,
+                )
+
                 self.storage_backend = MooncakeStore()
                 self.get_hash_str = get_hash_str_mooncake
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
             elif storage_backend == "hf3fs":
                 from sglang.srt.distributed import get_tensor_model_parallel_rank
+                from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
+                    HiCacheHF3FS,
+                )
 
                 rank = get_tensor_model_parallel_rank()
                 bytes_per_page = (
@@ -293,6 +286,16 @@ class HiCacheController:
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.prefetch_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )
+                self.backup_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )
 
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
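Net effect of the change: HiCacheFile, MooncakeStore, and HiCacheHF3FS are now imported only inside the branch that selects that storage backend, so the controller module can be imported even when an optional backend dependency is not installed, and the TP synchronization groups are created only after the backend is initialized. A minimal, runnable sketch of the deferred-import pattern itself, using standard-library stand-ins rather than the project's backend classes:

def make_encoder(kind: str):
    # Each backend is imported only when it is actually selected, mirroring the
    # per-branch imports added in this commit; an uninstalled backend cannot
    # break importing this module.
    if kind == "json":
        import json
        return json.dumps
    elif kind == "pickle":
        import pickle
        return pickle.dumps
    raise ValueError(f"unknown backend: {kind}")


if __name__ == "__main__":
    encode = make_encoder("json")
    print(encode({"page_size": 64}))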