[HiCacheStorage]: Improve 3fs kvstore‘s performance and resolve mla issues (#9876)
This commit is contained in:
@@ -4,10 +4,12 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, OrderedDict, Tuple
|
||||||
|
|
||||||
|
import orjson
|
||||||
import requests
|
import requests
|
||||||
from fastapi import FastAPI, HTTPException, Request, status
|
from fastapi import FastAPI, HTTPException, Request, Response
|
||||||
|
from fastapi.responses import ORJSONResponse
|
||||||
from requests.adapters import HTTPAdapter
|
from requests.adapters import HTTPAdapter
|
||||||
from urllib3.util.retry import Retry
|
from urllib3.util.retry import Retry
|
||||||
|
|
||||||
@@ -24,10 +26,10 @@ class RankMetadata:
|
|||||||
"""Holds all metadata for a single rank."""
|
"""Holds all metadata for a single rank."""
|
||||||
|
|
||||||
def __init__(self, num_pages: int):
|
def __init__(self, num_pages: int):
|
||||||
self.lock = threading.RLock()
|
self.lock = threading.Lock()
|
||||||
self.num_pages = num_pages
|
self.num_pages = num_pages
|
||||||
self.free_pages: List[int] = list(range(num_pages))
|
self.free_pages: List[int] = list(range(num_pages))
|
||||||
self.key_to_index: Dict[str, int] = {}
|
self.key_to_index: OrderedDict[str, int] = OrderedDict()
|
||||||
# Todo: Support multi files for HF3FS
|
# Todo: Support multi files for HF3FS
|
||||||
|
|
||||||
def exists_keys(self, keys: List[str]) -> List[bool]:
|
def exists_keys(self, keys: List[str]) -> List[bool]:
|
||||||
@@ -46,16 +48,18 @@ class RankMetadata:
|
|||||||
for i, (key, prefix_key) in enumerate(keys):
|
for i, (key, prefix_key) in enumerate(keys):
|
||||||
if key in self.key_to_index:
|
if key in self.key_to_index:
|
||||||
results[i] = (True, self.key_to_index[key])
|
results[i] = (True, self.key_to_index[key])
|
||||||
|
self.key_to_index.move_to_end(key)
|
||||||
else:
|
else:
|
||||||
new_keys_to_process.append((i, key, prefix_key))
|
new_keys_to_process.append((i, key, prefix_key))
|
||||||
|
|
||||||
# Todo: Implementing data eviction logic after HiCache supports prefix information pass-through
|
# Todo: Implementing data eviction logic after HiCache supports prefix information pass-through
|
||||||
for i, key, prefix_key in new_keys_to_process:
|
for i, key, prefix_key in new_keys_to_process:
|
||||||
if len(self.free_pages) > 0:
|
if len(self.free_pages) > 0:
|
||||||
page_idx = self.free_pages.pop()
|
page_index = self.free_pages.pop()
|
||||||
results[i] = (False, page_idx)
|
|
||||||
else:
|
else:
|
||||||
results[i] = (False, -1)
|
page_index = self.key_to_index.popitem(last=False)[1]
|
||||||
|
|
||||||
|
results[i] = (False, page_index)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -68,6 +72,7 @@ class RankMetadata:
|
|||||||
with self.lock:
|
with self.lock:
|
||||||
for key, page_index in written_keys_to_confirm:
|
for key, page_index in written_keys_to_confirm:
|
||||||
self.key_to_index[key] = page_index
|
self.key_to_index[key] = page_index
|
||||||
|
self.key_to_index.move_to_end(key)
|
||||||
|
|
||||||
for page_index in pages_to_release:
|
for page_index in pages_to_release:
|
||||||
if page_index not in self.free_pages:
|
if page_index not in self.free_pages:
|
||||||
@@ -94,7 +99,14 @@ class RankMetadata:
|
|||||||
def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
|
def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
|
||||||
"""Get page indices for keys."""
|
"""Get page indices for keys."""
|
||||||
with self.lock:
|
with self.lock:
|
||||||
return [self.key_to_index.get(key) for key in keys]
|
results = []
|
||||||
|
for key in keys:
|
||||||
|
if key in self.key_to_index:
|
||||||
|
results.append(self.key_to_index[key])
|
||||||
|
self.key_to_index.move_to_end(key)
|
||||||
|
else:
|
||||||
|
results.append(None)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
class GlobalMetadataState:
|
class GlobalMetadataState:
|
||||||
@@ -182,7 +194,8 @@ class Hf3fsMetadataServer:
|
|||||||
|
|
||||||
def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
|
def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
|
||||||
self.state = GlobalMetadataState(persistence_path, save_interval)
|
self.state = GlobalMetadataState(persistence_path, save_interval)
|
||||||
self.app = FastAPI()
|
self.app = FastAPI(default_response_class=ORJSONResponse)
|
||||||
|
|
||||||
self._setup_routes()
|
self._setup_routes()
|
||||||
|
|
||||||
def _setup_routes(self):
|
def _setup_routes(self):
|
||||||
@@ -199,17 +212,25 @@ class Hf3fsMetadataServer:
|
|||||||
|
|
||||||
def get_rank_metadata(self, rank: int) -> RankMetadata:
|
def get_rank_metadata(self, rank: int) -> RankMetadata:
|
||||||
"""Get rank metadata with proper error handling."""
|
"""Get rank metadata with proper error handling."""
|
||||||
with self.state.global_lock:
|
if rank not in self.state.ranks:
|
||||||
if rank not in self.state.ranks:
|
raise HTTPException(
|
||||||
raise HTTPException(
|
status_code=404,
|
||||||
status_code=404,
|
detail=f"Rank {rank} not initialized. Please call /{rank}/initialize first.",
|
||||||
detail=f"Rank {rank} not initialized. Please call /{{rank}}/initialize first.",
|
)
|
||||||
)
|
return self.state.ranks[rank]
|
||||||
return self.state.ranks[rank]
|
|
||||||
|
async def _read_json(self, request: Request) -> dict:
|
||||||
|
"""Parse request JSON using orjson if available."""
|
||||||
|
body = await request.body()
|
||||||
|
return orjson.loads(body)
|
||||||
|
|
||||||
|
def _json_response(self, content: dict):
|
||||||
|
"""Return ORJSONResponse when available to bypass jsonable_encoder."""
|
||||||
|
return ORJSONResponse(content)
|
||||||
|
|
||||||
async def initialize(self, rank: int, request: Request):
|
async def initialize(self, rank: int, request: Request):
|
||||||
"""Initialize a rank with specified number of pages."""
|
"""Initialize a rank with specified number of pages."""
|
||||||
data = await request.json()
|
data = await self._read_json(request)
|
||||||
num_pages = data["num_pages"]
|
num_pages = data["num_pages"]
|
||||||
with self.state.global_lock:
|
with self.state.global_lock:
|
||||||
if rank in self.state.ranks:
|
if rank in self.state.ranks:
|
||||||
@@ -223,57 +244,55 @@ class Hf3fsMetadataServer:
|
|||||||
else:
|
else:
|
||||||
logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
|
logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
|
||||||
self.state.ranks[rank] = RankMetadata(num_pages)
|
self.state.ranks[rank] = RankMetadata(num_pages)
|
||||||
return {"message": f"Rank {rank} is ready."}
|
return Response(status_code=204)
|
||||||
|
|
||||||
async def exists(self, rank: int, request: Request):
|
async def exists(self, rank: int, request: Request):
|
||||||
"""Check if keys exist in metadata."""
|
"""Check if keys exist in metadata."""
|
||||||
data = await request.json()
|
data = await self._read_json(request)
|
||||||
keys = data["keys"]
|
keys = data["keys"]
|
||||||
metadata = self.get_rank_metadata(rank)
|
metadata = self.get_rank_metadata(rank)
|
||||||
results = metadata.exists_keys(keys)
|
results = metadata.exists_keys(keys)
|
||||||
return {"exists": results}
|
return self._json_response({"exists": results})
|
||||||
|
|
||||||
async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
|
async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
|
||||||
"""Reserve and allocate page indices for keys."""
|
"""Reserve and allocate page indices for keys."""
|
||||||
data = await request.json()
|
data = await self._read_json(request)
|
||||||
metadata = self.get_rank_metadata(rank)
|
metadata = self.get_rank_metadata(rank)
|
||||||
keys = data["keys"]
|
keys = data["keys"]
|
||||||
results = metadata.reserve_and_allocate_page_indices(keys)
|
results = metadata.reserve_and_allocate_page_indices(keys)
|
||||||
return {"indices": results}
|
return self._json_response({"indices": results})
|
||||||
|
|
||||||
async def confirm_write(self, rank: int, request: Request):
|
async def confirm_write(self, rank: int, request: Request):
|
||||||
"""Confirm write operations and release pages."""
|
"""Confirm write operations and release pages."""
|
||||||
data = await request.json()
|
data = await self._read_json(request)
|
||||||
metadata = self.get_rank_metadata(rank)
|
metadata = self.get_rank_metadata(rank)
|
||||||
success_written_keys = data.get("written_keys_to_confirm", [])
|
success_written_keys = data.get("written_keys_to_confirm", [])
|
||||||
released_pages = data.get("pages_to_release", [])
|
released_pages = data.get("pages_to_release", [])
|
||||||
|
|
||||||
metadata.confirm_write(success_written_keys, released_pages)
|
metadata.confirm_write(success_written_keys, released_pages)
|
||||||
|
|
||||||
return {
|
return Response(status_code=204)
|
||||||
"message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
|
|
||||||
}
|
|
||||||
|
|
||||||
async def delete_keys(self, rank: int, request: Request):
|
async def delete_keys(self, rank: int, request: Request):
|
||||||
"""Delete keys from metadata."""
|
"""Delete keys from metadata."""
|
||||||
data = await request.json()
|
data = await self._read_json(request)
|
||||||
metadata = self.get_rank_metadata(rank)
|
metadata = self.get_rank_metadata(rank)
|
||||||
count = metadata.delete_keys(data["keys"])
|
count = metadata.delete_keys(data["keys"])
|
||||||
return {"message": f"Rank {rank}: {count} keys deleted."}
|
return Response(status_code=204)
|
||||||
|
|
||||||
async def clear(self, rank: int):
|
async def clear(self, rank: int):
|
||||||
"""Clear all metadata for a rank."""
|
"""Clear all metadata for a rank."""
|
||||||
metadata = self.get_rank_metadata(rank)
|
metadata = self.get_rank_metadata(rank)
|
||||||
metadata.clear_all()
|
metadata.clear_all()
|
||||||
return {"message": f"Rank {rank}: Metadata cleared."}
|
return Response(status_code=204)
|
||||||
|
|
||||||
async def get_page_indices(self, rank: int, request: Request):
|
async def get_page_indices(self, rank: int, request: Request):
|
||||||
"""Get page indices for keys."""
|
"""Get page indices for keys."""
|
||||||
data = await request.json()
|
data = await self._read_json(request)
|
||||||
metadata = self.get_rank_metadata(rank)
|
metadata = self.get_rank_metadata(rank)
|
||||||
keys = data["keys"]
|
keys = data["keys"]
|
||||||
results = metadata.get_page_indices(keys)
|
results = metadata.get_page_indices(keys)
|
||||||
return {"indices": results}
|
return self._json_response({"indices": results})
|
||||||
|
|
||||||
def run(self, host: str = "0.0.0.0", port: int = 18000):
|
def run(self, host: str = "0.0.0.0", port: int = 18000):
|
||||||
"""Run the metadata server."""
|
"""Run the metadata server."""
|
||||||
@@ -309,14 +328,22 @@ class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
|
|||||||
status_forcelist=[500, 502, 503, 504],
|
status_forcelist=[500, 502, 503, 504],
|
||||||
allowed_methods=["GET", "POST"],
|
allowed_methods=["GET", "POST"],
|
||||||
)
|
)
|
||||||
adapter = HTTPAdapter(max_retries=retry_strategy)
|
adapter = HTTPAdapter(
|
||||||
|
max_retries=retry_strategy, pool_connections=256, pool_maxsize=256
|
||||||
|
)
|
||||||
self._session.mount("http://", adapter)
|
self._session.mount("http://", adapter)
|
||||||
|
|
||||||
def _post(self, endpoint: str, json_data: dict) -> dict:
|
def _post(self, endpoint: str, json_data: dict) -> dict:
|
||||||
try:
|
try:
|
||||||
response = self._session.post(f"{self.base_url}/{endpoint}", json=json_data)
|
url = f"{self.base_url}/{endpoint}"
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
payload = orjson.dumps(json_data) # type: ignore[union-attr]
|
||||||
|
response = self._session.post(url, data=payload, headers=headers)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
|
||||||
|
if response.status_code == 204 or not response.content:
|
||||||
|
return {}
|
||||||
|
return orjson.loads(response.content) # type: ignore[union-attr]
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
logging.error(f"Failed to POST to {endpoint} after retries: {e}")
|
logging.error(f"Failed to POST to {endpoint} after retries: {e}")
|
||||||
raise RuntimeError(f"Failed to connect to metadata server: {e}") from e
|
raise RuntimeError(f"Failed to connect to metadata server: {e}") from e
|
||||||
|
|||||||
@@ -113,6 +113,8 @@ def synchronized():
|
|||||||
|
|
||||||
|
|
||||||
class HiCacheHF3FS(HiCacheStorage):
|
class HiCacheHF3FS(HiCacheStorage):
|
||||||
|
"""HiCache backend that stores KV cache pages in HF3FS files."""
|
||||||
|
|
||||||
default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"
|
default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -176,15 +178,32 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
storage_config: HiCacheStorageConfig = None,
|
storage_config: HiCacheStorageConfig = None,
|
||||||
) -> "HiCacheHF3FS":
|
) -> "HiCacheHF3FS":
|
||||||
|
"""Create a HiCacheHF3FS instance from environment configuration.
|
||||||
|
|
||||||
|
Environment:
|
||||||
|
- Uses env var stored in `HiCacheHF3FS.default_env_var` to locate a JSON config.
|
||||||
|
- Falls back to a local single-machine config when the env var is not set.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If MLA Model is requested without global metadata server or required keys are missing.
|
||||||
|
"""
|
||||||
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
|
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
|
||||||
Hf3fsGlobalMetadataClient,
|
Hf3fsGlobalMetadataClient,
|
||||||
Hf3fsLocalMetadataClient,
|
Hf3fsLocalMetadataClient,
|
||||||
)
|
)
|
||||||
|
|
||||||
rank = storage_config.tp_rank if storage_config is not None else 0
|
if storage_config is not None:
|
||||||
|
rank, is_mla_model = storage_config.tp_rank, storage_config.is_mla_model
|
||||||
|
else:
|
||||||
|
rank, is_mla_model = 0, False
|
||||||
|
|
||||||
|
mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"
|
||||||
|
|
||||||
config_path = os.getenv(HiCacheHF3FS.default_env_var)
|
config_path = os.getenv(HiCacheHF3FS.default_env_var)
|
||||||
if not config_path:
|
if not config_path:
|
||||||
|
if is_mla_model:
|
||||||
|
raise ValueError(mla_unsupported_msg)
|
||||||
|
|
||||||
return HiCacheHF3FS(
|
return HiCacheHF3FS(
|
||||||
rank=rank,
|
rank=rank,
|
||||||
file_path=f"/data/hicache.{rank}.bin",
|
file_path=f"/data/hicache.{rank}.bin",
|
||||||
@@ -214,25 +233,27 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
raise ValueError(f"Missing required keys in config: {missing_keys}")
|
raise ValueError(f"Missing required keys in config: {missing_keys}")
|
||||||
|
|
||||||
# Choose metadata client based on configuration
|
# Choose metadata client based on configuration
|
||||||
is_mla_model = False
|
if config.get("metadata_server_url"):
|
||||||
if "metadata_server_url" in config and config["metadata_server_url"]:
|
|
||||||
# Use global metadata client to connect to metadata server
|
# Use global metadata client to connect to metadata server
|
||||||
metadata_server_url = config["metadata_server_url"]
|
metadata_server_url = config["metadata_server_url"]
|
||||||
metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
|
metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
|
||||||
|
|
||||||
# Enable MLA optimization only when using the global metadata client
|
|
||||||
is_mla_model = storage_config.is_mla_model if storage_config else False
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Using global metadata client with server url: {metadata_server_url}"
|
f"Using global metadata client with server url: {metadata_server_url}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Enable MLA optimization only when using the global metadata client
|
||||||
|
if is_mla_model:
|
||||||
|
raise ValueError(mla_unsupported_msg)
|
||||||
|
|
||||||
# Use local metadata client for single-machine deployment
|
# Use local metadata client for single-machine deployment
|
||||||
metadata_client = Hf3fsLocalMetadataClient()
|
metadata_client = Hf3fsLocalMetadataClient()
|
||||||
|
|
||||||
|
rank_for_path = 0 if is_mla_model else rank
|
||||||
return HiCacheHF3FS(
|
return HiCacheHF3FS(
|
||||||
rank=rank,
|
rank=rank,
|
||||||
# Let all ranks use the same file path for MLA model
|
# Let all ranks use the same file path for MLA model
|
||||||
file_path=f"{config['file_path_prefix']}.{rank if not is_mla_model else 0}.bin",
|
file_path=f"{config['file_path_prefix']}.{rank_for_path}.bin",
|
||||||
file_size=int(config["file_size"]),
|
file_size=int(config["file_size"]),
|
||||||
numjobs=int(config["numjobs"]),
|
numjobs=int(config["numjobs"]),
|
||||||
bytes_per_page=bytes_per_page,
|
bytes_per_page=bytes_per_page,
|
||||||
|
|||||||
Reference in New Issue
Block a user