193 lines
6.9 KiB
Python
193 lines
6.9 KiB
Python
"""
|
|
Benchmark tests for HiCache Storage functionality.
|
|
Usage:
|
|
python3 -m pytest test/srt/hicache/test_hicache_storage_benchmark.py -v
|
|
"""
|
|
|
|
import time
|
|
import unittest
|
|
from types import SimpleNamespace
|
|
from typing import Dict
|
|
|
|
import requests
|
|
from test_hicache_storage_e2e import HiCacheStorageBaseTest
|
|
|
|
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
|
from sglang.test.test_utils import is_in_ci, write_github_step_summary
|
|
|
|
|
|
class TestHiCacheStorageBenchmark(HiCacheStorageBaseTest):
    """Accuracy and throughput benchmarks run against a HiCache-enabled server.

    Every test talks to the already-running server exposed by the base class
    through ``self.base_url`` / ``self.base_host`` / ``self.base_port``.
    """

    @classmethod
    def _get_additional_server_args_and_env(cls):
        """Return this suite's extra server flags and environment overrides.

        Returns:
            A ``(server_args, env)`` tuple: a dict of CLI flag overrides and
            an empty dict of environment variables.
        """
        extra_args = {"--tp-size": 2, "--hicache-ratio": 1.5}
        extra_env = {}
        return extra_args, extra_env

    def flush_cache(self) -> bool:
        """POST to the server's ``/flush_cache`` endpoint.

        The tests call this between phases; the intent is that later requests
        can no longer be served from the device cache.

        Returns:
            True when the server replies with HTTP 200. False for any other
            status code or when ``requests`` raises ``RequestException``
            (the call uses a 10 second timeout).
        """
        flush_url = f"{self.base_url}/flush_cache"
        try:
            reply = requests.post(flush_url, timeout=10)
        except requests.RequestException:
            return False
        return reply.status_code == 200

    # === Accuracy Tests ===
    def test_eval_accuracy_with_cache_persistence(self):
        """Run GSM8K twice with a cache flush in between and compare accuracy.

        The first run must score above 0.60; both runs must score above 0.5
        and differ by less than 0.05. The throughput ratio is printed and
        reported but never asserted on.
        """
        print("\n=== Testing Eval Accuracy with Cache Persistence ===")

        # Phase 1: first pass over the dataset (intended to populate the cache).
        print("Phase 1: Running initial GSM8K evaluation to populate cache...")
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=400,
            max_new_tokens=512,
            parallel=32,
            host=f"http://{self.base_host}",
            port=int(self.base_port),
        )
        baseline = run_eval_few_shot_gsm8k(eval_args)
        print(f"Evaluation metrics: {baseline}")
        self.assertGreater(baseline["accuracy"], 0.60)

        # Phase 2: flush, then pause briefly before issuing new requests.
        print("Phase 2: Flushing device cache...")
        flushed = self.flush_cache()
        self.assertTrue(flushed, "Cache flush should succeed")
        time.sleep(2)

        # Phase 3: identical evaluation, timed end to end.
        print("Phase 3: Running second GSM8K evaluation using remote cache...")
        began_at = time.time()
        rerun = run_eval_few_shot_gsm8k(eval_args)
        cached_elapsed = time.time() - began_at

        print(f"Cached evaluation completed in {cached_elapsed:.2f}s")
        print(f"Cached accuracy: {rerun['accuracy']:.3f}")
        print(f"Cached throughput: {rerun['output_throughput']:.2f} token/s")

        # Accuracy must not drift between the two runs.
        accuracy_gap = abs(baseline["accuracy"] - rerun["accuracy"])
        print(f"Accuracy difference: {accuracy_gap:.4f}")

        self.assertGreater(
            baseline["accuracy"], 0.5, "Initial accuracy should be reasonable"
        )
        self.assertGreater(
            rerun["accuracy"], 0.5, "Cached accuracy should be reasonable"
        )
        self.assertLess(
            accuracy_gap, 0.05, "Accuracy should be consistent between cache states"
        )

        # Informational only: second-run throughput relative to the first run.
        speed_ratio = rerun["output_throughput"] / baseline["output_throughput"]
        print(f"Throughput ratio (cached/initial): {speed_ratio:.2f}")

        # The step summary is only published from CI runs.
        if not is_in_ci():
            return

        write_github_step_summary(
            f"### HiCache Storage Accuracy Test\n"
            f"Initial accuracy: {baseline['accuracy']:.3f}\n"
            f"Cached accuracy: {rerun['accuracy']:.3f}\n"
            f"Accuracy difference: {accuracy_gap:.4f}\n"
            f"Throughput ratio: {speed_ratio:.2f}\n"
        )

    # === Performance Benchmark Tests ===

    def test_throughput_benchmark_with_hicache(self):
        """Run the serving benchmark before and after a cache flush.

        The only assertion is that the flush succeeds; the benchmark numbers
        are reported to the CI step summary rather than checked.
        """
        print("\n=== Benchmarking Throughput with HiCache ===")

        # First benchmark pass.
        first_run = self._run_throughput_benchmark(
            test_name="hicache_offline_throughput",
            num_prompts=200,
            request_rate=10,
            additional_args=[],
        )

        # Flush, then pause briefly before the second pass.
        print("Phase 2: Flushing device cache...")
        flushed = self.flush_cache()
        self.assertTrue(flushed, "Cache flush should succeed")
        time.sleep(2)

        # Second benchmark pass, issued after the flush.
        second_run = self._run_throughput_benchmark(
            test_name="hicache_online_throughput",
            num_prompts=400,
            request_rate=10,
            additional_args=[],
        )

        # The step summary is only published from CI runs.
        if not is_in_ci():
            return

        write_github_step_summary(
            f"### HiCache Storage FileBackend Benchmark Test\n"
            f"First time throughput: {first_run['input_throughput']:.2f} token/s\n"
            f"Second time throughput: {second_run['input_throughput']:.2f} token/s\n"
            f"First time TTFT: {first_run['mean_ttft_ms']:.2f} ms\n"
            f"Second time TTFT: {second_run['mean_ttft_ms']:.2f} ms\n"
        )

    def _run_throughput_benchmark(
        self,
        test_name: str,
        num_prompts: int,
        request_rate: float,
        dataset_name: str = "random",
        additional_args: list = None,
    ) -> Dict:
        """Run ``sglang.bench_serving`` against the existing server.

        Args:
            test_name: Label used only in the progress messages.
            num_prompts: Number of prompts passed to the benchmark.
            request_rate: Request rate passed to the benchmark.
            dataset_name: Dataset name passed to the benchmark.
            additional_args: Accepted for interface compatibility; it is
                normalised to a list but not forwarded to the benchmark.

        Returns:
            Whatever ``run_benchmark`` returns on success. If anything in the
            benchmark path raises, the error is printed and a placeholder
            dict is returned instead (zero throughput, infinite latencies) so
            that the calling test does not hard-fail.
        """
        additional_args = [] if additional_args is None else additional_args

        print(f"Running {test_name} benchmark...")
        started_at = time.time()

        try:
            # Imported lazily; the benchmark reuses the running server
            # instead of launching a new one.
            from sglang.bench_serving import run_benchmark
            from sglang.test.test_utils import get_benchmark_args

            bench_args = get_benchmark_args(
                base_url=self.base_url,
                dataset_name=dataset_name,
                tokenizer=self.model,
                num_prompts=num_prompts,
                request_rate=request_rate,
                random_input_len=1024,
                random_output_len=64,
            )
            metrics = run_benchmark(bench_args)

            duration = time.time() - started_at
            print(f"{test_name} completed in {duration:.2f}s")
            out_throughput = metrics.get("output_throughput", 0.0)
            print(f"Output throughput: {out_throughput:.2f} token/s")

            return metrics

        except Exception as exc:
            print(f"Benchmark {test_name} failed: {exc}")
            # Deliberate best-effort fallback: minimal metrics, no re-raise.
            unavailable = float("inf")
            return {
                "output_throughput": 0.0,
                "input_throughput": 0.0,
                "mean_ttft_ms": unavailable,
                "mean_latency_ms": unavailable,
                "p99_ttft_ms": unavailable,
            }
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct execution runs the suite; verbosity=2 lists each test by name
    # together with its outcome.
    unittest.main(verbosity=2)
|