Add storage read/write bandwidth logs to monitor kvcache performance (#9965)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
pansicheng
2025-09-06 07:52:55 +08:00
committed by GitHub
parent efb0de2c8d
commit f84db115b1
6 changed files with 174 additions and 3 deletions

View File

@@ -128,6 +128,9 @@ class HiCacheStorage(ABC):
return i
return len(keys)
def get_stats(self):
    """Return a snapshot of backend I/O statistics, or None.

    Default implementation for backends that do not collect metrics;
    subclasses (e.g. HiCacheHF3FS) override this to return a populated
    StorageMetrics object and reset their internal counters.
    """
    return None
class HiCacheFile(HiCacheStorage):

View File

@@ -20,6 +20,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
MLATokenToKVPoolHost,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.metrics.collector import StorageMetricsCollector
logger = logging.getLogger(__name__)
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
hicache_write_policy: str,
hicache_io_backend: str,
hicache_mem_layout: str,
enable_metrics: bool,
hicache_storage_backend: Optional[str] = None,
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
model_name: Optional[str] = None,
@@ -73,6 +75,8 @@ class HiRadixCache(RadixCache):
self.tp_group = tp_cache_group
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
self.enable_storage = hicache_storage_backend is not None
self.enable_storage_metrics = self.enable_storage and enable_metrics
# todo: customizable storage prefetch threshold and timeout
self.prefetch_threshold = 256
self.prefetch_timeout = 3 # seconds
@@ -92,6 +96,14 @@ class HiRadixCache(RadixCache):
model_name=model_name,
storage_backend_extra_config=storage_backend_extra_config,
)
if self.enable_storage_metrics:
# TODO: support pp
labels = {
"storage_backend": hicache_storage_backend,
"tp_rank": self.cache_controller.tp_rank,
"dp_rank": self.cache_controller.dp_rank,
}
self.metrics_collector = StorageMetricsCollector(labels=labels)
# record the nodes with ongoing write through
self.ongoing_write_through = {}
@@ -379,6 +391,10 @@ class HiRadixCache(RadixCache):
self.loading_check()
if self.enable_storage:
self.drain_storage_control_queues()
if self.enable_storage_metrics:
self.metrics_collector.log_storage_metrics(
self.cache_controller.storage_backend.get_stats()
)
def drain_storage_control_queues(self):
"""
@@ -414,10 +430,13 @@ class HiRadixCache(RadixCache):
# process backup acks
for _ in range(n_backup):
ack_id = cc.ack_backup_queue.get()
operation = cc.ack_backup_queue.get()
ack_id = operation.id
entry = self.ongoing_backup.pop(ack_id, None)
if entry is not None:
entry.release_host()
if self.enable_storage_metrics:
self.metrics_collector.log_backuped_tokens(operation.completed_tokens)
# release host memory
host_indices_list = []
@@ -515,6 +534,11 @@ class HiRadixCache(RadixCache):
del self.ongoing_prefetch[req_id]
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
if self.enable_storage_metrics:
self.metrics_collector.log_prefetched_tokens(
min_completed_tokens - matched_length
)
return True
def match_prefix(self, key: List[int], **kwargs):

View File

@@ -5,6 +5,7 @@ import logging
import os
import signal
import threading
import time
from abc import ABC, abstractmethod
from functools import wraps
from typing import Any, List, Optional, Tuple
@@ -13,6 +14,7 @@ import torch
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
from sglang.srt.metrics.collector import StorageMetrics
logger = logging.getLogger(__name__)
@@ -135,6 +137,7 @@ class HiCacheHF3FS(HiCacheStorage):
self.file_size = file_size
self.numjobs = numjobs
self.bytes_per_page = bytes_per_page
self.gb_per_page = bytes_per_page / (1 << 30)
self.entries = entries
self.dtype = dtype
self.metadata_client = metadata_client
@@ -174,6 +177,11 @@ class HiCacheHF3FS(HiCacheStorage):
signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())
self.prefetch_pgs = []
self.backup_pgs = []
self.prefetch_bandwidth = []
self.backup_bandwidth = []
@staticmethod
def from_env_config(
bytes_per_page: int,
@@ -308,6 +316,8 @@ class HiCacheHF3FS(HiCacheStorage):
for _ in range(len(batch_indices))
]
start_time = time.perf_counter()
futures = [
self.executor.submit(
self.clients[self.ac.next()].batch_read,
@@ -318,6 +328,13 @@ class HiCacheHF3FS(HiCacheStorage):
]
read_results = [result for future in futures for result in future.result()]
end_time = time.perf_counter()
ionum = len(batch_indices)
self.prefetch_pgs.append(ionum)
self.prefetch_bandwidth.append(
ionum / (end_time - start_time) * self.gb_per_page
)
results = [None] * len(keys)
for batch_index, file_result, read_result in zip(
batch_indices, file_results, read_results
@@ -345,6 +362,7 @@ class HiCacheHF3FS(HiCacheStorage):
[target_sizes] if target_sizes is not None else None,
)
@synchronized()
def batch_set(
self,
keys: List[str],
@@ -374,6 +392,8 @@ class HiCacheHF3FS(HiCacheStorage):
assert value.is_contiguous()
file_values.append(value)
start_time = time.perf_counter()
futures = [
self.executor.submit(
self.clients[self.ac.next()].batch_write,
@@ -388,6 +408,11 @@ class HiCacheHF3FS(HiCacheStorage):
for result in future.result()
]
end_time = time.perf_counter()
ionum = len(batch_indices)
self.backup_pgs.append(ionum)
self.backup_bandwidth.append(ionum / (end_time - start_time) * self.gb_per_page)
written_keys_to_confirm = []
results = [index[0] for index in indices]
for batch_index, write_result in zip(batch_indices, write_results):
@@ -439,3 +464,16 @@ class HiCacheHF3FS(HiCacheStorage):
except Exception as e:
logger.error(f"close HiCacheHF3FS: {e}")
logger.info("close HiCacheHF3FS")
@synchronized()
def get_stats(self):
    """Snapshot and reset the accumulated prefetch/backup I/O statistics.

    Copies the per-batch page counts and bandwidth samples gathered since
    the previous call into a fresh ``StorageMetrics``, then clears the
    local accumulators so each call reports only the newest interval.
    Guarded by ``@synchronized()`` because ``batch_get``/``batch_set``
    append to these lists concurrently.
    """
    snapshot = StorageMetrics()
    # Move each sample list wholesale into the snapshot, then reset it.
    for attr in (
        "prefetch_pgs",
        "backup_pgs",
        "prefetch_bandwidth",
        "backup_bandwidth",
    ):
        samples = getattr(self, attr)
        getattr(snapshot, attr).extend(samples)
        samples.clear()
    return snapshot