SWA Prefix Cache (#7367)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
Hanming Lu
2025-07-13 12:31:07 -07:00
committed by GitHub
parent 0c55cbcfc5
commit 9379da77de
16 changed files with 1742 additions and 158 deletions

View File

@@ -52,10 +52,14 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -527,6 +531,8 @@ class Req:
self.last_node: Any = None
self.last_host_node: Any = None
self.host_hit_length = 0
# The node to lock until for swa radix tree lock ref
self.swa_uuid_for_lock: Optional[int] = None
# Whether or not it is chunked. It increments whenever
# it is chunked, and decrements whenever a chunked request is
@@ -745,6 +751,7 @@ class Req:
def reset_for_retract(self):
self.prefix_indices = []
self.last_node = None
self.swa_uuid_for_lock = None
self.extend_input_len = 0
self.is_retracted = True
self.input_token_logprobs = None
@@ -813,6 +820,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
tree_cache: BasePrefixCache = None
is_hybrid: bool = False
# Batch configs
model_config: ModelConfig = None
@@ -918,11 +926,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
):
return_logprob = any(req.return_logprob for req in reqs)
is_hybrid = False
if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
assert isinstance(tree_cache, SWARadixCache) or isinstance(
tree_cache, SWAChunkCache
), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
is_hybrid = True
return cls(
reqs=reqs,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
tree_cache=tree_cache,
is_hybrid=is_hybrid,
model_config=model_config,
enable_overlap=enable_overlap,
return_logprob=return_logprob,
@@ -953,9 +969,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
return req_pool_indices
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
if self.tree_cache is not None:
self.tree_cache.evict(num_tokens)
self._evict_tree_cache_if_needed(num_tokens)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
@@ -966,7 +980,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
error_msg = (
f"{phase_str} out of memory. Try to lower your batch size.\n"
f"Try to allocate {num_tokens} tokens.\n"
f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
f"{self._available_and_evictable_str()}"
)
logger.error(error_msg)
if self.tree_cache is not None:
@@ -986,16 +1000,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
extend_num_tokens: int,
backup_state: bool = False,
):
if (
self.token_to_kv_pool_allocator.available_size()
< extend_num_tokens
num_tokens = (
extend_num_tokens
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size
):
if self.tree_cache is not None:
self.tree_cache.evict(
extend_num_tokens
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
)
)
self._evict_tree_cache_if_needed(num_tokens)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
@@ -1007,9 +1016,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
error_msg = (
f"Prefill out of memory. Try to lower your batch size.\n"
f"Try to allocate {extend_num_tokens} tokens.\n"
f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
f"{self.tree_cache.evictable_size()=}\n"
f"{self._available_and_evictable_str()}"
)
logger.error(error_msg)
raise RuntimeError(error_msg)
@@ -1025,14 +1032,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
last_loc: torch.Tensor,
backup_state: bool = False,
):
if self.tree_cache is not None:
if (
self.token_to_kv_pool_allocator.available_size()
< len(seq_lens) * self.token_to_kv_pool_allocator.page_size
):
self.tree_cache.evict(
len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
)
num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
self._evict_tree_cache_if_needed(num_tokens)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
@@ -1042,9 +1044,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
error_msg = (
f"Decode out of memory. Try to lower your batch size.\n"
f"Try to allocate {len(seq_lens)} tokens.\n"
f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
f"{self.tree_cache.evictable_size()=}\n"
f"{self._available_and_evictable_str()}"
)
logger.error(error_msg)
raise RuntimeError(error_msg)
@@ -1181,7 +1181,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
)
if isinstance(self.tree_cache, SWAChunkCache):
self.tree_cache.evict(
self.tree_cache.evict_swa(
req, pre_len, self.model_config.attention_chunk_size
)
@@ -1371,17 +1371,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
def check_decode_mem(self, buf_multiplier=1):
tokens_required = (
num_tokens = (
self.new_page_count_next_decode()
* buf_multiplier
* self.token_to_kv_pool_allocator.page_size
)
if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
return True
self.tree_cache.evict(tokens_required)
return self.token_to_kv_pool_allocator.available_size() >= tokens_required
self._evict_tree_cache_if_needed(num_tokens)
return self._is_available_size_sufficient(num_tokens)
def retract_decode(self, server_args: ServerArgs):
"""Retract the decoding requests when there is not enough memory."""
@@ -1414,19 +1411,38 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
)
def _get_available_size():
    # In hybrid (SWA) mode the allocator tracks two pools; the usable
    # headroom is the smaller of the two (see the paired assertion below
    # that requires space in both pools).
    if self.is_hybrid:
        return min(
            self.token_to_kv_pool_allocator.full_available_size(),
            self.token_to_kv_pool_allocator.swa_available_size(),
        )
    else:
        return self.token_to_kv_pool_allocator.available_size()
retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True
while (
self.token_to_kv_pool_allocator.available_size()
< get_required_tokens(len(sorted_indices))
_get_available_size() < get_required_tokens(len(sorted_indices))
or first_iter
):
if len(sorted_indices) == 1:
# Corner case: only one request left
assert (
self.token_to_kv_pool_allocator.available_size() > 0
), "No space left for only one request"
if self.is_hybrid:
full_available_size = (
self.token_to_kv_pool_allocator.full_available_size()
)
swa_available_size = (
self.token_to_kv_pool_allocator.swa_available_size()
)
assert (
full_available_size > 0 and swa_available_size > 0
), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
else:
assert (
self.token_to_kv_pool_allocator.available_size() > 0
), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
break
first_iter = False
@@ -1458,15 +1474,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.req_to_token_pool.free(req.req_pool_idx)
# release the last node
self.tree_cache.dec_lock_ref(req.last_node)
if self.is_hybrid:
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
else:
self.tree_cache.dec_lock_ref(req.last_node)
# NOTE(lsyin): we should use the newly evictable memory instantly.
residual_size = (
len(sorted_indices) * global_config.retract_decode_steps
- self.token_to_kv_pool_allocator.available_size()
)
residual_size = max(0, residual_size)
self.tree_cache.evict(residual_size)
num_tokens = len(sorted_indices) * global_config.retract_decode_steps
self._evict_tree_cache_if_needed(num_tokens)
req.reset_for_retract()
@@ -1559,7 +1574,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# free memory
if isinstance(self.tree_cache, SWAChunkCache):
for req in self.reqs:
self.tree_cache.evict(
self.tree_cache.evict_swa(
req, req.seqlen - 1, self.model_config.attention_chunk_size
)
@@ -1778,6 +1793,53 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
is_extend_in_batch=self.is_extend_in_batch,
)
def _evict_tree_cache_if_needed(
    self,
    num_tokens: int,
) -> None:
    """Evict tree-cache entries until ``num_tokens`` KV slots can be allocated.

    No-op for ``SWAChunkCache``: its memory is reclaimed per request via
    ``evict_swa()`` at the call sites, not through this path.
    """
    if isinstance(self.tree_cache, SWAChunkCache):
        return

    if self.is_hybrid:
        # Hybrid (SWA) mode: two pools, both need room for num_tokens.
        full_available_size = self.token_to_kv_pool_allocator.full_available_size()
        swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
        if full_available_size < num_tokens or swa_available_size < num_tokens:
            if self.tree_cache is not None:
                # Evict only each pool's shortfall (0 if it already fits).
                full_num_tokens = max(0, num_tokens - full_available_size)
                swa_num_tokens = max(0, num_tokens - swa_available_size)
                self.tree_cache.evict(full_num_tokens, swa_num_tokens)
    else:
        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens)
