feature(hicache): Support hf3fs-hicache reusing kvcache across different instances (#8673)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
python/sglang/srt/mem_cache/storage/hf3fs/docs/README.md (new file, 29 lines)
@@ -0,0 +1,29 @@

# HF3FS as L3 KV Cache

This document describes how to use deepseek-hf3fs as the L3 KV cache for SGLang.

## Step 1: Install deepseek-3fs via the 3FS Operator (Coming Soon)

## Step 2: Set up the usrbio client

Please follow [setup_usrbio_client.md](setup_usrbio_client.md) to set up the usrbio client.

## Step 3: Deployment

### Single-node deployment

```bash
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
python3 -m sglang.launch_server \
    --model-path /code/models/Qwen3-32B/ \
    --host 0.0.0.0 --port 10000 \
    --page-size 64 \
    --enable-hierarchical-cache \
    --hicache-ratio 2 --hicache-size 0 \
    --hicache-write-policy write_through \
    --hicache-storage-backend hf3fs
```
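
Once the server is up, you can sanity-check the deployment with a request to SGLang's native `/generate` endpoint (a minimal sketch; the prompt and sampling parameters are placeholders):

```python
import requests

# Assumes the server launched above is reachable on localhost:10000.
resp = requests.post(
    "http://localhost:10000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    },
)
print(resp.json()["text"])
```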

### Multi-node deployment with shared KV cache

Please follow [deploy_sglang_3fs_multinode.md](deploy_sglang_3fs_multinode.md) to deploy SGLang with 3FS across multiple nodes so they share a KV cache.
python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md (new file, 65 lines)
@@ -0,0 +1,65 @@

# 1. Start the 3FS metadata service

```bash
nohup python3 -m sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server > meta.out &
```
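
To confirm the service is listening before starting the engines, you can probe one of its rank-scoped endpoints (a sketch; a 404 for a not-yet-initialized rank still proves the server is up):

```python
import requests

# Assumes the metadata server above runs on localhost:18000.
resp = requests.post("http://localhost:18000/0/exists", json={"keys": []})
# 200 once an SGLang instance has initialized rank 0;
# 404 ("Rank 0 not initialized ...") otherwise, which still shows liveness.
print(resp.status_code, resp.json())
```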

# 2. Start the SGLang engines

## HF3FS configuration

Edit the config file that each engine will load:

```bash
vim /sgl-workspace/sglang/benchmark/hf3fs/hf3fs_config.json
```

```json
{
    "file_path_prefix": "/data/hicache",
    "file_size": 1099511627776,
    "numjobs": 16,
    "entries": 8,
    "metadata_server_url": "http://metaServerIp:18000"
}
```
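
`file_size` is the per-rank HF3FS capacity in bytes (1099511627776 = 1 TiB here). The backend splits it into fixed-size pages, so the page budget per rank follows directly (a sketch; `bytes_per_page` depends on the model's KV layout and `--page-size`, and the value below is only an assumption):

```python
file_size = 1099511627776  # from hf3fs_config.json: 1 TiB per rank
bytes_per_page = 8 << 20   # assumed 8 MiB per page; model and page-size dependent

# Mirrors HiCacheHF3FS: num_pages = file_size // bytes_per_page
print(file_size // bytes_per_page)  # -> 131072 pages per rank
```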

## node1

```bash
export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs_config.json
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
rm -rf instance1.out && \
nohup python3 -m sglang.launch_server \
    --model-path /code/models/Qwen3-32B/ \
    --host 0.0.0.0 --port 10000 \
    --page-size 64 \
    --enable-hierarchical-cache \
    --hicache-ratio 2 --hicache-size 0 \
    --hicache-write-policy write_through \
    --hicache-storage-backend hf3fs --tp 2 > instance1.out &
```

## node2

```bash
export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs_config.json
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
rm -rf instance2.out && \
nohup python3 -m sglang.launch_server \
    --model-path /code/models/Qwen3-32B/ \
    --host 0.0.0.0 --port 10000 \
    --page-size 64 \
    --enable-hierarchical-cache \
    --hicache-ratio 2 --hicache-size 0 \
    --hicache-write-policy write_through \
    --hicache-storage-backend hf3fs --tp 2 > instance2.out &
```

# 3. Start the router

```bash
rm -rf router.out && \
nohup python -m sglang_router.launch_router --worker-urls http://node1:10000 http://node2:10000 > router.out &
```
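
With both nodes behind the router, a prefix written to HF3FS by one worker can be reused by the other. A quick way to exercise this (a sketch; it assumes the router was launched on node1 and listens on port 30000, the port the benchmark below targets):

```python
import requests

payload = {
    "text": "Explain hierarchical KV caching in one sentence.",
    "sampling_params": {"max_new_tokens": 32},
}

# Send the same prompt twice; the second request can reuse KV pages written
# to HF3FS by whichever worker served the first one.
for _ in range(2):
    r = requests.post("http://node1:30000/generate", json=payload)
    print(r.json()["meta_info"].get("cached_tokens"))
```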

# 4. Start the multi-turn benchmark

```bash
rm -rf bench_multiturn.out && \
nohup python3 benchmark/hicache/bench_multiturn.py \
    --model-path /code/models/Qwen3-32B \
    --dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
    --port 30000 \
    --request-length 2048 --num-clients 512 --num-rounds 5 --max-parallel 8 \
    > bench_multiturn.out &
```
python/sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py (new file, 443 lines)
@@ -0,0 +1,443 @@
import argparse
import atexit
import json
import logging
import threading
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import requests
from fastapi import FastAPI, HTTPException, Request, status
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import Hf3fsMetadataInterface

# --- Configuration ---
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


# --- Data Models ---
class RankMetadata:
    """Holds all metadata for a single rank."""

    def __init__(self, num_pages: int):
        self.lock = threading.RLock()
        self.num_pages = num_pages
        self.free_pages: List[int] = list(range(num_pages))
        self.key_to_index: Dict[str, int] = {}
        # Todo: Support multiple files for HF3FS

    def exists_keys(self, keys: List[str]) -> List[bool]:
        """Check if keys exist in metadata."""
        with self.lock:
            return [key in self.key_to_index for key in keys]

    def reserve_and_allocate_page_indices(
        self, keys: List[Tuple[str, str]]
    ) -> List[Tuple[bool, int]]:
        """Reserve and allocate page indices for keys."""
        with self.lock:
            results = [None] * len(keys)
            new_keys_to_process = []

            for i, (key, prefix_key) in enumerate(keys):
                if key in self.key_to_index:
                    results[i] = (True, self.key_to_index[key])
                else:
                    new_keys_to_process.append((i, key, prefix_key))

            # Todo: Implement data eviction logic once HiCache supports prefix information pass-through
            for i, key, prefix_key in new_keys_to_process:
                if len(self.free_pages) > 0:
                    page_idx = self.free_pages.pop()
                    results[i] = (False, page_idx)
                else:
                    results[i] = (False, -1)

            return results

    def confirm_write(
        self,
        written_keys_to_confirm: List[Tuple[str, int]],
        pages_to_release: List[int],
    ) -> None:
        """Confirm write operations and release pages."""
        with self.lock:
            for key, page_index in written_keys_to_confirm:
                self.key_to_index[key] = page_index

            for page_index in pages_to_release:
                if page_index not in self.free_pages:
                    self.free_pages.append(page_index)

    def delete_keys(self, keys: List[str]) -> int:
        """Delete keys and return the count of deleted keys."""
        with self.lock:
            count = 0
            for key in keys:
                if key in self.key_to_index:
                    page_index = self.key_to_index.pop(key)
                    if page_index not in self.free_pages:
                        self.free_pages.append(page_index)
                    count += 1
            return count

    def clear_all(self) -> None:
        """Clear all metadata."""
        with self.lock:
            self.free_pages = list(range(self.num_pages))
            self.key_to_index.clear()

    def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
        """Get page indices for keys."""
        with self.lock:
            return [self.key_to_index.get(key) for key in keys]
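

# Illustrative only, not executed by the server: how the endpoints below drive
# a RankMetadata instance (assuming a toy rank with 2 pages):
#
#   meta = RankMetadata(num_pages=2)
#   meta.reserve_and_allocate_page_indices([("k1", ""), ("k2", "")])
#   # -> [(False, 1), (False, 0)]: both keys are new, so pages are allocated
#   meta.confirm_write([("k1", 1)], pages_to_release=[0])
#   # -> "k1" becomes visible; page 0 returns to the free list
#   meta.get_page_indices(["k1", "k2"])
#   # -> [1, None]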


class GlobalMetadataState:
    """Manages the state for all ranks and persistence."""

    def __init__(self, persistence_path: Optional[str], save_interval: int):
        self.global_lock = threading.RLock()
        self.ranks: Dict[int, RankMetadata] = {}
        self.persistence_path = Path(persistence_path) if persistence_path else None
        self.save_interval = save_interval
        self.save_timer: Optional[threading.Timer] = None
        self.is_shutting_down = False

    def load_from_disk(self):
        if not self.persistence_path or not self.persistence_path.exists():
            logging.info("Persistence file not found. Starting with a clean state.")
            return

        logging.info(f"Loading state from {self.persistence_path}")
        try:
            with open(self.persistence_path, "r") as f:
                persisted_data = json.load(f)

            with self.global_lock:
                for rank_id_str, data in persisted_data.items():
                    rank_id = int(rank_id_str)
                    num_pages = data["num_pages"]
                    rank_meta = RankMetadata(num_pages)
                    rank_meta.free_pages = data["free_pages"]
                    rank_meta.key_to_index = dict(data["key_to_index"])
                    self.ranks[rank_id] = rank_meta
                logging.info(
                    f"Successfully loaded metadata for {len(self.ranks)} ranks."
                )
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            logging.error(
                f"Failed to load or parse persistence file: {e}. Starting fresh.",
                exc_info=True,
            )
            self.ranks.clear()

    def save_to_disk(self):
        if not self.persistence_path:
            return

        logging.info("Persisting metadata to disk...")
        with self.global_lock:
            serializable_state = {}
            for rank_id, rank_meta in self.ranks.items():
                with rank_meta.lock:
                    serializable_state[rank_id] = {
                        "num_pages": rank_meta.num_pages,
                        "free_pages": rank_meta.free_pages,
                        "key_to_index": list(rank_meta.key_to_index.items()),
                    }

        try:
            temp_path = self.persistence_path.with_suffix(".tmp")
            with open(temp_path, "w") as f:
                json.dump(serializable_state, f, indent=4)
            temp_path.rename(self.persistence_path)
            logging.info(f"Metadata successfully persisted to {self.persistence_path}")
        except Exception as e:
            logging.error(f"Failed to save metadata to disk: {e}", exc_info=True)

    def schedule_save(self):
        if self.is_shutting_down or not self.persistence_path:
            return
        self.save_to_disk()
        self.save_timer = threading.Timer(self.save_interval, self.schedule_save)
        self.save_timer.start()

    def shutdown(self):
        logging.info("Shutting down metadata server...")
        self.is_shutting_down = True
        if self.save_timer:
            self.save_timer.cancel()
        self.save_to_disk()
        logging.info("Shutdown complete.")


# --- Global MetadataServer implementation ---
class Hf3fsMetadataServer:
    """HF3FS metadata server that manages metadata for multiple ranks."""

    def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
        self.state = GlobalMetadataState(persistence_path, save_interval)
        self.app = FastAPI()
        self._setup_routes()

    def _setup_routes(self):
        """Set up FastAPI routes."""
        self.app.post("/{rank}/initialize")(self.initialize)
        self.app.post("/{rank}/exists")(self.exists)
        self.app.post("/{rank}/reserve_and_allocate_page_indices")(
            self.reserve_and_allocate_page_indices
        )
        self.app.post("/{rank}/confirm_write")(self.confirm_write)
        self.app.post("/{rank}/delete_keys")(self.delete_keys)
        self.app.post("/{rank}/clear")(self.clear)
        self.app.post("/{rank}/get_page_indices")(self.get_page_indices)

    def get_rank_metadata(self, rank: int) -> RankMetadata:
        """Get rank metadata with proper error handling."""
        with self.state.global_lock:
            if rank not in self.state.ranks:
                raise HTTPException(
                    status_code=404,
                    detail=f"Rank {rank} not initialized. Please call /{{rank}}/initialize first.",
                )
            return self.state.ranks[rank]

    async def initialize(self, rank: int, request: Request):
        """Initialize a rank with the specified number of pages."""
        data = await request.json()
        num_pages = data["num_pages"]
        with self.state.global_lock:
            if rank in self.state.ranks:
                logging.info(
                    f"Rank {rank} already exists. Initialization request ignored."
                )
                if self.state.ranks[rank].num_pages != num_pages:
                    logging.warning(
                        f"Rank {rank} initialized with different num_pages. Existing: {self.state.ranks[rank].num_pages}, New: {num_pages}"
                    )
            else:
                logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
                self.state.ranks[rank] = RankMetadata(num_pages)
        return {"message": f"Rank {rank} is ready."}

    async def exists(self, rank: int, request: Request):
        """Check if keys exist in metadata."""
        data = await request.json()
        keys = data["keys"]
        metadata = self.get_rank_metadata(rank)
        results = metadata.exists_keys(keys)
        return {"exists": results}

    async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
        """Reserve and allocate page indices for keys."""
        data = await request.json()
        metadata = self.get_rank_metadata(rank)
        keys = data["keys"]
        results = metadata.reserve_and_allocate_page_indices(keys)
        return {"indices": results}

    async def confirm_write(self, rank: int, request: Request):
        """Confirm write operations and release pages."""
        data = await request.json()
        metadata = self.get_rank_metadata(rank)
        success_written_keys = data.get("written_keys_to_confirm", [])
        released_pages = data.get("pages_to_release", [])

        metadata.confirm_write(success_written_keys, released_pages)

        return {
            "message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
        }

    async def delete_keys(self, rank: int, request: Request):
        """Delete keys from metadata."""
        data = await request.json()
        metadata = self.get_rank_metadata(rank)
        count = metadata.delete_keys(data["keys"])
        return {"message": f"Rank {rank}: {count} keys deleted."}

    async def clear(self, rank: int):
        """Clear all metadata for a rank."""
        metadata = self.get_rank_metadata(rank)
        metadata.clear_all()
        return {"message": f"Rank {rank}: Metadata cleared."}

    async def get_page_indices(self, rank: int, request: Request):
        """Get page indices for keys."""
        data = await request.json()
        metadata = self.get_rank_metadata(rank)
        keys = data["keys"]
        results = metadata.get_page_indices(keys)
        return {"indices": results}

    def run(self, host: str = "0.0.0.0", port: int = 18000):
        """Run the metadata server."""
        self.state.load_from_disk()
        if self.state.persistence_path:
            self.state.schedule_save()
        atexit.register(self.state.shutdown)

        import uvicorn

        logging.info(f"Starting metadata server on http://{host}:{port}")
        if self.state.persistence_path:
            logging.info(
                f"Persistence is ENABLED. Saving to '{self.state.persistence_path}' every {self.state.save_interval} seconds."
            )
        else:
            logging.info("Persistence is DISABLED.")

        uvicorn.run(self.app, host=host, port=port)


# --- Client implementation ---
class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
    """Global HTTP metadata client for HF3FS."""

    def __init__(self, base_url: str, max_retries: int = 3):
        self.base_url = base_url.rstrip("/")
        self._session = requests.Session()

        retry_strategy = Retry(
            total=max_retries,
            backoff_factor=0.3,
            status_forcelist=[500, 502, 503, 504],
            allowed_methods=["GET", "POST"],
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self._session.mount("http://", adapter)

    def _post(self, endpoint: str, json_data: dict) -> dict:
        try:
            response = self._session.post(f"{self.base_url}/{endpoint}", json=json_data)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logging.error(f"Failed to POST to {endpoint} after retries: {e}")
            raise RuntimeError(f"Failed to connect to metadata server: {e}") from e

    def initialize(self, rank: int, num_pages: int) -> None:
        self._post(f"{rank}/initialize", {"num_pages": num_pages})

    def reserve_and_allocate_page_indices(
        self, rank: int, keys: List[Tuple[str, str]]
    ) -> List[Tuple[bool, int]]:
        response = self._post(
            f"{rank}/reserve_and_allocate_page_indices", {"keys": keys}
        )
        return [tuple(item) for item in response.get("indices")]

    def confirm_write(
        self,
        rank: int,
        written_keys_to_confirm: List[Tuple[str, int]],
        pages_to_release: List[int],
    ) -> None:
        self._post(
            f"{rank}/confirm_write",
            {
                "written_keys_to_confirm": written_keys_to_confirm,
                "pages_to_release": pages_to_release,
            },
        )

    def delete_keys(self, rank: int, keys: List[str]) -> None:
        self._post(f"{rank}/delete_keys", {"keys": keys})

    def exists(self, rank: int, keys: List[str]) -> List[bool]:
        response = self._post(f"{rank}/exists", {"keys": keys})
        return response.get("exists", [])

    def clear(self, rank: int) -> None:
        self._post(f"{rank}/clear", {})

    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
        response = self._post(f"{rank}/get_page_indices", {"keys": keys})
        return response.get("indices")


class Hf3fsLocalMetadataClient(Hf3fsMetadataInterface):
    """Local metadata client that operates directly on a single in-memory RankMetadata, without a metadata server."""

    def __init__(self):
        self.rank_metadata = None

    def initialize(self, rank: int, num_pages: int) -> None:
        self.rank_metadata = RankMetadata(num_pages)

    def reserve_and_allocate_page_indices(
        self, rank: int, keys: List[Tuple[str, str]]
    ) -> List[Tuple[bool, int]]:
        """Reserve and allocate page indices for keys."""
        return self.rank_metadata.reserve_and_allocate_page_indices(keys)

    def confirm_write(
        self,
        rank: int,
        written_keys_to_confirm: List[Tuple[str, int]],
        pages_to_release: List[int],
    ) -> None:
        """Confirm write operations."""
        self.rank_metadata.confirm_write(written_keys_to_confirm, pages_to_release)

    def delete_keys(self, rank: int, keys: List[str]) -> None:
        """Delete keys."""
        self.rank_metadata.delete_keys(keys)

    def exists(self, rank: int, keys: List[str]) -> List[bool]:
        """Check if keys exist."""
        return self.rank_metadata.exists_keys(keys)

    def clear(self, rank: int) -> None:
        """Clear all metadata for the rank."""
        self.rank_metadata.clear_all()

    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
        """Get page indices for keys."""
        return self.rank_metadata.get_page_indices(keys)


def run_metadata_server(
    host: str = "0.0.0.0",
    port: int = 18000,
    persistence_path: Optional[str] = None,
    save_interval: int = 60,
):
    """Run the HF3FS metadata server."""
    global server
    server = Hf3fsMetadataServer(
        persistence_path=persistence_path, save_interval=save_interval
    )
    server.run(host=host, port=port)


# --- Main Execution ---
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="HF3FS Metadata Server")
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Host to bind the server to."
    )
    parser.add_argument(
        "--port", type=int, default=18000, help="Port to run the server on."
    )
    parser.add_argument(
        "--persistence-path",
        type=str,
        default=None,
        help="Path to the file for persisting metadata. If not provided, persistence is disabled.",
    )
    parser.add_argument(
        "--save-interval",
        type=int,
        default=60,
        help="Interval in seconds for periodically saving metadata to disk.",
    )
    args = parser.parse_args()

    run_metadata_server(args.host, args.port, args.persistence_path, args.save_interval)
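
For reference, this is how the HTTP client defined above drives the server's endpoints (a minimal sketch, assuming a metadata server already running on localhost:18000):

```python
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
    Hf3fsGlobalMetadataClient,
)

client = Hf3fsGlobalMetadataClient("http://localhost:18000")
client.initialize(0, num_pages=1024)  # POST /0/initialize

# Each key is paired with the hash key of its prefix block (unused for now).
allocs = client.reserve_and_allocate_page_indices(0, [("page_hash_a", "")])
is_hit, page_idx = allocs[0]  # (False, <allocated page index>) on a miss

# Make the key visible to all instances once the page has been written.
client.confirm_write(0, [("page_hash_a", page_idx)], [])
print(client.exists(0, ["page_hash_a"]))  # -> [True]
```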

python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py (modified)
@@ -5,9 +5,9 @@ import logging
 import os
 import signal
 import threading
-from collections import OrderedDict
+from abc import ABC, abstractmethod
 from functools import wraps
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import torch

@@ -17,6 +17,75 @@ from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
 logger = logging.getLogger(__name__)


+class Hf3fsMetadataInterface(ABC):
+    """Interface for HF3FS metadata operations."""
+
+    @abstractmethod
+    def initialize(self, rank: int, num_pages: int) -> None:
+        """Initialize the metadata service with the specified number of pages."""
+        pass
+
+    @abstractmethod
+    def reserve_and_allocate_page_indices(
+        self,
+        rank: int,
+        keys: List[Tuple[str, str]],
+    ) -> List[Tuple[bool, int]]:
+        """
+        Reserve and allocate page indices for the specified keys.
+
+        Args:
+            rank: The rank of the process.
+            keys: The keys to reserve and allocate page indices for. Each tuple
+                contains a key and the key of its prefix block.
+
+        Returns:
+            List[Tuple[bool, int]]: A list of tuples, where each tuple contains a
+                boolean indicating whether the key already existed and an integer
+                giving the allocated page index.
+        """
+        pass
+
+    @abstractmethod
+    def confirm_write(
+        self,
+        rank: int,
+        written_keys_to_confirm: List[Tuple[str, int]],
+        pages_to_release: List[int],
+    ) -> None:
+        """
+        Confirm that key-value pairs have been successfully written to storage.
+
+        Args:
+            rank: The rank of the process.
+            written_keys_to_confirm: A list of tuples, each containing a key and
+                its corresponding page index.
+            pages_to_release: A list of page indices to be released.
+        """
+        pass
+
+    @abstractmethod
+    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
+        """
+        Get page indices for the specified keys.
+
+        Args:
+            rank: The rank of the process.
+            keys: A list of keys.
+
+        Returns:
+            List[Optional[int]]: The page index for each key.
+                If a key is not found, the corresponding index is None.
+        """
+        pass
+
+    @abstractmethod
+    def delete_keys(self, rank: int, keys: List[str]) -> None:
+        """Delete the specified keys and their associated pages."""
+        pass
+
+    @abstractmethod
+    def exists(self, rank: int, keys: List[str]) -> List[bool]:
+        """Check whether the specified keys exist."""
+        pass
+
+    @abstractmethod
+    def clear(self, rank: int) -> None:
+        """Clear all key-value pairs and page allocations for the specified rank."""
+        pass
+
+
 class AtomicCounter:
     def __init__(self, n: int):
         assert n > 0
@@ -48,32 +117,32 @@ class HiCacheHF3FS(HiCacheStorage):

     def __init__(
         self,
+        rank: int,
         file_path: str,
         file_size: int,
         numjobs: int,
         bytes_per_page: int,
         entries: int,
         dtype: torch.dtype,
+        metadata_client: Hf3fsMetadataInterface,
     ):
+        self.rank = rank
         self.file_path = file_path
         self.file_size = file_size
         self.numjobs = numjobs
         self.bytes_per_page = bytes_per_page
         self.entries = entries
         self.dtype = dtype
+        self.metadata_client = metadata_client

         self.numel = self.bytes_per_page // self.dtype.itemsize

         self.num_pages = self.file_size // self.bytes_per_page

         logger.info(
-            "HiCacheHF3FS "
-            f"file_path = {self.file_path}, "
-            f"file_size = {self.file_size/(2**30):.2f} GB, "
-            f"numjobs = {self.numjobs}, "
-            f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, "
-            f"entries = {self.entries}, "
-            f"num_pages = {self.num_pages}"
+            f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
+            f"file_path={self.file_path}, "
+            f"file_size={self.file_size / (2 ** 30):.2f} GB, "
+            f"num_pages={self.num_pages}"
         )

         self.ac = AtomicCounter(self.numjobs)
@@ -84,15 +153,11 @@ class HiCacheHF3FS(HiCacheStorage):
             for _ in range(numjobs)
         ]
         self.executor = concurrent.futures.ThreadPoolExecutor(
-            max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS"
+            max_workers=self.numjobs, thread_name_prefix=f"HiCacheHF3FS-Rank{self.rank}"
         )

-        # Implemented a preliminary single-file page_hash -> file_offset index as interim storage.
-        # Future iterations may adopt a global KVCache manager to coordinate external cache instances
-        # through centralized metadata orchestration.
+        self.metadata_client.initialize(self.rank, self.num_pages)
+
         self.lock = threading.RLock()
-        self.free_pages = list(range(self.num_pages))
-        self.key_to_index = OrderedDict()

         atexit.register(self.close)

@@ -104,15 +169,22 @@ class HiCacheHF3FS(HiCacheStorage):
     def from_env_config(
         rank: int, bytes_per_page: int, dtype: torch.dtype
     ) -> "HiCacheHF3FS":
+        from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
+            Hf3fsGlobalMetadataClient,
+            Hf3fsLocalMetadataClient,
+        )
+
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
             return HiCacheHF3FS(
+                rank=rank,
                 file_path=f"/data/hicache.{rank}.bin",
                 file_size=1 << 40,
                 numjobs=16,
                 bytes_per_page=bytes_per_page,
                 entries=8,
                 dtype=dtype,
+                metadata_client=Hf3fsLocalMetadataClient(),
             )

         try:
@@ -121,6 +193,7 @@ class HiCacheHF3FS(HiCacheStorage):
         except Exception as e:
             raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")

+        # Check required keys (metadata_server_url is now optional)
         required_keys = {
             "file_path_prefix",
             "file_size",
@@ -131,19 +204,33 @@ class HiCacheHF3FS(HiCacheStorage):
         if missing_keys:
             raise ValueError(f"Missing required keys in config: {missing_keys}")

+        # Choose metadata client based on configuration
+        if "metadata_server_url" in config and config["metadata_server_url"]:
+            # Use global metadata client to connect to metadata server
+            metadata_server_url = config["metadata_server_url"]
+            metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+            logger.info(
+                f"Using global metadata client with server url: {metadata_server_url}"
+            )
+        else:
+            # Use local metadata client for single-machine deployment
+            metadata_client = Hf3fsLocalMetadataClient()
+
         return HiCacheHF3FS(
+            rank=rank,
             file_path=f"{config['file_path_prefix']}.{rank}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
+            metadata_client=metadata_client,
         )

     def get(
         self, key: str, target_location: Optional[torch.Tensor] = None
     ) -> torch.Tensor | None:
-        return self.batch_get([key], target_location)[0]
+        return self.batch_get([key], [target_location] if target_location else None)[0]

     @synchronized()
     def batch_get(
@@ -151,14 +238,14 @@ class HiCacheHF3FS(HiCacheStorage):
         keys: List[str],
         target_locations: Optional[List[torch.Tensor]] = None,
     ) -> List[torch.Tensor | None]:
+        page_indices = self.metadata_client.get_page_indices(self.rank, keys)
+
         batch_indices, file_offsets = [], []
-        for i, key in enumerate(keys):
-            if key not in self.key_to_index:
-                continue
-            batch_indices.append(i)
-            file_offsets.append(self.key_to_index[key] * self.bytes_per_page)
-            self.key_to_index.move_to_end(key)
-        # TODO: target_locations
+        for i, page_index in enumerate(page_indices):
+            if page_index is not None:
+                batch_indices.append(i)
+                file_offsets.append(page_index * self.bytes_per_page)
+
         file_results = [
             torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
         ]
@@ -180,7 +267,9 @@ class HiCacheHF3FS(HiCacheStorage):
             if read_result == self.bytes_per_page:
                 results[batch_index] = file_result
             else:
-                logger.error(f"HiCacheHF3FS get {keys[batch_index]} failed")
+                logger.error(
+                    f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
+                )

         return results

@@ -188,13 +277,21 @@ class HiCacheHF3FS(HiCacheStorage):
         return self.batch_set([key], [value])

     def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
-        indices = self.get_batch_set_indices(keys)
+        # Todo: Add prefix block's hash key
+        key_with_prefix = [(key, "") for key in keys]
+        indices = self.metadata_client.reserve_and_allocate_page_indices(
+            self.rank, key_with_prefix
+        )
+
         batch_indices, file_offsets, file_values = [], [], []
-        for i, (value, (is_written, index)) in enumerate(zip(values, indices)):
-            if is_written or index == -1:
+        pages_to_release = []
+
+        for i, (value, (is_written, page_index)) in enumerate(zip(values, indices)):
+            if is_written or page_index == -1:
                 continue
+
             batch_indices.append(i)
-            file_offsets.append(index * self.bytes_per_page)
+            file_offsets.append(page_index * self.bytes_per_page)
             file_values.append(value.contiguous())

         futures = [
@@ -211,62 +308,37 @@ class HiCacheHF3FS(HiCacheStorage):
             for result in future.result()
         ]

+        written_keys_to_confirm = []
         results = [index[0] for index in indices]
         for batch_index, write_result in zip(batch_indices, write_results):
             key = keys[batch_index]
-            index = indices[batch_index][1]
+            page_index = indices[batch_index][1]
             if write_result:
-                self.key_to_index[key] = index
-                self.key_to_index.move_to_end(key)
+                written_keys_to_confirm.append((key, page_index))
             else:
-                logger.error(f"HiCacheHF3FS set {key} failed")
-                self.free_pages.append(index)
+                logger.error(f"[Rank {self.rank}] HiCacheHF3FS set {key} failed")
+                pages_to_release.append(page_index)
             results[batch_index] = write_result

+        if len(written_keys_to_confirm) > 0 or len(pages_to_release) > 0:
+            self.metadata_client.confirm_write(
+                self.rank, written_keys_to_confirm, pages_to_release
+            )
+
         return all(results)

-    @synchronized()
-    def get_batch_set_indices(self, keys: List[str]) -> list:
-        ionum = len(keys)
-        # results: tuples of (is_written: bool, page_idx: int)
-        # - is_written: True = hit (no I/O), False = write (miss)
-        # - page_idx: page storing data
-        results = [None] * min(ionum, self.num_pages)
-        if ionum > self.num_pages:
-            results.extend([(False, -1)] * (ionum - self.num_pages))
-
-        new_keys = []
-        for batch_index, key in enumerate(keys[: self.num_pages]):
-            if key in self.key_to_index:
-                results[batch_index] = (True, self.key_to_index[key])
-                self.key_to_index.move_to_end(key)
-            else:
-                new_keys.append((batch_index, key))
-
-        for batch_index, _ in new_keys:
-            index = (
-                self.free_pages.pop()
-                if len(self.free_pages) > 0
-                else self.key_to_index.popitem(last=False)[1]
-            )
-            results[batch_index] = (False, index)
-
-        return results
-
     @synchronized()
     def delete(self, key: str) -> None:
-        if key not in self.key_to_index:
-            return
-        index = self.key_to_index.pop(key)
-        self.free_pages.append(index)
+        self.metadata_client.delete_keys(self.rank, [key])

     @synchronized()
     def exists(self, key: str) -> bool:
-        return key in self.key_to_index
+        result = self.metadata_client.exists(self.rank, [key])
+        return result[0] if result else False

     @synchronized()
     def clear(self) -> None:
-        self.free_pages = list(range(self.num_pages))
-        self.key_to_index.clear()
+        self.metadata_client.clear(self.rank)

     def close(self) -> None:
         try: