diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py b/python/sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py index 1967259ac..414d13adc 100644 --- a/python/sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +++ b/python/sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py @@ -4,10 +4,12 @@ import json import logging import threading from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, OrderedDict, Tuple +import orjson import requests -from fastapi import FastAPI, HTTPException, Request, status +from fastapi import FastAPI, HTTPException, Request, Response +from fastapi.responses import ORJSONResponse from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry @@ -24,10 +26,10 @@ class RankMetadata: """Holds all metadata for a single rank.""" def __init__(self, num_pages: int): - self.lock = threading.RLock() + self.lock = threading.Lock() self.num_pages = num_pages self.free_pages: List[int] = list(range(num_pages)) - self.key_to_index: Dict[str, int] = {} + self.key_to_index: OrderedDict[str, int] = OrderedDict() # Todo: Support multi files for HF3FS def exists_keys(self, keys: List[str]) -> List[bool]: @@ -46,16 +48,18 @@ class RankMetadata: for i, (key, prefix_key) in enumerate(keys): if key in self.key_to_index: results[i] = (True, self.key_to_index[key]) + self.key_to_index.move_to_end(key) else: new_keys_to_process.append((i, key, prefix_key)) # Todo: Implementing data eviction logic after HiCache supports prefix information pass-through for i, key, prefix_key in new_keys_to_process: if len(self.free_pages) > 0: - page_idx = self.free_pages.pop() - results[i] = (False, page_idx) + page_index = self.free_pages.pop() else: - results[i] = (False, -1) + page_index = self.key_to_index.popitem(last=False)[1] + + results[i] = (False, page_index) return results @@ -68,6 +72,7 @@ class RankMetadata: with self.lock: for key, page_index in written_keys_to_confirm: self.key_to_index[key] = page_index + self.key_to_index.move_to_end(key) for page_index in pages_to_release: if page_index not in self.free_pages: @@ -94,7 +99,14 @@ class RankMetadata: def get_page_indices(self, keys: List[str]) -> List[Optional[int]]: """Get page indices for keys.""" with self.lock: - return [self.key_to_index.get(key) for key in keys] + results = [] + for key in keys: + if key in self.key_to_index: + results.append(self.key_to_index[key]) + self.key_to_index.move_to_end(key) + else: + results.append(None) + return results class GlobalMetadataState: @@ -182,7 +194,8 @@ class Hf3fsMetadataServer: def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60): self.state = GlobalMetadataState(persistence_path, save_interval) - self.app = FastAPI() + self.app = FastAPI(default_response_class=ORJSONResponse) + self._setup_routes() def _setup_routes(self): @@ -199,17 +212,25 @@ class Hf3fsMetadataServer: def get_rank_metadata(self, rank: int) -> RankMetadata: """Get rank metadata with proper error handling.""" - with self.state.global_lock: - if rank not in self.state.ranks: - raise HTTPException( - status_code=404, - detail=f"Rank {rank} not initialized. Please call /{{rank}}/initialize first.", - ) - return self.state.ranks[rank] + if rank not in self.state.ranks: + raise HTTPException( + status_code=404, + detail=f"Rank {rank} not initialized. Please call /{rank}/initialize first.", + ) + return self.state.ranks[rank] + + async def _read_json(self, request: Request) -> dict: + """Parse request JSON using orjson if available.""" + body = await request.body() + return orjson.loads(body) + + def _json_response(self, content: dict): + """Return ORJSONResponse when available to bypass jsonable_encoder.""" + return ORJSONResponse(content) async def initialize(self, rank: int, request: Request): """Initialize a rank with specified number of pages.""" - data = await request.json() + data = await self._read_json(request) num_pages = data["num_pages"] with self.state.global_lock: if rank in self.state.ranks: @@ -223,57 +244,55 @@ class Hf3fsMetadataServer: else: logging.info(f"Initializing new Rank {rank} with {num_pages} pages.") self.state.ranks[rank] = RankMetadata(num_pages) - return {"message": f"Rank {rank} is ready."} + return Response(status_code=204) async def exists(self, rank: int, request: Request): """Check if keys exist in metadata.""" - data = await request.json() + data = await self._read_json(request) keys = data["keys"] metadata = self.get_rank_metadata(rank) results = metadata.exists_keys(keys) - return {"exists": results} + return self._json_response({"exists": results}) async def reserve_and_allocate_page_indices(self, rank: int, request: Request): """Reserve and allocate page indices for keys.""" - data = await request.json() + data = await self._read_json(request) metadata = self.get_rank_metadata(rank) keys = data["keys"] results = metadata.reserve_and_allocate_page_indices(keys) - return {"indices": results} + return self._json_response({"indices": results}) async def confirm_write(self, rank: int, request: Request): """Confirm write operations and release pages.""" - data = await request.json() + data = await self._read_json(request) metadata = self.get_rank_metadata(rank) success_written_keys = data.get("written_keys_to_confirm", []) released_pages = data.get("pages_to_release", []) metadata.confirm_write(success_written_keys, released_pages) - return { - "message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released." - } + return Response(status_code=204) async def delete_keys(self, rank: int, request: Request): """Delete keys from metadata.""" - data = await request.json() + data = await self._read_json(request) metadata = self.get_rank_metadata(rank) count = metadata.delete_keys(data["keys"]) - return {"message": f"Rank {rank}: {count} keys deleted."} + return Response(status_code=204) async def clear(self, rank: int): """Clear all metadata for a rank.""" metadata = self.get_rank_metadata(rank) metadata.clear_all() - return {"message": f"Rank {rank}: Metadata cleared."} + return Response(status_code=204) async def get_page_indices(self, rank: int, request: Request): """Get page indices for keys.""" - data = await request.json() + data = await self._read_json(request) metadata = self.get_rank_metadata(rank) keys = data["keys"] results = metadata.get_page_indices(keys) - return {"indices": results} + return self._json_response({"indices": results}) def run(self, host: str = "0.0.0.0", port: int = 18000): """Run the metadata server.""" @@ -309,14 +328,22 @@ class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface): status_forcelist=[500, 502, 503, 504], allowed_methods=["GET", "POST"], ) - adapter = HTTPAdapter(max_retries=retry_strategy) + adapter = HTTPAdapter( + max_retries=retry_strategy, pool_connections=256, pool_maxsize=256 + ) self._session.mount("http://", adapter) def _post(self, endpoint: str, json_data: dict) -> dict: try: - response = self._session.post(f"{self.base_url}/{endpoint}", json=json_data) + url = f"{self.base_url}/{endpoint}" + headers = {"Content-Type": "application/json"} + payload = orjson.dumps(json_data) # type: ignore[union-attr] + response = self._session.post(url, data=payload, headers=headers) response.raise_for_status() - return response.json() + + if response.status_code == 204 or not response.content: + return {} + return orjson.loads(response.content) # type: ignore[union-attr] except requests.exceptions.RequestException as e: logging.error(f"Failed to POST to {endpoint} after retries: {e}") raise RuntimeError(f"Failed to connect to metadata server: {e}") from e diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py index 82e850d37..a30230cdc 100644 --- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py @@ -113,6 +113,8 @@ def synchronized(): class HiCacheHF3FS(HiCacheStorage): + """HiCache backend that stores KV cache pages in HF3FS files.""" + default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH" def __init__( @@ -176,15 +178,32 @@ class HiCacheHF3FS(HiCacheStorage): dtype: torch.dtype, storage_config: HiCacheStorageConfig = None, ) -> "HiCacheHF3FS": + """Create a HiCacheHF3FS instance from environment configuration. + + Environment: + - Uses env var stored in `HiCacheHF3FS.default_env_var` to locate a JSON config. + - Falls back to a local single-machine config when the env var is not set. + + Raises: + ValueError: If MLA Model is requested without global metadata server or required keys are missing. + """ from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import ( Hf3fsGlobalMetadataClient, Hf3fsLocalMetadataClient, ) - rank = storage_config.tp_rank if storage_config is not None else 0 + if storage_config is not None: + rank, is_mla_model = storage_config.tp_rank, storage_config.is_mla_model + else: + rank, is_mla_model = 0, False + + mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md" config_path = os.getenv(HiCacheHF3FS.default_env_var) if not config_path: + if is_mla_model: + raise ValueError(mla_unsupported_msg) + return HiCacheHF3FS( rank=rank, file_path=f"/data/hicache.{rank}.bin", @@ -214,25 +233,27 @@ class HiCacheHF3FS(HiCacheStorage): raise ValueError(f"Missing required keys in config: {missing_keys}") # Choose metadata client based on configuration - is_mla_model = False - if "metadata_server_url" in config and config["metadata_server_url"]: + if config.get("metadata_server_url"): # Use global metadata client to connect to metadata server metadata_server_url = config["metadata_server_url"] metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url) - # Enable MLA optimization only when using the global metadata client - is_mla_model = storage_config.is_mla_model if storage_config else False logger.info( f"Using global metadata client with server url: {metadata_server_url}" ) else: + # Enable MLA optimization only when using the global metadata client + if is_mla_model: + raise ValueError(mla_unsupported_msg) + # Use local metadata client for single-machine deployment metadata_client = Hf3fsLocalMetadataClient() + rank_for_path = 0 if is_mla_model else rank return HiCacheHF3FS( rank=rank, # Let all ranks use the same file path for MLA model - file_path=f"{config['file_path_prefix']}.{rank if not is_mla_model else 0}.bin", + file_path=f"{config['file_path_prefix']}.{rank_for_path}.bin", file_size=int(config["file_size"]), numjobs=int(config["numjobs"]), bytes_per_page=bytes_per_page,