From f84db115b15288dd850dca2d799c5222d8d2c55d Mon Sep 17 00:00:00 2001 From: pansicheng Date: Sat, 6 Sep 2025 07:52:55 +0800 Subject: [PATCH] Add storage read/write bandwidth logs to monitor kvcache performance (#9965) Co-authored-by: Zhiqiang Xie --- .../sglang/srt/managers/cache_controller.py | 5 +- python/sglang/srt/managers/scheduler.py | 1 + .../sglang/srt/mem_cache/hicache_storage.py | 3 + python/sglang/srt/mem_cache/hiradix_cache.py | 26 ++++- .../mem_cache/storage/hf3fs/storage_hf3fs.py | 38 +++++++ python/sglang/srt/metrics/collector.py | 104 +++++++++++++++++- 6 files changed, 174 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 6a08cd2eb..6bc7bd8f1 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -33,6 +33,7 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, ) from sglang.srt.layers.dp_attention import ( + get_attention_dp_rank, get_attention_tp_rank, get_attention_tp_size, is_dp_attention_enabled, @@ -402,9 +403,11 @@ class HiCacheController: if is_dp_attention_enabled(): self.tp_rank = get_attention_tp_rank() self.tp_size = get_attention_tp_size() + self.dp_rank = get_attention_dp_rank() else: self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() + self.dp_rank = 0 # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool. is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool) @@ -885,7 +888,7 @@ class HiCacheController: if not self.backup_skip: self._page_backup(operation) - self.ack_backup_queue.put(operation.id) + self.ack_backup_queue.put(operation) except Empty: continue diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 91901ca8b..2dbc63191 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -623,6 +623,7 @@ class Scheduler( hicache_write_policy=server_args.hicache_write_policy, hicache_io_backend=server_args.hicache_io_backend, hicache_mem_layout=server_args.hicache_mem_layout, + enable_metrics=self.enable_metrics, hicache_storage_backend=server_args.hicache_storage_backend, hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy, model_name=server_args.served_model_name, diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py index 9112e748d..d5b4540f4 100644 --- a/python/sglang/srt/mem_cache/hicache_storage.py +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -128,6 +128,9 @@ class HiCacheStorage(ABC): return i return len(keys) + def get_stats(self): + return None + class HiCacheFile(HiCacheStorage): diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 5f78ee111..d97b0033a 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -20,6 +20,7 @@ from sglang.srt.mem_cache.memory_pool_host import ( MLATokenToKVPoolHost, ) from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode +from sglang.srt.metrics.collector import StorageMetricsCollector logger = logging.getLogger(__name__) @@ -37,6 +38,7 @@ class HiRadixCache(RadixCache): hicache_write_policy: str, hicache_io_backend: str, hicache_mem_layout: str, + enable_metrics: bool, hicache_storage_backend: Optional[str] = None, hicache_storage_prefetch_policy: Optional[str] = "best_effort", model_name: Optional[str] = None, @@ -73,6 +75,8 @@ class HiRadixCache(RadixCache): self.tp_group = tp_cache_group self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group) self.enable_storage = hicache_storage_backend is not None + self.enable_storage_metrics = self.enable_storage and enable_metrics + # todo: customizable storage prefetch threshold and timeout self.prefetch_threshold = 256 self.prefetch_timeout = 3 # seconds @@ -92,6 +96,14 @@ class HiRadixCache(RadixCache): model_name=model_name, storage_backend_extra_config=storage_backend_extra_config, ) + if self.enable_storage_metrics: + # TODO: support pp + labels = { + "storage_backend": hicache_storage_backend, + "tp_rank": self.cache_controller.tp_rank, + "dp_rank": self.cache_controller.dp_rank, + } + self.metrics_collector = StorageMetricsCollector(labels=labels) # record the nodes with ongoing write through self.ongoing_write_through = {} @@ -379,6 +391,10 @@ class HiRadixCache(RadixCache): self.loading_check() if self.enable_storage: self.drain_storage_control_queues() + if self.enable_storage_metrics: + self.metrics_collector.log_storage_metrics( + self.cache_controller.storage_backend.get_stats() + ) def drain_storage_control_queues(self): """ @@ -414,10 +430,13 @@ class HiRadixCache(RadixCache): # process backup acks for _ in range(n_backup): - ack_id = cc.ack_backup_queue.get() + operation = cc.ack_backup_queue.get() + ack_id = operation.id entry = self.ongoing_backup.pop(ack_id, None) if entry is not None: entry.release_host() + if self.enable_storage_metrics: + self.metrics_collector.log_backuped_tokens(operation.completed_tokens) # release host memory host_indices_list = [] @@ -515,6 +534,11 @@ class HiRadixCache(RadixCache): del self.ongoing_prefetch[req_id] self.cache_controller.prefetch_tokens_occupied -= len(token_ids) + if self.enable_storage_metrics: + self.metrics_collector.log_prefetched_tokens( + min_completed_tokens - matched_length + ) + return True def match_prefix(self, key: List[int], **kwargs): diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py index 48d545889..7f64eb837 100644 --- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py @@ -5,6 +5,7 @@ import logging import os import signal import threading +import time from abc import ABC, abstractmethod from functools import wraps from typing import Any, List, Optional, Tuple @@ -13,6 +14,7 @@ import torch from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient +from sglang.srt.metrics.collector import StorageMetrics logger = logging.getLogger(__name__) @@ -135,6 +137,7 @@ class HiCacheHF3FS(HiCacheStorage): self.file_size = file_size self.numjobs = numjobs self.bytes_per_page = bytes_per_page + self.gb_per_page = bytes_per_page / (1 << 30) self.entries = entries self.dtype = dtype self.metadata_client = metadata_client @@ -174,6 +177,11 @@ class HiCacheHF3FS(HiCacheStorage): signal.signal(signal.SIGTERM, lambda sig, frame: self.close()) signal.signal(signal.SIGQUIT, lambda sig, frame: self.close()) + self.prefetch_pgs = [] + self.backup_pgs = [] + self.prefetch_bandwidth = [] + self.backup_bandwidth = [] + @staticmethod def from_env_config( bytes_per_page: int, @@ -308,6 +316,8 @@ class HiCacheHF3FS(HiCacheStorage): for _ in range(len(batch_indices)) ] + start_time = time.perf_counter() + futures = [ self.executor.submit( self.clients[self.ac.next()].batch_read, @@ -318,6 +328,13 @@ class HiCacheHF3FS(HiCacheStorage): ] read_results = [result for future in futures for result in future.result()] + end_time = time.perf_counter() + ionum = len(batch_indices) + self.prefetch_pgs.append(ionum) + self.prefetch_bandwidth.append( + ionum / (end_time - start_time) * self.gb_per_page + ) + results = [None] * len(keys) for batch_index, file_result, read_result in zip( batch_indices, file_results, read_results @@ -345,6 +362,7 @@ class HiCacheHF3FS(HiCacheStorage): [target_sizes] if target_sizes is not None else None, ) + @synchronized() def batch_set( self, keys: List[str], @@ -374,6 +392,8 @@ class HiCacheHF3FS(HiCacheStorage): assert value.is_contiguous() file_values.append(value) + start_time = time.perf_counter() + futures = [ self.executor.submit( self.clients[self.ac.next()].batch_write, @@ -388,6 +408,11 @@ class HiCacheHF3FS(HiCacheStorage): for result in future.result() ] + end_time = time.perf_counter() + ionum = len(batch_indices) + self.backup_pgs.append(ionum) + self.backup_bandwidth.append(ionum / (end_time - start_time) * self.gb_per_page) + written_keys_to_confirm = [] results = [index[0] for index in indices] for batch_index, write_result in zip(batch_indices, write_results): @@ -439,3 +464,16 @@ class HiCacheHF3FS(HiCacheStorage): except Exception as e: logger.error(f"close HiCacheHF3FS: {e}") logger.info("close HiCacheHF3FS") + + @synchronized() + def get_stats(self): + storage_metrics = StorageMetrics() + storage_metrics.prefetch_pgs.extend(self.prefetch_pgs) + storage_metrics.backup_pgs.extend(self.backup_pgs) + storage_metrics.prefetch_bandwidth.extend(self.prefetch_bandwidth) + storage_metrics.backup_bandwidth.extend(self.backup_bandwidth) + self.prefetch_pgs.clear() + self.backup_pgs.clear() + self.prefetch_bandwidth.clear() + self.backup_bandwidth.clear() + return storage_metrics diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py index f1bb74689..b174bbeb3 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -14,7 +14,7 @@ """Utilities for Prometheus Metrics Collection.""" import time -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from typing import Dict, List, Optional, Union @@ -559,3 +559,105 @@ class TokenizerMetricsCollector: def observe_one_aborted_request(self): self.num_aborted_requests_total.labels(**self.labels).inc(1) + + +@dataclass +class StorageMetrics: + prefetch_pgs: List[int] = field(default_factory=list) + backup_pgs: List[int] = field(default_factory=list) + prefetch_bandwidth: List[float] = field(default_factory=list) + backup_bandwidth: List[float] = field(default_factory=list) + + +class StorageMetricsCollector: + def __init__( + self, + labels: Dict[str, str], + ): + from prometheus_client import Counter, Histogram + + self.labels = labels + + self.prefetched_tokens_total = Counter( + name="sglang:prefetched_tokens_total", + documentation="Number of prefetched prompt tokens.", + labelnames=labels.keys(), + ) + + self.backuped_tokens_total = Counter( + name="sglang:backuped_tokens_total", + documentation="Number of backuped tokens.", + labelnames=labels.keys(), + ) + + bucket_io = [ + 1, + 5, + 10, + 50, + 100, + ] + + bucket_bandwidth = [ + 0.1, + 0.5, + 1, + 5, + 10, + 50, + 100, + ] + + self.histogram_prefetch_pgs = Histogram( + name="sglang:prefetch_pgs", + documentation="Histogram of prefetch pages of batches.", + labelnames=labels.keys(), + buckets=bucket_io, + ) + + self.histogram_backup_pgs = Histogram( + name="sglang:backup_pgs", + documentation="Histogram of backup pages of batches.", + labelnames=labels.keys(), + buckets=bucket_io, + ) + + self.histogram_prefetch_bandwidth = Histogram( + name="sglang:prefetch_bandwidth", + documentation="Histogram of prefetch bandwidth in GB/s.", + labelnames=labels.keys(), + buckets=bucket_bandwidth, + ) + + self.histogram_backup_bandwidth = Histogram( + name="sglang:backup_bandwidth", + documentation="Histogram of backup bandwidth in GB/s.", + labelnames=labels.keys(), + buckets=bucket_bandwidth, + ) + + def log_prefetched_tokens(self, prefetched_tokens: int): + if prefetched_tokens > 0: + self.prefetched_tokens_total.labels(**self.labels).inc(prefetched_tokens) + + def log_backuped_tokens(self, backuped_tokens: int): + if backuped_tokens > 0: + self.backuped_tokens_total.labels(**self.labels).inc(backuped_tokens) + + def _log_histogram(self, histogram, data: Union[int, float]): + histogram.labels(**self.labels).observe(data) + + def log_storage_metrics(self, storage_metrics: Optional[StorageMetrics] = None): + if storage_metrics is None: + return + + assert isinstance(storage_metrics, StorageMetrics) + + for v in storage_metrics.prefetch_pgs: + self._log_histogram(self.histogram_prefetch_pgs, v) + for v in storage_metrics.backup_pgs: + self._log_histogram(self.histogram_backup_pgs, v) + for v in storage_metrics.prefetch_bandwidth: + self._log_histogram(self.histogram_prefetch_bandwidth, v) + for v in storage_metrics.backup_bandwidth: + self._log_histogram(self.histogram_backup_bandwidth, v)