[HiCache] Add tests for hicache storage mooncake backend (#10171)
Signed-off-by: Shangming Cai <csmthu@gmail.com> Co-authored-by: hzh0425 <hzh0425@apache.org> Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
@@ -74,7 +74,7 @@ fi
|
||||
$PIP_CMD list
|
||||
|
||||
# Install additional dependencies
|
||||
$PIP_CMD install mooncake-transfer-engine==0.3.5 nvidia-cuda-nvrtc-cu12 py-spy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX
|
||||
$PIP_CMD install mooncake-transfer-engine==0.3.6 nvidia-cuda-nvrtc-cu12 py-spy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX
|
||||
|
||||
if [ "$IS_BLACKWELL" != "1" ]; then
|
||||
# For lmms_evals evaluating MMMU
|
||||
|
||||
316
test/srt/hicache/test_hicache_storage_mooncake_backend.py
Normal file
316
test/srt/hicache/test_hicache_storage_mooncake_backend.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""
|
||||
Benchmark tests for HiCache Storage with Mooncake backend.
|
||||
Usage:
|
||||
python3.10 -m pytest test/srt/hicache/test_hicache_storage_mooncake_backend.py -v
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
import requests
|
||||
from test_hicache_storage_file_backend import HiCacheStorageBaseMixin
|
||||
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||
CustomTestCase,
|
||||
find_available_port,
|
||||
)
|
||||
|
||||
|
||||
class HiCacheStorageMooncakeBackendBaseMixin(HiCacheStorageBaseMixin):
    """Base mixin that runs the HiCache storage tests against a Mooncake backend.

    Before the parent's class-level setup runs, this mixin launches a Mooncake
    HTTP metadata server and a Mooncake master as child processes. They are
    killed again after the parent's class-level teardown.
    """

    # Base values handed to find_available_port(); subclasses may override them.
    mooncake_master_port_base = 50051
    mooncake_metadata_port_base = 8080

    # Seconds a running master process is given before it is assumed ready
    # (no health endpoint is probed for the master).
    _master_startup_grace_seconds = 5

    # Popen handles of the launched services; None while a service is not running.
    metadata_service_process = None
    master_service_process = None

    @classmethod
    def setUpClass(cls):
        """Pick ports, start the Mooncake services, then run the parent setup.

        If the parent setup raises, the Mooncake services are stopped before
        the exception propagates: unittest does not call tearDownClass when
        setUpClass fails, so they would otherwise be leaked.
        """
        # Look the bases up on cls so that subclass overrides take effect.
        cls.mooncake_master_port = find_available_port(cls.mooncake_master_port_base)
        cls.mooncake_metadata_port = find_available_port(
            cls.mooncake_metadata_port_base
        )

        # The Mooncake services must be up before the parent setup runs.
        cls._start_mooncake_services()
        try:
            super().setUpClass()
        except BaseException:
            cls._stop_mooncake_services()
            raise

    @classmethod
    def tearDownClass(cls):
        """Run the parent teardown, then stop the Mooncake services.

        The services are stopped even when the parent teardown raises.
        """
        try:
            super().tearDownClass()
        finally:
            cls._stop_mooncake_services()

    @classmethod
    def _launch_service(cls, label, command):
        """Spawn ``command`` in its own session and return the Popen handle.

        Args:
            label: Service name used in the warning message.
            command: Argument list passed to ``subprocess.Popen``.

        Returns:
            The ``Popen`` handle, or ``None`` (after printing a warning) when
            the process could not be spawned.
        """
        try:
            # start_new_session=True runs setsid() in the child, so the child
            # leads its own process group and the whole group can be killed
            # in _stop_mooncake_services. stdout/stderr are inherited on
            # purpose: nothing in this class reads from a pipe, and an unread
            # PIPE blocks the child once the OS pipe buffer is full.
            return subprocess.Popen(command, start_new_session=True)
        except (FileNotFoundError, subprocess.SubprocessError) as e:
            print(f"Warning: Could not start Mooncake {label} service: {e}")
            return None

    @classmethod
    def _start_mooncake_services(cls):
        """Start the Mooncake metadata and master services and wait for them.

        A service that cannot be spawned is recorded as ``None`` and only
        produces a warning; the method itself does not raise for that case.
        """
        print("Starting Mooncake services...")
        print(
            f"Using master port: {cls.mooncake_master_port}, metadata port: {cls.mooncake_metadata_port}"
        )

        cls.metadata_service_process = cls._launch_service(
            "metadata",
            [
                "python3",
                "-m",
                "mooncake.http_metadata_server",
                "--port",
                str(cls.mooncake_metadata_port),
            ],
        )
        if cls.metadata_service_process is not None:
            print(
                f"Mooncake metadata service started on port {cls.mooncake_metadata_port}"
            )

        cls.master_service_process = cls._launch_service(
            "master",
            ["mooncake_master", "--port", str(cls.mooncake_master_port)],
        )
        if cls.master_service_process is not None:
            print(f"Mooncake master service started on port {cls.mooncake_master_port}")

        # Poll for readiness instead of sleeping for a fixed amount of time.
        cls._wait_for_mooncake_services_ready()

    @staticmethod
    def _is_service_running(process):
        """Return True when ``process`` was spawned and has not exited yet."""
        return process is not None and process.poll() is None

    @staticmethod
    def _metadata_service_responds(url):
        """Return True when a GET on ``url`` answers with HTTP 200."""
        # NOTE(review): only HTTP 200 counts as ready. Confirm the metadata
        # server answers 200 to a GET without a key; if it does not, the
        # caller always waits for its full timeout.
        try:
            return requests.get(url, timeout=2).status_code == 200
        except (requests.RequestException, ConnectionError):
            # Service might not be fully started yet.
            return False

    @classmethod
    def _wait_for_mooncake_services_ready(cls, timeout: int = 30) -> bool:
        """Wait until both Mooncake services look ready.

        The metadata service is ready once its HTTP endpoint answers with 200.
        The master service is assumed ready once its process has been alive
        for ``_master_startup_grace_seconds``. Waiting stops early when either
        process was never started or has already exited, because readiness
        can no longer be reached in that case.

        Args:
            timeout: Maximum number of seconds to wait.

        Returns:
            True when both services became ready, False otherwise. A False
            result only prints a warning; the caller continues anyway.
        """
        print("Waiting for Mooncake services to be ready...")

        metadata_url = f"http://127.0.0.1:{cls.mooncake_metadata_port}/metadata"
        # Monotonic clock: elapsed time must not jump with wall-clock changes.
        start_time = time.monotonic()
        metadata_ready = False
        master_ready = False
        services_ready = False

        while time.monotonic() - start_time < timeout:
            if not (
                cls._is_service_running(cls.metadata_service_process)
                and cls._is_service_running(cls.master_service_process)
            ):
                print("Warning: a Mooncake service process is not running")
                break

            if not metadata_ready and cls._metadata_service_responds(metadata_url):
                metadata_ready = True
                print("Mooncake metadata service is ready")

            if (
                not master_ready
                and time.monotonic() - start_time > cls._master_startup_grace_seconds
            ):
                master_ready = True
                print("Mooncake master service is ready")

            if metadata_ready and master_ready:
                services_ready = True
                print("All Mooncake services are ready")
                break

            time.sleep(2)

        if not services_ready:
            print(
                "Warning: Mooncake services may not be fully ready, continuing anyway..."
            )

        return services_ready

    @classmethod
    def _stop_mooncake_services(cls):
        """Kill the process groups of both Mooncake services (best effort).

        Failures are reported as warnings and never raised. Each handle is
        reset to ``None`` afterwards, so calling this method again is a no-op.
        """
        print("Stopping Mooncake services...")

        for attr_name, label in (
            ("metadata_service_process", "metadata"),
            ("master_service_process", "master"),
        ):
            process = getattr(cls, attr_name, None)
            if process is None:
                continue
            try:
                # Signal 9 is SIGKILL; it goes to the whole process group.
                os.killpg(os.getpgid(process.pid), 9)
                process.wait(timeout=5)
                print(f"Mooncake {label} service stopped")
            except (ProcessLookupError, subprocess.TimeoutExpired, OSError) as e:
                print(f"Warning: Could not stop Mooncake {label} service: {e}")
            finally:
                setattr(cls, attr_name, None)

    @classmethod
    def _get_additional_server_args_and_env(cls):
        """Return the Mooncake-specific server arguments and environment.

        Returns:
            A ``(server_args, env_vars)`` tuple of dicts. Subclasses extend
            the result to select a particular configuration.
        """
        server_args = {
            "--tp-size": 1,
            "--hicache-ratio": 2,
            "--hicache-storage-backend": "mooncake",
        }

        # Point the Mooncake client at the services started by this class.
        env_vars = {
            "MOONCAKE_MASTER": f"127.0.0.1:{cls.mooncake_master_port}",
            "MOONCAKE_PROTOCOL": "rdma",
            "MOONCAKE_DEVICE": "mlx5_roce0,mlx5_roce1",
            "MOONCAKE_TE_META_DATA_SERVER": f"http://127.0.0.1:{cls.mooncake_metadata_port}/metadata",
            "MOONCAKE_GLOBAL_SEGMENT_SIZE": "4294967296",  # 4 GiB
        }

        return server_args, env_vars
