From 32852fe9e9f75abb92eb5b49b03237c8e515a3e2 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Thu, 23 Oct 2025 20:53:26 +0800 Subject: [PATCH] Move memory runtime checker to mixin class (#12014) --- python/sglang/srt/managers/scheduler.py | 139 +-------------- .../scheduler_runtime_checker_mixin.py | 164 ++++++++++++++++++ 2 files changed, 168 insertions(+), 135 deletions(-) create mode 100644 python/sglang/srt/managers/scheduler_runtime_checker_mixin.py diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 3a9f92ffc..917d33e87 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -137,6 +137,9 @@ from sglang.srt.managers.scheduler_output_processor_mixin import ( from sglang.srt.managers.scheduler_pp_mixin import SchedulerPPMixin from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper +from sglang.srt.managers.scheduler_runtime_checker_mixin import ( + SchedulerRuntimeCheckerMixin, +) from sglang.srt.managers.scheduler_update_weights_mixin import ( SchedulerUpdateWeightsMixin, ) @@ -207,6 +210,7 @@ class Scheduler( SchedulerMetricsMixin, SchedulerDisaggregationDecodeMixin, SchedulerDisaggregationPrefillMixin, + SchedulerRuntimeCheckerMixin, SchedulerPPMixin, ): """A scheduler that manages a tensor parallel GPU worker.""" @@ -1506,141 +1510,6 @@ class Scheduler( for tokenized_req in recv_req: self.handle_embedding_request(tokenized_req) - def self_check_during_idle(self): - self.check_memory() - self.check_tree_cache() - self.new_token_ratio = self.init_new_token_ratio - self.maybe_sleep_on_idle() - - def check_memory(self): - if self.is_hybrid: - ( - full_num_used, - swa_num_used, - _, - _, - full_available_size, - full_evictable_size, - swa_available_size, - swa_evictable_size, - ) = self._get_swa_token_info() - memory_leak = full_num_used != 0 or swa_num_used != 0 - token_msg = ( - f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n" - f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n" - ) - elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache): - ( - full_num_used, - mamba_num_used, - _, - _, - full_available_size, - full_evictable_size, - mamba_available_size, - mamba_evictable_size, - ) = self._get_mamba_token_info() - memory_leak = ( - full_num_used != self.tree_cache.full_protected_size() - or mamba_num_used != self.tree_cache.mamba_protected_size() - ) - token_msg = ( - f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n" - f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n" - ) - else: - _, _, available_size, evictable_size = self._get_token_info() - protected_size = self.tree_cache.protected_size() - memory_leak = (available_size + evictable_size) != ( - # self.max_total_num_tokens - # if not self.enable_hierarchical_cache - # else self.max_total_num_tokens - protected_size - self.max_total_num_tokens - - protected_size - ) - token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n" - - if memory_leak: - msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}" - raise ValueError(msg) - - if self.disaggregation_mode == DisaggregationMode.DECODE: - req_total_size = ( - self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size - ) - else: - req_total_size = self.req_to_token_pool.size - - if len(self.req_to_token_pool.free_slots) != req_total_size: - msg = ( - "req_to_token_pool memory leak detected!" - f"available_size={len(self.req_to_token_pool.free_slots)}, " - f"total_size={self.req_to_token_pool.size}\n" - ) - raise ValueError(msg) - - if ( - self.enable_metrics - and self.current_scheduler_metrics_enabled() - and time.perf_counter() > self.metrics_collector.last_log_time + 30 - ): - # During idle time, also collect metrics every 30 seconds. - if self.is_hybrid: - ( - full_num_used, - swa_num_used, - full_token_usage, - swa_token_usage, - _, - _, - _, - _, - ) = self._get_swa_token_info() - num_used = max(full_num_used, swa_num_used) - token_usage = max(full_token_usage, swa_token_usage) - elif self.is_hybrid_gdn: - ( - num_used, - _, - token_usage, - _, - _, - _, - _, - _, - ) = self._get_mamba_token_info() - else: - num_used, token_usage, _, _ = self._get_token_info() - num_running_reqs = len(self.running_batch.reqs) - self.stats.num_running_reqs = num_running_reqs - self.stats.num_used_tokens = num_used - self.stats.token_usage = round(token_usage, 2) - self.stats.gen_throughput = 0 - self.stats.num_queue_reqs = len(self.waiting_queue) - self.stats.num_grammar_queue_reqs = len(self.grammar_queue) - if self.disaggregation_mode == DisaggregationMode.PREFILL: - self.stats.num_prefill_prealloc_queue_reqs = len( - self.disagg_prefill_bootstrap_queue.queue - ) - self.stats.num_prefill_inflight_queue_reqs = len( - self.disagg_prefill_inflight_queue - ) - if self.disaggregation_mode == DisaggregationMode.DECODE: - self.stats.num_decode_prealloc_queue_reqs = len( - self.disagg_decode_prealloc_queue.queue - ) - self.stats.num_decode_transfer_queue_reqs = len( - self.disagg_decode_transfer_queue.queue - ) - self.metrics_collector.log_stats(self.stats) - self._publish_kv_events() - - def check_tree_cache(self): - if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or ( - self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache) - ): - self.tree_cache.sanity_check() - def _get_token_info(self): available_size = self.token_to_kv_pool_allocator.available_size() evictable_size = self.tree_cache.evictable_size() diff --git a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py new file mode 100644 index 000000000..20a57aa83 --- /dev/null +++ b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import time +from typing import TYPE_CHECKING + +from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache +from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache + +if TYPE_CHECKING: + from sglang.srt.managers.scheduler import Scheduler + + +class SchedulerRuntimeCheckerMixin: + + def _check_hybrid_memory(self: Scheduler): + ( + full_num_used, + swa_num_used, + _, + _, + full_available_size, + full_evictable_size, + swa_available_size, + swa_evictable_size, + ) = self._get_swa_token_info() + memory_leak = full_num_used != 0 or swa_num_used != 0 + token_msg = ( + f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n" + f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n" + ) + return memory_leak, token_msg + + def _check_mamba_memory(self: Scheduler): + ( + full_num_used, + mamba_num_used, + _, + _, + full_available_size, + full_evictable_size, + mamba_available_size, + mamba_evictable_size, + ) = self._get_mamba_token_info() + memory_leak = ( + full_num_used != self.tree_cache.full_protected_size() + or mamba_num_used != self.tree_cache.mamba_protected_size() + ) + token_msg = ( + f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n" + f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n" + ) + return memory_leak, token_msg + + def _check_radix_cache_memory(self: Scheduler): + _, _, available_size, evictable_size = self._get_token_info() + protected_size = self.tree_cache.protected_size() + memory_leak = (available_size + evictable_size) != ( + # self.max_total_num_tokens + # if not self.enable_hierarchical_cache + # else self.max_total_num_tokens - protected_size + self.max_total_num_tokens + - protected_size + ) + token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n" + return memory_leak, token_msg + + def _check_req_pool(self: Scheduler): + if self.disaggregation_mode == DisaggregationMode.DECODE: + req_total_size = ( + self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size + ) + else: + req_total_size = self.req_to_token_pool.size + + if len(self.req_to_token_pool.free_slots) != req_total_size: + msg = ( + "req_to_token_pool memory leak detected!" + f"available_size={len(self.req_to_token_pool.free_slots)}, " + f"total_size={self.req_to_token_pool.size}\n" + ) + raise ValueError(msg) + + def check_memory(self: Scheduler): + if self.is_hybrid: + memory_leak, token_msg = self._check_hybrid_memory() + elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache): + memory_leak, token_msg = self._check_mamba_memory() + else: + memory_leak, token_msg = self._check_radix_cache_memory() + + if memory_leak: + msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}" + raise ValueError(msg) + + self._check_req_pool() + + if ( + self.enable_metrics + and self.current_scheduler_metrics_enabled() + and time.perf_counter() > self.metrics_collector.last_log_time + 30 + ): + # During idle time, also collect metrics every 30 seconds. + if self.is_hybrid: + ( + full_num_used, + swa_num_used, + full_token_usage, + swa_token_usage, + _, + _, + _, + _, + ) = self._get_swa_token_info() + num_used = max(full_num_used, swa_num_used) + token_usage = max(full_token_usage, swa_token_usage) + elif self.is_hybrid_gdn: + ( + num_used, + _, + token_usage, + _, + _, + _, + _, + _, + ) = self._get_mamba_token_info() + else: + num_used, token_usage, _, _ = self._get_token_info() + num_running_reqs = len(self.running_batch.reqs) + self.stats.num_running_reqs = num_running_reqs + self.stats.num_used_tokens = num_used + self.stats.token_usage = round(token_usage, 2) + self.stats.gen_throughput = 0 + self.stats.num_queue_reqs = len(self.waiting_queue) + self.stats.num_grammar_queue_reqs = len(self.grammar_queue) + if self.disaggregation_mode == DisaggregationMode.PREFILL: + self.stats.num_prefill_prealloc_queue_reqs = len( + self.disagg_prefill_bootstrap_queue.queue + ) + self.stats.num_prefill_inflight_queue_reqs = len( + self.disagg_prefill_inflight_queue + ) + if self.disaggregation_mode == DisaggregationMode.DECODE: + self.stats.num_decode_prealloc_queue_reqs = len( + self.disagg_decode_prealloc_queue.queue + ) + self.stats.num_decode_transfer_queue_reqs = len( + self.disagg_decode_transfer_queue.queue + ) + self.metrics_collector.log_stats(self.stats) + self._publish_kv_events() + + def check_tree_cache(self: Scheduler): + if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or ( + self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache) + ): + self.tree_cache.sanity_check() + + def self_check_during_idle(self: Scheduler): + self.check_memory() + self.check_tree_cache() + self.new_token_ratio = self.init_new_token_ratio + self.maybe_sleep_on_idle()