165 lines
6.4 KiB
Python
165 lines
6.4 KiB
Python
from __future__ import annotations
|
|
|
|
import time
|
|
from typing import TYPE_CHECKING
|
|
|
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
|
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
|
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
|
|
|
if TYPE_CHECKING:
|
|
from sglang.srt.managers.scheduler import Scheduler
|
|
|
|
|
|
class SchedulerRuntimeCheckerMixin:
    """Runtime self-checks mixed into ``Scheduler``.

    The methods detect leaks in the token/KV pools and in the request pool,
    run tree-cache sanity checks, and periodically refresh and log scheduler
    metrics. They read ``Scheduler`` attributes directly, so this mixin is only
    meaningful when combined with that class.
    """

    def _check_hybrid_memory(self: Scheduler) -> tuple[bool, str]:
        """Check the full and SWA token pools of a hybrid model for leaks.

        Returns:
            ``(memory_leak, token_msg)``: ``memory_leak`` is True when either
            pool still reports used tokens; ``token_msg`` is a two-line
            diagnostic with pool sizes and tree-cache protected sizes.
        """
        (
            full_num_used,
            swa_num_used,
            _,
            _,
            full_available_size,
            full_evictable_size,
            swa_available_size,
            swa_evictable_size,
        ) = self._get_swa_token_info()
        # NOTE(review): unlike the mamba/radix checks, tree-cache protected
        # sizes are not taken into account here; presumably
        # _get_swa_token_info already excludes them from *_num_used — confirm.
        memory_leak = full_num_used != 0 or swa_num_used != 0
        token_msg = (
            f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
            f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
        )
        return memory_leak, token_msg

    def _check_mamba_memory(self: Scheduler) -> tuple[bool, str]:
        """Check the full KV pool and the mamba state pool for leaks.

        A leak is reported when the used count of either pool differs from
        the amount the tree cache currently protects in that pool.

        Returns:
            ``(memory_leak, token_msg)`` with a two-line diagnostic message.
        """
        (
            full_num_used,
            mamba_num_used,
            _,
            _,
            full_available_size,
            full_evictable_size,
            mamba_available_size,
            mamba_evictable_size,
        ) = self._get_mamba_token_info()
        memory_leak = (
            full_num_used != self.tree_cache.full_protected_size()
            or mamba_num_used != self.tree_cache.mamba_protected_size()
        )
        token_msg = (
            f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
            f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
        )
        return memory_leak, token_msg

    def _check_radix_cache_memory(self: Scheduler) -> tuple[bool, str]:
        """Check the single token pool used with the default tree cache.

        Every token must be either free (``available_size``), evictable from
        the tree cache, or protected by it; any mismatch is reported as a leak.

        Returns:
            ``(memory_leak, token_msg)`` with a one-line diagnostic message.
        """
        _, _, available_size, evictable_size = self._get_token_info()
        protected_size = self.tree_cache.protected_size()
        expected_size = self.max_total_num_tokens - protected_size
        memory_leak = (available_size + evictable_size) != expected_size
        token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
        return memory_leak, token_msg

    def _check_req_pool(self: Scheduler) -> None:
        """Verify every ``req_to_token_pool`` slot has been returned.

        In decode disaggregation mode the expected number of free slots also
        includes the pool's ``pre_alloc_size``.

        Raises:
            ValueError: if the number of free slots differs from the expected
                total.
        """
        if self.disaggregation_mode == DisaggregationMode.DECODE:
            req_total_size = (
                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
            )
        else:
            req_total_size = self.req_to_token_pool.size

        available_size = len(self.req_to_token_pool.free_slots)
        if available_size != req_total_size:
            # Report the total that was actually compared against (it includes
            # pre_alloc_size in decode mode), not just the base pool size.
            raise ValueError(
                "req_to_token_pool memory leak detected! "
                f"available_size={available_size}, "
                f"total_size={req_total_size}\n"
            )

    def check_memory(self: Scheduler) -> None:
        """Verify no token or request slots leaked, then refresh idle metrics.

        The token-pool check is chosen by model/cache type:
        - hybrid (SWA) models use ``_check_hybrid_memory``;
        - hybrid GDN models with a ``MambaRadixCache`` use ``_check_mamba_memory``;
        - everything else uses ``_check_radix_cache_memory``.

        When metrics are enabled and more than 30 seconds have passed since the
        metrics collector last logged, scheduler stats are refreshed and
        logged. KV events are published once the leak checks pass.

        Raises:
            ValueError: if a token-pool or request-pool leak is detected.
        """
        if self.is_hybrid:
            memory_leak, token_msg = self._check_hybrid_memory()
        elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
            memory_leak, token_msg = self._check_mamba_memory()
        else:
            memory_leak, token_msg = self._check_radix_cache_memory()

        if memory_leak:
            raise ValueError(
                f"token_to_kv_pool_allocator memory leak detected! {token_msg}"
            )

        self._check_req_pool()

        if (
            self.enable_metrics
            and self.current_scheduler_metrics_enabled()
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            if self.is_hybrid:
                (
                    full_num_used,
                    swa_num_used,
                    full_token_usage,
                    swa_token_usage,
                    _,
                    _,
                    _,
                    _,
                ) = self._get_swa_token_info()
                # Report the more heavily used of the two pools.
                num_used = max(full_num_used, swa_num_used)
                token_usage = max(full_token_usage, swa_token_usage)
            elif self.is_hybrid_gdn:
                (
                    num_used,
                    _,
                    token_usage,
                    _,
                    _,
                    _,
                    _,
                    _,
                ) = self._get_mamba_token_info()
            else:
                num_used, token_usage, _, _ = self._get_token_info()

            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            # Nothing is being generated while idle.
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                self.stats.num_prefill_prealloc_queue_reqs = len(
                    self.disagg_prefill_bootstrap_queue.queue
                )
                self.stats.num_prefill_inflight_queue_reqs = len(
                    self.disagg_prefill_inflight_queue
                )
            elif self.disaggregation_mode == DisaggregationMode.DECODE:
                self.stats.num_decode_prealloc_queue_reqs = len(
                    self.disagg_decode_prealloc_queue.queue
                )
                self.stats.num_decode_transfer_queue_reqs = len(
                    self.disagg_decode_transfer_queue.queue
                )
            self.metrics_collector.log_stats(self.stats)

        self._publish_kv_events()

    def check_tree_cache(self: Scheduler) -> None:
        """Run the tree cache's own ``sanity_check`` for SWA/mamba radix caches.

        Only a ``SWARadixCache`` on a hybrid model or a ``MambaRadixCache`` on
        a hybrid GDN model is checked; other tree caches are skipped.
        """
        if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or (
            self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
        ):
            self.tree_cache.sanity_check()

    def self_check_during_idle(self: Scheduler) -> None:
        """Run memory and tree-cache checks, reset the new-token ratio, maybe sleep.

        Raises:
            ValueError: propagated from ``check_memory`` on a detected leak.
        """
        self.check_memory()
        self.check_tree_cache()
        # Restore the ratio to its configured initial value.
        self.new_token_ratio = self.init_new_token_ratio
        self.maybe_sleep_on_idle()