Add storage read/write bandwidth logs to monitor kvcache performance (#9965)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
pansicheng
2025-09-06 07:52:55 +08:00
committed by GitHub
parent efb0de2c8d
commit f84db115b1
6 changed files with 174 additions and 3 deletions

View File

@@ -33,6 +33,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.dp_attention import (
+    get_attention_dp_rank,
     get_attention_tp_rank,
     get_attention_tp_size,
     is_dp_attention_enabled,
@@ -402,9 +403,11 @@ class HiCacheController:
         if is_dp_attention_enabled():
             self.tp_rank = get_attention_tp_rank()
             self.tp_size = get_attention_tp_size()
+            self.dp_rank = get_attention_dp_rank()
         else:
             self.tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
+            self.dp_rank = 0

         # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
         is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
@@ -885,7 +888,7 @@ class HiCacheController:
                 if not self.backup_skip:
                     self._page_backup(operation)
-                self.ack_backup_queue.put(operation.id)
+                self.ack_backup_queue.put(operation)
             except Empty:
                 continue

View File

@@ -623,6 +623,7 @@ class Scheduler(
             hicache_write_policy=server_args.hicache_write_policy,
             hicache_io_backend=server_args.hicache_io_backend,
             hicache_mem_layout=server_args.hicache_mem_layout,
+            enable_metrics=self.enable_metrics,
             hicache_storage_backend=server_args.hicache_storage_backend,
             hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
             model_name=server_args.served_model_name,

View File

@@ -128,6 +128,9 @@ class HiCacheStorage(ABC):
                 return i
         return len(keys)

+    def get_stats(self):
+        return None
+

 class HiCacheFile(HiCacheStorage):

View File

@@ -20,6 +20,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
     MLATokenToKVPoolHost,
 )
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+from sglang.srt.metrics.collector import StorageMetricsCollector

 logger = logging.getLogger(__name__)
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
hicache_write_policy: str, hicache_write_policy: str,
hicache_io_backend: str, hicache_io_backend: str,
hicache_mem_layout: str, hicache_mem_layout: str,
enable_metrics: bool,
hicache_storage_backend: Optional[str] = None, hicache_storage_backend: Optional[str] = None,
hicache_storage_prefetch_policy: Optional[str] = "best_effort", hicache_storage_prefetch_policy: Optional[str] = "best_effort",
model_name: Optional[str] = None, model_name: Optional[str] = None,
@@ -73,6 +75,8 @@ class HiRadixCache(RadixCache):
self.tp_group = tp_cache_group self.tp_group = tp_cache_group
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group) self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
self.enable_storage = hicache_storage_backend is not None self.enable_storage = hicache_storage_backend is not None
self.enable_storage_metrics = self.enable_storage and enable_metrics
# todo: customizable storage prefetch threshold and timeout # todo: customizable storage prefetch threshold and timeout
self.prefetch_threshold = 256 self.prefetch_threshold = 256
self.prefetch_timeout = 3 # seconds self.prefetch_timeout = 3 # seconds
@@ -92,6 +96,14 @@ class HiRadixCache(RadixCache):
model_name=model_name, model_name=model_name,
storage_backend_extra_config=storage_backend_extra_config, storage_backend_extra_config=storage_backend_extra_config,
) )
if self.enable_storage_metrics:
# TODO: support pp
labels = {
"storage_backend": hicache_storage_backend,
"tp_rank": self.cache_controller.tp_rank,
"dp_rank": self.cache_controller.dp_rank,
}
self.metrics_collector = StorageMetricsCollector(labels=labels)
# record the nodes with ongoing write through # record the nodes with ongoing write through
self.ongoing_write_through = {} self.ongoing_write_through = {}
@@ -379,6 +391,10 @@ class HiRadixCache(RadixCache):
self.loading_check() self.loading_check()
if self.enable_storage: if self.enable_storage:
self.drain_storage_control_queues() self.drain_storage_control_queues()
if self.enable_storage_metrics:
self.metrics_collector.log_storage_metrics(
self.cache_controller.storage_backend.get_stats()
)
def drain_storage_control_queues(self): def drain_storage_control_queues(self):
""" """
@@ -414,10 +430,13 @@ class HiRadixCache(RadixCache):
         # process backup acks
         for _ in range(n_backup):
-            ack_id = cc.ack_backup_queue.get()
+            operation = cc.ack_backup_queue.get()
+            ack_id = operation.id
             entry = self.ongoing_backup.pop(ack_id, None)
             if entry is not None:
                 entry.release_host()