def _is_available_size_sufficient(self, num_tokens: int) -> bool:
    """Return True if ``num_tokens`` KV slots are currently free.

    In hybrid (SWA) mode, both the full-attention pool and the
    sliding-window pool must have at least ``num_tokens`` available.
    """
    if self.is_hybrid:
        return (
            self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
            and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
        )
    else:
        return self.token_to_kv_pool_allocator.available_size() >= num_tokens
def _available_and_evictable_str(self) -> str:
    """Summarize free + evictable token counts for out-of-memory messages.

    Reports per-pool numbers (full and swa) in hybrid mode; a single
    total otherwise. Each returned string ends with a newline so it can
    be embedded directly in a multi-line error message.
    """
    if self.is_hybrid:
        full_available_size = self.token_to_kv_pool_allocator.full_available_size()
        swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
        full_evictable_size = self.tree_cache.full_evictable_size()
        swa_evictable_size = self.tree_cache.swa_evictable_size()
        return (
            f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
            f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
            f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
            f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
        )
    else:
        available_size = self.token_to_kv_pool_allocator.available_size()
        evictable_size = self.tree_cache.evictable_size()
        return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
def __str__(self):
return (
f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "

View File

@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
import torch
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
@@ -311,21 +312,43 @@ class PrefillAdder:
]
)
@property
def rem_total_tokens(self):
return (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
- self.rem_total_token_offset
self.is_hybrid = isinstance(
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
)
@property
def rem_total_tokens(self):
if self.is_hybrid:
available_and_evictable = min(
self.token_to_kv_pool_allocator.full_available_size()
+ self.tree_cache.full_evictable_size(),
self.token_to_kv_pool_allocator.swa_available_size()
+ self.tree_cache.swa_evictable_size(),
)
else:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
return available_and_evictable - self.rem_total_token_offset
@property
def cur_rem_tokens(self):
return (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
- self.cur_rem_token_offset
)
if self.is_hybrid:
available_and_evictable = min(
self.token_to_kv_pool_allocator.full_available_size()
+ self.tree_cache.full_evictable_size(),
self.token_to_kv_pool_allocator.swa_available_size()
+ self.tree_cache.swa_evictable_size(),
)
else:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
return available_and_evictable - self.cur_rem_token_offset
def ceil_paged_tokens(self, tokens: int) -> int:
return -(-tokens // self.page_size) * self.page_size
@@ -376,11 +399,18 @@ class PrefillAdder:
@contextmanager
def _lock_node(self, last_node: TreeNode):
try:
self.tree_cache.inc_lock_ref(last_node)
yield None
finally:
self.tree_cache.dec_lock_ref(last_node)
if self.is_hybrid:
try:
swa_uuid_for_lock = self.tree_cache.inc_lock_ref(last_node)
yield None
finally:
self.tree_cache.dec_lock_ref(last_node, swa_uuid_for_lock)
else:
try:
self.tree_cache.inc_lock_ref(last_node)
yield None
finally:
self.tree_cache.dec_lock_ref(last_node)
def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
# Early exit if no enough tokens for the input tokens
@@ -422,16 +452,19 @@ class PrefillAdder:
else:
add_req_state(req, insert_sort=True)
cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
tokens_freed = 0
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
# tokens_left gives a conservative calculation as the last token is not stored
bs = len(self.req_states) - i
min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
# reserve tokens for corner cases
if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
return AddReqResult.NO_TOKEN
tokens_freed += tokens_occupied
if not self.is_hybrid:
# Skip this logic for swa. The SWA has different memory management, and
# this mechanism is underestimating the memory usage.
cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
tokens_freed = 0
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
# tokens_left gives a conservative calculation as the last token is not stored
bs = len(self.req_states) - i
min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
# reserve tokens for corner cases
if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
return AddReqResult.NO_TOKEN
tokens_freed += tokens_occupied
if (
self.rem_chunk_tokens is None # chunked prefill is disabled
@@ -499,7 +532,11 @@ class PrefillAdder:
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
# Non-chunked prefill
self.can_run_list.append(req)
self.tree_cache.inc_lock_ref(req.last_node)
if self.is_hybrid:
swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
req.swa_uuid_for_lock = swa_uuid_for_lock
else:
self.tree_cache.inc_lock_ref(req.last_node)
self._update_prefill_budget(
prefix_len,
input_tokens,
@@ -520,7 +557,11 @@ class PrefillAdder:
self.can_run_list.append(req)
self.new_chunked_req = req
self.tree_cache.inc_lock_ref(req.last_node)
if self.is_hybrid:
swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
req.swa_uuid_for_lock = swa_uuid_for_lock
else:
self.tree_cache.inc_lock_ref(req.last_node)
self._update_prefill_budget(prefix_len, trunc_len, 0)
return self.budget_state()

View File

@@ -129,10 +129,10 @@ from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.reasoning_parser import ReasoningParser
@@ -390,6 +390,14 @@ class Scheduler(
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed)
# Hybrid
self.is_hybrid = self.tp_worker.is_hybrid
if self.is_hybrid:
self.sliding_window_size = self.tp_worker.sliding_window_size
self.full_tokens_per_layer, self.swa_tokens_per_layer = (
self.tp_worker.get_tokens_per_layer_info()
)
# Print debug info
if tp_rank == 0:
avail_mem = get_available_gpu_memory(
@@ -570,7 +578,7 @@ class Scheduler(
server_args.chunked_prefill_size is not None
and server_args.disable_radix_cache
):
if self.model_config.is_hybrid:
if self.is_hybrid:
ChunkCacheClass = SWAChunkCache
else:
ChunkCacheClass = ChunkCache
@@ -603,6 +611,17 @@ class Scheduler(
self.tp_worker.register_hicache_layer_transfer_counter(
self.tree_cache.cache_controller.layer_done_counter
)
elif self.is_hybrid:
assert (
self.server_args.disaggregation_mode == "null"
), "Hybrid mode does not support disaggregation yet"
self.tree_cache = SWARadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
sliding_window_size=self.sliding_window_size,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
)
else:
self.tree_cache = RadixCache(
@@ -774,6 +793,7 @@ class Scheduler(
else:
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
@@ -819,6 +839,7 @@ class Scheduler(
elif batch is None:
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
@@ -955,6 +976,7 @@ class Scheduler(
# When the server is idle, self-check and re-init some states
if server_is_idle:
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
@@ -1306,9 +1328,26 @@ class Scheduler(
self.last_input_throughput = self.last_prefill_tokens / gap_latency
self.last_prefill_tokens = adder.log_input_tokens
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
self.tree_cache.evictable_size()
)
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
f"full token usage: {full_token_usage:.2f}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"token usage: {token_usage:.2f}, "
num_new_seq = len(can_run_list)
f = (
@@ -1316,7 +1355,7 @@ class Scheduler(
f"#new-seq: {num_new_seq}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"{usage_msg}"
f"{token_msg}"
)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -1338,7 +1377,7 @@ class Scheduler(
)
self.stats.num_running_reqs = running_bs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
self.stats.token_usage = round(token_usage, 2)
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.cache_hit_rate = cache_hit_rate
@@ -1361,16 +1400,35 @@ class Scheduler(
self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0
num_running_reqs = len(batch.reqs)
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
self.tree_cache.evictable_size()
)
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
f"#full token: {full_num_used}, "
f"full token usage: {full_token_usage:.2f}, "
f"#swa token: {swa_num_used}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
if RECORD_STEP_TIME:
self.step_time_dict[num_running_reqs].append(
gap_latency / self.server_args.decode_log_interval
)
msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
if self.spec_algorithm.is_none():
spec_accept_length = 0
@@ -1398,7 +1456,7 @@ class Scheduler(
if self.enable_metrics:
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = num_used / self.max_total_num_tokens
self.stats.token_usage = round(token_usage, 2)
self.stats.cache_hit_rate = 0.0
self.stats.gen_throughput = self.last_gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue)
@@ -1409,24 +1467,34 @@ class Scheduler(
self._publish_kv_events()
def check_memory(self):
if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
available_token_size = self.token_to_kv_pool_allocator.full_available_size()
else:
available_token_size = self.token_to_kv_pool_allocator.available_size()
available_size = available_token_size + self.tree_cache.evictable_size()
protected_size = self.tree_cache.protected_size()
memory_leak = available_size != (
self.max_total_num_tokens
if not self.enable_hierarchical_cache
else self.max_total_num_tokens - protected_size
)
if memory_leak:
msg = (
"token_to_kv_pool_allocator memory leak detected! "
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
f"{available_token_size=}\n"
f"{self.tree_cache.evictable_size()=}\n"
if self.is_hybrid:
(
full_num_used,
swa_num_used,
_,
_,
full_available_size,
full_evictable_size,
swa_available_size,
swa_evictable_size,
) = self._get_swa_token_info()
memory_leak = full_num_used != 0 or swa_num_used != 0
token_msg = (
f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
)
else:
_, _, available_size, evictable_size = self._get_token_info()
protected_size = self.tree_cache.protected_size()
memory_leak = (available_size + evictable_size) != (
self.max_total_num_tokens
if not self.enable_hierarchical_cache
else self.max_total_num_tokens - protected_size
)
token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
if memory_leak:
msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
raise ValueError(msg)
if self.disaggregation_mode == DisaggregationMode.DECODE:
@@ -1450,20 +1518,66 @@ class Scheduler(
and time.perf_counter() > self.metrics_collector.last_log_time + 30
):
# During idle time, also collect metrics every 30 seconds.
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
else:
num_used, token_usage, _, _ = self._get_token_info()
num_running_reqs = len(self.running_batch.reqs)
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = num_used / self.max_total_num_tokens
self.stats.token_usage = round(token_usage, 2)
self.stats.gen_throughput = 0
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.metrics_collector.log_stats(self.stats)
self._publish_kv_events()
def check_tree_cache(self):
    # Idle-time self-check: run the SWA radix tree's internal
    # consistency check (called alongside check_memory when the
    # scheduler has no work).
    if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
        self.tree_cache.sanity_check()
def _get_token_info(self):
    """Return ``(num_used, token_usage, available_size, evictable_size)``
    for the single (non-hybrid) KV token pool.

    ``num_used`` counts tokens that are neither free in the allocator nor
    evictable from the tree cache; ``token_usage`` is that count as a
    fraction of ``max_total_num_tokens``.
    """
    available_size = self.token_to_kv_pool_allocator.available_size()
    evictable_size = self.tree_cache.evictable_size()
    num_used = self.max_total_num_tokens - (available_size + evictable_size)
    token_usage = num_used / self.max_total_num_tokens
    return num_used, token_usage, available_size, evictable_size
def _get_swa_token_info(self):
    """Compute usage stats for both KV pools in hybrid (SWA) mode.

    Returns an 8-tuple::

        (full_num_used, swa_num_used,
         full_token_usage, swa_token_usage,
         full_available_size, full_evictable_size,
         swa_available_size, swa_evictable_size)

    where "used" = per-layer capacity - (free + evictable), and usage is
    the used count as a fraction of the corresponding per-layer capacity.
    """
    full_available_size = self.token_to_kv_pool_allocator.full_available_size()
    full_evictable_size = self.tree_cache.full_evictable_size()
    swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
    swa_evictable_size = self.tree_cache.swa_evictable_size()
    full_num_used = self.full_tokens_per_layer - (
        full_available_size + full_evictable_size
    )
    swa_num_used = self.swa_tokens_per_layer - (
        swa_available_size + swa_evictable_size
    )
    full_token_usage = full_num_used / self.full_tokens_per_layer
    swa_token_usage = swa_num_used / self.swa_tokens_per_layer
    return (
        full_num_used,
        swa_num_used,
        full_token_usage,
        swa_token_usage,
        full_available_size,
        full_evictable_size,
        swa_available_size,
        swa_evictable_size,
    )
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
# Merge the prefill batch into the running batch
chunked_req_to_exclude = set()
@@ -2042,11 +2156,30 @@ class Scheduler(
if not disable_request_logging():
# Print batch size and memory pool info to check whether there are de-sync issues.
if self.is_hybrid:
(
_,
_,
_,
_,
full_available_size,
full_evictable_size,
swa_available_size,
swa_evictable_size,
) = self._get_swa_token_info()
info_msg = (
f"{full_available_size=}, "
f"{full_evictable_size=}, "
f"{swa_available_size=}, "
f"{swa_evictable_size=}, "
)
else:
_, _, available_size, evictable_size = self._get_token_info()
info_msg = f"{available_size=}, " f"{evictable_size=}, "
logger.error(
f"{self.cur_batch.batch_size()=}, "
f"{self.cur_batch.reqs=}, "
f"{self.token_to_kv_pool_allocator.available_size()=}, "
f"{self.tree_cache.evictable_size()=}, "
f"{info_msg}"
)
pyspy_dump_schedulers()
@@ -2101,11 +2234,24 @@ class Scheduler(
def get_load(self):
# TODO(lsyin): use dynamically maintained num_waiting_tokens
load = (
self.max_total_num_tokens
- self.token_to_kv_pool_allocator.available_size()
- self.tree_cache.evictable_size()
)
if self.is_hybrid:
load_full = (
self.full_tokens_per_layer
- self.token_to_kv_pool_allocator.full_available_size()
- self.tree_cache.full_evictable_size()
)
load_swa = (
self.swa_tokens_per_layer
- self.token_to_kv_pool_allocator.swa_available_size()
- self.tree_cache.swa_evictable_size()
)
load = max(load_full, load_swa)
else:
load = (
self.max_total_num_tokens
- self.token_to_kv_pool_allocator.available_size()
- self.tree_cache.evictable_size()
)
load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
load += sum(

View File

@@ -174,6 +174,20 @@ class TpModelWorker:
self.model_runner.token_to_kv_pool.size,
)
@property
def sliding_window_size(self) -> Optional[int]:
    # Expose the model runner's sliding window size (presumably None
    # when the model has no sliding-window attention — confirm in
    # ModelRunner).
    return self.model_runner.sliding_window_size
@property
def is_hybrid(self) -> bool:
    # Hybrid (SWA) mode is signaled by model_runner.is_hybrid being
    # non-None.
    return self.model_runner.is_hybrid is not None
def get_tokens_per_layer_info(self):
    """Return ``(full_max_total_num_tokens, swa_max_total_num_tokens)``
    from the model runner (per-layer KV capacities for hybrid mode)."""
    return (
        self.model_runner.full_max_total_num_tokens,
        self.model_runner.swa_max_total_num_tokens,
    )
def get_pad_input_ids_func(self):
return getattr(self.model_runner.model, "pad_input_ids", None)

View File

@@ -102,6 +102,17 @@ class TpModelWorkerClient:
def get_worker_info(self):
    # Delegate to the wrapped TpModelWorker.
    return self.worker.get_worker_info()
def get_tokens_per_layer_info(self):
    # Delegate to the wrapped TpModelWorker (returns the full/swa
    # per-layer token capacities).
    return self.worker.get_tokens_per_layer_info()
@property
def sliding_window_size(self) -> Optional[int]:
    # Delegate to the wrapped TpModelWorker.
    return self.worker.sliding_window_size
@property
def is_hybrid(self) -> bool:
    # Delegate to the wrapped TpModelWorker.
    return self.worker.is_hybrid
def get_pad_input_ids_func(self):
    # Delegate to the wrapped TpModelWorker.
    return self.worker.get_pad_input_ids_func()