|
||||
|
||||
|
||||
'''
|
||||
# Layer-first layout test, disabled for the same reason as in #10131.
# TODO(mateng): make it work and re-enable it.
|
||||
class TestMooncakeBackendLayerFirstLayout(
|
||||
HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase
|
||||
):
|
||||
"""Layer first layout tests for HiCache-Mooncake backend"""
|
||||
|
||||
@classmethod
|
||||
def _get_additional_server_args_and_env(cls):
|
||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
||||
server_args, env_vars = super()._get_additional_server_args_and_env()
|
||||
server_args["--hicache-mem-layout"] = "layer_first"
|
||||
server_args["--hicache-io-backend"] = "direct"
|
||||
return server_args, env_vars
|
||||
'''
|
||||
|
||||
|
||||
class TestMooncakeBackendPageFirstLayout(
    HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase
):
    """HiCache-Mooncake backend tests configured for the page-first layout."""

    @classmethod
    def _get_additional_server_args_and_env(cls):
        """Return the base arguments and environment with the page_first layout."""
        base_args, base_env = super()._get_additional_server_args_and_env()
        base_args.update({"--hicache-mem-layout": "page_first"})
        return base_args, base_env
|
||||
|
||||
|
||||
class TestMooncakeBackendMLAModel(
    HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase
):
    """HiCache-Mooncake backend tests configured for the default MLA test model."""

    @classmethod
    def _get_model_name(cls):
        """Return the name of the MLA model these tests run against."""
        return DEFAULT_MLA_MODEL_NAME_FOR_TEST

    @classmethod
    def _get_additional_server_args_and_env(cls):
        """Return the base arguments and environment with page_first layout and TP size 2."""
        base_args, base_env = super()._get_additional_server_args_and_env()
        base_args.update(
            {
                "--hicache-mem-layout": "page_first",
                "--tp-size": 2,
            }
        )
        return base_args, base_env
