[Hicache]: Add E2E CI For 3FS-KVStore (#10131)
This commit is contained in:
164
python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py
Normal file
164
python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Hf3fsClient(ABC):
    """Abstract interface for HF3FS clients.

    A client is bound to one storage file and exposes batched reads and
    writes of tensors at byte offsets within that file. Every method here is
    abstract (including ``__init__``), so only subclasses that implement the
    full set can be instantiated.
    """

    @abstractmethod
    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
        """Initialize the HF3FS client.

        Args:
            path: File path for storage.
            size: Total size of storage file.
            bytes_per_page: Bytes per page.
            entries: Number of entries for batch operations.
        """

    @abstractmethod
    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
        """Batch read from storage.

        Args:
            offsets: One offset into the storage file per request.
            tensors: Destination tensors, paired positionally with ``offsets``.

        Returns:
            One integer per request.
            NOTE(review): the mock implementation in this file reports the
            number of bytes transferred (0 on error); confirm the same
            convention holds for other implementations.
        """

    @abstractmethod
    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
        """Batch write to storage.

        Args:
            offsets: One offset into the storage file per request.
            tensors: Source tensors, paired positionally with ``offsets``.

        Returns:
            One integer per request.
            NOTE(review): the mock implementation in this file reports the
            number of bytes transferred (0 on error); confirm the same
            convention holds for other implementations.
        """

    @abstractmethod
    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
        """Validate batch operation parameters."""

    @abstractmethod
    def get_size(self) -> int:
        """Get total storage size."""

    @abstractmethod
    def close(self) -> None:
        """Close the client and cleanup resources."""

    @abstractmethod
    def flush(self) -> None:
        """Flush data to disk."""
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Hf3fsMockClient(Hf3fsClient):
    """Mock implementation of Hf3fsClient for CI testing purposes.

    The "storage" is an ordinary local file accessed through a raw file
    descriptor. Tensors are written and read back as their raw bytes at the
    requested byte offsets.
    """

    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
        """Initialize mock HF3FS client.

        Creates the parent directory (if any) and the backing file, then
        resizes the file to ``size`` bytes.

        Args:
            path: File path for storage.
            size: Total size of storage file, in bytes.
            bytes_per_page: Bytes per page (stored; not used by the mock I/O).
            entries: Number of entries for batch operations (stored; not used
                by the mock I/O).

        Raises:
            OSError: If the directory or file cannot be created or resized.
        """
        self.path = path
        self.size = size
        self.bytes_per_page = bytes_per_page
        self.entries = entries
        # Defined before anything can fail so close()/flush() never hit a
        # missing attribute on a partially constructed instance.
        self.file = -1

        # os.path.dirname() returns "" for a bare filename and
        # os.makedirs("") raises FileNotFoundError, so only create a
        # directory when there is one to create.
        parent_dir = os.path.dirname(self.path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Create and size the backing file; do not leak the descriptor if
        # the resize fails.
        fd = os.open(self.path, os.O_RDWR | os.O_CREAT)
        try:
            os.ftruncate(fd, size)
        except Exception:
            os.close(fd)
            raise
        self.file = fd

        logger.info(
            f"Hf3fsMockClient initialized: path={path}, size={size}, "
            f"bytes_per_page={bytes_per_page}, entries={entries}"
        )

    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
        """Batch read from mock storage.

        For each ``(offset, tensor)`` pair, reads ``tensor``'s byte size from
        the file at ``offset`` and copies the data into ``tensor`` in place.

        Returns:
            One entry per request: the tensor's byte size on success, the
            number of bytes actually read on a short read (the tensor is left
            untouched), or 0 if the request raised.

        Raises:
            ValueError: If ``offsets`` and ``tensors`` differ in length.
        """
        self.check(offsets, tensors)

        results = []
        for offset, tensor in zip(offsets, tensors):
            size = tensor.numel() * tensor.element_size()
            try:
                # pread is positional, so there is no window between a seek
                # and a read in which another user of the descriptor could
                # move the file offset.
                raw = os.pread(self.file, size, offset)
                if len(raw) == size:
                    # torch.frombuffer warns on read-only buffers such as
                    # bytes; stage the data in a writable bytearray.
                    staged = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
                    tensor.copy_(staged.view(tensor.dtype).view(tensor.shape))
                    results.append(size)
                else:
                    logger.warning(f"Short read: expected {size}, got {len(raw)}")
                    results.append(len(raw))
            except Exception as e:
                # Best effort per request: report 0 bytes and keep going.
                logger.error(f"Error reading from offset {offset}: {e}")
                results.append(0)

        return results

    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
        """Batch write to mock storage.

        For each ``(offset, tensor)`` pair, writes the tensor's raw bytes to
        the file at ``offset``.

        Returns:
            One entry per request: the tensor's byte size on success, the
            number of bytes actually written on a short write, or 0 if the
            request raised.

        Raises:
            ValueError: If ``offsets`` and ``tensors`` differ in length.
        """
        self.check(offsets, tensors)

        results = []
        for offset, tensor in zip(offsets, tensors):
            size = tensor.numel() * tensor.element_size()
            try:
                # Reinterpret the (contiguous) tensor as bytes.
                byte_view = tensor.contiguous().view(torch.uint8).flatten()
                data = byte_view.numpy().tobytes()

                # Positional write; see batch_read for why not lseek+write.
                written = os.pwrite(self.file, data, offset)
                if written == size:
                    results.append(size)
                else:
                    logger.warning(f"Short write: expected {size}, got {written}")
                    results.append(written)
            except Exception as e:
                # Best effort per request: report 0 bytes and keep going.
                logger.error(f"Error writing to offset {offset}: {e}")
                results.append(0)

        return results

    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
        """Validate batch operation parameters.

        Raises:
            ValueError: If ``offsets`` and ``tensors`` differ in length.
                zip() in batch_read/batch_write would otherwise silently drop
                the unmatched requests and return a shorter result list.
        """
        if len(offsets) != len(tensors):
            raise ValueError(
                f"Hf3fsMockClient.check: got {len(offsets)} offsets "
                f"for {len(tensors)} tensors"
            )

    def get_size(self) -> int:
        """Get total storage size."""
        return self.size

    def close(self) -> None:
        """Close the mock client and cleanup resources.

        Safe to call more than once; errors are logged rather than raised.
        """
        try:
            if hasattr(self, "file") and self.file >= 0:
                os.close(self.file)
                self.file = -1  # Mark as closed
            logger.info(f"MockHf3fsClient closed: {self.path}")
        except Exception as e:
            logger.error(f"Error closing MockHf3fsClient: {e}")

    def flush(self) -> None:
        """Flush data to disk.

        Errors (including flushing an already closed client) are logged
        rather than raised.
        """
        try:
            os.fsync(self.file)
        except Exception as e:
            logger.error(f"Error flushing MockHf3fsClient: {e}")
|
||||||
@@ -9,6 +9,8 @@ from typing import List
|
|||||||
import torch
|
import torch
|
||||||
from torch.utils.cpp_extension import load
|
from torch.utils.cpp_extension import load
|
||||||
|
|
||||||
|
from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
|
||||||
|
|
||||||
root = Path(__file__).parent.resolve()
|
root = Path(__file__).parent.resolve()
|
||||||
hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])
|
hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])
|
||||||
|
|
||||||
@@ -51,7 +53,9 @@ def wsynchronized():
|
|||||||
return _decorator
|
return _decorator
|
||||||
|
|
||||||
|
|
||||||
class Hf3fsClient:
|
class Hf3fsUsrBioClient(Hf3fsClient):
|
||||||
|
"""HF3FS client implementation using usrbio."""
|
||||||
|
|
||||||
def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
|
def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
|
||||||
if not HF3FS_AVAILABLE:
|
if not HF3FS_AVAILABLE:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@@ -13,7 +13,7 @@ from typing import Any, List, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
|
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
|
||||||
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
|
from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
|
||||||
from sglang.srt.metrics.collector import StorageMetrics
|
from sglang.srt.metrics.collector import StorageMetrics
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -114,6 +114,33 @@ def synchronized():
|
|||||||
return _decorator
|
return _decorator
|
||||||
|
|
||||||
|
|
||||||
|
def create_hf3fs_client(
    path: str, size: int, bytes_per_page: int, entries: int, use_mock: bool = False
) -> Hf3fsClient:
    """Factory function to create appropriate HF3FS client.

    Args:
        path: File path for storage
        size: Total size of storage file
        bytes_per_page: Bytes per page
        entries: Number of entries for batch operations
        use_mock: Whether to use mock client instead of real usrbio client

    Returns:
        An ``Hf3fsMockClient`` when ``use_mock`` is true, otherwise an
        ``Hf3fsUsrBioClient``. Both are constructed with
        ``(path, size, bytes_per_page, entries)``.
    """
    if use_mock:
        from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsMockClient

        # The previous message was a placeholder-less f-string with a
        # dangling "[Rank" prefix; no rank is available in this scope.
        logger.info("Using Hf3fsMockClient for testing")
        return Hf3fsMockClient(path, size, bytes_per_page, entries)

    # Imported only on this branch.
    # NOTE(review): presumably so the mock path never has to import the
    # usrbio client module and its native dependencies - confirm.
    from sglang.srt.mem_cache.storage.hf3fs.hf3fs_usrbio_client import (
        Hf3fsUsrBioClient,
    )

    return Hf3fsUsrBioClient(path, size, bytes_per_page, entries)
