[HiCacheStorage]: Improve 3FS kvstore's performance and resolve MLA issues (#9876)
This commit is contained in:
@@ -4,10 +4,12 @@ import json
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, OrderedDict, Tuple
|
||||
|
||||
import orjson
|
||||
import requests
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from fastapi import FastAPI, HTTPException, Request, Response
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
@@ -24,10 +26,10 @@ class RankMetadata:
|
||||
"""Holds all metadata for a single rank."""
|
||||
|
||||
def __init__(self, num_pages: int):
|
||||
self.lock = threading.RLock()
|
||||
self.lock = threading.Lock()
|
||||
self.num_pages = num_pages
|
||||
self.free_pages: List[int] = list(range(num_pages))
|
||||
self.key_to_index: Dict[str, int] = {}
|
||||
self.key_to_index: OrderedDict[str, int] = OrderedDict()
|
||||
# Todo: Support multiple files for HF3FS
|
||||
|
||||
def exists_keys(self, keys: List[str]) -> List[bool]:
|
||||
@@ -46,16 +48,18 @@ class RankMetadata:
|
||||
for i, (key, prefix_key) in enumerate(keys):
|
||||
if key in self.key_to_index:
|
||||
results[i] = (True, self.key_to_index[key])
|
||||
self.key_to_index.move_to_end(key)
|
||||
else:
|
||||
new_keys_to_process.append((i, key, prefix_key))
|
||||
|
||||
# Todo: Implement data eviction logic after HiCache supports prefix information pass-through
|
||||
for i, key, prefix_key in new_keys_to_process:
|
||||
if len(self.free_pages) > 0:
|
||||
page_idx = self.free_pages.pop()
|
||||
results[i] = (False, page_idx)
|
||||
page_index = self.free_pages.pop()
|
||||
else:
|
||||
results[i] = (False, -1)
|
||||
page_index = self.key_to_index.popitem(last=False)[1]
|
||||
|
||||
results[i] = (False, page_index)
|
||||
|
||||
return results
|
||||
|
||||
@@ -68,6 +72,7 @@ class RankMetadata:
|
||||
with self.lock:
|
||||
for key, page_index in written_keys_to_confirm:
|
||||
self.key_to_index[key] = page_index
|
||||
self.key_to_index.move_to_end(key)
|
||||
|
||||
for page_index in pages_to_release:
|
||||
if page_index not in self.free_pages:
|
||||
@@ -94,7 +99,14 @@ class RankMetadata:
|
||||
def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
|
||||
"""Get page indices for keys."""
|
||||
with self.lock:
|
||||
return [self.key_to_index.get(key) for key in keys]
|
||||
results = []
|
||||
for key in keys:
|
||||
if key in self.key_to_index:
|
||||
results.append(self.key_to_index[key])
|
||||
self.key_to_index.move_to_end(key)
|
||||
else:
|
||||
results.append(None)
|
||||
return results
|
||||
|
||||
|
||||
class GlobalMetadataState:
|
||||
@@ -182,7 +194,8 @@ class Hf3fsMetadataServer:
|
||||
|
||||
def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
|
||||
self.state = GlobalMetadataState(persistence_path, save_interval)
|
||||
self.app = FastAPI()
|
||||
self.app = FastAPI(default_response_class=ORJSONResponse)
|
||||
|
||||
self._setup_routes()
|
||||
|
||||
def _setup_routes(self):
|
||||
@@ -199,17 +212,25 @@ class Hf3fsMetadataServer:
|
||||
|
||||
def get_rank_metadata(self, rank: int) -> RankMetadata:
|
||||
"""Get rank metadata with proper error handling."""
|
||||
with self.state.global_lock:
|
||||
if rank not in self.state.ranks:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Rank {rank} not initialized. Please call /{{rank}}/initialize first.",
|
||||
)
|
||||
return self.state.ranks[rank]
|
||||
if rank not in self.state.ranks:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Rank {rank} not initialized. Please call /{rank}/initialize first.",
|
||||
)
|
||||
return self.state.ranks[rank]
|
||||
|
||||
async def _read_json(self, request: Request) -> dict:
    """Decode the raw request body as JSON with orjson.

    The full body is awaited from ``request.body()`` and the resulting bytes
    are passed straight to ``orjson.loads``.

    Args:
        request: Incoming HTTP request whose body holds the JSON payload.

    Returns:
        Whatever ``orjson.loads`` produces for the body. The annotation says
        ``dict``, but nothing here validates that the payload is a JSON
        object.
    """
    return orjson.loads(await request.body())
