Hicache Storage Layer Prototype (#7704)
@@ -25,6 +25,8 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
+from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+
 logger = logging.getLogger(__name__)
 
 
@@ -159,6 +161,57 @@ class TransferBuffer:
         self.buffers.queue.clear()
 
 
+class StorageOperation:
+    counter = 0
+
+    def __init__(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.host_indices = host_indices
+        self.token_ids = token_ids
+        self.last_hash = last_hash
+        self.completed_tokens = 0
+        self.hash_value = []
+
+        self.id = StorageOperation.counter
+        StorageOperation.counter += 1
+
+    def __lt__(self, other: "StorageOperation"):
+        return self.id < other.id
+
+
+class PrefetchOperation(StorageOperation):
+    def __init__(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.request_id = request_id
+
+        self._done_flag = False
+        self._lock = threading.Lock()
+
+        super().__init__(host_indices, token_ids, last_hash)
+
+    def increment(self, num_tokens: int):
+        with self._lock:
+            if self._done_flag:
+                return
+            self.completed_tokens += num_tokens
+
+    def mark_done(self):
+        with self._lock:
+            self._done_flag = True
+
+    def is_done(self) -> bool:
+        return self._done_flag
+
+
 class HiCacheController:
 
     def __init__(
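Note on the lock in PrefetchOperation: the controller thread can call mark_done() to terminate an operation while the IO thread is still committing pages, and increment() refuses to count tokens once the done flag is set, so completed_tokens is frozen at the moment of termination. A minimal standalone sketch of that handshake (module-level stand-ins for the instance attributes above; the driver calls at the bottom are hypothetical):

import threading

_lock = threading.Lock()
_done_flag = False
completed_tokens = 0

def increment(num_tokens: int) -> None:
    global completed_tokens
    with _lock:
        if _done_flag:  # terminated by the controller: stop counting
            return
        completed_tokens += num_tokens

def mark_done() -> None:
    global _done_flag
    with _lock:
        _done_flag = True  # readers now see a frozen completed_tokens

increment(64)   # IO thread commits one page of 64 tokens
mark_done()     # controller terminates the prefetch
increment(64)   # late page: ignored after termination
assert completed_tokens == 64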
@@ -169,6 +222,8 @@ class HiCacheController:
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
+        storage_backend: Optional[str] = None,
+        prefetch_threshold: int = 256,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -186,6 +241,19 @@ class HiCacheController:
         else:
             self.io_backend = io_backend
 
+        self.enable_storage = False
+        # todo: move backend initialization to storage backend module
+        if storage_backend is not None:
+            if storage_backend == "file":
+                self.storage_backend = HiCacheFile()
+                self.enable_storage = True
+                # todo: threshold policy for prefetching
+                self.prefetch_threshold = prefetch_threshold
+            else:
+                raise NotImplementedError(
+                    f"Unsupported storage backend: {storage_backend}"
+                )
+
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
@@ -218,9 +286,26 @@ class HiCacheController:
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
         )
 
         self.write_thread.start()
         self.load_thread.start()
 
+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_queue = Queue()
+            self.backup_queue = Queue()
+
+            self.prefetch_revoke_queue = Queue()
+            self.ack_backup_queue = Queue()
+
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def reset(self):
         self.stop_event.set()
         self.write_thread.join()
@@ -232,6 +317,13 @@ class HiCacheController:
         self.load_buffer.clear()
         self.ack_write_queue.queue.clear()
         self.ack_load_queue.queue.clear()
+        if self.enable_storage:
+            self.prefetch_thread.join()
+            self.backup_thread.join()
+            self.prefetch_queue.queue.clear()
+            self.backup_queue.queue.clear()
+            self.prefetch_revoke_queue.queue.clear()
+            self.ack_backup_queue.queue.clear()
 
         self.write_thread = threading.Thread(
             target=self.write_thread_func_direct, daemon=True
@@ -243,6 +335,16 @@ class HiCacheController:
         self.write_thread.start()
         self.load_thread.start()
 
+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def write(
         self,
         device_indices: torch.Tensor,
@@ -383,3 +485,142 @@ class HiCacheController:
             raise ValueError(
                 f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
             )
+
+    def prefetch(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ) -> int:
+        """
+        Prefetch KV caches from storage backend to host memory.
+        """
+        operation = PrefetchOperation(
+            request_id, host_indices, new_input_tokens, last_hash
+        )
+        self.prefetch_queue.put(operation)
+        return operation
+
+    def terminate_prefetch(self, operation):
+        operation.mark_done()
+        return operation.completed_tokens, operation.hash_value
+
+    def prefetch_io_aux_func(self):
+        """
+        Auxiliary function conducting IO operations for prefetching.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.prefetch_buffer.get(block=True, timeout=1)
+                for h in operation.hash_value:
+                    page_data = self.storage_backend.get(h)
+                    if page_data is None:
+                        logger.warning(
+                            f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
+                        )
+                        break
+                    self.mem_pool_host.set_from_flat_data_page(
+                        operation.host_indices[operation.completed_tokens],
+                        page_data,
+                    )
+                    operation.increment(self.page_size)
+                    if operation.is_done():
+                        # operation terminated by controller, release pre-allocated memory
+                        self.mem_pool_host.free(
+                            operation.host_indices[operation.completed_tokens :]
+                        )
+                        break
+            except Empty:
+                continue
+
+    def prefetch_thread_func(self):
+        """
+        Manage prefetching operations from storage backend to host memory.
+        """
+        self.prefetch_buffer = Queue()
+        aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
+        aux_thread.start()
+        while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
+            try:
+                operation = self.prefetch_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_fetch = operation.token_ids
+
+                storage_hit_count = 0
+                remaining_tokens = len(tokens_to_fetch)
+                hash_value = []
+                while remaining_tokens >= self.page_size:
+                    last_hash = get_hash_str(
+                        tokens_to_fetch[
+                            storage_hit_count : storage_hit_count + self.page_size
+                        ],
+                        last_hash,
+                    )
+                    if self.storage_backend.exists(last_hash):
+                        storage_hit_count += self.page_size
+                        hash_value.append(last_hash)
+                        remaining_tokens -= self.page_size
+                    else:
+                        break
+
+                if storage_hit_count < self.prefetch_threshold:
+                    # not to prefetch if not enough benefits
+                    self.prefetch_revoke_queue.put(operation.request_id)
+                else:
+                    operation.hash_value = hash_value
+                    logger.debug(
+                        f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
+                    )
+                    self.prefetch_buffer.put(operation)
+
+            except Empty:
+                continue
+
+    def write_storage(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ) -> int:
+        """
+        Write KV caches from host memory to storage backend.
+        """
+        operation = StorageOperation(host_indices, token_ids, last_hash)
+        self.backup_queue.put(operation)
+        return operation.id
+
+    def backup_thread_func(self):
+        """
+        Manage backup operations from host memory to storage backend.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.backup_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_backup = operation.token_ids
+
+                for i in range(0, len(tokens_to_backup), self.page_size):
+                    last_hash = get_hash_str(
+                        tokens_to_backup[i : i + self.page_size], last_hash
+                    )
+                    # todo, handle failures in storage backend
+                    self.storage_backend.set(
+                        last_hash,
+                        self.mem_pool_host.get_flat_data_page(
+                            operation.host_indices[i]
+                        ),
+                    )
+                    operation.completed_tokens += self.page_size
+                    operation.hash_value.append(last_hash)
+
+                self.ack_backup_queue.put((operation.id, operation.hash_value))
+
+            except Empty:
+                continue
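The backup path above is a ticketed handshake: write_storage() enqueues a StorageOperation and returns its id, backup_thread_func() persists the pages and posts (id, hash_value) to ack_backup_queue, and the owner of the ticket polls that queue later (see check_backup_progress in the radix-cache changes below). A hedged sketch of the pattern with plain queues; the id and hashes are made up:

from queue import Queue

backup_queue: Queue = Queue()
ack_backup_queue: Queue = Queue()

# producer (write_storage): enqueue work and keep the ticket id
backup_queue.put((7, ["hash_page_0", "hash_page_1"]))

# consumer (backup_thread_func): persist each page, then acknowledge
op_id, hashes = backup_queue.get()
ack_backup_queue.put((op_id, hashes))

# poller (check_backup_progress): match the ack to the pending node
assert ack_backup_queue.get() == (7, ["hash_page_0", "hash_page_1"])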
@@ -262,6 +262,7 @@ class Scheduler(
         )
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
+        self.enable_hicache_storage = server_args.hicache_storage_backend is not None
         self.page_size = server_args.page_size
         self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
@@ -614,6 +615,7 @@ class Scheduler(
                     == "fa3"  # hot fix for incompatibility
                     else server_args.hicache_io_backend
                 ),
+                hicache_storage_backend=server_args.hicache_storage_backend,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
@@ -1258,6 +1260,15 @@ class Scheduler(
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.disagg_decode_prealloc_queue.add(req)
         else:
+            if self.enable_hicache_storage:
+                req.init_next_round_input(self.tree_cache)
+                last_hash = req.last_host_node.get_last_hash_value()
+                matched_len = len(req.prefix_indices) + req.host_hit_length
+                if (matched_len > 0 and last_hash is not None) or matched_len == 0:
+                    new_input_tokens = req.fill_ids[matched_len:]
+                    self.tree_cache.prefetch_from_storage(
+                        req.rid, req.last_host_node, new_input_tokens, last_hash
+                    )
             self.waiting_queue.append(req)
 
     def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
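The guard above prefetches only the suffix that missed both cache tiers: matched_len is the device radix hit plus the host hit, and new_input_tokens is everything after it; prefetch is skipped when there is a partial match but no last hash to chain from. A worked example with made-up lengths:

# hypothetical request with a 1024-token prompt
fill_ids = list(range(1024))
device_prefix_len = 256  # len(req.prefix_indices): already on GPU
host_hit_length = 128    # additionally resident in host memory

matched_len = device_prefix_len + host_hit_length  # 384
new_input_tokens = fill_ids[matched_len:]          # 640 tokens to probe in storage
assert len(new_input_tokens) == 640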
@@ -1731,6 +1742,9 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break
 
+            if self.enable_hicache_storage:
+                self.tree_cache.check_prefetch_progress(req.rid)
+
             req.init_next_round_input(self.tree_cache)
             res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
 
python/sglang/srt/mem_cache/hicache_storage.py (new file, 152 lines)
@@ -0,0 +1,152 @@
+import hashlib
+import logging
+import os
+from abc import ABC, abstractmethod
+from typing import List, Optional
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
+    hasher = hashlib.sha256()
+
+    if prior_hash:
+        hasher.update(bytes.fromhex(prior_hash))
+
+    for t in token_ids:
+        hasher.update(t.to_bytes(4, byteorder="little", signed=False))
+
+    return hasher.hexdigest()
+
+
+class HiCacheStorage(ABC):
+    """
+    HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
+    It abstracts the underlying storage mechanism, allowing different implementations to be used.
+    """
+
+    # todo, translate tensor object access for different TP ranks
+    # potentially pass model and TP configs into storage backend
+    # todo, the page size of storage backend does not have to be the same as that of the host memory pool
+
+    @abstractmethod
+    def get(
+        self, key: str, target_location: Optional[torch.Tensor] = None
+    ) -> torch.Tensor | None:
+        """
+        Retrieve the value associated with the given key.
+        Returns None if the key does not exist.
+        """
+        pass
+
+    @abstractmethod
+    def batch_get(
+        self, keys: List[str], target_locations: Optional[List[torch.Tensor]] = None
+    ) -> List[torch.Tensor | None]:
+        """
+        Retrieve values for multiple keys.
+        Returns a list of tensors or None for each key.
+        """
+        pass
+
+    @abstractmethod
+    def set(self, key, value) -> bool:
+        """
+        Store the value associated with the given key.
+        Returns True if the operation was successful, False otherwise.
+        """
+        pass
+
+    @abstractmethod
+    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+        """
+        Store multiple key-value pairs.
+        Returns True if all operations were successful, False otherwise.
+        """
+        pass
+
+    @abstractmethod
+    def exists(self, key: str) -> bool:
+        """
+        Check if the key exists in the storage.
+        Returns True if the key exists, False otherwise.
+        """
+        pass
+
+
+class HiCacheFile(HiCacheStorage):
+
+    def __init__(self, file_path: str = "/tmp/hicache"):
+        self.file_path = file_path
+        if not os.path.exists(self.file_path):
+            os.makedirs(self.file_path)
+            logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
+
+    def get(
+        self, key: str, target_location: Optional[torch.Tensor] = None
+    ) -> torch.Tensor | None:
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        try:
+            # todo: fixing the target_location logic to enable in-place loading
+            loaded_tensor = torch.load(tensor_path)
+            if isinstance(loaded_tensor, torch.Tensor):
+                return loaded_tensor
+            else:
+                logger.error(f"Loaded data for key {key} is not a tensor.")
+                return None
+        except FileNotFoundError:
+            return None
+
+    def batch_get(
+        self,
+        keys: List[str],
+        target_locations: Optional[List[torch.Tensor]] = None,
+    ) -> List[torch.Tensor | None]:
+        return [
+            self.get(key, target_location)
+            for key, target_location in zip(
+                keys, target_locations or [None] * len(keys)
+            )
+        ]
+
+    def set(self, key: str, value: torch.Tensor) -> bool:
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        if self.exists(key):
+            logger.debug(f"Key {key} already exists. Skipped.")
+            return True
+        try:
+            torch.save(value, tensor_path)
+            return True
+        except Exception as e:
+            logger.error(f"Failed to save tensor {key}: {e}")
+            return False
+
+    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+        for key, value in zip(keys, values):
+            if not self.set(key, value):
+                return False
+        return True
+
+    def exists(self, key: str) -> bool:
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        return os.path.exists(tensor_path)
+
+    def delete(self, key: str) -> None:
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        try:
+            os.remove(tensor_path)
+        except FileNotFoundError:
+            logger.warning(f"Key {key} does not exist. Cannot delete.")
+            return
+
+    def clear(self) -> None:
+        try:
+            for filename in os.listdir(self.file_path):
+                file_path = os.path.join(self.file_path, filename)
+                if os.path.isfile(file_path):
+                    os.remove(file_path)
+            logger.info("Cleared all entries in HiCacheFile storage.")
+        except Exception as e:
+            logger.error(f"Failed to clear HiCacheFile storage: {e}")
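The keying scheme chains page digests: get_hash_str hashes a page's token ids together with the previous page's digest, so a key identifies the entire prefix ending at that page rather than the page in isolation. A small end-to-end sketch against HiCacheFile (page size, token ids, and the zero tensor standing in for a flat KV page are all arbitrary; the directory is the default /tmp/hicache):

import torch
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str

storage = HiCacheFile()
page_size = 4
tokens = [11, 12, 13, 14, 21, 22, 23, 24]

last_hash = None
for i in range(0, len(tokens), page_size):
    # each page key folds in every preceding page via the chained digest
    last_hash = get_hash_str(tokens[i : i + page_size], last_hash)
    storage.set(last_hash, torch.zeros(8))  # stand-in for a flat KV page

assert storage.exists(last_hash)
assert storage.get(last_hash) is not None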
@@ -35,6 +35,7 @@ class HiRadixCache(RadixCache):
         hicache_size: int,
         hicache_write_policy: str,
         hicache_io_backend: str,
+        hicache_storage_backend: Optional[str] = None,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
@@ -49,6 +50,9 @@ class HiRadixCache(RadixCache):
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
 
         self.tp_group = tp_cache_group
+        self.enable_storage = hicache_storage_backend is not None
+        # todo: customizable storage prefetch threshold
+        self.prefetch_threshold = 256
 
         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
@@ -58,16 +62,22 @@ class HiRadixCache(RadixCache):
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
            io_backend=hicache_io_backend,
+            storage_backend=hicache_storage_backend,
+            prefetch_threshold=self.prefetch_threshold,
         )
 
         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
+        # record the ongoing prefetch requests
+        self.ongoing_prefetch = {}
+        self.ongoing_backup = {}
         # todo: dynamically adjust the threshold
         self.write_through_threshold = (
             1 if hicache_write_policy == "write_through" else 3
         )
+        self.write_through_threshold_storage = 3
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -108,13 +118,30 @@ class HiRadixCache(RadixCache):
 
         return len(host_indices)
 
+    def write_backup_storage(self, node: TreeNode):
+        operation_id = self.cache_controller.write_storage(
+            node.host_value, node.key, node.parent.get_last_hash_value()
+        )
+        self.ongoing_backup[operation_id] = node
+        node.protect_host()
+
     def inc_hit_count(self, node: TreeNode):
-        if node.backuped or self.cache_controller.write_policy == "write_back":
+        if self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1
-        if node.hit_count >= self.write_through_threshold:
-            self.write_backup(node)
-            node.hit_count = 0
+
+        if not node.backuped:
+            if node.hit_count >= self.write_through_threshold:
+                # write to host if the node is not backuped
+                self.write_backup(node)
+        else:
+            if (
+                self.enable_storage
+                and (not node.backuped_storage)
+                and node.hit_count >= self.write_through_threshold_storage
+            ):
+                # if the node is backuped on host memory but not on storage
+                self.write_backup_storage(node)
 
     def writing_check(self, write_back=False):
         if write_back:
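inc_hit_count is now a two-tier policy: a hot node is first written through to host memory, and only a node that is already host-backed (and not yet on storage) graduates to the storage backup. A compressed restatement of the branch as a pure function, with both thresholds fixed at the selective-policy default of 3:

def next_backup_tier(backuped: bool, backuped_storage: bool,
                     hit_count: int, enable_storage: bool) -> str:
    # mirrors inc_hit_count: host first, storage only for host-backed nodes
    if not backuped:
        return "host" if hit_count >= 3 else "none"
    if enable_storage and not backuped_storage and hit_count >= 3:
        return "storage"
    return "none"

assert next_backup_tier(False, False, 3, True) == "host"
assert next_backup_tier(True, False, 3, True) == "storage"
assert next_backup_tier(True, True, 9, True) == "none"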
@@ -221,6 +248,10 @@ class HiRadixCache(RadixCache):
             if not x.evicted:
                 continue
 
+            # node is protected from eviction as it has ongoing prefetch or backup to storage
+            if x.host_ref_counter > 0:
+                continue
+
             num_evicted += self.cache_controller.evict_host(x.host_value)
 
             for k, v in x.parent.children.items():
@@ -314,6 +345,85 @@ class HiRadixCache(RadixCache):
     def check_hicache_events(self):
         self.writing_check()
         self.loading_check()
+        if self.enable_storage:
+            self.check_revoked_prefetch()
+            self.check_backup_progress()
+
+    def check_revoked_prefetch(self):
+        queue_size = torch.tensor(
+            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to hiradix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            req_id = self.cache_controller.prefetch_revoke_queue.get()
+            if req_id in self.ongoing_prefetch:
+                last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
+                last_host_node.release_host()
+                self.cache_controller.mem_pool_host.free(host_indices)
+                del self.ongoing_prefetch[req_id]
+
+    def check_backup_progress(self):
+        queue_size = torch.tensor(
+            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to hiradix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
+            self.ongoing_backup[ack_id].hash_value = hash_value
+            self.ongoing_backup[ack_id].release_host()
+            del self.ongoing_backup[ack_id]
+
+    def check_prefetch_progress(self, req_id: str):
+        if req_id not in self.ongoing_prefetch:
+            # there is no ongoing prefetch for this request or it has been revoked
+            return
+
+        # todo: more policies for prefetch progress such as timeout
+        # the current policy is to prefetch with best effort and terminate when queuing is over
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
+            req_id
+        ]
+        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
+            operation
+        )
+        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
+
+        min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to hiradix cache
+            torch.distributed.all_reduce(
+                min_completed_tokens,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        min_completed_tokens = min_completed_tokens.item()
+        fetched_token_ids = token_ids[:min_completed_tokens]
+        written_indices = host_indices[:min_completed_tokens]
+        matched_length = self._insert_helper_host(
+            last_host_node,
+            fetched_token_ids,
+            written_indices,
+            hash_value[:min_completed_tokens],
+        )
+
+        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
+        self.cache_controller.mem_pool_host.free(
+            host_indices[min_completed_tokens:completed_tokens]
+        )
+        last_host_node.release_host()
+        del self.ongoing_prefetch[req_id]
+
     def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
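All three checks reduce with ReduceOp.MIN before touching the tree so that every TP rank applies the identical update: a rank that has seen fewer acks, or committed fewer prefetched tokens, caps what its peers may insert. A single-process illustration of the invariant (no process group; the per-rank counts are made up):

# completed prefetch tokens observed by each TP rank at poll time
per_rank_completed = [512, 448, 512, 512]

# what all_reduce(..., op=ReduceOp.MIN) leaves on every rank
min_completed_tokens = min(per_rank_completed)

# every rank inserts exactly this many tokens, keeping the trees in lockstep
assert min_completed_tokens == 448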
@@ -348,6 +458,71 @@ class HiRadixCache(RadixCache):
             host_hit_length=host_hit_length,
         )
 
+    def prefetch_from_storage(
+        self,
+        req_id: str,
+        last_host_node: TreeNode,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
+            return
+
+        last_host_node.protect_host()
+        host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
+        if host_indices is None:
+            self.evict_host(len(new_input_tokens))
+            host_indices = self.cache_controller.mem_pool_host.alloc(
+                len(new_input_tokens)
+            )
+        if host_indices is None:
+            last_host_node.release_host()
+            # no sufficient host memory to prefetch
+            return
+        operation = self.cache_controller.prefetch(
+            req_id, host_indices, new_input_tokens, last_hash
+        )
+        self.ongoing_prefetch[req_id] = (
+            last_host_node,
+            new_input_tokens,
+            host_indices,
+            operation,
+        )
+
+    def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
+        node.last_access_time = time.monotonic()
+        if len(key) == 0:
+            return 0
+
+        child_key = self.get_child_key_fn(key)
+
+        matched_length = 0
+        while len(key) > 0 and child_key in node.children.keys():
+            node = node.children[child_key]
+            node.last_access_time = time.monotonic()
+            prefix_len = self.key_match_fn(node.key, key)
+            key = key[prefix_len:]
+            host_value = host_value[prefix_len:]
+            hash_value = hash_value[prefix_len:]
+            matched_length += prefix_len
+
+            if prefix_len < len(node.key):
+                new_node = self._split_node(node.key, node, prefix_len)
+                node = new_node
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)
+
+        if len(key):
+            new_node = TreeNode()
+            new_node.parent = node
+            new_node.key = key
+            new_node.value = None
+            new_node.host_value = host_value
+            new_node.hash_value = hash_value
+            node.children[child_key] = new_node
+        return matched_length
+
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.monotonic()
         child_key = self.get_child_key_fn(key)
@@ -99,6 +99,20 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        """
+        Get a flat data page from the host memory pool.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        """
+        Set a flat data page to the host memory pool.
+        """
+        raise NotImplementedError()
+
     @synchronized()
     def clear(self):
         # Initialize memory states and tracking structures.
@@ -227,6 +241,19 @@ class MHATokenToKVPoolHost(HostKVCache):
             pin_memory=self.pin_memory,
         )
 
+    # todo, page first memory layout
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
+            2,
+            self.layer_num,
+            self.page_size,
+            self.head_num,
+            self.head_dim,
+        )
+
     @property
     def k_buffer(self):
         return self.kv_buffer[0]
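get_flat_data_page and set_from_flat_data_page round-trip one page through a flat 1-D tensor, which is the unit the storage backend stores. A shape check for the MHA layout with toy dimensions (all sizes hypothetical):

import torch

layer_num, pool_size, page_size, head_num, head_dim = 2, 16, 4, 8, 64
kv_buffer = torch.randn(2, layer_num, pool_size, head_num, head_dim)

index = 8
page = kv_buffer[:, :, index : index + page_size, :, :].flatten()
assert page.numel() == 2 * layer_num * page_size * head_num * head_dim

# writing back reverses the flatten with the matching reshape
kv_buffer[:, :, index : index + page_size, :, :] = page.reshape(
    2, layer_num, page_size, head_num, head_dim
)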
@@ -276,3 +303,14 @@ class MLATokenToKVPoolHost(HostKVCache):
             device=self.device,
             pin_memory=self.pin_memory,
         )
+
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
+            self.layer_num,
+            self.page_size,
+            1,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+        )
@@ -55,8 +55,13 @@ class TreeNode:
         self.hit_count = 0
         # indicating the node is loading KV cache from host
         self.loading = False
+        # indicating the node is locked to protect from eviction
+        # incremented when the node is referenced by a storage operation
+        self.host_ref_counter = 0
         # store the host indices of KV cache
         self.host_value: Optional[torch.Tensor] = None
+        # store hash values of each page
+        self.hash_value: Optional[List[str]] = None
 
         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -69,6 +74,27 @@ class TreeNode:
     def backuped(self):
         return self.host_value is not None
 
+    @property
+    def backuped_storage(self):
+        return self.hash_value is not None and len(self.hash_value) > 0
+
+    def protect_host(self):
+        """Protect the host value from eviction."""
+        self.host_ref_counter += 1
+
+    def release_host(self):
+        """Release the host value, allowing it to be evicted."""
+        if self.host_ref_counter > 0:
+            self.host_ref_counter -= 1
+        else:
+            raise RuntimeError("Host reference counter is already zero.")
+
+    def get_last_hash_value(self) -> Optional[str]:
+        """Returns the hash value of the last page in this node."""
+        if self.hash_value is None or len(self.hash_value) == 0:
+            return None
+        return self.hash_value[-1]
+
     def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time
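host_ref_counter is a reference count rather than a lock: prefetch_from_storage and write_backup_storage call protect_host() before handing a node to a background thread, the completion and revoke paths call release_host(), and the eviction loop skips any node whose counter is nonzero. A standalone sketch of the balanced pairing (counter as a module global; error behavior as in release_host above):

host_ref_counter = 0

def protect_host() -> None:
    global host_ref_counter
    host_ref_counter += 1

def release_host() -> None:
    global host_ref_counter
    if host_ref_counter > 0:
        host_ref_counter -= 1
    else:
        raise RuntimeError("Host reference counter is already zero.")

protect_host()                      # prefetch in flight: node is not evictable
evictable = host_ref_counter == 0   # False while the operation is pending
release_host()                      # completion or revoke path
assert host_ref_counter == 0 and not evictable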
@@ -222,6 +222,7 @@ class ServerArgs:
     hicache_size: int = 0
     hicache_write_policy: str = "write_through_selective"
     hicache_io_backend: str = ""
+    hicache_storage_backend: Optional[str] = None
     flashinfer_mla_disable_ragged: bool = False
     disable_shared_experts_fusion: bool = False
     disable_chunked_prefix_cache: bool = False
@@ -1604,6 +1605,13 @@ class ServerArgs:
             default=ServerArgs.hicache_io_backend,
             help="The IO backend for KV cache transfer between CPU and GPU",
         )
+        parser.add_argument(
+            "--hicache-storage-backend",
+            type=str,
+            choices=["file"],  # todo, mooncake
+            default=ServerArgs.hicache_storage_backend,
+            help="The storage backend for hierarchical KV cache.",
+        )
         parser.add_argument(
             "--flashinfer-mla-disable-ragged",
             action="store_true",
@@ -64,6 +64,7 @@ suites = {
         TestFile("test_fused_moe.py", 30),
         TestFile("test_hicache.py", 116),
         TestFile("test_hicache_mla.py", 127),
+        TestFile("test_hicache_storage.py", 127),
         TestFile("test_hidden_states.py", 55),
         TestFile("test_int8_kernel.py", 8),
         TestFile("test_input_embeddings.py", 38),
test/srt/test_hicache_storage.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+class TestHiCache(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--enable-hierarchical-cache",
+                "--mem-fraction-static",
+                0.7,
+                "--hicache-size",
+                100,
+                "--page-size",
+                "64",
+                "--hicache-storage-backend",
+                "file",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.65)
+
+
+if __name__ == "__main__":
+    unittest.main()