Add storage read/write bandwidth logs to monitor kvcache performance (#9965)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -33,6 +33,7 @@ from sglang.srt.distributed import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.dp_attention import (
|
from sglang.srt.layers.dp_attention import (
|
||||||
|
get_attention_dp_rank,
|
||||||
get_attention_tp_rank,
|
get_attention_tp_rank,
|
||||||
get_attention_tp_size,
|
get_attention_tp_size,
|
||||||
is_dp_attention_enabled,
|
is_dp_attention_enabled,
|
||||||
@@ -402,9 +403,11 @@ class HiCacheController:
|
|||||||
if is_dp_attention_enabled():
|
if is_dp_attention_enabled():
|
||||||
self.tp_rank = get_attention_tp_rank()
|
self.tp_rank = get_attention_tp_rank()
|
||||||
self.tp_size = get_attention_tp_size()
|
self.tp_size = get_attention_tp_size()
|
||||||
|
self.dp_rank = get_attention_dp_rank()
|
||||||
else:
|
else:
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.dp_rank = 0
|
||||||
|
|
||||||
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
|
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
|
||||||
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
|
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
|
||||||
@@ -885,7 +888,7 @@ class HiCacheController:
|
|||||||
|
|
||||||
if not self.backup_skip:
|
if not self.backup_skip:
|
||||||
self._page_backup(operation)
|
self._page_backup(operation)
|
||||||
self.ack_backup_queue.put(operation.id)
|
self.ack_backup_queue.put(operation)
|
||||||
|
|
||||||
except Empty:
|
except Empty:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -623,6 +623,7 @@ class Scheduler(
|
|||||||
hicache_write_policy=server_args.hicache_write_policy,
|
hicache_write_policy=server_args.hicache_write_policy,
|
||||||
hicache_io_backend=server_args.hicache_io_backend,
|
hicache_io_backend=server_args.hicache_io_backend,
|
||||||
hicache_mem_layout=server_args.hicache_mem_layout,
|
hicache_mem_layout=server_args.hicache_mem_layout,
|
||||||
|
enable_metrics=self.enable_metrics,
|
||||||
hicache_storage_backend=server_args.hicache_storage_backend,
|
hicache_storage_backend=server_args.hicache_storage_backend,
|
||||||
hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
|
hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
|
||||||
model_name=server_args.served_model_name,
|
model_name=server_args.served_model_name,
|
||||||
|
|||||||
@@ -128,6 +128,9 @@ class HiCacheStorage(ABC):
|
|||||||
return i
|
return i
|
||||||
return len(keys)
|
return len(keys)
|
||||||
|
|
||||||
|
def get_stats(self):
    """Return storage statistics collected by this backend.

    This default implementation has nothing to report and always
    returns ``None``.
    """
    return None
|
||||||
|
|
||||||
|
|
||||||
class HiCacheFile(HiCacheStorage):
|
class HiCacheFile(HiCacheStorage):
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
|
|||||||
MLATokenToKVPoolHost,
|
MLATokenToKVPoolHost,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||||
|
from sglang.srt.metrics.collector import StorageMetricsCollector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
|
|||||||
hicache_write_policy: str,
|
hicache_write_policy: str,
|
||||||
hicache_io_backend: str,
|
hicache_io_backend: str,
|
||||||
hicache_mem_layout: str,
|
hicache_mem_layout: str,
|
||||||
|
enable_metrics: bool,
|
||||||
hicache_storage_backend: Optional[str] = None,
|
hicache_storage_backend: Optional[str] = None,
|
||||||
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
|
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
|
||||||
model_name: Optional[str] = None,
|
model_name: Optional[str] = None,
|
||||||
@@ -73,6 +75,8 @@ class HiRadixCache(RadixCache):
|
|||||||
self.tp_group = tp_cache_group
|
self.tp_group = tp_cache_group
|
||||||
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
|
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
|
||||||
self.enable_storage = hicache_storage_backend is not None
|
self.enable_storage = hicache_storage_backend is not None
|
||||||
|
self.enable_storage_metrics = self.enable_storage and enable_metrics
|
||||||
|
|
||||||
# todo: customizable storage prefetch threshold and timeout
|
# todo: customizable storage prefetch threshold and timeout
|
||||||
self.prefetch_threshold = 256
|
self.prefetch_threshold = 256
|
||||||
self.prefetch_timeout = 3 # seconds
|
self.prefetch_timeout = 3 # seconds
|
||||||
@@ -92,6 +96,14 @@ class HiRadixCache(RadixCache):
|
|||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
storage_backend_extra_config=storage_backend_extra_config,
|
storage_backend_extra_config=storage_backend_extra_config,
|
||||||
)
|
)
|
||||||
|
if self.enable_storage_metrics:
|
||||||
|
# TODO: support pp
|
||||||
|
labels = {
|
||||||
|
"storage_backend": hicache_storage_backend,
|
||||||
|
"tp_rank": self.cache_controller.tp_rank,
|
||||||
|
"dp_rank": self.cache_controller.dp_rank,
|
||||||
|
}
|
||||||
|
self.metrics_collector = StorageMetricsCollector(labels=labels)
|
||||||
|
|
||||||
# record the nodes with ongoing write through
|
# record the nodes with ongoing write through
|
||||||
self.ongoing_write_through = {}
|
self.ongoing_write_through = {}
|
||||||
@@ -379,6 +391,10 @@ class HiRadixCache(RadixCache):
|
|||||||
self.loading_check()
|
self.loading_check()
|
||||||
if self.enable_storage:
|
if self.enable_storage:
|
||||||
self.drain_storage_control_queues()
|
self.drain_storage_control_queues()
|
||||||
|
if self.enable_storage_metrics:
|
||||||
|
self.metrics_collector.log_storage_metrics(
|
||||||
|
self.cache_controller.storage_backend.get_stats()
|
||||||
|
)
|
||||||
|
|
||||||
def drain_storage_control_queues(self):
|
def drain_storage_control_queues(self):
|
||||||
"""
|
"""
|
||||||
@@ -414,10 +430,13 @@ class HiRadixCache(RadixCache):
|
|||||||
|
|
||||||
# process backup acks
|
# process backup acks
|
||||||
for _ in range(n_backup):
|
for _ in range(n_backup):
|
||||||
ack_id = cc.ack_backup_queue.get()
|
operation = cc.ack_backup_queue.get()
|
||||||
|
ack_id = operation.id
|
||||||
entry = self.ongoing_backup.pop(ack_id, None)
|
entry = self.ongoing_backup.pop(ack_id, None)
|
||||||
if entry is not None:
|
if entry is not None:
|
||||||
entry.release_host()
|
entry.release_host()
|
||||||
|
if self.enable_storage_metrics:
|
||||||
|
self.metrics_collector.log_backuped_tokens(operation.completed_tokens)
|
||||||
|
|
||||||
# release host memory
|
# release host memory
|
||||||
host_indices_list = []
|
host_indices_list = []
|
||||||
@@ -515,6 +534,11 @@ class HiRadixCache(RadixCache):
|
|||||||
del self.ongoing_prefetch[req_id]
|
del self.ongoing_prefetch[req_id]
|
||||||
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
||||||
|
|
||||||
|
if self.enable_storage_metrics:
|
||||||
|
self.metrics_collector.log_prefetched_tokens(
|
||||||
|
min_completed_tokens - matched_length
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def match_prefix(self, key: List[int], **kwargs):
|
def match_prefix(self, key: List[int], **kwargs):
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, List, Optional, Tuple
|
from typing import Any, List, Optional, Tuple
|
||||||
@@ -13,6 +14,7 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
|
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
|
||||||
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
|
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
|
||||||
|
from sglang.srt.metrics.collector import StorageMetrics
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -135,6 +137,7 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
self.file_size = file_size
|
self.file_size = file_size
|
||||||
self.numjobs = numjobs
|
self.numjobs = numjobs
|
||||||
self.bytes_per_page = bytes_per_page
|
self.bytes_per_page = bytes_per_page
|
||||||
|
self.gb_per_page = bytes_per_page / (1 << 30)
|
||||||
self.entries = entries
|
self.entries = entries
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.metadata_client = metadata_client
|
self.metadata_client = metadata_client
|
||||||
@@ -174,6 +177,11 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
|
signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
|
||||||
signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())
|
signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())
|
||||||
|
|
||||||
|
self.prefetch_pgs = []
|
||||||
|
self.backup_pgs = []
|
||||||
|
self.prefetch_bandwidth = []
|
||||||
|
self.backup_bandwidth = []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_env_config(
|
def from_env_config(
|
||||||
bytes_per_page: int,
|
bytes_per_page: int,
|
||||||
@@ -308,6 +316,8 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
for _ in range(len(batch_indices))
|
for _ in range(len(batch_indices))
|
||||||
]
|
]
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
futures = [
|
futures = [
|
||||||
self.executor.submit(
|
self.executor.submit(
|
||||||
self.clients[self.ac.next()].batch_read,
|
self.clients[self.ac.next()].batch_read,
|
||||||
@@ -318,6 +328,13 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
]
|
]
|
||||||
read_results = [result for future in futures for result in future.result()]
|
read_results = [result for future in futures for result in future.result()]
|
||||||
|
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
ionum = len(batch_indices)
|
||||||
|
self.prefetch_pgs.append(ionum)
|
||||||
|
self.prefetch_bandwidth.append(
|
||||||
|
ionum / (end_time - start_time) * self.gb_per_page
|
||||||
|
)
|
||||||
|
|
||||||
results = [None] * len(keys)
|
results = [None] * len(keys)
|
||||||
for batch_index, file_result, read_result in zip(
|
for batch_index, file_result, read_result in zip(
|
||||||
batch_indices, file_results, read_results
|
batch_indices, file_results, read_results
|
||||||
@@ -345,6 +362,7 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
[target_sizes] if target_sizes is not None else None,
|
[target_sizes] if target_sizes is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@synchronized()
|
||||||
def batch_set(
|
def batch_set(
|
||||||
self,
|
self,
|
||||||
keys: List[str],
|
keys: List[str],
|
||||||
@@ -374,6 +392,8 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
assert value.is_contiguous()
|
assert value.is_contiguous()
|
||||||
file_values.append(value)
|
file_values.append(value)
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
futures = [
|
futures = [
|
||||||
self.executor.submit(
|
self.executor.submit(
|
||||||
self.clients[self.ac.next()].batch_write,
|
self.clients[self.ac.next()].batch_write,
|
||||||
@@ -388,6 +408,11 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
for result in future.result()
|
for result in future.result()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
ionum = len(batch_indices)
|
||||||
|
self.backup_pgs.append(ionum)
|
||||||
|
self.backup_bandwidth.append(ionum / (end_time - start_time) * self.gb_per_page)
|
||||||
|
|
||||||
written_keys_to_confirm = []
|
written_keys_to_confirm = []
|
||||||
results = [index[0] for index in indices]
|
results = [index[0] for index in indices]
|
||||||
for batch_index, write_result in zip(batch_indices, write_results):
|
for batch_index, write_result in zip(batch_indices, write_results):
|
||||||
@@ -439,3 +464,16 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"close HiCacheHF3FS: {e}")
|
logger.error(f"close HiCacheHF3FS: {e}")
|
||||||
logger.info("close HiCacheHF3FS")
|
logger.info("close HiCacheHF3FS")
|
||||||
|
|
||||||
|
@synchronized()
def get_stats(self):
    """Snapshot the I/O samples recorded so far and reset the recorders.

    Returns:
        A ``StorageMetrics`` holding copies of the prefetch/backup page
        counts and bandwidth samples accumulated on this instance. The
        instance's own lists are emptied before returning.
    """
    # Copy first: the instance lists are cleared in place below, so the
    # snapshot must not alias them.
    snapshot = StorageMetrics(
        prefetch_pgs=list(self.prefetch_pgs),
        backup_pgs=list(self.backup_pgs),
        prefetch_bandwidth=list(self.prefetch_bandwidth),
        backup_bandwidth=list(self.backup_bandwidth),
    )
    for samples in (
        self.prefetch_pgs,
        self.backup_pgs,
        self.prefetch_bandwidth,
        self.backup_bandwidth,
    ):
        samples.clear()
    return snapshot
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
"""Utilities for Prometheus Metrics Collection."""
|
"""Utilities for Prometheus Metrics Collection."""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
@@ -559,3 +559,105 @@ class TokenizerMetricsCollector:
|
|||||||
|
|
||||||
def observe_one_aborted_request(self):
|
def observe_one_aborted_request(self):
|
||||||
self.num_aborted_requests_total.labels(**self.labels).inc(1)
|
self.num_aborted_requests_total.labels(**self.labels).inc(1)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class StorageMetrics:
    """Container for batches of storage I/O samples.

    Each field is a list of samples and defaults to its own fresh empty
    list, so instances never share sample storage.
    """

    # Page counts of prefetch batches.
    prefetch_pgs: List[int] = field(default_factory=list)
    # Page counts of backup batches.
    backup_pgs: List[int] = field(default_factory=list)
    # Bandwidth samples of prefetch batches.
    prefetch_bandwidth: List[float] = field(default_factory=list)
    # Bandwidth samples of backup batches.
    backup_bandwidth: List[float] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class StorageMetricsCollector:
    """Prometheus counters and histograms for cache storage I/O."""

    def __init__(
        self,
        labels: Dict[str, str],
    ):
        """Create the counters and histograms.

        Args:
            labels: Label name to label value mapping. The keys become the
                label names of every metric; the whole mapping is applied
                each time a metric is updated.
        """
        # Imported here rather than at module level.
        from prometheus_client import Counter, Histogram

        self.labels = labels
        label_names = labels.keys()

        self.prefetched_tokens_total = Counter(
            name="sglang:prefetched_tokens_total",
            documentation="Number of prefetched prompt tokens.",
            labelnames=label_names,
        )
        self.backuped_tokens_total = Counter(
            name="sglang:backuped_tokens_total",
            documentation="Number of backuped tokens.",
            labelnames=label_names,
        )

        # Bucket upper bounds: pages per batch, and bandwidth in GB/s.
        pages_buckets = [1, 5, 10, 50, 100]
        bandwidth_buckets = [0.1, 0.5, 1, 5, 10, 50, 100]

        def new_histogram(name, documentation, buckets):
            # All histograms share the same label names.
            return Histogram(
                name=name,
                documentation=documentation,
                labelnames=label_names,
                buckets=buckets,
            )

        self.histogram_prefetch_pgs = new_histogram(
            "sglang:prefetch_pgs",
            "Histogram of prefetch pages of batches.",
            pages_buckets,
        )
        self.histogram_backup_pgs = new_histogram(
            "sglang:backup_pgs",
            "Histogram of backup pages of batches.",
            pages_buckets,
        )
        self.histogram_prefetch_bandwidth = new_histogram(
            "sglang:prefetch_bandwidth",
            "Histogram of prefetch bandwidth in GB/s.",
            bandwidth_buckets,
        )
        self.histogram_backup_bandwidth = new_histogram(
            "sglang:backup_bandwidth",
            "Histogram of backup bandwidth in GB/s.",
            bandwidth_buckets,
        )

    def log_prefetched_tokens(self, prefetched_tokens: int):
        """Add a positive token count to the prefetched-tokens counter.

        Zero and negative counts are ignored.
        """
        if prefetched_tokens > 0:
            counter = self.prefetched_tokens_total.labels(**self.labels)
            counter.inc(prefetched_tokens)

    def log_backuped_tokens(self, backuped_tokens: int):
        """Add a positive token count to the backed-up-tokens counter.

        Zero and negative counts are ignored.
        """
        if backuped_tokens > 0:
            counter = self.backuped_tokens_total.labels(**self.labels)
            counter.inc(backuped_tokens)

    def _log_histogram(self, histogram, data: Union[int, float]):
        """Observe one sample on ``histogram`` under this collector's labels."""
        labelled = histogram.labels(**self.labels)
        labelled.observe(data)

    def log_storage_metrics(self, storage_metrics: Optional[StorageMetrics] = None):
        """Record every sample in ``storage_metrics`` on its histogram.

        Args:
            storage_metrics: Samples to record. ``None`` is a no-op.
        """
        if storage_metrics is None:
            return

        assert isinstance(storage_metrics, StorageMetrics)

        # Keep the order: prefetch pages, backup pages, prefetch bandwidth,
        # backup bandwidth.
        series = (
            (self.histogram_prefetch_pgs, storage_metrics.prefetch_pgs),
            (self.histogram_backup_pgs, storage_metrics.backup_pgs),
            (self.histogram_prefetch_bandwidth, storage_metrics.prefetch_bandwidth),
            (self.histogram_backup_bandwidth, storage_metrics.backup_bandwidth),
        )
        for histogram, samples in series:
            for sample in samples:
                self._log_histogram(histogram, sample)
|
||||||
|
|||||||
Reference in New Issue
Block a user