Move memory runtime checker to mixin class (#12014)
This commit is contained in:
@@ -137,6 +137,9 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
|
||||
from sglang.srt.managers.scheduler_pp_mixin import SchedulerPPMixin
|
||||
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
|
||||
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
|
||||
from sglang.srt.managers.scheduler_runtime_checker_mixin import (
|
||||
SchedulerRuntimeCheckerMixin,
|
||||
)
|
||||
from sglang.srt.managers.scheduler_update_weights_mixin import (
|
||||
SchedulerUpdateWeightsMixin,
|
||||
)
|
||||
@@ -207,6 +210,7 @@ class Scheduler(
|
||||
SchedulerMetricsMixin,
|
||||
SchedulerDisaggregationDecodeMixin,
|
||||
SchedulerDisaggregationPrefillMixin,
|
||||
SchedulerRuntimeCheckerMixin,
|
||||
SchedulerPPMixin,
|
||||
):
|
||||
"""A scheduler that manages a tensor parallel GPU worker."""
|
||||
@@ -1506,141 +1510,6 @@ class Scheduler(
|
||||
for tokenized_req in recv_req:
|
||||
self.handle_embedding_request(tokenized_req)
|
||||
|
||||
def self_check_during_idle(self):
|
||||
self.check_memory()
|
||||
self.check_tree_cache()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
self.maybe_sleep_on_idle()
|
||||
|
||||
def check_memory(self):
|
||||
if self.is_hybrid:
|
||||
(
|
||||
full_num_used,
|
||||
swa_num_used,
|
||||
_,
|
||||
_,
|
||||
full_available_size,
|
||||
full_evictable_size,
|
||||
swa_available_size,
|
||||
swa_evictable_size,
|
||||
) = self._get_swa_token_info()
|
||||
memory_leak = full_num_used != 0 or swa_num_used != 0
|
||||
token_msg = (
|
||||
f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
|
||||
f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
|
||||
)
|
||||
elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
|
||||
(
|
||||
full_num_used,
|
||||
mamba_num_used,
|
||||
_,
|
||||
_,
|
||||
full_available_size,
|
||||
full_evictable_size,
|
||||
mamba_available_size,
|
||||
mamba_evictable_size,
|
||||
) = self._get_mamba_token_info()
|
||||
memory_leak = (
|
||||
full_num_used != self.tree_cache.full_protected_size()
|
||||
or mamba_num_used != self.tree_cache.mamba_protected_size()
|
||||
)
|
||||
token_msg = (
|
||||
f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
|
||||
f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
|
||||
)
|
||||
else:
|
||||
_, _, available_size, evictable_size = self._get_token_info()
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
memory_leak = (available_size + evictable_size) != (
|
||||
# self.max_total_num_tokens
|
||||
# if not self.enable_hierarchical_cache
|
||||
# else self.max_total_num_tokens - protected_size
|
||||
self.max_total_num_tokens
|
||||
- protected_size
|
||||
)
|
||||
token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
|
||||
|
||||
if memory_leak:
|
||||
msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
|
||||
raise ValueError(msg)
|
||||
|
||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
req_total_size = (
|
||||
self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
|
||||
)
|
||||
else:
|
||||
req_total_size = self.req_to_token_pool.size
|
||||
|
||||
if len(self.req_to_token_pool.free_slots) != req_total_size:
|
||||
msg = (
|
||||
"req_to_token_pool memory leak detected!"
|
||||
f"available_size={len(self.req_to_token_pool.free_slots)}, "
|
||||
f"total_size={self.req_to_token_pool.size}\n"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if (
|
||||
self.enable_metrics
|
||||
and self.current_scheduler_metrics_enabled()
|
||||
and time.perf_counter() > self.metrics_collector.last_log_time + 30
|
||||
):
|
||||
# During idle time, also collect metrics every 30 seconds.
|
||||
if self.is_hybrid:
|
||||
(
|
||||
full_num_used,
|
||||
swa_num_used,
|
||||
full_token_usage,
|
||||
swa_token_usage,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = self._get_swa_token_info()
|
||||
num_used = max(full_num_used, swa_num_used)
|
||||
token_usage = max(full_token_usage, swa_token_usage)
|
||||
elif self.is_hybrid_gdn:
|
||||
(
|
||||
num_used,
|
||||
_,
|
||||
token_usage,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = self._get_mamba_token_info()
|
||||
else:
|
||||
num_used, token_usage, _, _ = self._get_token_info()
|
||||
num_running_reqs = len(self.running_batch.reqs)
|
||||
self.stats.num_running_reqs = num_running_reqs
|
||||
self.stats.num_used_tokens = num_used
|
||||
self.stats.token_usage = round(token_usage, 2)
|
||||
self.stats.gen_throughput = 0
|
||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
||||
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
self.stats.num_prefill_prealloc_queue_reqs = len(
|
||||
self.disagg_prefill_bootstrap_queue.queue
|
||||
)
|
||||
self.stats.num_prefill_inflight_queue_reqs = len(
|
||||
self.disagg_prefill_inflight_queue
|
||||
)
|
||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
self.stats.num_decode_prealloc_queue_reqs = len(
|
||||
self.disagg_decode_prealloc_queue.queue
|
||||
)
|
||||
self.stats.num_decode_transfer_queue_reqs = len(
|
||||
self.disagg_decode_transfer_queue.queue
|
||||
)
|
||||
self.metrics_collector.log_stats(self.stats)
|
||||
self._publish_kv_events()
|
||||
|
||||
def check_tree_cache(self):
|
||||
if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or (
|
||||
self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
|
||||
):
|
||||
self.tree_cache.sanity_check()
|
||||
|
||||
def _get_token_info(self):
|
||||
available_size = self.token_to_kv_pool_allocator.available_size()
|
||||
evictable_size = self.tree_cache.evictable_size()
|
||||
|
||||
164
python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
Normal file
164
python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.scheduler import Scheduler
|
||||
|
||||
|
||||
class SchedulerRuntimeCheckerMixin:
|
||||
|
||||
def _check_hybrid_memory(self: Scheduler):
|
||||
(
|
||||
full_num_used,
|
||||
swa_num_used,
|
||||
_,
|
||||
_,
|
||||
full_available_size,
|
||||
full_evictable_size,
|
||||
swa_available_size,
|
||||
swa_evictable_size,
|
||||
) = self._get_swa_token_info()
|
||||
memory_leak = full_num_used != 0 or swa_num_used != 0
|
||||
token_msg = (
|
||||
f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
|
||||
f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
|
||||
)
|
||||
return memory_leak, token_msg
|
||||
|
||||
def _check_mamba_memory(self: Scheduler):
|
||||
(
|
||||
full_num_used,
|
||||
mamba_num_used,
|
||||
_,
|
||||
_,
|
||||
full_available_size,
|
||||
full_evictable_size,
|
||||
mamba_available_size,
|
||||
mamba_evictable_size,
|
||||
) = self._get_mamba_token_info()
|
||||
memory_leak = (
|
||||
full_num_used != self.tree_cache.full_protected_size()
|
||||
or mamba_num_used != self.tree_cache.mamba_protected_size()
|
||||
)
|
||||
token_msg = (
|
||||
f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
|
||||
f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
|
||||
)
|
||||
return memory_leak, token_msg
|
||||
|
||||
def _check_radix_cache_memory(self: Scheduler):
|
||||
_, _, available_size, evictable_size = self._get_token_info()
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
memory_leak = (available_size + evictable_size) != (
|
||||
# self.max_total_num_tokens
|
||||
# if not self.enable_hierarchical_cache
|
||||
# else self.max_total_num_tokens - protected_size
|
||||
self.max_total_num_tokens
|
||||
- protected_size
|
||||
)
|
||||
token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
|
||||
return memory_leak, token_msg
|
||||
|
||||
def _check_req_pool(self: Scheduler):
|
||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
req_total_size = (
|
||||
self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
|
||||
)
|
||||
else:
|
||||
req_total_size = self.req_to_token_pool.size
|
||||
|
||||
if len(self.req_to_token_pool.free_slots) != req_total_size:
|
||||
msg = (
|
||||
"req_to_token_pool memory leak detected!"
|
||||
f"available_size={len(self.req_to_token_pool.free_slots)}, "
|
||||
f"total_size={self.req_to_token_pool.size}\n"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
def check_memory(self: Scheduler):
|
||||
if self.is_hybrid:
|
||||
memory_leak, token_msg = self._check_hybrid_memory()
|
||||
elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
|
||||
memory_leak, token_msg = self._check_mamba_memory()
|
||||
else:
|
||||
memory_leak, token_msg = self._check_radix_cache_memory()
|
||||
|
||||
if memory_leak:
|
||||
msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
|
||||
raise ValueError(msg)
|
||||
|
||||
self._check_req_pool()
|
||||
|
||||
if (
|
||||
self.enable_metrics
|
||||
and self.current_scheduler_metrics_enabled()
|
||||
and time.perf_counter() > self.metrics_collector.last_log_time + 30
|
||||
):
|
||||
# During idle time, also collect metrics every 30 seconds.
|
||||
if self.is_hybrid:
|
||||
(
|
||||
full_num_used,
|
||||
swa_num_used,
|
||||
full_token_usage,
|
||||
swa_token_usage,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = self._get_swa_token_info()
|
||||
num_used = max(full_num_used, swa_num_used)
|
||||
token_usage = max(full_token_usage, swa_token_usage)
|
||||
elif self.is_hybrid_gdn:
|
||||
(
|
||||
num_used,
|
||||
_,
|
||||
token_usage,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = self._get_mamba_token_info()
|
||||
else:
|
||||
num_used, token_usage, _, _ = self._get_token_info()
|
||||
num_running_reqs = len(self.running_batch.reqs)
|
||||
self.stats.num_running_reqs = num_running_reqs
|
||||
self.stats.num_used_tokens = num_used
|
||||
self.stats.token_usage = round(token_usage, 2)
|
||||
self.stats.gen_throughput = 0
|
||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
||||
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
self.stats.num_prefill_prealloc_queue_reqs = len(
|
||||
self.disagg_prefill_bootstrap_queue.queue
|
||||
)
|
||||
self.stats.num_prefill_inflight_queue_reqs = len(
|
||||
self.disagg_prefill_inflight_queue
|
||||
)
|
||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
self.stats.num_decode_prealloc_queue_reqs = len(
|
||||
self.disagg_decode_prealloc_queue.queue
|
||||
)
|
||||
self.stats.num_decode_transfer_queue_reqs = len(
|
||||
self.disagg_decode_transfer_queue.queue
|
||||
)
|
||||
self.metrics_collector.log_stats(self.stats)
|
||||
self._publish_kv_events()
|
||||
|
||||
def check_tree_cache(self: Scheduler):
|
||||
if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or (
|
||||
self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
|
||||
):
|
||||
self.tree_cache.sanity_check()
|
||||
|
||||
def self_check_during_idle(self: Scheduler):
|
||||
self.check_memory()
|
||||
self.check_tree_cache()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
self.maybe_sleep_on_idle()
|
||||
Reference in New Issue
Block a user