[HICache]: Refactor HiCache CI (#11011)
Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -62,19 +62,7 @@ class TestHf3fsBackendLayerFirstLayout(
|
|||||||
server_args, env_vars = super()._get_additional_server_args_and_env()
|
server_args, env_vars = super()._get_additional_server_args_and_env()
|
||||||
server_args["--hicache-mem-layout"] = "layer_first"
|
server_args["--hicache-mem-layout"] = "layer_first"
|
||||||
server_args["--hicache-io-backend"] = "direct"
|
server_args["--hicache-io-backend"] = "direct"
|
||||||
return server_args, env_vars
|
server_args["--tp-size"] = 2
|
||||||
|
|
||||||
|
|
||||||
class TestHf3fsBackendPageFirstLayout(
|
|
||||||
HiCacheStorage3FSBackendBaseMixin, CustomTestCase
|
|
||||||
):
|
|
||||||
"""Page first layout tests for HiCache-Hf3fs backend"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_additional_server_args_and_env(cls):
|
|
||||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
|
||||||
server_args, env_vars = super()._get_additional_server_args_and_env()
|
|
||||||
server_args["--hicache-mem-layout"] = "page_first"
|
|
||||||
return server_args, env_vars
|
return server_args, env_vars
|
||||||
|
|
||||||
|
|
||||||
@@ -91,44 +79,9 @@ class TestHf3fsBackendAccuracy(HiCacheStorage3FSBackendBaseMixin, CustomTestCase
|
|||||||
|
|
||||||
def test_eval_accuracy(self):
|
def test_eval_accuracy(self):
|
||||||
"""Test eval accuracy with cache persistence across cache flushes"""
|
"""Test eval accuracy with cache persistence across cache flushes"""
|
||||||
print("\n=== Testing Eval Accuracy with Cache Persistence ===")
|
from test_hicache_storage_file_backend import run_eval_accuracy_test
|
||||||
|
|
||||||
# First evaluation - populate cache
|
run_eval_accuracy_test(self)
|
||||||
print("Phase 1: Running initial GSM8K evaluation to populate cache...")
|
|
||||||
args_initial = SimpleNamespace(
|
|
||||||
num_shots=5,
|
|
||||||
data_path=None,
|
|
||||||
num_questions=50,
|
|
||||||
max_new_tokens=512,
|
|
||||||
parallel=10,
|
|
||||||
host=f"http://{self.base_host}",
|
|
||||||
port=int(self.base_port),
|
|
||||||
)
|
|
||||||
metrics_initial = run_eval_few_shot_gsm8k(args_initial)
|
|
||||||
|
|
||||||
# Flush cache to force remote storage access
|
|
||||||
print("Phase 2: Flushing device cache...")
|
|
||||||
self.assertTrue(self.flush_cache(), "Cache flush should succeed")
|
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
# Second evaluation - should use remote cache
|
|
||||||
print("Phase 3: Running second GSM8K evaluation using remote cache...")
|
|
||||||
metrics_cached = run_eval_few_shot_gsm8k(args_initial)
|
|
||||||
|
|
||||||
# Verify accuracy consistency
|
|
||||||
accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
|
|
||||||
print(f"Accuracy difference: {accuracy_diff:.4f}")
|
|
||||||
|
|
||||||
# Assertions
|
|
||||||
self.assertGreater(
|
|
||||||
metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable"
|
|
||||||
)
|
|
||||||
self.assertGreater(
|
|
||||||
metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable"
|
|
||||||
)
|
|
||||||
self.assertLess(
|
|
||||||
accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from sglang.test.test_utils import (
|
|||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
CustomTestCase,
|
CustomTestCase,
|
||||||
|
is_in_ci,
|
||||||
popen_launch_server,
|
popen_launch_server,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -89,6 +90,7 @@ class HiCacheStorageBaseMixin:
|
|||||||
"""Launch server with HiCache enabled"""
|
"""Launch server with HiCache enabled"""
|
||||||
|
|
||||||
additional_server_args, env_vars = cls._get_additional_server_args_and_env()
|
additional_server_args, env_vars = cls._get_additional_server_args_and_env()
|
||||||
|
env_vars["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = "1"
|
||||||
server_args = cls._get_base_server_args()
|
server_args = cls._get_base_server_args()
|
||||||
if additional_server_args:
|
if additional_server_args:
|
||||||
server_args.update(additional_server_args)
|
server_args.update(additional_server_args)
|
||||||
@@ -215,42 +217,7 @@ class HiCacheStorageBaseMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestHiCacheStorageTP(HiCacheStorageBaseMixin, CustomTestCase):
|
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
|
||||||
"""Multi-TP tests for HiCache Storage functionality"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_additional_server_args_and_env(cls):
|
|
||||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
|
||||||
server_args = {"--tp-size": 2}
|
|
||||||
return server_args, {}
|
|
||||||
|
|
||||||
|
|
||||||
class TestHiCacheStorageLayerFirstDirectIO(HiCacheStorageBaseMixin, CustomTestCase):
|
|
||||||
"""Layer first direct tests for HiCache Storage functionality"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_additional_server_args_and_env(cls):
|
|
||||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
|
||||||
server_args = {
|
|
||||||
"--hicache-mem-layout": "layer_first",
|
|
||||||
"--hicache-io-backend": "direct",
|
|
||||||
}
|
|
||||||
return server_args, {}
|
|
||||||
|
|
||||||
|
|
||||||
class TestHiCacheStoragePageFirstDirectIO(HiCacheStorageBaseMixin, CustomTestCase):
|
|
||||||
"""Page first direct tests for HiCache Storage functionality"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_additional_server_args_and_env(cls):
|
|
||||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
|
||||||
server_args = {
|
|
||||||
"--hicache-mem-layout": "page_first_direct",
|
|
||||||
"--hicache-io-backend": "direct",
|
|
||||||
}
|
|
||||||
return server_args, {}
|
|
||||||
|
|
||||||
|
|
||||||
class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseMixin, CustomTestCase):
|
class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseMixin, CustomTestCase):
|
||||||
"""Page first layout tests for HiCache Storage functionality"""
|
"""Page first layout tests for HiCache Storage functionality"""
|
||||||
|
|
||||||
@@ -261,6 +228,7 @@ class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseMixin, CustomTestCase)
|
|||||||
return server_args, {}
|
return server_args, {}
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
|
||||||
class TestHiCacheStorageMLA(HiCacheStorageBaseMixin, CustomTestCase):
|
class TestHiCacheStorageMLA(HiCacheStorageBaseMixin, CustomTestCase):
|
||||||
"""MLA Model tests for HiCache Storage functionality"""
|
"""MLA Model tests for HiCache Storage functionality"""
|
||||||
|
|
||||||
@@ -276,71 +244,84 @@ class TestHiCacheStorageMLA(HiCacheStorageBaseMixin, CustomTestCase):
|
|||||||
return server_args, {}
|
return server_args, {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestHiCacheStoragePageFirstDirectIO(HiCacheStorageBaseMixin, CustomTestCase):
|
||||||
|
"""Page first direct tests for HiCache Storage functionality"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_additional_server_args_and_env(cls):
|
||||||
|
"""Get additional server arguments specific to configuration - override in subclasses"""
|
||||||
|
server_args = {
|
||||||
|
"--hicache-mem-layout": "page_first_direct",
|
||||||
|
"--hicache-io-backend": "direct",
|
||||||
|
"--tp-size": 2,
|
||||||
|
}
|
||||||
|
return server_args, {}
|
||||||
|
|
||||||
|
|
||||||
class TestHiCacheStorageAccuracy(HiCacheStorageBaseMixin, CustomTestCase):
|
class TestHiCacheStorageAccuracy(HiCacheStorageBaseMixin, CustomTestCase):
|
||||||
"""Accuracy tests for HiCache Storage functionality"""
|
"""Accuracy tests for HiCache Storage functionality"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_additional_server_args_and_env(cls):
|
def _get_additional_server_args_and_env(cls):
|
||||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
"""Get additional server arguments specific to configuration - override in subclasses"""
|
||||||
server_args = {"--tp-size": 2, "--hicache-ratio": 1.5}
|
server_args = {
|
||||||
|
"--tp-size": 2,
|
||||||
|
"--hicache-ratio": 1.5,
|
||||||
|
}
|
||||||
|
|
||||||
return server_args, {}
|
return server_args, {}
|
||||||
|
|
||||||
def test_eval_accuracy(self):
|
def test_eval_accuracy(self):
|
||||||
"""Test eval accuracy with cache persistence across cache flushes"""
|
"""Test eval accuracy with cache persistence across cache flushes"""
|
||||||
print("\n=== Testing Eval Accuracy with Cache Persistence ===")
|
run_eval_accuracy_test(self)
|
||||||
|
|
||||||
# First evaluation - populate cache
|
|
||||||
print("Phase 1: Running initial GSM8K evaluation to populate cache...")
|
|
||||||
args_initial = SimpleNamespace(
|
|
||||||
num_shots=5,
|
|
||||||
data_path=None,
|
|
||||||
num_questions=50,
|
|
||||||
max_new_tokens=512,
|
|
||||||
parallel=10,
|
|
||||||
host=f"http://{self.base_host}",
|
|
||||||
port=int(self.base_port),
|
|
||||||
)
|
|
||||||
metrics_initial = run_eval_few_shot_gsm8k(args_initial)
|
|
||||||
|
|
||||||
# Flush cache to force remote storage access
|
|
||||||
print("Phase 2: Flushing device cache...")
|
|
||||||
self.assertTrue(self.flush_cache(), "Cache flush should succeed")
|
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
# Second evaluation - should use remote cache
|
|
||||||
print("Phase 3: Running second GSM8K evaluation using remote cache...")
|
|
||||||
metrics_cached = run_eval_few_shot_gsm8k(args_initial)
|
|
||||||
|
|
||||||
# Verify accuracy consistency
|
|
||||||
accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
|
|
||||||
print(f"Accuracy difference: {accuracy_diff:.4f}")
|
|
||||||
|
|
||||||
# Assertions
|
|
||||||
self.assertGreater(
|
|
||||||
metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable"
|
|
||||||
)
|
|
||||||
self.assertGreater(
|
|
||||||
metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable"
|
|
||||||
)
|
|
||||||
self.assertLess(
|
|
||||||
accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Add other backends tests(3fs/mooncake)
|
def run_eval_accuracy_test(test_instance, accuracy_threshold: float = 0.03):
|
||||||
# class TestHiCacheStorageMooncakeBackend(HiCacheStorageBaseTest):
|
"""Generic eval accuracy test with configurable accuracy threshold
|
||||||
# """Mooncake backend tests for HiCache Storage functionality"""
|
|
||||||
|
|
||||||
# @classmethod
|
Args:
|
||||||
# def _get_additional_server_args_and_env(cls):
|
test_instance: The test class instance that provides base_host, base_port, flush_cache, and assert methods
|
||||||
# """Get additional server arguments specific to configuration - override in subclasses"""
|
"""
|
||||||
# server_args = ["--hicache-storage-backend", "mooncake"]
|
print("\n=== Testing Eval Accuracy with Cache Persistence ===")
|
||||||
# env = {
|
|
||||||
# "MOONCAKE_TE_META_DATA_SERVER": "http://127.0.0.1:8080/metadata",
|
# First evaluation - populate cache
|
||||||
# "MOONCAKE_MASTER": "127.0.0.1:50051"
|
print("Phase 1: Running initial GSM8K evaluation to populate cache...")
|
||||||
# xxxxx
|
args_initial = SimpleNamespace(
|
||||||
# }
|
num_shots=5,
|
||||||
# return server_args, {}
|
data_path=None,
|
||||||
|
num_questions=50,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=10,
|
||||||
|
host=f"http://{test_instance.base_host}",
|
||||||
|
port=int(test_instance.base_port),
|
||||||
|
)
|
||||||
|
metrics_initial = run_eval_few_shot_gsm8k(args_initial)
|
||||||
|
|
||||||
|
# Flush cache to force remote storage access
|
||||||
|
print("Phase 2: Flushing device cache...")
|
||||||
|
test_instance.assertTrue(test_instance.flush_cache(), "Cache flush should succeed")
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
# Second evaluation - should use remote cache
|
||||||
|
print("Phase 3: Running second GSM8K evaluation using remote cache...")
|
||||||
|
metrics_cached = run_eval_few_shot_gsm8k(args_initial)
|
||||||
|
|
||||||
|
# Verify accuracy consistency
|
||||||
|
accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
|
||||||
|
print(f"Accuracy difference: {accuracy_diff:.4f}")
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
test_instance.assertGreater(
|
||||||
|
metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable"
|
||||||
|
)
|
||||||
|
test_instance.assertGreater(
|
||||||
|
metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable"
|
||||||
|
)
|
||||||
|
test_instance.assertLess(
|
||||||
|
accuracy_diff,
|
||||||
|
accuracy_threshold,
|
||||||
|
"Accuracy should be consistent between cache states",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from sglang.test.test_utils import (
|
|||||||
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||||
CustomTestCase,
|
CustomTestCase,
|
||||||
find_available_port,
|
find_available_port,
|
||||||
|
is_in_ci,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -226,6 +227,7 @@ class TestMooncakeBackendLayerFirstLayout(
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
|
||||||
class TestMooncakeBackendPageFirstLayout(
|
class TestMooncakeBackendPageFirstLayout(
|
||||||
HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase
|
HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase
|
||||||
):
|
):
|
||||||
@@ -236,21 +238,6 @@ class TestMooncakeBackendPageFirstLayout(
|
|||||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
"""Get additional server arguments specific to configuration - override in subclasses"""
|
||||||
server_args, env_vars = super()._get_additional_server_args_and_env()
|
server_args, env_vars = super()._get_additional_server_args_and_env()
|
||||||
server_args["--hicache-mem-layout"] = "page_first"
|
server_args["--hicache-mem-layout"] = "page_first"
|
||||||
server_args["--hicache-io-backend"] = "kernel"
|
|
||||||
return server_args, env_vars
|
|
||||||
|
|
||||||
|
|
||||||
class TestMooncakeBackendPageFirstDirectLayout(
|
|
||||||
HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase
|
|
||||||
):
|
|
||||||
"""Page first layout tests for HiCache-Mooncake backend"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_additional_server_args_and_env(cls):
|
|
||||||
"""Get additional server arguments specific to configuration - override in subclasses"""
|
|
||||||
server_args, env_vars = super()._get_additional_server_args_and_env()
|
|
||||||
server_args["--hicache-mem-layout"] = "page_first_direct"
|
|
||||||
server_args["--hicache-io-backend"] = "direct"
|
|
||||||
return server_args, env_vars
|
return server_args, env_vars
|
||||||
|
|
||||||
|
|
||||||
@@ -284,48 +271,15 @@ class TestMooncakeBackendAccuracy(
|
|||||||
server_args, env_vars = super()._get_additional_server_args_and_env()
|
server_args, env_vars = super()._get_additional_server_args_and_env()
|
||||||
server_args["--hicache-ratio"] = 1.5
|
server_args["--hicache-ratio"] = 1.5
|
||||||
server_args["--tp-size"] = 2
|
server_args["--tp-size"] = 2
|
||||||
|
server_args["--hicache-mem-layout"] = "page_first_direct"
|
||||||
|
server_args["--hicache-io-backend"] = "direct"
|
||||||
return server_args, env_vars
|
return server_args, env_vars
|
||||||
|
|
||||||
def test_eval_accuracy(self):
|
def test_eval_accuracy(self):
|
||||||
"""Test eval accuracy with cache persistence across cache flushes"""
|
"""Test eval accuracy with cache persistence across cache flushes"""
|
||||||
print("\n=== Testing Eval Accuracy with Cache Persistence ===")
|
from test_hicache_storage_file_backend import run_eval_accuracy_test
|
||||||
|
|
||||||
# First evaluation - populate cache
|
run_eval_accuracy_test(self)
|
||||||
print("Phase 1: Running initial GSM8K evaluation to populate cache...")
|
|
||||||
args_initial = SimpleNamespace(
|
|
||||||
num_shots=5,
|
|
||||||
data_path=None,
|
|
||||||
num_questions=50,
|
|
||||||
max_new_tokens=512,
|
|
||||||
parallel=10,
|
|
||||||
host=f"http://{self.base_host}",
|
|
||||||
port=int(self.base_port),
|
|
||||||
)
|
|
||||||
metrics_initial = run_eval_few_shot_gsm8k(args_initial)
|
|
||||||
|
|
||||||
# Flush cache to force remote storage access
|
|
||||||
print("Phase 2: Flushing device cache...")
|
|
||||||
self.assertTrue(self.flush_cache(), "Cache flush should succeed")
|
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
# Second evaluation - should use remote cache
|
|
||||||
print("Phase 3: Running second GSM8K evaluation using remote cache...")
|
|
||||||
metrics_cached = run_eval_few_shot_gsm8k(args_initial)
|
|
||||||
|
|
||||||
# Verify accuracy consistency
|
|
||||||
accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
|
|
||||||
print(f"Accuracy difference: {accuracy_diff:.4f}")
|
|
||||||
|
|
||||||
# Assertions
|
|
||||||
self.assertGreater(
|
|
||||||
metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable"
|
|
||||||
)
|
|
||||||
self.assertGreater(
|
|
||||||
metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable"
|
|
||||||
)
|
|
||||||
self.assertLess(
|
|
||||||
accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -132,8 +132,8 @@ suites = {
|
|||||||
TestFile("test_load_weights_from_remote_instance.py", 72),
|
TestFile("test_load_weights_from_remote_instance.py", 72),
|
||||||
TestFile("test_patch_torch.py", 19),
|
TestFile("test_patch_torch.py", 19),
|
||||||
TestFile("test_release_memory_occupation.py", 257),
|
TestFile("test_release_memory_occupation.py", 257),
|
||||||
TestFile("hicache/test_hicache_storage_file_backend.py", 400),
|
TestFile("hicache/test_hicache_storage_file_backend.py", 200),
|
||||||
TestFile("hicache/test_hicache_storage_3fs_backend.py", 400),
|
TestFile("hicache/test_hicache_storage_3fs_backend.py", 200),
|
||||||
],
|
],
|
||||||
"per-commit-4-gpu": [
|
"per-commit-4-gpu": [
|
||||||
TestFile("test_gpt_oss_4gpu.py", 600),
|
TestFile("test_gpt_oss_4gpu.py", 600),
|
||||||
@@ -144,7 +144,7 @@ suites = {
|
|||||||
TestFile("test_multi_instance_release_memory_occupation.py", 64),
|
TestFile("test_multi_instance_release_memory_occupation.py", 64),
|
||||||
],
|
],
|
||||||
"per-commit-8-gpu": [
|
"per-commit-8-gpu": [
|
||||||
TestFile("hicache/test_hicache_storage_mooncake_backend.py", 800),
|
TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400),
|
||||||
TestFile("lora/test_lora_llama4.py", 600),
|
TestFile("lora/test_lora_llama4.py", 600),
|
||||||
TestFile("test_disaggregation.py", 499),
|
TestFile("test_disaggregation.py", 499),
|
||||||
TestFile("test_disaggregation_dp_attention.py", 155),
|
TestFile("test_disaggregation_dp_attention.py", 155),
|
||||||
|
|||||||
Reference in New Issue
Block a user