+            if self.enable_storage_metrics:
+                self.metrics_collector.log_backuped_tokens(operation.completed_tokens)
# release host memory # release host memory
host_indices_list = [] host_indices_list = []
@@ -515,6 +534,11 @@ class HiRadixCache(RadixCache):
del self.ongoing_prefetch[req_id] del self.ongoing_prefetch[req_id]
self.cache_controller.prefetch_tokens_occupied -= len(token_ids) self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
if self.enable_storage_metrics:
self.metrics_collector.log_prefetched_tokens(
min_completed_tokens - matched_length
)
return True return True
def match_prefix(self, key: List[int], **kwargs): def match_prefix(self, key: List[int], **kwargs):

View File

@@ -5,6 +5,7 @@ import logging
import os import os
import signal import signal
import threading import threading
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import wraps from functools import wraps
from typing import Any, List, Optional, Tuple from typing import Any, List, Optional, Tuple
@@ -13,6 +14,7 @@ import torch
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
from sglang.srt.metrics.collector import StorageMetrics
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -135,6 +137,7 @@ class HiCacheHF3FS(HiCacheStorage):
self.file_size = file_size self.file_size = file_size
self.numjobs = numjobs self.numjobs = numjobs
self.bytes_per_page = bytes_per_page self.bytes_per_page = bytes_per_page
self.gb_per_page = bytes_per_page / (1 << 30)
self.entries = entries self.entries = entries
self.dtype = dtype self.dtype = dtype
self.metadata_client = metadata_client self.metadata_client = metadata_client
@@ -174,6 +177,11 @@ class HiCacheHF3FS(HiCacheStorage):
signal.signal(signal.SIGTERM, lambda sig, frame: self.close()) signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
signal.signal(signal.SIGQUIT, lambda sig, frame: self.close()) signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())
self.prefetch_pgs = []
self.backup_pgs = []
self.prefetch_bandwidth = []
self.backup_bandwidth = []
@staticmethod @staticmethod
def from_env_config( def from_env_config(
bytes_per_page: int, bytes_per_page: int,
@@ -308,6 +316,8 @@ class HiCacheHF3FS(HiCacheStorage):
for _ in range(len(batch_indices)) for _ in range(len(batch_indices))
] ]
start_time = time.perf_counter()
futures = [ futures = [
self.executor.submit( self.executor.submit(
self.clients[self.ac.next()].batch_read, self.clients[self.ac.next()].batch_read,
@@ -318,6 +328,13 @@ class HiCacheHF3FS(HiCacheStorage):
] ]
read_results = [result for future in futures for result in future.result()] read_results = [result for future in futures for result in future.result()]
end_time = time.perf_counter()
ionum = len(batch_indices)
self.prefetch_pgs.append(ionum)
self.prefetch_bandwidth.append(
ionum / (end_time - start_time) * self.gb_per_page
)
results = [None] * len(keys) results = [None] * len(keys)
for batch_index, file_result, read_result in zip( for batch_index, file_result, read_result in zip(
batch_indices, file_results, read_results batch_indices, file_results, read_results
@@ -345,6 +362,7 @@ class HiCacheHF3FS(HiCacheStorage):
[target_sizes] if target_sizes is not None else None, [target_sizes] if target_sizes is not None else None,
) )
@synchronized()
def batch_set( def batch_set(
self, self,
keys: List[str], keys: List[str],
@@ -374,6 +392,8 @@ class HiCacheHF3FS(HiCacheStorage):
assert value.is_contiguous() assert value.is_contiguous()
file_values.append(value) file_values.append(value)
start_time = time.perf_counter()
futures = [ futures = [
self.executor.submit( self.executor.submit(
self.clients[self.ac.next()].batch_write, self.clients[self.ac.next()].batch_write,
@@ -388,6 +408,11 @@ class HiCacheHF3FS(HiCacheStorage):
for result in future.result() for result in future.result()
] ]
end_time = time.perf_counter()
ionum = len(batch_indices)
self.backup_pgs.append(ionum)
self.backup_bandwidth.append(ionum / (end_time - start_time) * self.gb_per_page)
written_keys_to_confirm = [] written_keys_to_confirm = []
results = [index[0] for index in indices] results = [index[0] for index in indices]
for batch_index, write_result in zip(batch_indices, write_results): for batch_index, write_result in zip(batch_indices, write_results):
@@ -439,3 +464,16 @@ class HiCacheHF3FS(HiCacheStorage):
except Exception as e: except Exception as e:
logger.error(f"close HiCacheHF3FS: {e}") logger.error(f"close HiCacheHF3FS: {e}")
logger.info("close HiCacheHF3FS") logger.info("close HiCacheHF3FS")
@synchronized()
def get_stats(self):
storage_metrics = StorageMetrics()
storage_metrics.prefetch_pgs.extend(self.prefetch_pgs)
storage_metrics.backup_pgs.extend(self.backup_pgs)
storage_metrics.prefetch_bandwidth.extend(self.prefetch_bandwidth)
storage_metrics.backup_bandwidth.extend(self.backup_bandwidth)
self.prefetch_pgs.clear()
self.backup_pgs.clear()
self.prefetch_bandwidth.clear()
self.backup_bandwidth.clear()
return storage_metrics

View File

@@ -14,7 +14,7 @@
"""Utilities for Prometheus Metrics Collection.""" """Utilities for Prometheus Metrics Collection."""
import time import time
from dataclasses import dataclass from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
@@ -559,3 +559,105 @@ class TokenizerMetricsCollector:
def observe_one_aborted_request(self): def observe_one_aborted_request(self):
self.num_aborted_requests_total.labels(**self.labels).inc(1) self.num_aborted_requests_total.labels(**self.labels).inc(1)
@dataclass
class StorageMetrics:
prefetch_pgs: List[int] = field(default_factory=list)
backup_pgs: List[int] = field(default_factory=list)
prefetch_bandwidth: List[float] = field(default_factory=list)
backup_bandwidth: List[float] = field(default_factory=list)
class StorageMetricsCollector:
def __init__(
self,
labels: Dict[str, str],
):
from prometheus_client import Counter, Histogram
self.labels = labels
self.prefetched_tokens_total = Counter(
name="sglang:prefetched_tokens_total",
documentation="Number of prefetched prompt tokens.",
labelnames=labels.keys(),
)
self.backuped_tokens_total = Counter(
name="sglang:backuped_tokens_total",
documentation="Number of backuped tokens.",
labelnames=labels.keys(),
)
bucket_io = [
1,
5,
10,
50,
100,
]
bucket_bandwidth = [
0.1,
0.5,
1,
5,
10,
50,
100,
]
self.histogram_prefetch_pgs = Histogram(
name="sglang:prefetch_pgs",
documentation="Histogram of prefetch pages of batches.",
labelnames=labels.keys(),
buckets=bucket_io,
)
self.histogram_backup_pgs = Histogram(
name="sglang:backup_pgs",
documentation="Histogram of backup pages of batches.",
labelnames=labels.keys(),
buckets=bucket_io,
)
self.histogram_prefetch_bandwidth = Histogram(
name="sglang:prefetch_bandwidth",
documentation="Histogram of prefetch bandwidth in GB/s.",
labelnames=labels.keys(),
buckets=bucket_bandwidth,
)
self.histogram_backup_bandwidth = Histogram(
name="sglang:backup_bandwidth",
documentation="Histogram of backup bandwidth in GB/s.",
labelnames=labels.keys(),
buckets=bucket_bandwidth,
)
def log_prefetched_tokens(self, prefetched_tokens: int):
if prefetched_tokens > 0:
self.prefetched_tokens_total.labels(**self.labels).inc(prefetched_tokens)
def log_backuped_tokens(self, backuped_tokens: int):
if backuped_tokens > 0:
self.backuped_tokens_total.labels(**self.labels).inc(backuped_tokens)
def _log_histogram(self, histogram, data: Union[int, float]):
histogram.labels(**self.labels).observe(data)
def log_storage_metrics(self, storage_metrics: Optional[StorageMetrics] = None):
if storage_metrics is None:
return
assert isinstance(storage_metrics, StorageMetrics)
for v in storage_metrics.prefetch_pgs:
self._log_histogram(self.histogram_prefetch_pgs, v)
for v in storage_metrics.backup_pgs:
self._log_histogram(self.histogram_backup_pgs, v)
for v in storage_metrics.prefetch_bandwidth:
self._log_histogram(self.histogram_prefetch_bandwidth, v)
for v in storage_metrics.backup_bandwidth:
self._log_histogram(self.histogram_backup_bandwidth, v)