From 1b7afad0dd36dd2b5d7ced4e21cd6ef381f356fb Mon Sep 17 00:00:00 2001
From: hzh0425
Date: Sat, 9 Aug 2025 16:03:00 +0800
Subject: [PATCH] feature(hicache): Support hf3fs-hicache reusing kvcache
 across different instances (#8673)

Co-authored-by: Zhiqiang Xie
---
 .../mem_cache/storage/hf3fs/docs/README.md    |  29 ++
 .../hf3fs/docs/deploy_sglang_3fs_multinode.md |  65 +++
 .../setup_usrbio_client.md}                   |   0
 .../storage/hf3fs/mini_3fs_metadata_server.py | 443 ++++++++++++++++++
 .../mem_cache/storage/hf3fs/storage_hf3fs.py  | 210 ++++++---
 5 files changed, 678 insertions(+), 69 deletions(-)
 create mode 100644 python/sglang/srt/mem_cache/storage/hf3fs/docs/README.md
 create mode 100644 python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md
 rename python/sglang/srt/mem_cache/storage/hf3fs/{README.md => docs/setup_usrbio_client.md} (100%)
 create mode 100644 python/sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py

diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/docs/README.md b/python/sglang/srt/mem_cache/storage/hf3fs/docs/README.md
new file mode 100644
index 000000000..63be34293
--- /dev/null
+++ b/python/sglang/srt/mem_cache/storage/hf3fs/docs/README.md
@@ -0,0 +1,29 @@
# HF3FS as L3 KV Cache

This document describes how to use deepseek-hf3fs as the L3 KV cache for SGLang.

## Step 1: Install deepseek-3fs with the 3fs-Operator (Coming Soon)

## Step 2: Set up the usrbio client

Please follow [setup_usrbio_client.md](setup_usrbio_client.md) to set up the usrbio client.

## Step 3: Deployment

### Single-node deployment

```bash
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
python3 -m sglang.launch_server \
    --model-path /code/models/Qwen3-32B/ \
    --host 0.0.0.0 --port 10000 \
    --page-size 64 \
    --enable-hierarchical-cache \
    --hicache-ratio 2 --hicache-size 0 \
    --hicache-write-policy write_through \
    --hicache-storage-backend hf3fs
```

### Multi-node deployment with a shared KV cache

Please follow [deploy_sglang_3fs_multinode.md](deploy_sglang_3fs_multinode.md) to deploy SGLang with 3FS on multiple nodes that share one KV cache.

diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md b/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md
new file mode 100644
index 000000000..c2955cd3e
--- /dev/null
+++ b/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md
@@ -0,0 +1,65 @@
# 1. Start the 3FS metadata service

```bash
nohup python3 -m sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server > meta.out &
```
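Before launching the SGLang instances, it is worth checking that the metadata service is reachable. The snippet below is not part of this patch; it is a minimal smoke test that assumes the server runs at `http://metaServerIp:18000` and exercises the `/{rank}/initialize` and `/{rank}/exists` endpoints implemented by `mini_3fs_metadata_server.py`.

```python
import requests

base = "http://metaServerIp:18000"  # assumed address; use your metadata_server_url

# Initialize rank 0 with a tiny page budget (re-initialization is ignored server-side).
resp = requests.post(f"{base}/0/initialize", json={"num_pages": 4})
resp.raise_for_status()
print(resp.json())  # {'message': 'Rank 0 is ready.'}

# Nothing has been written yet, so existence checks should return False.
resp = requests.post(f"{base}/0/exists", json={"keys": ["smoke-test-key"]})
print(resp.json())  # {'exists': [False]}
```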
# 2. Start the SGLang engines

## HF3FS configuration

```bash
vim /sgl-workspace/sglang/benchmark/hf3fs/hf3fs_config.json
{
    "file_path_prefix": "/data/hicache",
    "file_size": 1099511627776,
    "numjobs": 16,
    "entries": 8,
    "metadata_server_url": "http://metaServerIp:18000"
}
```
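With this configuration, each rank stores its KV pages in a single 1 TiB file (`file_size` is in bytes). The number of pages a rank registers with the metadata server is computed in `storage_hf3fs.py` as `file_size // bytes_per_page`. The sketch below only illustrates that arithmetic; the `bytes_per_page` value is a made-up example, since the real value is derived by SGLang from the model and `--page-size` at startup.

```python
# Page-count arithmetic used by HiCacheHF3FS (illustrative numbers).
file_size = 1_099_511_627_776  # from hf3fs_config.json: 1 TiB, in bytes
bytes_per_page = 8 * 2**20     # hypothetical 8 MiB per KV page (model-dependent)

num_pages = file_size // bytes_per_page
print(num_pages)  # 131072 pages registered with the metadata server for this rank
```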
## Node 1

```bash
export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs_config.json
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
rm -rf instance1.out && \
nohup python3 -m sglang.launch_server \
    --model-path /code/models/Qwen3-32B/ \
    --host 0.0.0.0 --port 10000 \
    --page-size 64 \
    --enable-hierarchical-cache \
    --hicache-ratio 2 --hicache-size 0 \
    --hicache-write-policy write_through \
    --hicache-storage-backend hf3fs --tp 2 > instance1.out &
```

## Node 2

```bash
export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs_config.json
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
rm -rf instance2.out && \
nohup python3 -m sglang.launch_server \
    --model-path /code/models/Qwen3-32B/ \
    --host 0.0.0.0 --port 10000 \
    --page-size 64 \
    --enable-hierarchical-cache \
    --hicache-ratio 2 --hicache-size 0 \
    --hicache-write-policy write_through \
    --hicache-storage-backend hf3fs --tp 2 > instance2.out &
```

# 3. Start the router

```bash
rm -rf router.out && \
nohup python -m sglang_router.launch_router --worker-urls http://node1:10000 http://node2:10000 > router.out &
```

# 4. Start the multi-turn benchmark

```bash
rm -rf bench_multiturn.out && \
nohup python3 benchmark/hicache/bench_multiturn.py \
    --model-path /code/models/Qwen3-32B \
    --dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
    --port 30000 \
    --request-length 2048 --num-clients 512 --num-rounds 5 --max-parallel 8 \
    > bench_multiturn.out &
```

diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/README.md b/python/sglang/srt/mem_cache/storage/hf3fs/docs/setup_usrbio_client.md
similarity index 100%
rename from python/sglang/srt/mem_cache/storage/hf3fs/README.md
rename to python/sglang/srt/mem_cache/storage/hf3fs/docs/setup_usrbio_client.md

diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py b/python/sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py
new file mode 100644
index 000000000..1967259ac
--- /dev/null
+++ b/python/sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py
@@ -0,0 +1,443 @@
import argparse
import atexit
import json
import logging
import threading
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import requests
from fastapi import FastAPI, HTTPException, Request, status
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import Hf3fsMetadataInterface

# --- Configuration ---
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
Starting fresh.", + exc_info=True, + ) + self.ranks.clear() + + def save_to_disk(self): + if not self.persistence_path: + return + + logging.info("Persisting metadata to disk...") + with self.global_lock: + serializable_state = {} + for rank_id, rank_meta in self.ranks.items(): + with rank_meta.lock: + serializable_state[rank_id] = { + "num_pages": rank_meta.num_pages, + "free_pages": rank_meta.free_pages, + "key_to_index": list(rank_meta.key_to_index.items()), + } + + try: + temp_path = self.persistence_path.with_suffix(".tmp") + with open(temp_path, "w") as f: + json.dump(serializable_state, f, indent=4) + temp_path.rename(self.persistence_path) + logging.info(f"Metadata successfully persisted to {self.persistence_path}") + except Exception as e: + logging.error(f"Failed to save metadata to disk: {e}", exc_info=True) + + def schedule_save(self): + if self.is_shutting_down or not self.persistence_path: + return + self.save_to_disk() + self.save_timer = threading.Timer(self.save_interval, self.schedule_save) + self.save_timer.start() + + def shutdown(self): + logging.info("Shutting down metadata server...") + self.is_shutting_down = True + if self.save_timer: + self.save_timer.cancel() + self.save_to_disk() + logging.info("Shutdown complete.") + + +# --- Global MetadataServer implementation --- +class Hf3fsMetadataServer: + """HF3FS Metadata Server that manages metadata for multiple ranks.""" + + def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60): + self.state = GlobalMetadataState(persistence_path, save_interval) + self.app = FastAPI() + self._setup_routes() + + def _setup_routes(self): + """Setup FastAPI routes.""" + self.app.post("/{rank}/initialize")(self.initialize) + self.app.post("/{rank}/exists")(self.exists) + self.app.post("/{rank}/reserve_and_allocate_page_indices")( + self.reserve_and_allocate_page_indices + ) + self.app.post("/{rank}/confirm_write")(self.confirm_write) + self.app.post("/{rank}/delete_keys")(self.delete_keys) + self.app.post("/{rank}/clear")(self.clear) + self.app.post("/{rank}/get_page_indices")(self.get_page_indices) + + def get_rank_metadata(self, rank: int) -> RankMetadata: + """Get rank metadata with proper error handling.""" + with self.state.global_lock: + if rank not in self.state.ranks: + raise HTTPException( + status_code=404, + detail=f"Rank {rank} not initialized. Please call /{{rank}}/initialize first.", + ) + return self.state.ranks[rank] + + async def initialize(self, rank: int, request: Request): + """Initialize a rank with specified number of pages.""" + data = await request.json() + num_pages = data["num_pages"] + with self.state.global_lock: + if rank in self.state.ranks: + logging.info( + f"Rank {rank} already exists. Initialization request ignored." + ) + if self.state.ranks[rank].num_pages != num_pages: + logging.warning( + f"Rank {rank} initialized with different num_pages. 
class GlobalMetadataState:
    """Manages the state for all ranks and persistence."""

    def __init__(self, persistence_path: Optional[str], save_interval: int):
        self.global_lock = threading.RLock()
        self.ranks: Dict[int, RankMetadata] = {}
        self.persistence_path = Path(persistence_path) if persistence_path else None
        self.save_interval = save_interval
        self.save_timer: Optional[threading.Timer] = None
        self.is_shutting_down = False

    def load_from_disk(self):
        if not self.persistence_path or not self.persistence_path.exists():
            logging.info("Persistence file not found. Starting with a clean state.")
            return

        logging.info(f"Loading state from {self.persistence_path}")
        try:
            with open(self.persistence_path, "r") as f:
                persisted_data = json.load(f)

            with self.global_lock:
                for rank_id_str, data in persisted_data.items():
                    rank_id = int(rank_id_str)
                    num_pages = data["num_pages"]
                    rank_meta = RankMetadata(num_pages)
                    rank_meta.free_pages = data["free_pages"]
                    rank_meta.key_to_index = dict(data["key_to_index"])
                    self.ranks[rank_id] = rank_meta
                logging.info(
                    f"Successfully loaded metadata for {len(self.ranks)} ranks."
                )
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            logging.error(
                f"Failed to load or parse persistence file: {e}. Starting fresh.",
                exc_info=True,
            )
            self.ranks.clear()

    def save_to_disk(self):
        if not self.persistence_path:
            return

        logging.info("Persisting metadata to disk...")
        with self.global_lock:
            serializable_state = {}
            for rank_id, rank_meta in self.ranks.items():
                with rank_meta.lock:
                    serializable_state[rank_id] = {
                        "num_pages": rank_meta.num_pages,
                        "free_pages": rank_meta.free_pages,
                        "key_to_index": list(rank_meta.key_to_index.items()),
                    }

        try:
            temp_path = self.persistence_path.with_suffix(".tmp")
            with open(temp_path, "w") as f:
                json.dump(serializable_state, f, indent=4)
            temp_path.rename(self.persistence_path)
            logging.info(f"Metadata successfully persisted to {self.persistence_path}")
        except Exception as e:
            logging.error(f"Failed to save metadata to disk: {e}", exc_info=True)

    def schedule_save(self):
        if self.is_shutting_down or not self.persistence_path:
            return
        self.save_to_disk()
        self.save_timer = threading.Timer(self.save_interval, self.schedule_save)
        self.save_timer.start()

    def shutdown(self):
        logging.info("Shutting down metadata server...")
        self.is_shutting_down = True
        if self.save_timer:
            self.save_timer.cancel()
        self.save_to_disk()
        logging.info("Shutdown complete.")
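# On-disk layout produced by GlobalMetadataState.save_to_disk above, with
# hypothetical values; note that json.dump stringifies the integer rank ids
# and serializes the (key, page_index) pairs as two-element arrays.
#
#     {
#         "0": {
#             "num_pages": 4,
#             "free_pages": [0, 1],
#             "key_to_index": [["k1", 3], ["k2", 2]]
#         }
#     }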
+ ) + else: + logging.info("Persistence is DISABLED.") + + uvicorn.run(self.app, host=host, port=port) + + +# --- Client implementation --- +class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface): + """Global http metadata client for HF3FS.""" + + def __init__(self, base_url: str, max_retries: int = 3): + self.base_url = base_url.rstrip("/") + self._session = requests.Session() + + retry_strategy = Retry( + total=max_retries, + backoff_factor=0.3, + status_forcelist=[500, 502, 503, 504], + allowed_methods=["GET", "POST"], + ) + adapter = HTTPAdapter(max_retries=retry_strategy) + self._session.mount("http://", adapter) + + def _post(self, endpoint: str, json_data: dict) -> dict: + try: + response = self._session.post(f"{self.base_url}/{endpoint}", json=json_data) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + logging.error(f"Failed to POST to {endpoint} after retries: {e}") + raise RuntimeError(f"Failed to connect to metadata server: {e}") from e + + def initialize(self, rank: int, num_pages: int) -> None: + self._post(f"{rank}/initialize", {"num_pages": num_pages}) + + def reserve_and_allocate_page_indices( + self, rank: int, keys: List[Tuple[str, str]] + ) -> List[Tuple[bool, int]]: + response = self._post( + f"{rank}/reserve_and_allocate_page_indices", {"keys": keys} + ) + return [tuple(item) for item in response.get("indices")] + + def confirm_write( + self, + rank: int, + written_keys_to_confirm: List[Tuple[str, int]], + pages_to_release: List[int], + ) -> None: + self._post( + f"{rank}/confirm_write", + { + "written_keys_to_confirm": written_keys_to_confirm, + "pages_to_release": pages_to_release, + }, + ) + + def delete_keys(self, rank: int, keys: List[str]) -> None: + self._post(f"{rank}/delete_keys", {"keys": keys}) + + def exists(self, rank: int, keys: List[str]) -> List[bool]: + response = self._post(f"{rank}/exists", {"keys": keys}) + return response.get("exists", []) + + def clear(self, rank: int) -> None: + self._post(f"{rank}/clear", {}) + + def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]: + response = self._post(f"{rank}/get_page_indices", {"keys": keys}) + return response.get("indices") + + +class Hf3fsLocalMetadataClient(Hf3fsMetadataInterface): + """Local metadata client that directly operates on single RankMetadata in memory without metadata server.""" + + def __init__(self): + self.rank_metadata = None + + def initialize(self, rank: int, num_pages: int) -> None: + self.rank_metadata = RankMetadata(num_pages) + + def reserve_and_allocate_page_indices( + self, rank: int, keys: List[Tuple[str, str]] + ) -> List[Tuple[bool, int]]: + """Reserve and allocate page indices for keys.""" + return self.rank_metadata.reserve_and_allocate_page_indices(keys) + + def confirm_write( + self, + rank: int, + written_keys_to_confirm: List[Tuple[str, int]], + pages_to_release: List[int], + ) -> None: + """Confirm write operations.""" + self.rank_metadata.confirm_write(written_keys_to_confirm, pages_to_release) + + def delete_keys(self, rank: int, keys: List[str]) -> None: + """Delete keys.""" + self.rank_metadata.delete_keys(keys) + + def exists(self, rank: int, keys: List[str]) -> List[bool]: + """Check if keys exist.""" + return self.rank_metadata.exists_keys(keys) + + def clear(self, rank: int) -> None: + """Clear all metadata for rank.""" + self.rank_metadata.clear_all() + + def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]: + """Get page indices for keys.""" + return 
# --- Client implementations ---
class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
    """Global HTTP metadata client for HF3FS."""

    def __init__(self, base_url: str, max_retries: int = 3):
        self.base_url = base_url.rstrip("/")
        self._session = requests.Session()

        retry_strategy = Retry(
            total=max_retries,
            backoff_factor=0.3,
            status_forcelist=[500, 502, 503, 504],
            allowed_methods=["GET", "POST"],
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self._session.mount("http://", adapter)

    def _post(self, endpoint: str, json_data: dict) -> dict:
        try:
            response = self._session.post(f"{self.base_url}/{endpoint}", json=json_data)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logging.error(f"Failed to POST to {endpoint} after retries: {e}")
            raise RuntimeError(f"Failed to connect to metadata server: {e}") from e

    def initialize(self, rank: int, num_pages: int) -> None:
        self._post(f"{rank}/initialize", {"num_pages": num_pages})

    def reserve_and_allocate_page_indices(
        self, rank: int, keys: List[Tuple[str, str]]
    ) -> List[Tuple[bool, int]]:
        response = self._post(
            f"{rank}/reserve_and_allocate_page_indices", {"keys": keys}
        )
        return [tuple(item) for item in response.get("indices")]

    def confirm_write(
        self,
        rank: int,
        written_keys_to_confirm: List[Tuple[str, int]],
        pages_to_release: List[int],
    ) -> None:
        self._post(
            f"{rank}/confirm_write",
            {
                "written_keys_to_confirm": written_keys_to_confirm,
                "pages_to_release": pages_to_release,
            },
        )

    def delete_keys(self, rank: int, keys: List[str]) -> None:
        self._post(f"{rank}/delete_keys", {"keys": keys})

    def exists(self, rank: int, keys: List[str]) -> List[bool]:
        response = self._post(f"{rank}/exists", {"keys": keys})
        return response.get("exists", [])

    def clear(self, rank: int) -> None:
        self._post(f"{rank}/clear", {})

    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
        response = self._post(f"{rank}/get_page_indices", {"keys": keys})
        return response.get("indices")
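# Illustrative example (editor's sketch, not part of the module): the
# client-side write path against a running metadata server, mirroring what
# HiCacheHF3FS.batch_set does; the URL is hypothetical.
#
#     client = Hf3fsGlobalMetadataClient("http://metaServerIp:18000")
#     client.initialize(rank=0, num_pages=4)
#     ((is_written, page_idx),) = client.reserve_and_allocate_page_indices(
#         0, [("k1", "")]
#     )
#     if not is_written and page_idx != -1:
#         # ... write the page to 3FS at offset page_idx * bytes_per_page ...
#         client.confirm_write(0, [("k1", page_idx)], pages_to_release=[])
#     assert client.get_page_indices(0, ["k1"]) == [page_idx]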
class Hf3fsLocalMetadataClient(Hf3fsMetadataInterface):
    """Local metadata client that operates directly on a single in-memory RankMetadata, without a metadata server."""

    def __init__(self):
        self.rank_metadata: Optional[RankMetadata] = None

    def initialize(self, rank: int, num_pages: int) -> None:
        self.rank_metadata = RankMetadata(num_pages)

    def reserve_and_allocate_page_indices(
        self, rank: int, keys: List[Tuple[str, str]]
    ) -> List[Tuple[bool, int]]:
        """Reserve and allocate page indices for keys."""
        return self.rank_metadata.reserve_and_allocate_page_indices(keys)

    def confirm_write(
        self,
        rank: int,
        written_keys_to_confirm: List[Tuple[str, int]],
        pages_to_release: List[int],
    ) -> None:
        """Confirm write operations."""
        self.rank_metadata.confirm_write(written_keys_to_confirm, pages_to_release)

    def delete_keys(self, rank: int, keys: List[str]) -> None:
        """Delete keys."""
        self.rank_metadata.delete_keys(keys)

    def exists(self, rank: int, keys: List[str]) -> List[bool]:
        """Check if keys exist."""
        return self.rank_metadata.exists_keys(keys)

    def clear(self, rank: int) -> None:
        """Clear all metadata for the rank."""
        self.rank_metadata.clear_all()

    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
        """Get page indices for keys."""
        return self.rank_metadata.get_page_indices(keys)


def run_metadata_server(
    host: str = "0.0.0.0",
    port: int = 18000,
    persistence_path: Optional[str] = None,
    save_interval: int = 60,
):
    """Run the HF3FS metadata server."""
    global server
    server = Hf3fsMetadataServer(
        persistence_path=persistence_path, save_interval=save_interval
    )

    server.run(host=host, port=port)


# --- Main Execution ---
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="HF3FS Metadata Server")
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Host to bind the server to."
    )
    parser.add_argument(
        "--port", type=int, default=18000, help="Port to run the server on."
    )
    parser.add_argument(
        "--persistence-path",
        type=str,
        default=None,
        help="Path to the file for persisting metadata. If not provided, persistence is disabled.",
    )
    parser.add_argument(
        "--save-interval",
        type=int,
        default=60,
        help="Interval in seconds for periodically saving metadata to disk.",
    )
    args = parser.parse_args()

    run_metadata_server(args.host, args.port, args.persistence_path, args.save_interval)

diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
index 0cc2b0a26..e7dd01c73 100644
--- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
+++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
@@ -5,9 +5,9 @@ import logging
 import os
 import signal
 import threading
-from collections import OrderedDict
+from abc import ABC, abstractmethod
 from functools import wraps
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import torch
 
@@ -17,6 +17,75 @@ from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
 logger = logging.getLogger(__name__)
 
 
+class Hf3fsMetadataInterface(ABC):
+    """Interface for HF3FS metadata operations."""
+
+    @abstractmethod
+    def initialize(self, rank: int, num_pages: int) -> None:
+        """Initialize the metadata service with the specified number of pages."""
+        pass
+
+    @abstractmethod
+    def reserve_and_allocate_page_indices(
+        self,
+        rank: int,
+        keys: List[Tuple[str, str]],
+    ) -> List[Tuple[bool, int]]:
+        """
+        Reserve and allocate page indices for the specified keys.
+        Args:
+            rank: The rank of the process.
+            keys: The keys to reserve and allocate page indices for. Each tuple contains a key and the key of its prefix block.
+        Returns:
+            List[Tuple[bool, int]]: A list of tuples, where each tuple contains a boolean indicating whether the key already exists and an integer giving the allocated page index.
+        """
+        pass
+
+    @abstractmethod
+    def confirm_write(
+        self,
+        rank: int,
+        written_keys_to_confirm: List[Tuple[str, int]],
+        pages_to_release: List[int],
+    ) -> None:
+        """
+        Confirm that key-value pairs have been successfully written to storage.
+        Args:
+            rank: The rank of the process.
+            written_keys_to_confirm: A list of tuples, where each tuple contains a key and its corresponding page index.
+            pages_to_release: A list of page indices to be released.
+        """
+        pass
+
+    @abstractmethod
+    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
+        """
+        Get page indices for the specified keys.
+        Args:
+            rank: The rank of the process.
+            keys: A list of keys.
+        Returns:
+            List[Optional[int]]: A list of integers representing the page indices for the specified keys.
+            If a key is not found, the corresponding index will be None.
+        """
+        pass
+
+    @abstractmethod
+    def delete_keys(self, rank: int, keys: List[str]) -> None:
+        """Delete the specified keys and their associated pages."""
+        pass
+
+    @abstractmethod
+    def exists(self, rank: int, keys: List[str]) -> List[bool]:
+        """Check if the specified keys exist."""
+        pass
+
+    @abstractmethod
+    def clear(self, rank: int) -> None:
+        """Clear all key-value pairs and page allocations for the specified rank."""
+        pass
+
+
 class AtomicCounter:
     def __init__(self, n: int):
         assert n > 0
@@ -48,32 +117,32 @@ class HiCacheHF3FS(HiCacheStorage):
 
     def __init__(
         self,
+        rank: int,
         file_path: str,
         file_size: int,
         numjobs: int,
         bytes_per_page: int,
         entries: int,
         dtype: torch.dtype,
+        metadata_client: Hf3fsMetadataInterface,
    ):
+        self.rank = rank
         self.file_path = file_path
         self.file_size = file_size
         self.numjobs = numjobs
         self.bytes_per_page = bytes_per_page
         self.entries = entries
         self.dtype = dtype
+        self.metadata_client = metadata_client
 
         self.numel = self.bytes_per_page // self.dtype.itemsize
         self.num_pages = self.file_size // self.bytes_per_page
 
         logger.info(
-            "HiCacheHF3FS "
-            f"file_path = {self.file_path}, "
-            f"file_size = {self.file_size/(2**30):.2f} GB, "
-            f"numjobs = {self.numjobs}, "
-            f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, "
-            f"entries = {self.entries}, "
-            f"num_pages = {self.num_pages}"
+            f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
+            f"file_path={self.file_path}, "
+            f"file_size={self.file_size / (2 ** 30):.2f} GB, "
+            f"num_pages={self.num_pages}"
         )
 
         self.ac = AtomicCounter(self.numjobs)
@@ -84,15 +153,11 @@ class HiCacheHF3FS(HiCacheStorage):
             for _ in range(numjobs)
         ]
         self.executor = concurrent.futures.ThreadPoolExecutor(
-            max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS"
+            max_workers=self.numjobs, thread_name_prefix=f"HiCacheHF3FS-Rank{self.rank}"
         )
 
-        # Implemented a preliminary single-file page_hash -> file_offset index as interim storage.
-        # Future iterations may adopt a global KVCache manager to coordinate external cache instances
-        # through centralized metadata orchestration.
+        self.metadata_client.initialize(self.rank, self.num_pages)
+
         self.lock = threading.RLock()
-        self.free_pages = list(range(self.num_pages))
-        self.key_to_index = OrderedDict()
 
         atexit.register(self.close)
@@ -104,15 +169,22 @@ class HiCacheHF3FS(HiCacheStorage):
     def from_env_config(
         rank: int, bytes_per_page: int, dtype: torch.dtype
     ) -> "HiCacheHF3FS":
+        from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
+            Hf3fsGlobalMetadataClient,
+            Hf3fsLocalMetadataClient,
+        )
+
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
             return HiCacheHF3FS(
+                rank=rank,
                 file_path=f"/data/hicache.{rank}.bin",
                 file_size=1 << 40,
                 numjobs=16,
                 bytes_per_page=bytes_per_page,
                 entries=8,
                 dtype=dtype,
+                metadata_client=Hf3fsLocalMetadataClient(),
             )
 
         try:
@@ -121,6 +193,7 @@ class HiCacheHF3FS(HiCacheStorage):
         except Exception as e:
             raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")
 
+        # Check required keys (metadata_server_url is now optional)
         required_keys = {
             "file_path_prefix",
             "file_size",
@@ -131,19 +204,33 @@ class HiCacheHF3FS(HiCacheStorage):
         if missing_keys:
             raise ValueError(f"Missing required keys in config: {missing_keys}")
 
+        # Choose metadata client based on configuration
+        if "metadata_server_url" in config and config["metadata_server_url"]:
+            # Use global metadata client to connect to metadata server
+            metadata_server_url = config["metadata_server_url"]
+            metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+            logger.info(
+                f"Using global metadata client with server url: {metadata_server_url}"
+            )
+        else:
+            # Use local metadata client for single-machine deployment
+            metadata_client = Hf3fsLocalMetadataClient()
+
         return HiCacheHF3FS(
+            rank=rank,
             file_path=f"{config['file_path_prefix']}.{rank}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
+            metadata_client=metadata_client,
         )
 
     def get(
         self, key: str, target_location: Optional[torch.Tensor] = None
     ) -> torch.Tensor | None:
-        return self.batch_get([key], target_location)[0]
+        # Compare against None explicitly: `if target_location` would invoke
+        # Tensor.__bool__, which raises for tensors with more than one element.
+        return self.batch_get(
+            [key], [target_location] if target_location is not None else None
+        )[0]
 
     @synchronized()
     def batch_get(
@@ -151,14 +238,14 @@ class HiCacheHF3FS(HiCacheStorage):
         keys: List[str],
         target_locations: Optional[List[torch.Tensor]] = None,
     ) -> List[torch.Tensor | None]:
+        page_indices = self.metadata_client.get_page_indices(self.rank, keys)
+
         batch_indices, file_offsets = [], []
-        for i, key in enumerate(keys):
-            if key not in self.key_to_index:
-                continue
-            batch_indices.append(i)
-            file_offsets.append(self.key_to_index[key] * self.bytes_per_page)
-            self.key_to_index.move_to_end(key)
-        # TODO: target_locations
+        for i, page_index in enumerate(page_indices):
+            if page_index is not None:
+                batch_indices.append(i)
+                file_offsets.append(page_index * self.bytes_per_page)
+
         file_results = [
             torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
         ]
@@ -180,7 +267,9 @@ class HiCacheHF3FS(HiCacheStorage):
             if read_result == self.bytes_per_page:
                 results[batch_index] = file_result
             else:
-                logger.error(f"HiCacheHF3FS get {keys[batch_index]} failed")
+                logger.error(
+                    f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
+                )
 
         return results
 
@@ -188,13 +277,21 @@ class HiCacheHF3FS(HiCacheStorage):
     def set(self, key: str, value: torch.Tensor) -> bool:
         return self.batch_set([key], [value])
 
     def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
-        indices = self.get_batch_set_indices(keys)
+        # TODO: Add the prefix block's hash key once HiCache passes it through.
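+        # The empty string is a placeholder for the prefix hash: the metadata
+        # server currently ignores it, and a (False, -1) entry in the reply
+        # means no free page could be allocated for that key.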
+        key_with_prefix = [(key, "") for key in keys]
+        indices = self.metadata_client.reserve_and_allocate_page_indices(
+            self.rank, key_with_prefix
+        )
+
         batch_indices, file_offsets, file_values = [], [], []
-        for i, (value, (is_written, index)) in enumerate(zip(values, indices)):
-            if is_written or index == -1:
+        pages_to_release = []
+
+        for i, (value, (is_written, page_index)) in enumerate(zip(values, indices)):
+            if is_written or page_index == -1:
                 continue
+
             batch_indices.append(i)
-            file_offsets.append(index * self.bytes_per_page)
+            file_offsets.append(page_index * self.bytes_per_page)
             file_values.append(value.contiguous())
 
         futures = [
@@ -211,62 +308,37 @@ class HiCacheHF3FS(HiCacheStorage):
             for result in future.result()
         ]
 
+        written_keys_to_confirm = []
         results = [index[0] for index in indices]
         for batch_index, write_result in zip(batch_indices, write_results):
             key = keys[batch_index]
-            index = indices[batch_index][1]
+            page_index = indices[batch_index][1]
             if write_result:
-                self.key_to_index[key] = index
-                self.key_to_index.move_to_end(key)
+                written_keys_to_confirm.append((key, page_index))
             else:
-                logger.error(f"HiCacheHF3FS set {key} failed")
-                self.free_pages.append(index)
+                logger.error(f"[Rank {self.rank}] HiCacheHF3FS set {key} failed")
+                pages_to_release.append(page_index)
             results[batch_index] = write_result
+
+        if len(written_keys_to_confirm) > 0 or len(pages_to_release) > 0:
+            self.metadata_client.confirm_write(
+                self.rank, written_keys_to_confirm, pages_to_release
+            )
+
         return all(results)
 
-    @synchronized()
-    def get_batch_set_indices(self, keys: List[str]) -> list:
-        ionum = len(keys)
-        # results: tuples of (is_written: bool, page_idx: int)
-        #   - is_written: True = hit (no I/O), False = write (miss)
-        #   - page_idx: page storing data
-        results = [None] * min(ionum, self.num_pages)
-        if ionum > self.num_pages:
-            results.extend([(False, -1)] * (ionum - self.num_pages))
-
-        new_keys = []
-        for batch_index, key in enumerate(keys[: self.num_pages]):
-            if key in self.key_to_index:
-                results[batch_index] = (True, self.key_to_index[key])
-                self.key_to_index.move_to_end(key)
-            else:
-                new_keys.append((batch_index, key))
-
-        for batch_index, _ in new_keys:
-            index = (
-                self.free_pages.pop()
-                if len(self.free_pages) > 0
-                else self.key_to_index.popitem(last=False)[1]
-            )
-            results[batch_index] = (False, index)
-
-        return results
-
     @synchronized()
     def delete(self, key: str) -> None:
-        if key not in self.key_to_index:
-            return
-        index = self.key_to_index.pop(key)
-        self.free_pages.append(index)
+        self.metadata_client.delete_keys(self.rank, [key])
 
     @synchronized()
     def exists(self, key: str) -> bool:
-        return key in self.key_to_index
+        result = self.metadata_client.exists(self.rank, [key])
+        return result[0] if result else False
 
     @synchronized()
     def clear(self) -> None:
-        self.free_pages = list(range(self.num_pages))
-        self.key_to_index.clear()
+        self.metadata_client.clear(self.rank)
 
     def close(self) -> None:
         try: