feature(hicache): Support hf3fs-hicache reusing kvcache across different instances (#8673)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
29
python/sglang/srt/mem_cache/storage/hf3fs/docs/README.md
Normal file
29
python/sglang/srt/mem_cache/storage/hf3fs/docs/README.md
Normal file
@@ -0,0 +1,29 @@
|
||||
# HF3FS as L3 KV Cache
|
||||
|
||||
This document describes how to use deepseek-hf3fs as the L3 KV cache for SGLang.
|
||||
|
||||
## Step 1: Install deepseek-3fs with the 3fs-Operator (Coming Soon)
|
||||
|
||||
## Step 2: Set up the usrbio client
|
||||
|
||||
Please follow the document [setup_usrbio_client.md](setup_usrbio_client.md) to set up the usrbio client.
|
||||
|
||||
## Step 3: Deployment
|
||||
|
||||
### Single node deployment
|
||||
|
||||
```bash
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
|
||||
python3 -m sglang.launch_server \
|
||||
--model-path /code/models/Qwen3-32B/ \
|
||||
--host 0.0.0.0 --port 10000 \
|
||||
--page-size 64 \
|
||||
--enable-hierarchical-cache \
|
||||
--hicache-ratio 2 --hicache-size 0 \
|
||||
--hicache-write-policy write_through \
|
||||
--hicache-storage-backend hf3fs
|
||||
```
|
||||
|
||||
### Multi nodes deployment to share KV cache
|
||||
|
||||
Please follow the document [deploy_sglang_3fs_multinode.md](deploy_sglang_3fs_multinode.md) to deploy SGLang with 3FS on multiple nodes to share KV cache.
|
||||
@@ -0,0 +1,65 @@
|
||||
# 1. Start the 3FS metadata service
|
||||
```bash
|
||||
nohup python3 -m sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server > meta.out &
|
||||
```
|
||||
|
||||
|
||||
# 2. Start the SGLang engines
|
||||
## HF3FS configuration
|
||||
```bash
|
||||
vim /sgl-workspace/sglang/benchmark/hf3fs/hf3fs_config.json
|
||||
{
|
||||
"file_path_prefix": "/data/hicache",
|
||||
"file_size": 1099511627776,
|
||||
"numjobs": 16,
|
||||
"entries": 8,
|
||||
"metadata_server_url": "http://metaServerIp:18000"
|
||||
}
|
||||
```
|
||||
|
||||
## node1
|
||||
```bash
|
||||
export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs_config.json
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
|
||||
rm -rf instance1.out && \
|
||||
nohup python3 -m sglang.launch_server \
|
||||
--model-path /code/models/Qwen3-32B/ \
|
||||
--host 0.0.0.0 --port 10000 \
|
||||
--page-size 64 \
|
||||
--enable-hierarchical-cache \
|
||||
--hicache-ratio 2 --hicache-size 0 \
|
||||
--hicache-write-policy write_through \
|
||||
--hicache-storage-backend hf3fs --tp 2 > instance1.out &
|
||||
```
|
||||
|
||||
## node2
|
||||
```bash
|
||||
export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs_config.json
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
|
||||
rm -rf instance2.out && \
|
||||
nohup python3 -m sglang.launch_server \
|
||||
--model-path /code/models/Qwen3-32B/ \
|
||||
--host 0.0.0.0 --port 10000 \
|
||||
--page-size 64 \
|
||||
--enable-hierarchical-cache \
|
||||
--hicache-ratio 2 --hicache-size 0 \
|
||||
--hicache-write-policy write_through \
|
||||
--hicache-storage-backend hf3fs --tp 2 > instance2.out &
|
||||
```
|
||||
|
||||
# 3. Start the router
|
||||
```bash
|
||||
rm -rf router.out && \
|
||||
nohup python -m sglang_router.launch_router --worker-urls http://node1:10000 http://node2:10000 > router.out &
|
||||
```
|
||||
|
||||
# 4. Start the multi-turn benchmark
|
||||
```bash
|
||||
rm -rf bench_multiturn.out && \
|
||||
nohup python3 benchmark/hicache/bench_multiturn.py \
|
||||
--model-path /code/models/Qwen3-32B \
|
||||
--dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--port 30000 \
|
||||
--request-length 2048 --num-clients 512 --num-rounds 5 --max-parallel 8 \
|
||||
> bench_multiturn.out &
|
||||
```
|
||||
@@ -0,0 +1,443 @@
|
||||
import argparse
|
||||
import atexit
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import Hf3fsMetadataInterface
|
||||
|
||||
# --- Configuration ---
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
# --- Data Models ---
|
||||
class RankMetadata:
    """Holds all metadata for a single rank."""

    def __init__(self, num_pages: int):
        # Reentrant lock guarding every read/write of the tables below.
        self.lock = threading.RLock()
        self.num_pages = num_pages
        # Pages are handed out from the tail of this list (LIFO).
        self.free_pages: List[int] = list(range(num_pages))
        self.key_to_index: Dict[str, int] = {}
        # Todo: Support multi files for HF3FS

    def exists_keys(self, keys: List[str]) -> List[bool]:
        """Check if keys exist in metadata."""
        with self.lock:
            table = self.key_to_index
            return [k in table for k in keys]

    def reserve_and_allocate_page_indices(
        self, keys: List[Tuple[str, str]]
    ) -> List[Tuple[bool, int]]:
        """Reserve and allocate page indices for keys."""
        with self.lock:
            out = [None] * len(keys)
            miss_positions = []

            # First pass: answer hits immediately, remember misses.
            for pos, (key, _prefix_key) in enumerate(keys):
                page = self.key_to_index.get(key)
                if page is not None:
                    out[pos] = (True, page)
                else:
                    miss_positions.append(pos)

            # Todo: Implementing data eviction logic after HiCache supports prefix information pass-through
            # Second pass: hand out free pages to misses; -1 signals exhaustion.
            for pos in miss_positions:
                if self.free_pages:
                    out[pos] = (False, self.free_pages.pop())
                else:
                    out[pos] = (False, -1)

            return out

    def confirm_write(
        self,
        written_keys_to_confirm: List[Tuple[str, int]],
        pages_to_release: List[int],
    ) -> None:
        """Confirm write operations and release pages."""
        with self.lock:
            for key, page_index in written_keys_to_confirm:
                self.key_to_index[key] = page_index

            for page_index in pages_to_release:
                # Guard against double-freeing a page already on the free list.
                if page_index not in self.free_pages:
                    self.free_pages.append(page_index)

    def delete_keys(self, keys: List[str]) -> int:
        """Delete keys and return count of deleted keys."""
        with self.lock:
            deleted = 0
            for key in keys:
                page_index = self.key_to_index.pop(key, None)
                if page_index is None:
                    continue
                if page_index not in self.free_pages:
                    self.free_pages.append(page_index)
                deleted += 1
            return deleted

    def clear_all(self) -> None:
        """Clear all metadata."""
        with self.lock:
            self.free_pages = list(range(self.num_pages))
            self.key_to_index.clear()

    def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
        """Get page indices for keys."""
        with self.lock:
            lookup = self.key_to_index.get
            return [lookup(k) for k in keys]
|
||||
|
||||
|
||||
class GlobalMetadataState:
    """Manages the state for all ranks and persistence."""

    def __init__(self, persistence_path: Optional[str], save_interval: int):
        # Guards the rank table; each RankMetadata carries its own lock as well.
        self.global_lock = threading.RLock()
        self.ranks: Dict[int, RankMetadata] = {}
        self.persistence_path = Path(persistence_path) if persistence_path else None
        self.save_interval = save_interval
        self.save_timer: Optional[threading.Timer] = None
        self.is_shutting_down = False

    def load_from_disk(self):
        """Restore all rank metadata from the persistence file, if present."""
        if not self.persistence_path or not self.persistence_path.exists():
            logging.info("Persistence file not found. Starting with a clean state.")
            return

        logging.info(f"Loading state from {self.persistence_path}")
        try:
            with open(self.persistence_path, "r") as f:
                persisted_data = json.load(f)

            with self.global_lock:
                for rank_id_str, data in persisted_data.items():
                    # JSON object keys are strings; rank ids are ints.
                    restored = RankMetadata(data["num_pages"])
                    restored.free_pages = data["free_pages"]
                    restored.key_to_index = dict(data["key_to_index"])
                    self.ranks[int(rank_id_str)] = restored
                logging.info(
                    f"Successfully loaded metadata for {len(self.ranks)} ranks."
                )
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            logging.error(
                f"Failed to load or parse persistence file: {e}. Starting fresh.",
                exc_info=True,
            )
            self.ranks.clear()

    def save_to_disk(self):
        """Snapshot every rank's metadata and write it atomically to disk."""
        if not self.persistence_path:
            return

        logging.info("Persisting metadata to disk...")
        with self.global_lock:
            snapshot = {}
            for rank_id, rank_meta in self.ranks.items():
                with rank_meta.lock:
                    snapshot[rank_id] = {
                        "num_pages": rank_meta.num_pages,
                        "free_pages": rank_meta.free_pages,
                        "key_to_index": list(rank_meta.key_to_index.items()),
                    }

        try:
            # Write to a temp file then rename, so a crash mid-write never
            # leaves a truncated persistence file behind.
            temp_path = self.persistence_path.with_suffix(".tmp")
            with open(temp_path, "w") as f:
                json.dump(snapshot, f, indent=4)
            temp_path.rename(self.persistence_path)
            logging.info(f"Metadata successfully persisted to {self.persistence_path}")
        except Exception as e:
            logging.error(f"Failed to save metadata to disk: {e}", exc_info=True)

    def schedule_save(self):
        """Save immediately, then re-arm the periodic save timer."""
        if self.is_shutting_down or not self.persistence_path:
            return
        self.save_to_disk()
        self.save_timer = threading.Timer(self.save_interval, self.schedule_save)
        self.save_timer.start()

    def shutdown(self):
        """Cancel the save timer and flush state one final time."""
        logging.info("Shutting down metadata server...")
        self.is_shutting_down = True
        if self.save_timer:
            self.save_timer.cancel()
        self.save_to_disk()
        logging.info("Shutdown complete.")
|
||||
|
||||
|
||||
# --- Global MetadataServer implementation ---
|
||||
class Hf3fsMetadataServer:
    """HF3FS Metadata Server that manages metadata for multiple ranks."""

    def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
        # persistence_path: optional JSON file used to survive restarts;
        # save_interval: seconds between periodic snapshots (only when persisting).
        self.state = GlobalMetadataState(persistence_path, save_interval)
        self.app = FastAPI()
        self._setup_routes()

    def _setup_routes(self):
        """Setup FastAPI routes."""
        # Every endpoint is a POST scoped by {rank}, so multiple TP ranks can
        # share one server without interfering with each other's page tables.
        self.app.post("/{rank}/initialize")(self.initialize)
        self.app.post("/{rank}/exists")(self.exists)
        self.app.post("/{rank}/reserve_and_allocate_page_indices")(
            self.reserve_and_allocate_page_indices
        )
        self.app.post("/{rank}/confirm_write")(self.confirm_write)
        self.app.post("/{rank}/delete_keys")(self.delete_keys)
        self.app.post("/{rank}/clear")(self.clear)
        self.app.post("/{rank}/get_page_indices")(self.get_page_indices)

    def get_rank_metadata(self, rank: int) -> RankMetadata:
        """Get rank metadata with proper error handling."""
        # 404 before any work: the rank must have called /initialize first.
        with self.state.global_lock:
            if rank not in self.state.ranks:
                raise HTTPException(
                    status_code=404,
                    detail=f"Rank {rank} not initialized. Please call /{{rank}}/initialize first.",
                )
            return self.state.ranks[rank]

    async def initialize(self, rank: int, request: Request):
        """Initialize a rank with specified number of pages."""
        data = await request.json()
        num_pages = data["num_pages"]
        with self.state.global_lock:
            if rank in self.state.ranks:
                # Idempotent: re-initialization is ignored, but a mismatched
                # page count is worth a warning since it hints at a config drift.
                logging.info(
                    f"Rank {rank} already exists. Initialization request ignored."
                )
                if self.state.ranks[rank].num_pages != num_pages:
                    logging.warning(
                        f"Rank {rank} initialized with different num_pages. Existing: {self.state.ranks[rank].num_pages}, New: {num_pages}"
                    )
            else:
                logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
                self.state.ranks[rank] = RankMetadata(num_pages)
        return {"message": f"Rank {rank} is ready."}

    async def exists(self, rank: int, request: Request):
        """Check if keys exist in metadata."""
        data = await request.json()
        keys = data["keys"]
        metadata = self.get_rank_metadata(rank)
        results = metadata.exists_keys(keys)
        return {"exists": results}

    async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
        """Reserve and allocate page indices for keys."""
        # Body: {"keys": [[key, prefix_key], ...]}; reply pairs are
        # (already_written, page_index) with page_index == -1 on exhaustion.
        data = await request.json()
        metadata = self.get_rank_metadata(rank)
        keys = data["keys"]
        results = metadata.reserve_and_allocate_page_indices(keys)
        return {"indices": results}

    async def confirm_write(self, rank: int, request: Request):
        """Confirm write operations and release pages."""
        data = await request.json()
        metadata = self.get_rank_metadata(rank)
        success_written_keys = data.get("written_keys_to_confirm", [])
        released_pages = data.get("pages_to_release", [])

        metadata.confirm_write(success_written_keys, released_pages)

        return {
            "message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
        }

    async def delete_keys(self, rank: int, request: Request):
        """Delete keys from metadata."""
        data = await request.json()
        metadata = self.get_rank_metadata(rank)
        count = metadata.delete_keys(data["keys"])
        return {"message": f"Rank {rank}: {count} keys deleted."}

    async def clear(self, rank: int):
        """Clear all metadata for a rank."""
        metadata = self.get_rank_metadata(rank)
        metadata.clear_all()
        return {"message": f"Rank {rank}: Metadata cleared."}

    async def get_page_indices(self, rank: int, request: Request):
        """Get page indices for keys."""
        data = await request.json()
        metadata = self.get_rank_metadata(rank)
        keys = data["keys"]
        results = metadata.get_page_indices(keys)
        return {"indices": results}

    def run(self, host: str = "0.0.0.0", port: int = 18000):
        """Run the metadata server."""
        self.state.load_from_disk()
        if self.state.persistence_path:
            self.state.schedule_save()
        # Flush state on interpreter exit so a clean shutdown loses nothing.
        atexit.register(self.state.shutdown)

        # Imported lazily so merely importing this module does not require uvicorn.
        import uvicorn

        logging.info(f"Starting metadata server on http://{host}:{port}")
        if self.state.persistence_path:
            logging.info(
                f"Persistence is ENABLED. Saving to '{self.state.persistence_path}' every {self.state.save_interval} seconds."
            )
        else:
            logging.info("Persistence is DISABLED.")

        uvicorn.run(self.app, host=host, port=port)
|
||||
|
||||
|
||||
# --- Client implementation ---
|
||||
class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
    """Global http metadata client for HF3FS."""

    def __init__(self, base_url: str, max_retries: int = 3):
        # Normalize once so endpoint paths can simply be appended.
        self.base_url = base_url.rstrip("/")

        # Retry transient server-side failures with exponential backoff.
        retries = Retry(
            total=max_retries,
            backoff_factor=0.3,
            status_forcelist=[500, 502, 503, 504],
            allowed_methods=["GET", "POST"],
        )
        self._session = requests.Session()
        self._session.mount("http://", HTTPAdapter(max_retries=retries))

    def _post(self, endpoint: str, json_data: dict) -> dict:
        """POST json_data to the endpoint and return the decoded JSON reply."""
        url = f"{self.base_url}/{endpoint}"
        try:
            resp = self._session.post(url, json=json_data)
            resp.raise_for_status()
            return resp.json()
        except requests.exceptions.RequestException as e:
            logging.error(f"Failed to POST to {endpoint} after retries: {e}")
            raise RuntimeError(f"Failed to connect to metadata server: {e}") from e

    def initialize(self, rank: int, num_pages: int) -> None:
        """Register this rank and its page count with the metadata server."""
        self._post(f"{rank}/initialize", {"num_pages": num_pages})

    def reserve_and_allocate_page_indices(
        self, rank: int, keys: List[Tuple[str, str]]
    ) -> List[Tuple[bool, int]]:
        """Reserve pages for keys; see Hf3fsMetadataInterface for the contract."""
        response = self._post(
            f"{rank}/reserve_and_allocate_page_indices", {"keys": keys}
        )
        # JSON arrays decode as lists; callers expect (bool, int) tuples.
        return [tuple(entry) for entry in response.get("indices")]

    def confirm_write(
        self,
        rank: int,
        written_keys_to_confirm: List[Tuple[str, int]],
        pages_to_release: List[int],
    ) -> None:
        """Report successfully written keys and pages to free back to the server."""
        payload = {
            "written_keys_to_confirm": written_keys_to_confirm,
            "pages_to_release": pages_to_release,
        }
        self._post(f"{rank}/confirm_write", payload)

    def delete_keys(self, rank: int, keys: List[str]) -> None:
        """Remove the given keys from the server-side index."""
        self._post(f"{rank}/delete_keys", {"keys": keys})

    def exists(self, rank: int, keys: List[str]) -> List[bool]:
        """Return one existence flag per key."""
        return self._post(f"{rank}/exists", {"keys": keys}).get("exists", [])

    def clear(self, rank: int) -> None:
        """Drop all metadata stored for this rank."""
        self._post(f"{rank}/clear", {})

    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
        """Return the page index for each key (None for missing keys)."""
        return self._post(f"{rank}/get_page_indices", {"keys": keys}).get("indices")
|
||||
|
||||
|
||||
class Hf3fsLocalMetadataClient(Hf3fsMetadataInterface):
    """Local metadata client that directly operates on single RankMetadata in memory without metadata server."""

    def __init__(self):
        # Created lazily by initialize(); this client serves exactly one rank,
        # so the rank argument of every method is effectively ignored.
        self.rank_metadata = None

    def initialize(self, rank: int, num_pages: int) -> None:
        """Build the in-memory page table."""
        self.rank_metadata = RankMetadata(num_pages)

    def reserve_and_allocate_page_indices(
        self, rank: int, keys: List[Tuple[str, str]]
    ) -> List[Tuple[bool, int]]:
        """Forward reservation and allocation to the in-memory table."""
        return self.rank_metadata.reserve_and_allocate_page_indices(keys)

    def confirm_write(
        self,
        rank: int,
        written_keys_to_confirm: List[Tuple[str, int]],
        pages_to_release: List[int],
    ) -> None:
        """Record successful writes and free the released pages."""
        self.rank_metadata.confirm_write(written_keys_to_confirm, pages_to_release)

    def delete_keys(self, rank: int, keys: List[str]) -> None:
        """Drop the given keys from the table."""
        self.rank_metadata.delete_keys(keys)

    def exists(self, rank: int, keys: List[str]) -> List[bool]:
        """Return one existence flag per key."""
        return self.rank_metadata.exists_keys(keys)

    def clear(self, rank: int) -> None:
        """Reset the table to an empty, fully-free state."""
        self.rank_metadata.clear_all()

    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
        """Look up the page index for each key (None when absent)."""
        return self.rank_metadata.get_page_indices(keys)
|
||||
|
||||
|
||||
def run_metadata_server(
    host: str = "0.0.0.0",
    port: int = 18000,
    persistence_path: Optional[str] = None,
    save_interval: int = 60,
):
    """Run the HF3FS metadata server.

    Args:
        host: Interface to bind; defaults to all interfaces.
        port: TCP port to listen on.
        persistence_path: Optional JSON file for persisting metadata across restarts.
        save_interval: Seconds between periodic snapshots (used only when persisting).
    """
    # Kept at module scope so the live instance is reachable after startup.
    global server
    server = Hf3fsMetadataServer(
        persistence_path=persistence_path, save_interval=save_interval
    )

    # Blocks until the uvicorn event loop exits.
    server.run(host=host, port=port)
|
||||
|
||||
|
||||
# --- Main Execution ---
|
||||
if __name__ == "__main__":
    # CLI entry point: parse flags and hand off to run_metadata_server.
    parser = argparse.ArgumentParser(description="HF3FS Metadata Server")
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Host to bind the server to."
    )
    parser.add_argument(
        "--port", type=int, default=18000, help="Port to run the server on."
    )
    parser.add_argument(
        "--persistence-path",
        type=str,
        default=None,
        help="Path to the file for persisting metadata. If not provided, persistence is disabled.",
    )
    parser.add_argument(
        "--save-interval",
        type=int,
        default=60,
        help="Interval in seconds for periodically saving metadata to disk.",
    )
    args = parser.parse_args()

    run_metadata_server(args.host, args.port, args.persistence_path, args.save_interval)
