Conditionally import HiCacheHF3FS (#8598)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -25,12 +25,6 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||||
|
|
||||||
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
|
|
||||||
from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
|
|
||||||
MooncakeStore,
|
|
||||||
get_hash_str_mooncake,
|
|
||||||
)
|
|
||||||
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -251,16 +245,7 @@ class HiCacheController:
|
|||||||
self.enable_storage = False
|
self.enable_storage = False
|
||||||
# todo: move backend initialization to storage backend module
|
# todo: move backend initialization to storage backend module
|
||||||
if storage_backend is not None:
|
if storage_backend is not None:
|
||||||
# create a new communication group for synchronizing storage operations across TP workers
|
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
|
||||||
self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
|
|
||||||
if self.tp_world_size > 1:
|
|
||||||
group_ranks = torch.distributed.get_process_group_ranks(tp_group)
|
|
||||||
self.prefetch_tp_group = torch.distributed.new_group(
|
|
||||||
group_ranks, backend="gloo"
|
|
||||||
)
|
|
||||||
self.backup_tp_group = torch.distributed.new_group(
|
|
||||||
group_ranks, backend="gloo"
|
|
||||||
)
|
|
||||||
|
|
||||||
if storage_backend == "file":
|
if storage_backend == "file":
|
||||||
self.storage_backend = HiCacheFile()
|
self.storage_backend = HiCacheFile()
|
||||||
@@ -271,11 +256,19 @@ class HiCacheController:
|
|||||||
self.storage_backend = HiCacheNixl()
|
self.storage_backend = HiCacheNixl()
|
||||||
self.get_hash_str = get_hash_str
|
self.get_hash_str = get_hash_str
|
||||||
elif storage_backend == "mooncake":
|
elif storage_backend == "mooncake":
|
||||||
|
from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
|
||||||
|
MooncakeStore,
|
||||||
|
get_hash_str_mooncake,
|
||||||
|
)
|
||||||
|
|
||||||
self.storage_backend = MooncakeStore()
|
self.storage_backend = MooncakeStore()
|
||||||
self.get_hash_str = get_hash_str_mooncake
|
self.get_hash_str = get_hash_str_mooncake
|
||||||
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
|
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
|
||||||
elif storage_backend == "hf3fs":
|
elif storage_backend == "hf3fs":
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
|
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
|
||||||
|
HiCacheHF3FS,
|
||||||
|
)
|
||||||
|
|
||||||
rank = get_tensor_model_parallel_rank()
|
rank = get_tensor_model_parallel_rank()
|
||||||
bytes_per_page = (
|
bytes_per_page = (
|
||||||
@@ -293,6 +286,16 @@ class HiCacheController:
|
|||||||
self.enable_storage = True
|
self.enable_storage = True
|
||||||
# todo: threshold policy for prefetching
|
# todo: threshold policy for prefetching
|
||||||
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
|
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
|
||||||
|
# create a new communication group for synchronizing storage operations across TP workers
|
||||||
|
self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
|
||||||
|
if self.tp_world_size > 1:
|
||||||
|
group_ranks = torch.distributed.get_process_group_ranks(tp_group)
|
||||||
|
self.prefetch_tp_group = torch.distributed.new_group(
|
||||||
|
group_ranks, backend="gloo"
|
||||||
|
)
|
||||||
|
self.backup_tp_group = torch.distributed.new_group(
|
||||||
|
group_ranks, backend="gloo"
|
||||||
|
)
|
||||||
|
|
||||||
self.load_cache_event = load_cache_event
|
self.load_cache_event = load_cache_event
|
||||||
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
||||||
|
|||||||
Reference in New Issue
Block a user