[HiCacheStorage]: Improve 3fs kvstore's performance and resolve mla issues (#9876)

This commit is contained in:
hzh0425
2025-09-02 10:01:48 +08:00
committed by GitHub
parent cb9e0e4180
commit 58d06fdc95
2 changed files with 88 additions and 40 deletions

View File

@@ -4,10 +4,12 @@ import json
import logging
import threading
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, OrderedDict, Tuple
import orjson
import requests
from fastapi import FastAPI, HTTPException, Request, status
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import ORJSONResponse
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
@@ -24,10 +26,10 @@ class RankMetadata:
"""Holds all metadata for a single rank."""
def __init__(self, num_pages: int):
self.lock = threading.RLock()
self.lock = threading.Lock()
self.num_pages = num_pages
self.free_pages: List[int] = list(range(num_pages))
self.key_to_index: Dict[str, int] = {}
self.key_to_index: OrderedDict[str, int] = OrderedDict()
# Todo: Support multi files for HF3FS
def exists_keys(self, keys: List[str]) -> List[bool]:
@@ -46,16 +48,18 @@ class RankMetadata:
for i, (key, prefix_key) in enumerate(keys):
if key in self.key_to_index:
results[i] = (True, self.key_to_index[key])
self.key_to_index.move_to_end(key)
else:
new_keys_to_process.append((i, key, prefix_key))
# Todo: Implementing data eviction logic after HiCache supports prefix information pass-through
for i, key, prefix_key in new_keys_to_process:
if len(self.free_pages) > 0:
page_idx = self.free_pages.pop()
results[i] = (False, page_idx)
page_index = self.free_pages.pop()
else:
results[i] = (False, -1)
page_index = self.key_to_index.popitem(last=False)[1]
results[i] = (False, page_index)
return results
@@ -68,6 +72,7 @@ class RankMetadata:
with self.lock:
for key, page_index in written_keys_to_confirm:
self.key_to_index[key] = page_index
self.key_to_index.move_to_end(key)
for page_index in pages_to_release:
if page_index not in self.free_pages:
@@ -94,7 +99,14 @@ class RankMetadata:
def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
"""Get page indices for keys."""
with self.lock:
return [self.key_to_index.get(key) for key in keys]
results = []
for key in keys:
if key in self.key_to_index:
results.append(self.key_to_index[key])
self.key_to_index.move_to_end(key)
else:
results.append(None)
return results
class GlobalMetadataState:
@@ -182,7 +194,8 @@ class Hf3fsMetadataServer:
def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
self.state = GlobalMetadataState(persistence_path, save_interval)
self.app = FastAPI()
self.app = FastAPI(default_response_class=ORJSONResponse)
self._setup_routes()
def _setup_routes(self):
@@ -199,17 +212,25 @@ class Hf3fsMetadataServer:
def get_rank_metadata(self, rank: int) -> RankMetadata:
"""Get rank metadata with proper error handling."""
with self.state.global_lock:
if rank not in self.state.ranks:
raise HTTPException(
status_code=404,
detail=f"Rank {rank} not initialized. Please call /{{rank}}/initialize first.",
)
return self.state.ranks[rank]
if rank not in self.state.ranks:
raise HTTPException(
status_code=404,
detail=f"Rank {rank} not initialized. Please call /{rank}/initialize first.",
)
return self.state.ranks[rank]
async def _read_json(self, request: Request) -> dict:
"""Parse request JSON using orjson if available."""
body = await request.body()
return orjson.loads(body)
def _json_response(self, content: dict):
"""Return ORJSONResponse when available to bypass jsonable_encoder."""
return ORJSONResponse(content)
async def initialize(self, rank: int, request: Request):
"""Initialize a rank with specified number of pages."""
data = await request.json()
data = await self._read_json(request)
num_pages = data["num_pages"]
with self.state.global_lock:
if rank in self.state.ranks:
@@ -223,57 +244,55 @@ class Hf3fsMetadataServer:
else:
logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
self.state.ranks[rank] = RankMetadata(num_pages)
return {"message": f"Rank {rank} is ready."}
return Response(status_code=204)
async def exists(self, rank: int, request: Request):
"""Check if keys exist in metadata."""
data = await request.json()
data = await self._read_json(request)
keys = data["keys"]
metadata = self.get_rank_metadata(rank)
results = metadata.exists_keys(keys)
return {"exists": results}
return self._json_response({"exists": results})
async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
"""Reserve and allocate page indices for keys."""
data = await request.json()
data = await self._read_json(request)
metadata = self.get_rank_metadata(rank)
keys = data["keys"]
results = metadata.reserve_and_allocate_page_indices(keys)
return {"indices": results}
return self._json_response({"indices": results})
async def confirm_write(self, rank: int, request: Request):
"""Confirm write operations and release pages."""
data = await request.json()
data = await self._read_json(request)
metadata = self.get_rank_metadata(rank)
success_written_keys = data.get("written_keys_to_confirm", [])
released_pages = data.get("pages_to_release", [])
metadata.confirm_write(success_written_keys, released_pages)
return {
"message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
}
return Response(status_code=204)
async def delete_keys(self, rank: int, request: Request):
"""Delete keys from metadata."""
data = await request.json()
data = await self._read_json(request)
metadata = self.get_rank_metadata(rank)
count = metadata.delete_keys(data["keys"])
return {"message": f"Rank {rank}: {count} keys deleted."}
return Response(status_code=204)
async def clear(self, rank: int):
"""Clear all metadata for a rank."""
metadata = self.get_rank_metadata(rank)
metadata.clear_all()
return {"message": f"Rank {rank}: Metadata cleared."}
return Response(status_code=204)
async def get_page_indices(self, rank: int, request: Request):
"""Get page indices for keys."""
data = await request.json()
data = await self._read_json(request)
metadata = self.get_rank_metadata(rank)
keys = data["keys"]
results = metadata.get_page_indices(keys)
return {"indices": results}
return self._json_response({"indices": results})
def run(self, host: str = "0.0.0.0", port: int = 18000):
"""Run the metadata server."""
@@ -309,14 +328,22 @@ class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
status_forcelist=[500, 502, 503, 504],
allowed_methods=["GET", "POST"],
)
adapter = HTTPAdapter(max_retries=retry_strategy)
adapter = HTTPAdapter(
max_retries=retry_strategy, pool_connections=256, pool_maxsize=256
)
self._session.mount("http://", adapter)
def _post(self, endpoint: str, json_data: dict) -> dict:
try:
response = self._session.post(f"{self.base_url}/{endpoint}", json=json_data)
url = f"{self.base_url}/{endpoint}"
headers = {"Content-Type": "application/json"}
payload = orjson.dumps(json_data) # type: ignore[union-attr]
response = self._session.post(url, data=payload, headers=headers)
response.raise_for_status()
return response.json()
if response.status_code == 204 or not response.content:
return {}
return orjson.loads(response.content) # type: ignore[union-attr]
except requests.exceptions.RequestException as e:
logging.error(f"Failed to POST to {endpoint} after retries: {e}")
raise RuntimeError(f"Failed to connect to metadata server: {e}") from e