|
||||
@@ -5,9 +5,9 @@ import logging
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import wraps
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -17,6 +17,75 @@ from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Hf3fsMetadataInterface(ABC):
    """Interface for HF3FS metadata operations.

    Implemented by Hf3fsGlobalMetadataClient (HTTP, shared across instances)
    and Hf3fsLocalMetadataClient (in-process, single instance).
    """

    @abstractmethod
    def initialize(self, rank: int, num_pages: int) -> None:
        """Initialize the metadata service with specified number of pages."""
        pass

    @abstractmethod
    def reserve_and_allocate_page_indices(
        self,
        rank: int,
        keys: List[Tuple[str, str]],
    ) -> List[Tuple[bool, int]]:
        """
        Reserve and allocate page indices for the specified keys.

        Args:
            rank: The rank of the process.
            keys: The keys to reserve and allocate page indices for. Each tuple contains a key and the key of its prefix block.

        Returns:
            List[Tuple[bool, int]]: A list of tuples, where each tuple contains a boolean indicating whether the key has existed and an integer indicating the allocated page index.
        """
        pass

    @abstractmethod
    def confirm_write(
        self,
        rank: int,
        written_keys_to_confirm: List[Tuple[str, int]],
        pages_to_release: List[int],
    ) -> None:
        """
        Confirm that key-value pairs have been successfully written to storage.

        Args:
            rank: The rank of the process.
            written_keys_to_confirm: A list of tuples, where each tuple contains a key and its corresponding page index.
            pages_to_release: A list of page indices to be released.
        """
        pass

    @abstractmethod
    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
        """
        Get page indices for the specified keys.

        Args:
            rank: The rank of the process.
            keys: A list of keys.

        Returns:
            List[Optional[int]]: A list of integers representing the page indices for the specified keys.
            If a key is not found, the corresponding index will be None.
        """
        pass

    @abstractmethod
    def delete_keys(self, rank: int, keys: List[str]) -> None:
        """Delete specified keys and their associated pages."""
        pass

    @abstractmethod
    def exists(self, rank: int, keys: List[str]) -> List[bool]:
        """Check if the specified keys exist."""
        pass

    @abstractmethod
    def clear(self, rank: int) -> None:
        """Clear all key-value pairs and page allocations for the specified rank."""
        pass
