[HICache]: Refactor HiCache CI (#11011)
Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -62,19 +62,7 @@ class TestHf3fsBackendLayerFirstLayout(
|
||||
server_args, env_vars = super()._get_additional_server_args_and_env()
|
||||
server_args["--hicache-mem-layout"] = "layer_first"
|
||||
server_args["--hicache-io-backend"] = "direct"
|
||||
return server_args, env_vars
|
||||
|
||||
|
||||
class TestHf3fsBackendPageFirstLayout(
|
||||
HiCacheStorage3FSBackendBaseMixin, CustomTestCase
|
||||
):
|
||||
"""Page first layout tests for HiCache-Hf3fs backend"""
|
||||
|
||||
@classmethod
|
||||
def _get_additional_server_args_and_env(cls):
|
||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
||||
server_args, env_vars = super()._get_additional_server_args_and_env()
|
||||
server_args["--hicache-mem-layout"] = "page_first"
|
||||
server_args["--tp-size"] = 2
|
||||
return server_args, env_vars
|
||||
|
||||
|
||||
@@ -91,44 +79,9 @@ class TestHf3fsBackendAccuracy(HiCacheStorage3FSBackendBaseMixin, CustomTestCase
|
||||
|
||||
def test_eval_accuracy(self):
|
||||
"""Test eval accuracy with cache persistence across cache flushes"""
|
||||
print("\n=== Testing Eval Accuracy with Cache Persistence ===")
|
||||
from test_hicache_storage_file_backend import run_eval_accuracy_test
|
||||
|
||||
# First evaluation - populate cache
|
||||
print("Phase 1: Running initial GSM8K evaluation to populate cache...")
|
||||
args_initial = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=50,
|
||||
max_new_tokens=512,
|
||||
parallel=10,
|
||||
host=f"http://{self.base_host}",
|
||||
port=int(self.base_port),
|
||||
)
|
||||
metrics_initial = run_eval_few_shot_gsm8k(args_initial)
|
||||
|
||||
# Flush cache to force remote storage access
|
||||
print("Phase 2: Flushing device cache...")
|
||||
self.assertTrue(self.flush_cache(), "Cache flush should succeed")
|
||||
time.sleep(2)
|
||||
|
||||
# Second evaluation - should use remote cache
|
||||
print("Phase 3: Running second GSM8K evaluation using remote cache...")
|
||||
metrics_cached = run_eval_few_shot_gsm8k(args_initial)
|
||||
|
||||
# Verify accuracy consistency
|
||||
accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
|
||||
print(f"Accuracy difference: {accuracy_diff:.4f}")
|
||||
|
||||
# Assertions
|
||||
self.assertGreater(
|
||||
metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable"
|
||||
)
|
||||
self.assertGreater(
|
||||
metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable"
|
||||
)
|
||||
self.assertLess(
|
||||
accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
|
||||
)
|
||||
run_eval_accuracy_test(self)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -24,6 +24,7 @@ from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
@@ -89,6 +90,7 @@ class HiCacheStorageBaseMixin:
|
||||
"""Launch server with HiCache enabled"""
|
||||
|
||||
additional_server_args, env_vars = cls._get_additional_server_args_and_env()
|
||||
env_vars["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = "1"
|
||||
server_args = cls._get_base_server_args()
|
||||
if additional_server_args:
|
||||
server_args.update(additional_server_args)
|
||||
@@ -215,42 +217,7 @@ class HiCacheStorageBaseMixin:
|
||||
)
|
||||
|
||||
|
||||
class TestHiCacheStorageTP(HiCacheStorageBaseMixin, CustomTestCase):
|
||||
"""Multi-TP tests for HiCache Storage functionality"""
|
||||
|
||||
@classmethod
|
||||
def _get_additional_server_args_and_env(cls):
|
||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
||||
server_args = {"--tp-size": 2}
|
||||
return server_args, {}
|
||||
|
||||
|
||||
class TestHiCacheStorageLayerFirstDirectIO(HiCacheStorageBaseMixin, CustomTestCase):
|
||||
"""Layer first direct tests for HiCache Storage functionality"""
|
||||
|
||||
@classmethod
|
||||
def _get_additional_server_args_and_env(cls):
|
||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
||||
server_args = {
|
||||
"--hicache-mem-layout": "layer_first",
|
||||
"--hicache-io-backend": "direct",
|
||||
}
|
||||
return server_args, {}
|
||||
|
||||
|
||||
class TestHiCacheStoragePageFirstDirectIO(HiCacheStorageBaseMixin, CustomTestCase):
|
||||
"""Page first direct tests for HiCache Storage functionality"""
|
||||
|
||||
@classmethod
|
||||
def _get_additional_server_args_and_env(cls):
|
||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
||||
server_args = {
|
||||
"--hicache-mem-layout": "page_first_direct",
|
||||
"--hicache-io-backend": "direct",
|
||||
}
|
||||
return server_args, {}
|
||||
|
||||
|
||||
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
|
||||
class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseMixin, CustomTestCase):
|
||||
"""Page first layout tests for HiCache Storage functionality"""
|
||||
|
||||
@@ -261,6 +228,7 @@ class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseMixin, CustomTestCase)
|
||||
return server_args, {}
|
||||
|
||||
|
||||
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
|
||||
class TestHiCacheStorageMLA(HiCacheStorageBaseMixin, CustomTestCase):
|
||||
"""MLA Model tests for HiCache Storage functionality"""
|
||||
|
||||
@@ -276,71 +244,84 @@ class TestHiCacheStorageMLA(HiCacheStorageBaseMixin, CustomTestCase):
|
||||
return server_args, {}
|
||||
|
||||
|
||||
class TestHiCacheStoragePageFirstDirectIO(HiCacheStorageBaseMixin, CustomTestCase):
|
||||
"""Page first direct tests for HiCache Storage functionality"""
|
||||
|
||||
@classmethod
|
||||
def _get_additional_server_args_and_env(cls):
|
||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
||||
server_args = {
|
||||
"--hicache-mem-layout": "page_first_direct",
|
||||
"--hicache-io-backend": "direct",
|
||||
"--tp-size": 2,
|
||||
}
|
||||
return server_args, {}
|
||||
|
||||
|
||||
class TestHiCacheStorageAccuracy(HiCacheStorageBaseMixin, CustomTestCase):
|
||||
"""Accuracy tests for HiCache Storage functionality"""
|
||||
|
||||
@classmethod
|
||||
def _get_additional_server_args_and_env(cls):
|
||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
||||
server_args = {"--tp-size": 2, "--hicache-ratio": 1.5}
|
||||
server_args = {
|
||||
"--tp-size": 2,
|
||||
"--hicache-ratio": 1.5,
|
||||
}
|
||||
|
||||
return server_args, {}
|
||||
|
||||
def test_eval_accuracy(self):
|
||||
"""Test eval accuracy with cache persistence across cache flushes"""
|
||||
print("\n=== Testing Eval Accuracy with Cache Persistence ===")
|
||||
|
||||
# First evaluation - populate cache
|
||||
print("Phase 1: Running initial GSM8K evaluation to populate cache...")
|
||||
args_initial = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=50,
|
||||
max_new_tokens=512,
|
||||
parallel=10,
|
||||
host=f"http://{self.base_host}",
|
||||
port=int(self.base_port),
|
||||
)
|
||||
metrics_initial = run_eval_few_shot_gsm8k(args_initial)
|
||||
|
||||
# Flush cache to force remote storage access
|
||||
print("Phase 2: Flushing device cache...")
|
||||
self.assertTrue(self.flush_cache(), "Cache flush should succeed")
|
||||
time.sleep(2)
|
||||
|
||||
# Second evaluation - should use remote cache
|
||||
print("Phase 3: Running second GSM8K evaluation using remote cache...")
|
||||
metrics_cached = run_eval_few_shot_gsm8k(args_initial)
|
||||
|
||||
# Verify accuracy consistency
|
||||
accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
|
||||
print(f"Accuracy difference: {accuracy_diff:.4f}")
|
||||
|
||||
# Assertions
|
||||
self.assertGreater(
|
||||
metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable"
|
||||
)
|
||||
self.assertGreater(
|
||||
metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable"
|
||||
)
|
||||
self.assertLess(
|
||||
accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
|
||||
)
|
||||
run_eval_accuracy_test(self)
|
||||
|
||||
|
||||
# TODO: Add other backends tests(3fs/mooncake)
|
||||
# class TestHiCacheStorageMooncakeBackend(HiCacheStorageBaseTest):
|
||||
# """Mooncake backend tests for HiCache Storage functionality"""
|
||||
def run_eval_accuracy_test(test_instance, accuracy_threshold: float = 0.03):
|
||||
"""Generic eval accuracy test with configurable accuracy threshold
|
||||
|
||||
# @classmethod
|
||||
# def _get_additional_server_args_and_env(cls):
|
||||
# """Get additional server arguments specific to configuration - override in subclasses"""
|
||||
# server_args = ["--hicache-storage-backend", "mooncake"]
|
||||
# env = {
|
||||
# "MOONCAKE_TE_META_DATA_SERVER": "http://127.0.0.1:8080/metadata",
|
||||
# "MOONCAKE_MASTER": "127.0.0.1:50051"
|
||||
# xxxxx
|
||||
# }
|
||||
# return server_args, {}
|
||||
Args:
|
||||
test_instance: The test class instance that provides base_host, base_port, flush_cache, and assert methods
|
||||
"""
|
||||
print("\n=== Testing Eval Accuracy with Cache Persistence ===")
|
||||
|
||||
# First evaluation - populate cache
|
||||
print("Phase 1: Running initial GSM8K evaluation to populate cache...")
|
||||
args_initial = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=50,
|
||||
max_new_tokens=512,
|
||||
parallel=10,
|
||||
host=f"http://{test_instance.base_host}",
|
||||
port=int(test_instance.base_port),
|
||||
)
|
||||
metrics_initial = run_eval_few_shot_gsm8k(args_initial)
|
||||
|
||||
# Flush cache to force remote storage access
|
||||
print("Phase 2: Flushing device cache...")
|
||||
test_instance.assertTrue(test_instance.flush_cache(), "Cache flush should succeed")
|
||||
time.sleep(2)
|
||||
|
||||
# Second evaluation - should use remote cache
|
||||
print("Phase 3: Running second GSM8K evaluation using remote cache...")
|
||||
metrics_cached = run_eval_few_shot_gsm8k(args_initial)
|
||||
|
||||
# Verify accuracy consistency
|
||||
accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
|
||||
print(f"Accuracy difference: {accuracy_diff:.4f}")
|
||||
|
||||
# Assertions
|
||||
test_instance.assertGreater(
|
||||
metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable"
|
||||
)
|
||||
test_instance.assertGreater(
|
||||
metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable"
|
||||
)
|
||||
test_instance.assertLess(
|
||||
accuracy_diff,
|
||||
accuracy_threshold,
|
||||
"Accuracy should be consistent between cache states",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -19,6 +19,7 @@ from sglang.test.test_utils import (
|
||||
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||
CustomTestCase,
|
||||
find_available_port,
|
||||
is_in_ci,
|
||||
)
|
||||
|
||||
|
||||
@@ -226,6 +227,7 @@ class TestMooncakeBackendLayerFirstLayout(
|
||||
'''
|
||||
|
||||
|
||||
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
|
||||
class TestMooncakeBackendPageFirstLayout(
|
||||
HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase
|
||||
):
|
||||
@@ -236,21 +238,6 @@ class TestMooncakeBackendPageFirstLayout(
|
||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
||||
server_args, env_vars = super()._get_additional_server_args_and_env()
|
||||
server_args["--hicache-mem-layout"] = "page_first"
|
||||
server_args["--hicache-io-backend"] = "kernel"
|
||||
return server_args, env_vars
|
||||
|
||||
|
||||
class TestMooncakeBackendPageFirstDirectLayout(
|
||||
HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase
|
||||
):
|
||||
"""Page first layout tests for HiCache-Mooncake backend"""
|
||||
|
||||
@classmethod
|
||||
def _get_additional_server_args_and_env(cls):
|
||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
||||
server_args, env_vars = super()._get_additional_server_args_and_env()
|
||||
server_args["--hicache-mem-layout"] = "page_first_direct"
|
||||
server_args["--hicache-io-backend"] = "direct"
|
||||
return server_args, env_vars
|
||||
|
||||
|
||||
@@ -284,48 +271,15 @@ class TestMooncakeBackendAccuracy(
|
||||
server_args, env_vars = super()._get_additional_server_args_and_env()
|
||||
server_args["--hicache-ratio"] = 1.5
|
||||
server_args["--tp-size"] = 2
|
||||
server_args["--hicache-mem-layout"] = "page_first_direct"
|
||||
server_args["--hicache-io-backend"] = "direct"
|
||||
return server_args, env_vars
|
||||
|
||||
def test_eval_accuracy(self):
|
||||
"""Test eval accuracy with cache persistence across cache flushes"""
|
||||
print("\n=== Testing Eval Accuracy with Cache Persistence ===")
|
||||
from test_hicache_storage_file_backend import run_eval_accuracy_test
|
||||
|
||||
# First evaluation - populate cache
|
||||
print("Phase 1: Running initial GSM8K evaluation to populate cache...")
|
||||
args_initial = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=50,
|
||||
max_new_tokens=512,
|
||||
parallel=10,
|
||||
host=f"http://{self.base_host}",
|
||||
port=int(self.base_port),
|
||||
)
|
||||
metrics_initial = run_eval_few_shot_gsm8k(args_initial)
|
||||
|
||||
# Flush cache to force remote storage access
|
||||
print("Phase 2: Flushing device cache...")
|
||||
self.assertTrue(self.flush_cache(), "Cache flush should succeed")
|
||||
time.sleep(2)
|
||||
|
||||
# Second evaluation - should use remote cache
|
||||
print("Phase 3: Running second GSM8K evaluation using remote cache...")
|
||||
metrics_cached = run_eval_few_shot_gsm8k(args_initial)
|
||||
|
||||
# Verify accuracy consistency
|
||||
accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
|
||||
print(f"Accuracy difference: {accuracy_diff:.4f}")
|
||||
|
||||
# Assertions
|
||||
self.assertGreater(
|
||||
metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable"
|
||||
)
|
||||
self.assertGreater(
|
||||
metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable"
|
||||
)
|
||||
self.assertLess(
|
||||
accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
|
||||
)
|
||||
run_eval_accuracy_test(self)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -132,8 +132,8 @@ suites = {
|
||||
TestFile("test_load_weights_from_remote_instance.py", 72),
|
||||
TestFile("test_patch_torch.py", 19),
|
||||
TestFile("test_release_memory_occupation.py", 257),
|
||||
TestFile("hicache/test_hicache_storage_file_backend.py", 400),
|
||||
TestFile("hicache/test_hicache_storage_3fs_backend.py", 400),
|
||||
TestFile("hicache/test_hicache_storage_file_backend.py", 200),
|
||||
TestFile("hicache/test_hicache_storage_3fs_backend.py", 200),
|
||||
],
|
||||
"per-commit-4-gpu": [
|
||||
TestFile("test_gpt_oss_4gpu.py", 600),
|
||||
@@ -144,7 +144,7 @@ suites = {
|
||||
TestFile("test_multi_instance_release_memory_occupation.py", 64),
|
||||
],
|
||||
"per-commit-8-gpu": [
|
||||
TestFile("hicache/test_hicache_storage_mooncake_backend.py", 800),
|
||||
TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400),
|
||||
TestFile("lora/test_lora_llama4.py", 600),
|
||||
TestFile("test_disaggregation.py", 499),
|
||||
TestFile("test_disaggregation_dp_attention.py", 155),
|
||||
|
||||
Reference in New Issue
Block a user