|
||||
|
||||
|
||||
class TestMooncakeBackendAccuracy(
    HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase
):
    """GSM8K accuracy checks for the HiCache-Mooncake backend."""

    @classmethod
    def _get_additional_server_args_and_env(cls):
        """Return the base arguments and environment with hicache ratio 1.5 and TP size 2."""
        base_args, base_env = super()._get_additional_server_args_and_env()
        base_args.update({"--hicache-ratio": 1.5, "--tp-size": 2})
        return base_args, base_env

    def test_eval_accuracy(self):
        """Run GSM8K before and after a cache flush and compare the accuracies.

        Both runs must score above 0.6 and differ by less than 0.05.
        """
        print("\n=== Testing Eval Accuracy with Cache Persistence ===")

        # Phase 1: the first run populates the cache.
        print("Phase 1: Running initial GSM8K evaluation to populate cache...")
        eval_config = {
            "num_shots": 5,
            "data_path": None,
            "num_questions": 50,
            "max_new_tokens": 512,
            "parallel": 10,
            "host": f"http://{self.base_host}",
            "port": int(self.base_port),
        }
        eval_args = SimpleNamespace(**eval_config)
        first_metrics = run_eval_few_shot_gsm8k(eval_args)

        # Phase 2: flush so that the next run has to go through remote storage.
        print("Phase 2: Flushing device cache...")
        flushed = self.flush_cache()
        self.assertTrue(flushed, "Cache flush should succeed")
        time.sleep(2)

        # Phase 3: the same evaluation again, now against the flushed cache.
        print("Phase 3: Running second GSM8K evaluation using remote cache...")
        second_metrics = run_eval_few_shot_gsm8k(eval_args)

        accuracy_diff = abs(first_metrics["accuracy"] - second_metrics["accuracy"])
        print(f"Accuracy difference: {accuracy_diff:.4f}")

        # Each run must be reasonable on its own, and the two must agree.
        for metrics, failure_message in (
            (first_metrics, "Initial accuracy should be reasonable"),
            (second_metrics, "Cached accuracy should be reasonable"),
        ):
            self.assertGreater(metrics["accuracy"], 0.6, failure_message)
        self.assertLess(
            accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
@@ -142,6 +142,7 @@ suites = {
|
||||
"per-commit-8-gpu": [
|
||||
# Disabled because it hangs on the CI.
|
||||
# TestFile("ep/test_moe_ep.py", 181),
|
||||
TestFile("hicache/test_hicache_storage_mooncake_backend.py", 800),
|
||||
TestFile("lora/test_lora_llama4.py", 600),
|
||||
TestFile("test_disaggregation.py", 499),
|
||||
TestFile("test_disaggregation_different_tp.py", 155),
|
||||
|
||||
Reference in New Issue
Block a user