|
||||
|
||||
|
||||
class AtomicCounter:
|
||||
def __init__(self, n: int):
|
||||
assert n > 0
|
||||
@@ -48,32 +117,32 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
file_path: str,
|
||||
file_size: int,
|
||||
numjobs: int,
|
||||
bytes_per_page: int,
|
||||
entries: int,
|
||||
dtype: torch.dtype,
|
||||
metadata_client: Hf3fsMetadataInterface,
|
||||
):
|
||||
self.rank = rank
|
||||
self.file_path = file_path
|
||||
self.file_size = file_size
|
||||
self.numjobs = numjobs
|
||||
self.bytes_per_page = bytes_per_page
|
||||
self.entries = entries
|
||||
self.dtype = dtype
|
||||
self.metadata_client = metadata_client
|
||||
|
||||
self.numel = self.bytes_per_page // self.dtype.itemsize
|
||||
|
||||
self.num_pages = self.file_size // self.bytes_per_page
|
||||
|
||||
logger.info(
|
||||
"HiCacheHF3FS "
|
||||
f"file_path = {self.file_path}, "
|
||||
f"file_size = {self.file_size/(2**30):.2f} GB, "
|
||||
f"numjobs = {self.numjobs}, "
|
||||
f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, "
|
||||
f"entries = {self.entries}, "
|
||||
f"num_pages = {self.num_pages}"
|
||||
f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
|
||||
f"file_path={self.file_path}, "
|
||||
f"file_size={self.file_size / (2 ** 30):.2f} GB, "
|
||||
f"num_pages={self.num_pages}"
|
||||
)
|
||||
|
||||
self.ac = AtomicCounter(self.numjobs)
|
||||
@@ -84,15 +153,11 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
for _ in range(numjobs)
|
||||
]
|
||||
self.executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS"
|
||||
max_workers=self.numjobs, thread_name_prefix=f"HiCacheHF3FS-Rank{self.rank}"
|
||||
)
|
||||
|
||||
# Implemented a preliminary single-file page_hash -> file_offset index as interim storage.
|
||||
# Future iterations may adopt a global KVCache manager to coordinate external cache instances
|
||||
# through centralized metadata orchestration.
|
||||
self.metadata_client.initialize(self.rank, self.num_pages)
|
||||
self.lock = threading.RLock()
|
||||
self.free_pages = list(range(self.num_pages))
|
||||
self.key_to_index = OrderedDict()
|
||||
|
||||
atexit.register(self.close)
|
||||
|
||||
@@ -104,15 +169,22 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
def from_env_config(
|
||||
rank: int, bytes_per_page: int, dtype: torch.dtype
|
||||
) -> "HiCacheHF3FS":
|
||||
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
|
||||
Hf3fsGlobalMetadataClient,
|
||||
Hf3fsLocalMetadataClient,
|
||||
)
|
||||
|
||||
config_path = os.getenv(HiCacheHF3FS.default_env_var)
|
||||
if not config_path:
|
||||
return HiCacheHF3FS(
|
||||
rank=rank,
|
||||
file_path=f"/data/hicache.{rank}.bin",
|
||||
file_size=1 << 40,
|
||||
numjobs=16,
|
||||
bytes_per_page=bytes_per_page,
|
||||
entries=8,
|
||||
dtype=dtype,
|
||||
metadata_client=Hf3fsLocalMetadataClient(),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -121,6 +193,7 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")
|
||||
|
||||
# Check required keys (metadata_server_url is now optional)
|
||||
required_keys = {
|
||||
"file_path_prefix",
|
||||
"file_size",
|
||||
@@ -131,19 +204,33 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing required keys in config: {missing_keys}")
|
||||
|
||||
# Choose metadata client based on configuration
|
||||
if "metadata_server_url" in config and config["metadata_server_url"]:
|
||||
# Use global metadata client to connect to metadata server
|
||||
metadata_server_url = config["metadata_server_url"]
|
||||
metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
|
||||
logger.info(
|
||||
f"Using global metadata client with server url: {metadata_server_url}"
|
||||
)
|
||||
else:
|
||||
# Use local metadata client for single-machine deployment
|
||||
metadata_client = Hf3fsLocalMetadataClient()
|
||||
|
||||
return HiCacheHF3FS(
|
||||
rank=rank,
|
||||
file_path=f"{config['file_path_prefix']}.{rank}.bin",
|
||||
file_size=int(config["file_size"]),
|
||||
numjobs=int(config["numjobs"]),
|
||||
bytes_per_page=bytes_per_page,
|
||||
entries=int(config["entries"]),
|
||||
dtype=dtype,
|
||||
metadata_client=metadata_client,
|
||||
)
|
||||
|
||||
def get(
|
||||
self, key: str, target_location: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor | None:
|
||||
return self.batch_get([key], target_location)[0]
|
||||
return self.batch_get([key], [target_location] if target_location else None)[0]
|
||||
|
||||
@synchronized()
|
||||
def batch_get(
|
||||
@@ -151,14 +238,14 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
keys: List[str],
|
||||
target_locations: Optional[List[torch.Tensor]] = None,
|
||||
) -> List[torch.Tensor | None]:
|
||||
page_indices = self.metadata_client.get_page_indices(self.rank, keys)
|
||||
|
||||
batch_indices, file_offsets = [], []
|
||||
for i, key in enumerate(keys):
|
||||
if key not in self.key_to_index:
|
||||
continue
|
||||
batch_indices.append(i)
|
||||
file_offsets.append(self.key_to_index[key] * self.bytes_per_page)
|
||||
self.key_to_index.move_to_end(key)
|
||||
# TODO: target_locations
|
||||
for i, page_index in enumerate(page_indices):
|
||||
if page_index is not None:
|
||||
batch_indices.append(i)
|
||||
file_offsets.append(page_index * self.bytes_per_page)
|
||||
|
||||
file_results = [
|
||||
torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
|
||||
]
|
||||
@@ -180,7 +267,9 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
if read_result == self.bytes_per_page:
|
||||
results[batch_index] = file_result
|
||||
else:
|
||||
logger.error(f"HiCacheHF3FS get {keys[batch_index]} failed")
|
||||
logger.error(
|
||||
f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -188,13 +277,21 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
return self.batch_set([key], [value])
|
||||
|
||||
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
|
||||
indices = self.get_batch_set_indices(keys)
|
||||
# Todo: Add prefix block's hash key
|
||||
key_with_prefix = [(key, "") for key in keys]
|
||||
indices = self.metadata_client.reserve_and_allocate_page_indices(
|
||||
self.rank, key_with_prefix
|
||||
)
|
||||
|
||||
batch_indices, file_offsets, file_values = [], [], []
|
||||
for i, (value, (is_written, index)) in enumerate(zip(values, indices)):
|
||||
if is_written or index == -1:
|
||||
pages_to_release = []
|
||||
|
||||
for i, (value, (is_written, page_index)) in enumerate(zip(values, indices)):
|
||||
if is_written or page_index == -1:
|
||||
continue
|
||||
|
||||
batch_indices.append(i)
|
||||
file_offsets.append(index * self.bytes_per_page)
|
||||
file_offsets.append(page_index * self.bytes_per_page)
|
||||
file_values.append(value.contiguous())
|
||||
|
||||
futures = [
|
||||
@@ -211,62 +308,37 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
for result in future.result()
|
||||
]
|
||||
|
||||
written_keys_to_confirm = []
|
||||
results = [index[0] for index in indices]
|
||||
for batch_index, write_result in zip(batch_indices, write_results):
|
||||
key = keys[batch_index]
|
||||
index = indices[batch_index][1]
|
||||
page_index = indices[batch_index][1]
|
||||
if write_result:
|
||||
self.key_to_index[key] = index
|
||||
self.key_to_index.move_to_end(key)
|
||||
written_keys_to_confirm.append((key, page_index))
|
||||
else:
|
||||
logger.error(f"HiCacheHF3FS set {key} failed")
|
||||
self.free_pages.append(index)
|
||||
logger.error(f"[Rank {self.rank}] HiCacheHF3FS set {key} failed")
|
||||
pages_to_release.append(page_index)
|
||||
results[batch_index] = write_result
|
||||
|
||||
if len(written_keys_to_confirm) > 0 or len(pages_to_release) > 0:
|
||||
self.metadata_client.confirm_write(
|
||||
self.rank, written_keys_to_confirm, pages_to_release
|
||||
)
|
||||
|
||||
return all(results)
|
||||
|
||||
@synchronized()
|
||||
def get_batch_set_indices(self, keys: List[str]) -> list:
|
||||
ionum = len(keys)
|
||||
# results: tuples of (is_written: bool, page_idx: int)
|
||||
# - is_written: True = hit (no I/O), False = write (miss)
|
||||
# - page_idx: page storing data
|
||||
results = [None] * min(ionum, self.num_pages)
|
||||
if ionum > self.num_pages:
|
||||
results.extend([(False, -1)] * (ionum - self.num_pages))
|
||||
|
||||
new_keys = []
|
||||
for batch_index, key in enumerate(keys[: self.num_pages]):
|
||||
if key in self.key_to_index:
|
||||
results[batch_index] = (True, self.key_to_index[key])
|
||||
self.key_to_index.move_to_end(key)
|
||||
else:
|
||||
new_keys.append((batch_index, key))
|
||||
|
||||
for batch_index, _ in new_keys:
|
||||
index = (
|
||||
self.free_pages.pop()
|
||||
if len(self.free_pages) > 0
|
||||
else self.key_to_index.popitem(last=False)[1]
|
||||
)
|
||||
results[batch_index] = (False, index)
|
||||
|
||||
return results
|
||||
|
||||
@synchronized()
|
||||
def delete(self, key: str) -> None:
|
||||
if key not in self.key_to_index:
|
||||
return
|
||||
index = self.key_to_index.pop(key)
|
||||
self.free_pages.append(index)
|
||||
self.metadata_client.delete_keys(self.rank, [key])
|
||||
|
||||
@synchronized()
|
||||
def exists(self, key: str) -> bool:
|
||||
return key in self.key_to_index
|
||||
result = self.metadata_client.exists(self.rank, [key])
|
||||
return result[0] if result else False
|
||||
|
||||
@synchronized()
|
||||
def clear(self) -> None:
|
||||
self.free_pages = list(range(self.num_pages))
|
||||
self.key_to_index.clear()
|
||||
self.metadata_client.clear(self.rank)
|
||||
|
||||
def close(self) -> None:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user