Add hf3fs support for hicache storage (based on #7704) (#7280)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
pansicheng
2025-07-31 08:42:41 +08:00
committed by GitHub
parent a79a5d7012
commit 299803343d
12 changed files with 1110 additions and 23 deletions

View File

@@ -26,6 +26,7 @@ if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
logger = logging.getLogger(__name__)
@@ -250,17 +251,33 @@ class HiCacheController:
self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
if self.tp_world_size > 1:
group_ranks = torch.distributed.get_process_group_ranks(tp_group)
self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
self.prefetch_tp_group = torch.distributed.new_group(
group_ranks, backend="gloo"
)
self.backup_tp_group = torch.distributed.new_group(
group_ranks, backend="gloo"
)
if storage_backend == "file":
self.storage_backend = HiCacheFile()
self.enable_storage = True
# todo: threshold policy for prefetching
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
elif storage_backend == "hf3fs":
from sglang.srt.distributed import get_tensor_model_parallel_rank
rank = get_tensor_model_parallel_rank()
bytes_per_page = (
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
)
dtype = mem_pool_host.dtype
self.storage_backend = HiCacheHF3FS.from_env_config(
rank, bytes_per_page, dtype
)
else:
raise NotImplementedError(
f"Unsupported storage backend: {storage_backend}"
)
self.enable_storage = True
# todo: threshold policy for prefetching
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
self.load_cache_event = load_cache_event
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -522,8 +539,8 @@ class HiCacheController:
while not self.stop_event.is_set():
try:
operation = self.prefetch_buffer.get(block=True, timeout=1)
for h in operation.hash_value:
page_data = self.storage_backend.get(h)
page_datas = self.storage_backend.batch_get(operation.hash_value)
for h, page_data in zip(operation.hash_value, page_datas):
if page_data is None:
logger.warning(
f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
@@ -531,7 +548,9 @@ class HiCacheController:
break
if operation.increment(self.page_size):
self.mem_pool_host.set_from_flat_data_page(
operation.host_indices[operation.completed_tokens],
operation.host_indices[
operation.completed_tokens - self.page_size
],
page_data,
)
else:
@@ -583,7 +602,7 @@ class HiCacheController:
torch.distributed.all_reduce(
storage_hit_count_tensor,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
group=self.prefetch_tp_group,
)
storage_hit_count = storage_hit_count_tensor.item()
@@ -635,21 +654,23 @@ class HiCacheController:
last_hash = operation.last_hash
tokens_to_backup = operation.token_ids
last_hashes, data_pages = [], []
for i in range(0, len(tokens_to_backup), self.page_size):
last_hash = get_hash_str(
tokens_to_backup[i : i + self.page_size], last_hash
)
success = self.storage_backend.set(
last_hash,
self.mem_pool_host.get_flat_data_page(
operation.host_indices[i]
),
data_page = self.mem_pool_host.get_flat_data_page(
operation.host_indices[i]
)
if not success:
logger.warning(f"Failed to write page {last_hash} to storage.")
break
operation.completed_tokens += self.page_size
operation.hash_value.append(last_hash)
last_hashes.append(last_hash)
data_pages.append(data_page)
success = self.storage_backend.batch_set(last_hashes, data_pages)
if not success:
logger.warning(f"Failed to write page {last_hashes} to storage.")
else:
operation.completed_tokens += len(tokens_to_backup)
operation.hash_value.extend(last_hashes)
min_completed_tokens = operation.completed_tokens
if self.tp_world_size > 1:
@@ -659,7 +680,7 @@ class HiCacheController:
torch.distributed.all_reduce(
completed_tokens_tensor,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
group=self.backup_tp_group,
)
min_completed_tokens = completed_tokens_tensor.item()