Move memory runtime checker to mixin class (#12014)
This commit is contained in:
@@ -137,6 +137,9 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
|
|||||||
from sglang.srt.managers.scheduler_pp_mixin import SchedulerPPMixin
|
from sglang.srt.managers.scheduler_pp_mixin import SchedulerPPMixin
|
||||||
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
|
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
|
||||||
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
|
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
|
||||||
|
from sglang.srt.managers.scheduler_runtime_checker_mixin import (
|
||||||
|
SchedulerRuntimeCheckerMixin,
|
||||||
|
)
|
||||||
from sglang.srt.managers.scheduler_update_weights_mixin import (
|
from sglang.srt.managers.scheduler_update_weights_mixin import (
|
||||||
SchedulerUpdateWeightsMixin,
|
SchedulerUpdateWeightsMixin,
|
||||||
)
|
)
|
||||||
@@ -207,6 +210,7 @@ class Scheduler(
|
|||||||
SchedulerMetricsMixin,
|
SchedulerMetricsMixin,
|
||||||
SchedulerDisaggregationDecodeMixin,
|
SchedulerDisaggregationDecodeMixin,
|
||||||
SchedulerDisaggregationPrefillMixin,
|
SchedulerDisaggregationPrefillMixin,
|
||||||
|
SchedulerRuntimeCheckerMixin,
|
||||||
SchedulerPPMixin,
|
SchedulerPPMixin,
|
||||||
):
|
):
|
||||||
"""A scheduler that manages a tensor parallel GPU worker."""
|
"""A scheduler that manages a tensor parallel GPU worker."""
|
||||||
@@ -1506,141 +1510,6 @@ class Scheduler(
|
|||||||
for tokenized_req in recv_req:
|
for tokenized_req in recv_req:
|
||||||
self.handle_embedding_request(tokenized_req)
|
self.handle_embedding_request(tokenized_req)
|
||||||
|
|
||||||
def self_check_during_idle(self):
|
|
||||||
self.check_memory()
|
|
||||||
self.check_tree_cache()
|
|
||||||
self.new_token_ratio = self.init_new_token_ratio
|
|
||||||
self.maybe_sleep_on_idle()
|
|
||||||
|
|
||||||
def check_memory(self):
|
|
||||||
if self.is_hybrid:
|
|
||||||
(
|
|
||||||
full_num_used,
|
|
||||||
swa_num_used,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
full_available_size,
|
|
||||||
full_evictable_size,
|
|
||||||
swa_available_size,
|
|
||||||
swa_evictable_size,
|
|
||||||
) = self._get_swa_token_info()
|
|
||||||
memory_leak = full_num_used != 0 or swa_num_used != 0
|
|
||||||
token_msg = (
|
|
||||||
f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
|
|
||||||
f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
|
|
||||||
)
|
|
||||||
elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
|
|
||||||
(
|
|
||||||
full_num_used,
|
|
||||||
mamba_num_used,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
full_available_size,
|
|
||||||
full_evictable_size,
|
|
||||||
mamba_available_size,
|
|
||||||
mamba_evictable_size,
|
|
||||||
) = self._get_mamba_token_info()
|
|
||||||
memory_leak = (
|
|
||||||
full_num_used != self.tree_cache.full_protected_size()
|
|
||||||
or mamba_num_used != self.tree_cache.mamba_protected_size()
|
|
||||||
)
|
|
||||||
token_msg = (
|
|
||||||
f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
|
|
||||||
f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
_, _, available_size, evictable_size = self._get_token_info()
|
|
||||||
protected_size = self.tree_cache.protected_size()
|
|
||||||
memory_leak = (available_size + evictable_size) != (
|
|
||||||
# self.max_total_num_tokens
|
|
||||||
# if not self.enable_hierarchical_cache
|
|
||||||
# else self.max_total_num_tokens - protected_size
|
|
||||||
self.max_total_num_tokens
|
|
||||||
- protected_size
|
|
||||||
)
|
|
||||||
token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
|
|
||||||
|
|
||||||
if memory_leak:
|
|
||||||
msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
|
||||||
req_total_size = (
|
|
||||||
self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
req_total_size = self.req_to_token_pool.size
|
|
||||||
|
|
||||||
if len(self.req_to_token_pool.free_slots) != req_total_size:
|
|
||||||
msg = (
|
|
||||||
"req_to_token_pool memory leak detected!"
|
|
||||||
f"available_size={len(self.req_to_token_pool.free_slots)}, "
|
|
||||||
f"total_size={self.req_to_token_pool.size}\n"
|
|
||||||
)
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.enable_metrics
|
|
||||||
and self.current_scheduler_metrics_enabled()
|
|
||||||
and time.perf_counter() > self.metrics_collector.last_log_time + 30
|
|
||||||
):
|
|
||||||
# During idle time, also collect metrics every 30 seconds.
|
|
||||||
if self.is_hybrid:
|
|
||||||
(
|
|
||||||
full_num_used,
|
|
||||||
swa_num_used,
|
|
||||||
full_token_usage,
|
|
||||||
swa_token_usage,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
) = self._get_swa_token_info()
|
|
||||||
num_used = max(full_num_used, swa_num_used)
|
|
||||||
token_usage = max(full_token_usage, swa_token_usage)
|
|
||||||
elif self.is_hybrid_gdn:
|
|
||||||
(
|
|
||||||
num_used,
|
|
||||||
_,
|
|
||||||
token_usage,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
) = self._get_mamba_token_info()
|
|
||||||
else:
|
|
||||||
num_used, token_usage, _, _ = self._get_token_info()
|
|
||||||
num_running_reqs = len(self.running_batch.reqs)
|
|
||||||
self.stats.num_running_reqs = num_running_reqs
|
|
||||||
self.stats.num_used_tokens = num_used
|
|
||||||
self.stats.token_usage = round(token_usage, 2)
|
|
||||||
self.stats.gen_throughput = 0
|
|
||||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
|
||||||
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
|
||||||
self.stats.num_prefill_prealloc_queue_reqs = len(
|
|
||||||
self.disagg_prefill_bootstrap_queue.queue
|
|
||||||
)
|
|
||||||
self.stats.num_prefill_inflight_queue_reqs = len(
|
|
||||||
self.disagg_prefill_inflight_queue
|
|
||||||
)
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
|
||||||
self.stats.num_decode_prealloc_queue_reqs = len(
|
|
||||||
self.disagg_decode_prealloc_queue.queue
|
|
||||||
)
|
|
||||||
self.stats.num_decode_transfer_queue_reqs = len(
|
|
||||||
self.disagg_decode_transfer_queue.queue
|
|
||||||
)
|
|
||||||
self.metrics_collector.log_stats(self.stats)
|
|
||||||
self._publish_kv_events()
|
|
||||||
|
|
||||||
def check_tree_cache(self):
|
|
||||||
if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or (
|
|
||||||
self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
|
|
||||||
):
|
|
||||||
self.tree_cache.sanity_check()
|
|
||||||
|
|
||||||
def _get_token_info(self):
|
def _get_token_info(self):
|
||||||
available_size = self.token_to_kv_pool_allocator.available_size()
|
available_size = self.token_to_kv_pool_allocator.available_size()
|
||||||
evictable_size = self.tree_cache.evictable_size()
|
evictable_size = self.tree_cache.evictable_size()
|
||||||
|
|||||||
164
python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
Normal file
164
python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
|
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
||||||
|
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.managers.scheduler import Scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerRuntimeCheckerMixin:
    """Idle-time self checks for ``Scheduler``.

    Provides memory-pool leak detection, tree-cache sanity checking and the
    periodic metrics report that is emitted while the scheduler is idle.
    Every method annotates ``self`` as ``Scheduler`` because it reads
    attributes and helpers supplied by the host class (``tree_cache``,
    ``req_to_token_pool``, ``_get_token_info``, ...).
    """

    def _check_hybrid_memory(self: Scheduler):
        """Check the full and SWA token accounting for leftover tokens.

        Returns:
            A ``(memory_leak, token_msg)`` tuple. ``memory_leak`` is True when
            either used-token counter is non-zero; ``token_msg`` is a
            two-line diagnostic string describing both pools.
        """
        (
            full_num_used,
            swa_num_used,
            _,
            _,
            full_available_size,
            full_evictable_size,
            swa_available_size,
            swa_evictable_size,
        ) = self._get_swa_token_info()
        # Any token still counted as used is reported as a leak.
        memory_leak = full_num_used != 0 or swa_num_used != 0
        token_msg = (
            f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
            f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
        )
        return memory_leak, token_msg

    def _check_mamba_memory(self: Scheduler):
        """Check the full and mamba token accounting against protected sizes.

        Returns:
            A ``(memory_leak, token_msg)`` tuple. ``memory_leak`` is True when
            a used-token counter differs from the matching protected size
            reported by the tree cache; ``token_msg`` is a two-line
            diagnostic string describing both pools.
        """
        (
            full_num_used,
            mamba_num_used,
            _,
            _,
            full_available_size,
            full_evictable_size,
            mamba_available_size,
            mamba_evictable_size,
        ) = self._get_mamba_token_info()
        # Unlike the SWA check, used tokens are compared with the protected
        # sizes rather than with zero.
        memory_leak = (
            full_num_used != self.tree_cache.full_protected_size()
            or mamba_num_used != self.tree_cache.mamba_protected_size()
        )
        token_msg = (
            f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
            f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
        )
        return memory_leak, token_msg

    def _check_radix_cache_memory(self: Scheduler):
        """Check the default token pool accounting.

        Returns:
            A ``(memory_leak, token_msg)`` tuple. ``memory_leak`` is True when
            ``available_size + evictable_size`` differs from
            ``max_total_num_tokens - protected_size``; ``token_msg`` is a
            one-line diagnostic string with those four values.
        """
        _, _, available_size, evictable_size = self._get_token_info()
        protected_size = self.tree_cache.protected_size()
        # NOTE: an earlier variant subtracted protected_size only when the
        # hierarchical cache was enabled; it is now always subtracted.
        expected_size = self.max_total_num_tokens - protected_size
        memory_leak = (available_size + evictable_size) != expected_size
        token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
        return memory_leak, token_msg

    def _check_req_pool(self: Scheduler):
        """Verify that every slot of ``req_to_token_pool`` is free.

        In DECODE disaggregation mode the expected number of free slots also
        includes ``pre_alloc_size``.

        Raises:
            ValueError: If the number of free slots differs from the expected
                total.
        """
        req_pool = self.req_to_token_pool
        if self.disaggregation_mode == DisaggregationMode.DECODE:
            req_total_size = req_pool.size + req_pool.pre_alloc_size
        else:
            req_total_size = req_pool.size

        num_free_slots = len(req_pool.free_slots)
        if num_free_slots != req_total_size:
            # Report the total that was actually compared against (it
            # includes pre_alloc_size in DECODE mode) and separate the
            # headline from the details so the message stays readable.
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={num_free_slots}, "
                f"total_size={req_total_size}\n"
            )
            raise ValueError(msg)

    def check_memory(self: Scheduler):
        """Run the leak checks and, when due, log idle metrics.

        The token-pool check is selected from ``is_hybrid`` /
        ``is_hybrid_gdn``; the request pool is checked afterwards. If metrics
        are enabled and more than 30 seconds have passed since the collector
        last logged, the scheduler stats are refreshed and logged, and KV
        events are published.

        Raises:
            ValueError: If a token-pool or request-pool leak is detected.
        """
        if self.is_hybrid:
            memory_leak, token_msg = self._check_hybrid_memory()
        elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
            memory_leak, token_msg = self._check_mamba_memory()
        else:
            memory_leak, token_msg = self._check_radix_cache_memory()

        if memory_leak:
            msg = f"token_to_kv_pool_allocator memory leak detected! {token_msg}"
            raise ValueError(msg)

        self._check_req_pool()

        if (
            self.enable_metrics
            and self.current_scheduler_metrics_enabled()
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            if self.is_hybrid:
                (
                    full_num_used,
                    swa_num_used,
                    full_token_usage,
                    swa_token_usage,
                    _,
                    _,
                    _,
                    _,
                ) = self._get_swa_token_info()
                # Report the larger of the two pools.
                num_used = max(full_num_used, swa_num_used)
                token_usage = max(full_token_usage, swa_token_usage)
            elif self.is_hybrid_gdn:
                # NOTE(review): the leak check above additionally requires
                # tree_cache to be a MambaRadixCache; this branch does not.
                # Confirm the difference is intentional.
                (
                    num_used,
                    _,
                    token_usage,
                    _,
                    _,
                    _,
                    _,
                    _,
                ) = self._get_mamba_token_info()
            else:
                num_used, token_usage, _, _ = self._get_token_info()

            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                self.stats.num_prefill_prealloc_queue_reqs = len(
                    self.disagg_prefill_bootstrap_queue.queue
                )
                self.stats.num_prefill_inflight_queue_reqs = len(
                    self.disagg_prefill_inflight_queue
                )
            if self.disaggregation_mode == DisaggregationMode.DECODE:
                self.stats.num_decode_prealloc_queue_reqs = len(
                    self.disagg_decode_prealloc_queue.queue
                )
                self.stats.num_decode_transfer_queue_reqs = len(
                    self.disagg_decode_transfer_queue.queue
                )
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()

    def check_tree_cache(self: Scheduler):
        """Call ``tree_cache.sanity_check()`` for SWA and mamba radix caches.

        The check runs only when the matching hybrid flag is set and the tree
        cache is an instance of the corresponding cache class.
        """
        if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or (
            self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
        ):
            self.tree_cache.sanity_check()

    def self_check_during_idle(self: Scheduler):
        """Run the idle-time checks and housekeeping.

        Checks memory and the tree cache, resets ``new_token_ratio`` to
        ``init_new_token_ratio``, then calls ``maybe_sleep_on_idle()``.
        """
        self.check_memory()
        self.check_tree_cache()
        self.new_token_ratio = self.init_new_token_ratio
        self.maybe_sleep_on_idle()
|
||||||
Reference in New Issue
Block a user