|
||||||
|
|
||||||
|
|
||||||
class HiCacheHF3FS(HiCacheStorage):
|
class HiCacheHF3FS(HiCacheStorage):
|
||||||
"""HiCache backend that stores KV cache pages in HF3FS files."""
|
"""HiCache backend that stores KV cache pages in HF3FS files."""
|
||||||
|
|
||||||
@@ -131,6 +158,7 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
metadata_client: Hf3fsMetadataInterface,
|
metadata_client: Hf3fsMetadataInterface,
|
||||||
is_mla_model: bool = False,
|
is_mla_model: bool = False,
|
||||||
is_page_first_layout: bool = False,
|
is_page_first_layout: bool = False,
|
||||||
|
use_mock_client: bool = False,
|
||||||
):
|
):
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.file_path = file_path
|
self.file_path = file_path
|
||||||
@@ -159,8 +187,12 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
|
|
||||||
self.ac = AtomicCounter(self.numjobs)
|
self.ac = AtomicCounter(self.numjobs)
|
||||||
self.clients = [
|
self.clients = [
|
||||||
Hf3fsClient(
|
create_hf3fs_client(
|
||||||
self.file_path, self.file_size, self.bytes_per_page, self.entries
|
self.file_path,
|
||||||
|
self.file_size,
|
||||||
|
self.bytes_per_page,
|
||||||
|
self.entries,
|
||||||
|
use_mock_client,
|
||||||
)
|
)
|
||||||
for _ in range(numjobs)
|
for _ in range(numjobs)
|
||||||
]
|
]
|
||||||
@@ -202,14 +234,24 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
Hf3fsLocalMetadataClient,
|
Hf3fsLocalMetadataClient,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
use_mock_client = False
|
||||||
if storage_config is not None:
|
if storage_config is not None:
|
||||||
rank, is_mla_model, is_page_first_layout = (
|
rank, is_mla_model, is_page_first_layout = (
|
||||||
storage_config.tp_rank,
|
storage_config.tp_rank,
|
||||||
storage_config.is_mla_model,
|
storage_config.is_mla_model,
|
||||||
storage_config.is_page_first_layout,
|
storage_config.is_page_first_layout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if storage_config.extra_config is not None:
|
||||||
|
use_mock_client = storage_config.extra_config.get(
|
||||||
|
"use_mock_hf3fs_client", False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
rank, is_mla_model, is_page_first_layout = 0, False, False
|
rank, is_mla_model, is_page_first_layout = (
|
||||||
|
0,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"
|
mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"
|
||||||
|
|
||||||
@@ -228,6 +270,7 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
metadata_client=Hf3fsLocalMetadataClient(),
|
metadata_client=Hf3fsLocalMetadataClient(),
|
||||||
is_page_first_layout=is_page_first_layout,
|
is_page_first_layout=is_page_first_layout,
|
||||||
|
use_mock_client=use_mock_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -277,6 +320,7 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
metadata_client=metadata_client,
|
metadata_client=metadata_client,
|
||||||
is_mla_model=is_mla_model,
|
is_mla_model=is_mla_model,
|
||||||
is_page_first_layout=is_page_first_layout,
|
is_page_first_layout=is_page_first_layout,
|
||||||
|
use_mock_client=use_mock_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
|
|||||||
135
test/srt/hicache/test_hicache_storage_3fs_backend.py
Normal file
135
test/srt/hicache/test_hicache_storage_3fs_backend.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""
|
||||||
|
Benchmark tests for HiCache Storage with 3FS backend.
|
||||||
|
Usage:
|
||||||
|
python3 -m pytest test/srt/hicache/test_hicache_storage_3fs_backend.py -v
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from test_hicache_storage_file_backend import HiCacheStorageBaseMixin
|
||||||
|
|
||||||
|
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||||
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class HiCacheStorage3FSBackendBaseMixin(HiCacheStorageBaseMixin):
    """Base mixin class with common setup and utilities"""

    @classmethod
    def _get_additional_server_args_and_env(cls):
        """Build the server arguments and environment for the HF3FS backend.

        The HF3FS settings (with the mock client enabled) are written to a
        JSON file under ``cls.temp_dir`` and also passed inline through
        ``--hicache-storage-backend-extra-config``.

        Returns:
            A ``(server_args, env_vars)`` tuple of dicts.
        """
        # NOTE(review): cls.temp_dir is not defined in this class; it is
        # assumed to be set up by the base mixin before this is called.
        hf3fs_config = {
            "file_path_prefix": os.path.join(cls.temp_dir, "hicache"),
            "file_size": 1024 * 1024 * 1024 * 2,
            "numjobs": 2,
            "entries": 8,
            "use_mock_hf3fs_client": True,
        }

        # Write config to temporary file
        config_file = os.path.join(cls.temp_dir, "hf3fs_config.json")
        with open(config_file, "w") as f:
            json.dump(hf3fs_config, f, indent=2)

        server_args = {
            "--tp-size": 1,
            "--hicache-ratio": 1.2,
            "--hicache-storage-backend": "hf3fs",
            "--hicache-storage-backend-extra-config": json.dumps(hf3fs_config),
        }

        # Set the environment variable to point to our config file
        env_vars = {
            "SGLANG_HICACHE_HF3FS_CONFIG_PATH": config_file,
        }

        return server_args, env_vars
|
||||||
|
|
||||||
|
|
||||||
|
class TestHf3fsBackendLayerFirstLayout(
    HiCacheStorage3FSBackendBaseMixin, CustomTestCase
):
    """Layer first layout tests for HiCache-Hf3fs backend"""

    @classmethod
    def _get_additional_server_args_and_env(cls):
        """Extend the shared HF3FS args with the layer-first layout and direct IO backend."""
        server_args, env_vars = super()._get_additional_server_args_and_env()
        server_args["--hicache-mem-layout"] = "layer_first"
        server_args["--hicache-io-backend"] = "direct"
        return server_args, env_vars
|
||||||
|
|
||||||
|
|
||||||
|
class TestHf3fsBackendPageFirstLayout(
    HiCacheStorage3FSBackendBaseMixin, CustomTestCase
):
    """Page first layout tests for HiCache-Hf3fs backend"""

    @classmethod
    def _get_additional_server_args_and_env(cls):
        """Extend the shared HF3FS args with the page-first memory layout."""
        server_args, env_vars = super()._get_additional_server_args_and_env()
        server_args["--hicache-mem-layout"] = "page_first"
        return server_args, env_vars
|
||||||
|
|
||||||
|
|
||||||
|
class TestHf3fsBackendAccuracy(HiCacheStorage3FSBackendBaseMixin, CustomTestCase):
    """Accuracy tests for HiCache-Hf3fs backend"""

    @classmethod
    def _get_additional_server_args_and_env(cls):
        """Extend the shared HF3FS args with a hicache ratio of 1.5 and TP size 2."""
        server_args, env_vars = super()._get_additional_server_args_and_env()
        server_args["--hicache-ratio"] = 1.5
        server_args["--tp-size"] = 2
        return server_args, env_vars

    def test_eval_accuracy(self):
        """Test eval accuracy with cache persistence across cache flushes"""
        # NOTE(review): base_host, base_port and flush_cache() are not defined
        # in this class; they are assumed to come from the base mixin.
        print("\n=== Testing Eval Accuracy with Cache Persistence ===")

        # First evaluation - populate cache
        print("Phase 1: Running initial GSM8K evaluation to populate cache...")
        # The same arguments are reused for both evaluation runs.
        args_initial = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=50,
            max_new_tokens=512,
            parallel=10,
            host=f"http://{self.base_host}",
            port=int(self.base_port),
        )
        metrics_initial = run_eval_few_shot_gsm8k(args_initial)

        # Flush cache to force remote storage access
        print("Phase 2: Flushing device cache...")
        self.assertTrue(self.flush_cache(), "Cache flush should succeed")
        time.sleep(2)

        # Second evaluation - should use remote cache
        print("Phase 3: Running second GSM8K evaluation using remote cache...")
        metrics_cached = run_eval_few_shot_gsm8k(args_initial)

        # Verify accuracy consistency
        accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
        print(f"Accuracy difference: {accuracy_diff:.4f}")

        # Assertions
        self.assertGreater(
            metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable"
        )
        self.assertGreater(
            metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable"
        )
        self.assertLess(
            accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
        )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main(verbosity=2)
|
||||||
@@ -1,192 +0,0 @@
|
|||||||
"""
|
|
||||||
Benchmark tests for HiCache Storage functionality.
|
|
||||||
Usage:
|
|
||||||
python3 -m pytest test/srt/hicache/test_hicache_storage_benchmark.py -v
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
import unittest
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from test_hicache_storage_e2e import HiCacheStorageBaseTest
|
|
||||||
|
|
||||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
|
||||||
from sglang.test.test_utils import is_in_ci, write_github_step_summary
|
|
||||||
|
|
||||||
|
|
||||||
class TestHiCacheStorageBenchmark(HiCacheStorageBaseTest):
    """Benchmark tests for HiCache Storage functionality"""

    @classmethod
    def _get_additional_server_args_and_env(cls):
        """Return the server arguments (TP size 2, hicache ratio 1.5) and an empty env."""
        server_args = {"--tp-size": 2, "--hicache-ratio": 1.5}
        return server_args, {}

    def flush_cache(self) -> bool:
        """Flush device cache to force remote storage access

        Returns:
            True if the ``/flush_cache`` endpoint answered with HTTP 200,
            False on any other status or on a request error.
        """
        try:
            response = requests.post(f"{self.base_url}/flush_cache", timeout=10)
            return response.status_code == 200
        except requests.RequestException:
            return False

    # === Accuracy Tests ===
    def test_eval_accuracy_with_cache_persistence(self):
        """Test eval accuracy with cache persistence across cache flushes"""
        print("\n=== Testing Eval Accuracy with Cache Persistence ===")

        # First evaluation - populate cache
        print("Phase 1: Running initial GSM8K evaluation to populate cache...")
        # The same arguments are reused for both evaluation runs.
        args_initial = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=400,
            max_new_tokens=512,
            parallel=32,
            host=f"http://{self.base_host}",
            port=int(self.base_port),
        )
        metrics_initial = run_eval_few_shot_gsm8k(args_initial)
        print(f"Evaluation metrics: {metrics_initial}")
        self.assertGreater(metrics_initial["accuracy"], 0.60)

        # Flush cache to force remote storage access
        print("Phase 2: Flushing device cache...")
        self.assertTrue(self.flush_cache(), "Cache flush should succeed")
        time.sleep(2)

        # Second evaluation - should use remote cache
        print("Phase 3: Running second GSM8K evaluation using remote cache...")

        start_time = time.time()
        metrics_cached = run_eval_few_shot_gsm8k(args_initial)
        cached_time = time.time() - start_time

        print(f"Cached evaluation completed in {cached_time:.2f}s")
        print(f"Cached accuracy: {metrics_cached['accuracy']:.3f}")
        print(f"Cached throughput: {metrics_cached['output_throughput']:.2f} token/s")

        # Verify accuracy consistency
        accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
        print(f"Accuracy difference: {accuracy_diff:.4f}")

        # Assertions
        self.assertGreater(
            metrics_initial["accuracy"], 0.5, "Initial accuracy should be reasonable"
        )
        self.assertGreater(
            metrics_cached["accuracy"], 0.5, "Cached accuracy should be reasonable"
        )
        self.assertLess(
            accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
        )

        # Performance should be similar or better with cache
        # (reported only; the ratio is not asserted on)
        throughput_ratio = (
            metrics_cached["output_throughput"] / metrics_initial["output_throughput"]
        )
        print(f"Throughput ratio (cached/initial): {throughput_ratio:.2f}")

        if is_in_ci():
            write_github_step_summary(
                f"### HiCache Storage Accuracy Test\n"
                f"Initial accuracy: {metrics_initial['accuracy']:.3f}\n"
                f"Cached accuracy: {metrics_cached['accuracy']:.3f}\n"
                f"Accuracy difference: {accuracy_diff:.4f}\n"
                f"Throughput ratio: {throughput_ratio:.2f}\n"
            )

    # === Performance Benchmark Tests ===

    def test_throughput_benchmark_with_hicache(self):
        """Benchmark throughput performance with HiCache enabled"""
        print("\n=== Benchmarking Throughput with HiCache ===")

        # throughput test
        res1 = self._run_throughput_benchmark(
            test_name="hicache_offline_throughput",
            num_prompts=200,
            request_rate=10,
            additional_args=[],
        )

        # Flush cache to force remote storage access
        print("Phase 2: Flushing device cache...")
        self.assertTrue(self.flush_cache(), "Cache flush should succeed")
        time.sleep(2)

        # Second benchmark, should use remote cache
        res2 = self._run_throughput_benchmark(
            test_name="hicache_online_throughput",
            num_prompts=400,
            request_rate=10,
            additional_args=[],
        )

        if is_in_ci():
            write_github_step_summary(
                f"### HiCache Storage FileBackend Benchmark Test\n"
                f"First time throughput: {res1['input_throughput']:.2f} token/s\n"
                f"Second time throughput: {res2['input_throughput']:.2f} token/s\n"
                f"First time TTFT: {res1['mean_ttft_ms']:.2f} ms\n"
                f"Second time TTFT: {res2['mean_ttft_ms']:.2f} ms\n"
            )

    def _run_throughput_benchmark(
        self,
        test_name: str,
        num_prompts: int,
        request_rate: float,
        dataset_name: str = "random",
        additional_args: list = None,
    ) -> Dict:
        """Helper method to run throughput benchmarks

        Args:
            test_name: Label used in the printed progress messages.
            num_prompts: Number of prompts passed to the benchmark.
            request_rate: Request rate passed to the benchmark.
            dataset_name: Dataset name passed to the benchmark.
            additional_args: Accepted for interface compatibility; not
                forwarded to the benchmark.

        Returns:
            The benchmark result dict, or a placeholder dict (zero
            throughput, infinite latencies) if the benchmark raised.
        """
        if additional_args is None:
            additional_args = []

        print(f"Running {test_name} benchmark...")
        start_time = time.time()

        try:
            # Use the existing server instead of launching a new one
            from sglang.bench_serving import run_benchmark
            from sglang.test.test_utils import get_benchmark_args

            args = get_benchmark_args(
                base_url=self.base_url,
                dataset_name=dataset_name,
                tokenizer=self.model,
                num_prompts=num_prompts,
                request_rate=request_rate,
                random_input_len=1024,
                random_output_len=64,
            )

            # Run benchmark
            result = run_benchmark(args)

            elapsed_time = time.time() - start_time
            print(f"{test_name} completed in {elapsed_time:.2f}s")
            print(
                f"Output throughput: {result.get('output_throughput', 0.0):.2f} token/s"
            )

            return result

        except Exception as e:
            print(f"Benchmark {test_name} failed: {e}")
            # Fallback to avoid hard failure; return minimal metrics
            return {
                "output_throughput": 0.0,
                "input_throughput": 0.0,
                "mean_ttft_ms": float("inf"),
                "mean_latency_ms": float("inf"),
                "p99_ttft_ms": float("inf"),
            }
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main(verbosity=2)
|
|
||||||
@@ -9,6 +9,7 @@ import random
|
|||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
@@ -16,6 +17,7 @@ import requests
|
|||||||
|
|
||||||
from sglang.bench_serving import get_tokenizer
|
from sglang.bench_serving import get_tokenizer
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
@@ -26,8 +28,8 @@ from sglang.test.test_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class HiCacheStorageBaseTest(CustomTestCase):
|
class HiCacheStorageBaseMixin:
|
||||||
"""Base test class with common setup and utilities"""
|
"""Base mixin class with common setup and utilities"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@@ -166,11 +168,7 @@ class HiCacheStorageBaseTest(CustomTestCase):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def gen_prompt(self, token_num: int) -> str:
|
def gen_prompt(self, token_num: int) -> str:
|
||||||
"""Generate a random prompt of specified token length using tokenizer vocabulary.
|
"""Generate a random prompt of specified token length using tokenizer vocabulary."""
|
||||||
|
|
||||||
This function mimics the implementation from bench_serving.py to create
|
|
||||||
realistic prompts for testing cache behavior.
|
|
||||||
"""
|
|
||||||
all_available_tokens = list(self.tokenizer.get_vocab().values())
|
all_available_tokens = list(self.tokenizer.get_vocab().values())
|
||||||
selected_tokens = random.choices(all_available_tokens, k=token_num)
|
selected_tokens = random.choices(all_available_tokens, k=token_num)
|
||||||
return self.tokenizer.decode(selected_tokens)
|
return self.tokenizer.decode(selected_tokens)
|
||||||
@@ -201,10 +199,9 @@ class HiCacheStorageBaseTest(CustomTestCase):
|
|||||||
|
|
||||||
# Second request with extended prompt - should hit remote cache
|
# Second request with extended prompt - should hit remote cache
|
||||||
print("Step 2: Testing cache hit from remote storage...")
|
print("Step 2: Testing cache hit from remote storage...")
|
||||||
extended_prompt = base_prompt + "\n\n" + self.gen_prompt(64)
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
response2 = self.send_request(extended_prompt, max_tokens=150)
|
response2 = self.send_request(base_prompt, max_tokens=150)
|
||||||
retrieval_time = time.time() - start_time
|
retrieval_time = time.time() - start_time
|
||||||
|
|
||||||
cached_tokens = self.get_cached_tokens(response2)
|
cached_tokens = self.get_cached_tokens(response2)
|
||||||
@@ -213,12 +210,12 @@ class HiCacheStorageBaseTest(CustomTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Assert cached tokens indicate a remote hit
|
# Assert cached tokens indicate a remote hit
|
||||||
self.assertEqual(
|
self.assertGreater(
|
||||||
cached_tokens, 768, "Expected significant cached tokens for remote hit"
|
cached_tokens, 700, "Expected significant cached tokens for remote hit"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestHiCacheStorageTP(HiCacheStorageBaseTest):
|
class TestHiCacheStorageTP(HiCacheStorageBaseMixin, CustomTestCase):
|
||||||
"""Multi-TP tests for HiCache Storage functionality"""
|
"""Multi-TP tests for HiCache Storage functionality"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -228,7 +225,7 @@ class TestHiCacheStorageTP(HiCacheStorageBaseTest):
|
|||||||
return server_args, {}
|
return server_args, {}
|
||||||
|
|
||||||
|
|
||||||
class TestHiCacheStorageLayerFirstDirectIO(HiCacheStorageBaseTest):
|
class TestHiCacheStorageLayerFirstDirectIO(HiCacheStorageBaseMixin, CustomTestCase):
|
||||||
"""Layer first direct tests for HiCache Storage functionality"""
|
"""Layer first direct tests for HiCache Storage functionality"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -241,7 +238,7 @@ class TestHiCacheStorageLayerFirstDirectIO(HiCacheStorageBaseTest):
|
|||||||
return server_args, {}
|
return server_args, {}
|
||||||
|
|
||||||
|
|
||||||
class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseTest):
|
class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseMixin, CustomTestCase):
|
||||||
"""Page first layout tests for HiCache Storage functionality"""
|
"""Page first layout tests for HiCache Storage functionality"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -251,7 +248,7 @@ class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseTest):
|
|||||||
return server_args, {}
|
return server_args, {}
|
||||||
|
|
||||||
|
|
||||||
class TestHiCacheStorageMLA(HiCacheStorageBaseTest):
|
class TestHiCacheStorageMLA(HiCacheStorageBaseMixin, CustomTestCase):
|
||||||
"""MLA Model tests for HiCache Storage functionality"""
|
"""MLA Model tests for HiCache Storage functionality"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -266,6 +263,57 @@ class TestHiCacheStorageMLA(HiCacheStorageBaseTest):
|
|||||||
return server_args, {}
|
return server_args, {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestHiCacheStorageAccuracy(HiCacheStorageBaseMixin, CustomTestCase):
    """Accuracy tests for HiCache Storage functionality"""

    @classmethod
    def _get_additional_server_args_and_env(cls):
        """Return the server arguments (TP size 2, hicache ratio 1.5) and an empty env."""
        server_args = {"--tp-size": 2, "--hicache-ratio": 1.5}
        return server_args, {}

    def test_eval_accuracy(self):
        """Test eval accuracy with cache persistence across cache flushes"""
        # NOTE(review): base_host, base_port and flush_cache() are not defined
        # in this class; they are assumed to come from the base mixin.
        print("\n=== Testing Eval Accuracy with Cache Persistence ===")

        # First evaluation - populate cache
        print("Phase 1: Running initial GSM8K evaluation to populate cache...")
        # The same arguments are reused for both evaluation runs.
        args_initial = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=50,
            max_new_tokens=512,
            parallel=10,
            host=f"http://{self.base_host}",
            port=int(self.base_port),
        )
        metrics_initial = run_eval_few_shot_gsm8k(args_initial)

        # Flush cache to force remote storage access
        print("Phase 2: Flushing device cache...")
        self.assertTrue(self.flush_cache(), "Cache flush should succeed")
        time.sleep(2)

        # Second evaluation - should use remote cache
        print("Phase 3: Running second GSM8K evaluation using remote cache...")
        metrics_cached = run_eval_few_shot_gsm8k(args_initial)

        # Verify accuracy consistency
        accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
        print(f"Accuracy difference: {accuracy_diff:.4f}")

        # Assertions
        self.assertGreater(
            metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable"
        )
        self.assertGreater(
            metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable"
        )
        self.assertLess(
            accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
        )
|
||||||
|
|
||||||
|
|
||||||
# TODO: Add other backends tests(3fs/mooncake)
|
# TODO: Add other backends tests(3fs/mooncake)
|
||||||
# class TestHiCacheStorageMooncakeBackend(HiCacheStorageBaseTest):
|
# class TestHiCacheStorageMooncakeBackend(HiCacheStorageBaseTest):
|
||||||
# """Mooncake backend tests for HiCache Storage functionality"""
|
# """Mooncake backend tests for HiCache Storage functionality"""
|
||||||
@@ -125,8 +125,8 @@ suites = {
|
|||||||
TestFile("test_dp_attention.py", 277),
|
TestFile("test_dp_attention.py", 277),
|
||||||
TestFile("test_patch_torch.py", 19),
|
TestFile("test_patch_torch.py", 19),
|
||||||
TestFile("test_release_memory_occupation.py", 127),
|
TestFile("test_release_memory_occupation.py", 127),
|
||||||
TestFile("hicache/test_hicache_storage_e2e.py", 400),
|
TestFile("hicache/test_hicache_storage_file_backend.py", 400),
|
||||||
TestFile("hicache/test_hicache_storage_benchmark.py", 400),
|
TestFile("hicache/test_hicache_storage_3fs_backend.py", 400),
|
||||||
],
|
],
|
||||||
"per-commit-4-gpu": [
|
"per-commit-4-gpu": [
|
||||||
TestFile("test_gpt_oss_4gpu.py", 600),
|
TestFile("test_gpt_oss_4gpu.py", 600),
|
||||||
|
|||||||
Reference in New Issue
Block a user