diff --git a/scripts/ci/ci_install_dependency.sh b/scripts/ci/ci_install_dependency.sh index f36f59206..0e370eec3 100755 --- a/scripts/ci/ci_install_dependency.sh +++ b/scripts/ci/ci_install_dependency.sh @@ -74,7 +74,7 @@ fi $PIP_CMD list # Install additional dependencies -$PIP_CMD install mooncake-transfer-engine==0.3.5 nvidia-cuda-nvrtc-cu12 py-spy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX +$PIP_CMD install mooncake-transfer-engine==0.3.6 nvidia-cuda-nvrtc-cu12 py-spy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX if [ "$IS_BLACKWELL" != "1" ]; then # For lmms_evals evaluating MMMU diff --git a/test/srt/hicache/test_hicache_storage_mooncake_backend.py b/test/srt/hicache/test_hicache_storage_mooncake_backend.py new file mode 100644 index 000000000..fdfa8c93f --- /dev/null +++ b/test/srt/hicache/test_hicache_storage_mooncake_backend.py @@ -0,0 +1,316 @@ +""" +Benchmark tests for HiCache Storage with Mooncake backend. +Usage: + python3.10 -m pytest test/srt/hicache/test_hicache_storage_mooncake_backend.py -v +""" + +import json +import os +import subprocess +import time +import unittest +from types import SimpleNamespace + +import requests +from test_hicache_storage_file_backend import HiCacheStorageBaseMixin + +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + CustomTestCase, + find_available_port, +) + + +class HiCacheStorageMooncakeBackendBaseMixin(HiCacheStorageBaseMixin): + """Base mixin class with common setup and utilities""" + + # Default port ranges for Mooncake services - can be overridden in subclasses + mooncake_master_port_base = 50051 + mooncake_metadata_port_base = 8080 + + @classmethod + def setUpClass(cls): + """Set up test environment and launch Mooncake services before server setup""" + # Find available ports for Mooncake services to avoid conflicts + cls.mooncake_master_port = find_available_port( + HiCacheStorageMooncakeBackendBaseMixin.mooncake_master_port_base + ) + cls.mooncake_metadata_port = find_available_port( + HiCacheStorageMooncakeBackendBaseMixin.mooncake_metadata_port_base + ) + + # Start Mooncake services first + cls._start_mooncake_services() + + # Call parent setup + super().setUpClass() + + @classmethod + def tearDownClass(cls): + """Clean up Mooncake services after server teardown""" + # Call parent teardown first + super().tearDownClass() + + # Stop Mooncake services + cls._stop_mooncake_services() + + @classmethod + def _start_mooncake_services(cls): + """Start Mooncake metadata and master services with configurable ports and readiness detection""" + print("Starting Mooncake services...") + print( + f"Using master port: {cls.mooncake_master_port}, metadata port: {cls.mooncake_metadata_port}" + ) + + # Start metadata service with configurable port + try: + # Start metadata server with port configuration + cls.metadata_service_process = subprocess.Popen( + [ + "python3", + "-m", + "mooncake.http_metadata_server", + "--port", + str(cls.mooncake_metadata_port), + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + preexec_fn=os.setsid, # Create new process group + ) + print( + f"Mooncake metadata service started on port {cls.mooncake_metadata_port}" + ) + except (FileNotFoundError, subprocess.SubprocessError) as e: + print(f"Warning: Could not start Mooncake metadata service: {e}") + cls.metadata_service_process = None + + # Start master service with configurable port + try: + # Start master server with port configuration + cls.master_service_process = subprocess.Popen( + ["mooncake_master", "--port", str(cls.mooncake_master_port)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + preexec_fn=os.setsid, # Create new process group + ) + print(f"Mooncake master service started on port {cls.mooncake_master_port}") + except (FileNotFoundError, subprocess.SubprocessError) as e: + print(f"Warning: Could not start Mooncake master service: {e}") + cls.master_service_process = None + + # Wait for services to be ready instead of fixed sleep + cls._wait_for_mooncake_services_ready() + + @classmethod + def _wait_for_mooncake_services_ready(cls, timeout: int = 30) -> bool: + """Wait for Mooncake services to be ready by checking their endpoints""" + print("Waiting for Mooncake services to be ready...") + + start_time = time.time() + services_ready = False + + while time.time() - start_time < timeout: + try: + # Check metadata service + metadata_ready = False + if ( + cls.metadata_service_process + and cls.metadata_service_process.poll() is None + ): + try: + # Try to connect to the metadata service + metadata_url = ( + f"http://127.0.0.1:{cls.mooncake_metadata_port}/metadata" + ) + response = requests.get(metadata_url, timeout=2) + if response.status_code == 200: + metadata_ready = True + print("Mooncake metadata service is ready") + except (requests.RequestException, ConnectionError): + # Service might not be fully started yet + pass + + # Check master service (if it has a health endpoint) + master_ready = False + if ( + cls.master_service_process + and cls.master_service_process.poll() is None + ): + # For now, we'll assume master service is ready if process is running + # and it's been a few seconds since startup + if ( + time.time() - start_time > 5 + ): # Give master service time to initialize + master_ready = True + print("Mooncake master service is ready") + + # Both services should be ready + if metadata_ready and master_ready: + services_ready = True + print("All Mooncake services are ready") + break + + except Exception as e: + print(f"Error checking service readiness: {e}") + + time.sleep(2) + + if not services_ready: + print( + "Warning: Mooncake services may not be fully ready, continuing anyway..." + ) + + return services_ready + + @classmethod + def _stop_mooncake_services(cls): + """Stop Mooncake services""" + print("Stopping Mooncake services...") + + # Stop metadata service + if hasattr(cls, "metadata_service_process") and cls.metadata_service_process: + try: + os.killpg(os.getpgid(cls.metadata_service_process.pid), 9) + cls.metadata_service_process.wait(timeout=5) + print("Mooncake metadata service stopped") + except (ProcessLookupError, subprocess.TimeoutExpired, OSError) as e: + print(f"Warning: Could not stop Mooncake metadata service: {e}") + + # Stop master service + if hasattr(cls, "master_service_process") and cls.master_service_process: + try: + os.killpg(os.getpgid(cls.master_service_process.pid), 9) + cls.master_service_process.wait(timeout=5) + print("Mooncake master service stopped") + except (ProcessLookupError, subprocess.TimeoutExpired, OSError) as e: + print(f"Warning: Could not stop Mooncake master service: {e}") + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + + server_args = { + "--tp-size": 1, + "--hicache-ratio": 2, + "--hicache-storage-backend": "mooncake", + } + + # Set the environment variables for Mooncake using dynamic ports + env_vars = { + "MOONCAKE_MASTER": f"127.0.0.1:{cls.mooncake_master_port}", + "MOONCAKE_PROTOCOL": "rdma", + "MOONCAKE_DEVICE": "mlx5_roce0,mlx5_roce1", + "MOONCAKE_TE_META_DATA_SERVER": f"http://127.0.0.1:{cls.mooncake_metadata_port}/metadata", + "MOONCAKE_GLOBAL_SEGMENT_SIZE": "4294967296", # 4 GiB + } + + return server_args, env_vars + + +''' +# Same as #10131, layer first layout test TODO(mateng): will make it work +class TestMooncakeBackendLayerFirstLayout( + HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase +): + """Layer first layout tests for HiCache-Mooncake backend""" + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + server_args, env_vars = super()._get_additional_server_args_and_env() + server_args["--hicache-mem-layout"] = "layer_first" + server_args["--hicache-io-backend"] = "direct" + return server_args, env_vars +''' + + +class TestMooncakeBackendPageFirstLayout( + HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase +): + """Page first layout tests for HiCache-Mooncake backend""" + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + server_args, env_vars = super()._get_additional_server_args_and_env() + server_args["--hicache-mem-layout"] = "page_first" + return server_args, env_vars + + +class TestMooncakeBackendMLAModel( + HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase +): + """MLA Model tests for HiCache-Mooncake backend""" + + @classmethod + def _get_model_name(cls): + """Use MLA model for testing""" + return DEFAULT_MLA_MODEL_NAME_FOR_TEST + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + server_args, env_vars = super()._get_additional_server_args_and_env() + server_args["--hicache-mem-layout"] = "page_first" + server_args["--tp-size"] = 2 + return server_args, env_vars + + +class TestMooncakeBackendAccuracy( + HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase +): + """Accuracy tests for HiCache-Mooncake backend""" + + @classmethod + def _get_additional_server_args_and_env(cls): + """Get additional server arguments specific to configuration - override in subclasses""" + server_args, env_vars = super()._get_additional_server_args_and_env() + server_args["--hicache-ratio"] = 1.5 + server_args["--tp-size"] = 2 + return server_args, env_vars + + def test_eval_accuracy(self): + """Test eval accuracy with cache persistence across cache flushes""" + print("\n=== Testing Eval Accuracy with Cache Persistence ===") + + # First evaluation - populate cache + print("Phase 1: Running initial GSM8K evaluation to populate cache...") + args_initial = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=50, + max_new_tokens=512, + parallel=10, + host=f"http://{self.base_host}", + port=int(self.base_port), + ) + metrics_initial = run_eval_few_shot_gsm8k(args_initial) + + # Flush cache to force remote storage access + print("Phase 2: Flushing device cache...") + self.assertTrue(self.flush_cache(), "Cache flush should succeed") + time.sleep(2) + + # Second evaluation - should use remote cache + print("Phase 3: Running second GSM8K evaluation using remote cache...") + metrics_cached = run_eval_few_shot_gsm8k(args_initial) + + # Verify accuracy consistency + accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"]) + print(f"Accuracy difference: {accuracy_diff:.4f}") + + # Assertions + self.assertGreater( + metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable" + ) + self.assertGreater( + metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable" + ) + self.assertLess( + accuracy_diff, 0.05, "Accuracy should be consistent between cache states" + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 912d32801..13d1000a3 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -142,6 +142,7 @@ suites = { "per-commit-8-gpu": [ # Disabled because it hangs on the CI. # TestFile("ep/test_moe_ep.py", 181), + TestFile("hicache/test_hicache_storage_mooncake_backend.py", 800), TestFile("lora/test_lora_llama4.py", 600), TestFile("test_disaggregation.py", 499), TestFile("test_disaggregation_different_tp.py", 155),