diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py index e7994d791..f96037873 100644 --- a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py @@ -1,10 +1,12 @@ import json import logging import os +import time import uuid from dataclasses import dataclass from typing import Any, List, Optional +import requests import torch from sglang.srt.mem_cache.hicache_storage import ( @@ -17,6 +19,10 @@ from sglang.srt.mem_cache.memory_pool_host import HostKVCache DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024 # 16 MB DEFAULT_MOONCAKE_CONFIG_PATH_ENV = "SGLANG_HICACHE_MOONCAKE_CONFIG_PATH" +SETUP_TIMEOUT = 600 # 10min +DEFAULT_MASTER_METRICS_PORT = 9003 +DEFAULT_CHECK_SERVER = False + logger = logging.getLogger(__name__) @@ -45,6 +51,8 @@ class MooncakeStoreConfig: protocol: str device_name: str master_server_address: str + master_metrics_port: int + check_server: bool @staticmethod def from_file() -> "MooncakeStoreConfig": @@ -67,6 +75,10 @@ class MooncakeStoreConfig: protocol=config.get("protocol", "tcp"), device_name=config.get("device_name", ""), master_server_address=config.get("master_server_address"), + master_metrics_port=config.get( + "master_metrics_port", DEFAULT_MASTER_METRICS_PORT + ), + check_server=config.get("check_server", DEFAULT_CHECK_SERVER), ) @staticmethod @@ -91,6 +103,10 @@ class MooncakeStoreConfig: protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"), device_name=os.getenv("MOONCAKE_DEVICE", ""), master_server_address=os.getenv("MOONCAKE_MASTER"), + master_metrics_port=int( + os.getenv("MOONCAKE_MASTER_METRICS_PORT", DEFAULT_GLOBAL_SEGMENT_SIZE) + ), + check_server=bool(os.getenv("MOONCAKE_CHECK_SERVER", DEFAULT_CHECK_SERVER)), ) @staticmethod @@ -111,6 +127,10 @@ class MooncakeStoreConfig: protocol=extra_config.get("protocol", "tcp"), device_name=extra_config.get("device_name", ""), master_server_address=extra_config["master_server_address"], + master_metrics_port=extra_config.get( + "master_metrics_port", DEFAULT_MASTER_METRICS_PORT + ), + check_server=extra_config.get("check_server", DEFAULT_CHECK_SERVER), ) @@ -166,6 +186,10 @@ class MooncakeStore(HiCacheStorage): self.extra_backend_tag = extra_config["extra_backend_tag"] logger.info(f"Using extra_backend_tag: {self.extra_backend_tag}") + # Check server status + if self.config.check_server: + self.check_server() + ret_code = self.store.setup( self.config.local_hostname, self.config.metadata_server, @@ -196,6 +220,39 @@ class MooncakeStore(HiCacheStorage): logger.error("An error occurred while loading the configuration: %s", exc) raise + def check_server(self): + master_server_ip = self.config.master_server_address.split(":")[0] + segments_url = f"http://{master_server_ip}:{self.config.master_metrics_port}/get_all_segments" + start_time = time.perf_counter() + + check_result = False + while time.perf_counter() - start_time < SETUP_TIMEOUT: + try: + check_segments_resp = requests.get(segments_url, timeout=3) + except Exception: + logger.info( + "waiting mooncake store server started, cost_time: %.2f seconds.", + time.perf_counter() - start_time, + ) + time.sleep(3) + continue + + if check_segments_resp.text == "": + logger.info( + "waiting mooncake store server started, cost_time: %.2f seconds.", + time.perf_counter() - start_time, + ) + time.sleep(3) + continue + + logger.info("Mooncake store server started successfully.") + check_result = True + break + + if not check_result: + logger.error("Launch mooncake store server timeout") + raise ValueError("Launch mooncake store server timeout") + def warmup(self): warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex warmup_value = bytes(4 * 1024) # 4 KB