View File

@@ -113,6 +113,8 @@ def synchronized():
class HiCacheHF3FS(HiCacheStorage):
"""HiCache backend that stores KV cache pages in HF3FS files."""
default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"
def __init__(
@@ -176,15 +178,32 @@ class HiCacheHF3FS(HiCacheStorage):
dtype: torch.dtype,
storage_config: HiCacheStorageConfig = None,
) -> "HiCacheHF3FS":
"""Create a HiCacheHF3FS instance from environment configuration.
Environment:
- Uses env var stored in `HiCacheHF3FS.default_env_var` to locate a JSON config.
- Falls back to a local single-machine config when the env var is not set.
Raises:
ValueError: If MLA Model is requested without global metadata server or required keys are missing.
"""
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
Hf3fsGlobalMetadataClient,
Hf3fsLocalMetadataClient,
)
rank = storage_config.tp_rank if storage_config is not None else 0
if storage_config is not None:
rank, is_mla_model = storage_config.tp_rank, storage_config.is_mla_model
else:
rank, is_mla_model = 0, False
mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"
config_path = os.getenv(HiCacheHF3FS.default_env_var)
if not config_path:
if is_mla_model:
raise ValueError(mla_unsupported_msg)
return HiCacheHF3FS(
rank=rank,
file_path=f"/data/hicache.{rank}.bin",
@@ -214,25 +233,27 @@ class HiCacheHF3FS(HiCacheStorage):
raise ValueError(f"Missing required keys in config: {missing_keys}")
# Choose metadata client based on configuration
is_mla_model = False
if "metadata_server_url" in config and config["metadata_server_url"]:
if config.get("metadata_server_url"):
# Use global metadata client to connect to metadata server
metadata_server_url = config["metadata_server_url"]
metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
# Enable MLA optimization only when using the global metadata client
is_mla_model = storage_config.is_mla_model if storage_config else False
logger.info(
f"Using global metadata client with server url: {metadata_server_url}"
)
else:
# Enable MLA optimization only when using the global metadata client
if is_mla_model:
raise ValueError(mla_unsupported_msg)
# Use local metadata client for single-machine deployment
metadata_client = Hf3fsLocalMetadataClient()
rank_for_path = 0 if is_mla_model else rank
return HiCacheHF3FS(
rank=rank,
# Let all ranks use the same file path for MLA model
file_path=f"{config['file_path_prefix']}.{rank if not is_mla_model else 0}.bin",
file_path=f"{config['file_path_prefix']}.{rank_for_path}.bin",
file_size=int(config["file_size"]),
numjobs=int(config["numjobs"]),
bytes_per_page=bytes_per_page,