|
||||
|
||||
def _json_response(self, content: dict):
    """Wrap ``content`` in an ``ORJSONResponse``.

    The handler returns a ready-made response object rather than a plain
    dict; the stated intent is to skip FastAPI's ``jsonable_encoder`` pass.

    Args:
        content: Payload to serialize into the response body.

    Returns:
        An ``ORJSONResponse`` built from ``content``.
    """
    return ORJSONResponse(content=content)
|
||||
|
||||
async def initialize(self, rank: int, request: Request):
|
||||
"""Initialize a rank with specified number of pages."""
|
||||
data = await request.json()
|
||||
data = await self._read_json(request)
|
||||
num_pages = data["num_pages"]
|
||||
with self.state.global_lock:
|
||||
if rank in self.state.ranks:
|
||||
@@ -223,57 +244,55 @@ class Hf3fsMetadataServer:
|
||||
else:
|
||||
logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
|
||||
self.state.ranks[rank] = RankMetadata(num_pages)
|
||||
return {"message": f"Rank {rank} is ready."}
|
||||
return Response(status_code=204)
|
||||
|
||||
async def exists(self, rank: int, request: Request):
|
||||
"""Check if keys exist in metadata."""
|
||||
data = await request.json()
|
||||
data = await self._read_json(request)
|
||||
keys = data["keys"]
|
||||
metadata = self.get_rank_metadata(rank)
|
||||
results = metadata.exists_keys(keys)
|
||||
return {"exists": results}
|
||||
return self._json_response({"exists": results})
|
||||
|
||||
async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
|
||||
"""Reserve and allocate page indices for keys."""
|
||||
data = await request.json()
|
||||
data = await self._read_json(request)
|
||||
metadata = self.get_rank_metadata(rank)
|
||||
keys = data["keys"]
|
||||
results = metadata.reserve_and_allocate_page_indices(keys)
|
||||
return {"indices": results}
|
||||
return self._json_response({"indices": results})
|
||||
|
||||
async def confirm_write(self, rank: int, request: Request):
|
||||
"""Confirm write operations and release pages."""
|
||||
data = await request.json()
|
||||
data = await self._read_json(request)
|
||||
metadata = self.get_rank_metadata(rank)
|
||||
success_written_keys = data.get("written_keys_to_confirm", [])
|
||||
released_pages = data.get("pages_to_release", [])
|
||||
|
||||
metadata.confirm_write(success_written_keys, released_pages)
|
||||
|
||||
return {
|
||||
"message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
|
||||
}
|
||||
return Response(status_code=204)
|
||||
|
||||
async def delete_keys(self, rank: int, request: Request):
|
||||
"""Delete keys from metadata."""
|
||||
data = await request.json()
|
||||
data = await self._read_json(request)
|
||||
metadata = self.get_rank_metadata(rank)
|
||||
count = metadata.delete_keys(data["keys"])
|
||||
return {"message": f"Rank {rank}: {count} keys deleted."}
|
||||
return Response(status_code=204)
|
||||
|
||||
async def clear(self, rank: int):
|
||||
"""Clear all metadata for a rank."""
|
||||
metadata = self.get_rank_metadata(rank)
|
||||
metadata.clear_all()
|
||||
return {"message": f"Rank {rank}: Metadata cleared."}
|
||||
return Response(status_code=204)
|
||||
|
||||
async def get_page_indices(self, rank: int, request: Request):
|
||||
"""Get page indices for keys."""
|
||||
data = await request.json()
|
||||
data = await self._read_json(request)
|
||||
metadata = self.get_rank_metadata(rank)
|
||||
keys = data["keys"]
|
||||
results = metadata.get_page_indices(keys)
|
||||
return {"indices": results}
|
||||
return self._json_response({"indices": results})
|
||||
|
||||
def run(self, host: str = "0.0.0.0", port: int = 18000):
|
||||
"""Run the metadata server."""
|
||||
@@ -309,14 +328,22 @@ class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
|
||||
status_forcelist=[500, 502, 503, 504],
|
||||
allowed_methods=["GET", "POST"],
|
||||
)
|
||||
adapter = HTTPAdapter(max_retries=retry_strategy)
|
||||
adapter = HTTPAdapter(
|
||||
max_retries=retry_strategy, pool_connections=256, pool_maxsize=256
|
||||
)
|
||||
self._session.mount("http://", adapter)
|
||||
|
||||
def _post(self, endpoint: str, json_data: dict) -> dict:
|
||||
try:
|
||||
response = self._session.post(f"{self.base_url}/{endpoint}", json=json_data)
|
||||
url = f"{self.base_url}/{endpoint}"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
payload = orjson.dumps(json_data) # type: ignore[union-attr]
|
||||
response = self._session.post(url, data=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
if response.status_code == 204 or not response.content:
|
||||
return {}
|
||||
return orjson.loads(response.content) # type: ignore[union-attr]
|
||||
except requests.exceptions.RequestException as e:
|
||||
logging.error(f"Failed to POST to {endpoint} after retries: {e}")
|
||||
raise RuntimeError(f"Failed to connect to metadata server: {e}") from e
|
||||
|
||||
@@ -113,6 +113,8 @@ def synchronized():
|
||||
|
||||
|
||||
class HiCacheHF3FS(HiCacheStorage):
|
||||
"""HiCache backend that stores KV cache pages in HF3FS files."""
|
||||
|
||||
default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"
|
||||
|
||||
def __init__(
|
||||
@@ -176,15 +178,32 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
dtype: torch.dtype,
|
||||
storage_config: HiCacheStorageConfig = None,
|
||||
) -> "HiCacheHF3FS":
|
||||
"""Create a HiCacheHF3FS instance from environment configuration.
|
||||
|
||||
Environment:
|
||||
- Uses env var stored in `HiCacheHF3FS.default_env_var` to locate a JSON config.
|
||||
- Falls back to a local single-machine config when the env var is not set.
|
||||
|
||||
Raises:
|
||||
ValueError: If MLA Model is requested without global metadata server or required keys are missing.
|
||||
"""
|
||||
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
|
||||
Hf3fsGlobalMetadataClient,
|
||||
Hf3fsLocalMetadataClient,
|
||||
)
|
||||
|
||||
rank = storage_config.tp_rank if storage_config is not None else 0
|
||||
if storage_config is not None:
|
||||
rank, is_mla_model = storage_config.tp_rank, storage_config.is_mla_model
|
||||
else:
|
||||
rank, is_mla_model = 0, False
|
||||
|
||||
mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"
|
||||
|
||||
config_path = os.getenv(HiCacheHF3FS.default_env_var)
|
||||
if not config_path:
|
||||
if is_mla_model:
|
||||
raise ValueError(mla_unsupported_msg)
|
||||
|
||||
return HiCacheHF3FS(
|
||||
rank=rank,
|
||||
file_path=f"/data/hicache.{rank}.bin",
|
||||
@@ -214,25 +233,27 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
raise ValueError(f"Missing required keys in config: {missing_keys}")
|
||||
|
||||
# Choose metadata client based on configuration
|
||||
is_mla_model = False
|
||||
if "metadata_server_url" in config and config["metadata_server_url"]:
|
||||
if config.get("metadata_server_url"):
|
||||
# Use global metadata client to connect to metadata server
|
||||
metadata_server_url = config["metadata_server_url"]
|
||||
metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
|
||||
|
||||
# Enable MLA optimization only when using the global metadata client
|
||||
is_mla_model = storage_config.is_mla_model if storage_config else False
|
||||
logger.info(
|
||||
f"Using global metadata client with server url: {metadata_server_url}"
|
||||
)
|
||||
else:
|
||||
# Enable MLA optimization only when using the global metadata client
|
||||
if is_mla_model:
|
||||
raise ValueError(mla_unsupported_msg)
|
||||
|
||||
# Use local metadata client for single-machine deployment
|
||||
metadata_client = Hf3fsLocalMetadataClient()
|
||||
|
||||
rank_for_path = 0 if is_mla_model else rank
|
||||
return HiCacheHF3FS(
|
||||
rank=rank,
|
||||
# Let all ranks use the same file path for MLA model
|
||||
file_path=f"{config['file_path_prefix']}.{rank if not is_mla_model else 0}.bin",
|
||||
file_path=f"{config['file_path_prefix']}.{rank_for_path}.bin",
|
||||
file_size=int(config["file_size"]),
|
||||
numjobs=int(config["numjobs"]),
|
||||
bytes_per_page=bytes_per_page,
|
||||
|
||||
Reference in New Issue
Block a user