SWA Prefix Cache (#7367)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
@@ -129,10 +129,10 @@ from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.managers.utils import validate_input_length
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
||||
from sglang.srt.reasoning_parser import ReasoningParser
|
||||
@@ -390,6 +390,14 @@ class Scheduler(
|
||||
global_server_args_dict.update(worker_global_server_args_dict)
|
||||
set_random_seed(self.random_seed)
|
||||
|
||||
# Hybrid
|
||||
self.is_hybrid = self.tp_worker.is_hybrid
|
||||
if self.is_hybrid:
|
||||
self.sliding_window_size = self.tp_worker.sliding_window_size
|
||||
self.full_tokens_per_layer, self.swa_tokens_per_layer = (
|
||||
self.tp_worker.get_tokens_per_layer_info()
|
||||
)
|
||||
|
||||
# Print debug info
|
||||
if tp_rank == 0:
|
||||
avail_mem = get_available_gpu_memory(
|
||||
@@ -570,7 +578,7 @@ class Scheduler(
|
||||
server_args.chunked_prefill_size is not None
|
||||
and server_args.disable_radix_cache
|
||||
):
|
||||
if self.model_config.is_hybrid:
|
||||
if self.is_hybrid:
|
||||
ChunkCacheClass = SWAChunkCache
|
||||
else:
|
||||
ChunkCacheClass = ChunkCache
|
||||
@@ -603,6 +611,17 @@ class Scheduler(
|
||||
self.tp_worker.register_hicache_layer_transfer_counter(
|
||||
self.tree_cache.cache_controller.layer_done_counter
|
||||
)
|
||||
elif self.is_hybrid:
|
||||
assert (
|
||||
self.server_args.disaggregation_mode == "null"
|
||||
), "Hybrid mode does not support disaggregation yet"
|
||||
self.tree_cache = SWARadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
sliding_window_size=self.sliding_window_size,
|
||||
page_size=self.page_size,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
|
||||
else:
|
||||
self.tree_cache = RadixCache(
|
||||
@@ -774,6 +793,7 @@ class Scheduler(
|
||||
else:
|
||||
# When the server is idle, do self-check and re-init some states
|
||||
self.check_memory()
|
||||
self.check_tree_cache()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
self.maybe_sleep_on_idle()
|
||||
|
||||
@@ -819,6 +839,7 @@ class Scheduler(
|
||||
elif batch is None:
|
||||
# When the server is idle, do self-check and re-init some states
|
||||
self.check_memory()
|
||||
self.check_tree_cache()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
self.maybe_sleep_on_idle()
|
||||
|
||||
@@ -955,6 +976,7 @@ class Scheduler(
|
||||
# When the server is idle, self-check and re-init some states
|
||||
if server_is_idle:
|
||||
self.check_memory()
|
||||
self.check_tree_cache()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
self.maybe_sleep_on_idle()
|
||||
|
||||
@@ -1306,9 +1328,26 @@ class Scheduler(
|
||||
self.last_input_throughput = self.last_prefill_tokens / gap_latency
|
||||
self.last_prefill_tokens = adder.log_input_tokens
|
||||
|
||||
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
|
||||
self.tree_cache.evictable_size()
|
||||
)
|
||||
if self.is_hybrid:
|
||||
(
|
||||
full_num_used,
|
||||
swa_num_used,
|
||||
full_token_usage,
|
||||
swa_token_usage,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = self._get_swa_token_info()
|
||||
num_used = max(full_num_used, swa_num_used)
|
||||
token_usage = max(full_token_usage, swa_token_usage)
|
||||
token_msg = (
|
||||
f"full token usage: {full_token_usage:.2f}, "
|
||||
f"swa token usage: {swa_token_usage:.2f}, "
|
||||
)
|
||||
else:
|
||||
num_used, token_usage, _, _ = self._get_token_info()
|
||||
token_msg = f"token usage: {token_usage:.2f}, "
|
||||
|
||||
num_new_seq = len(can_run_list)
|
||||
f = (
|
||||
@@ -1316,7 +1355,7 @@ class Scheduler(
|
||||
f"#new-seq: {num_new_seq}, "
|
||||
f"#new-token: {adder.log_input_tokens}, "
|
||||
f"#cached-token: {adder.log_hit_tokens}, "
|
||||
f"{usage_msg}"
|
||||
f"{token_msg}"
|
||||
)
|
||||
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
@@ -1338,7 +1377,7 @@ class Scheduler(
|
||||
)
|
||||
self.stats.num_running_reqs = running_bs
|
||||
self.stats.num_used_tokens = num_used
|
||||
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
|
||||
self.stats.token_usage = round(token_usage, 2)
|
||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
||||
self.stats.cache_hit_rate = cache_hit_rate
|
||||
|
||||
@@ -1361,16 +1400,35 @@ class Scheduler(
|
||||
self.last_gen_throughput = self.num_generated_tokens / gap_latency
|
||||
self.num_generated_tokens = 0
|
||||
num_running_reqs = len(batch.reqs)
|
||||
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
|
||||
self.tree_cache.evictable_size()
|
||||
)
|
||||
if self.is_hybrid:
|
||||
(
|
||||
full_num_used,
|
||||
swa_num_used,
|
||||
full_token_usage,
|
||||
swa_token_usage,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = self._get_swa_token_info()
|
||||
num_used = max(full_num_used, swa_num_used)
|
||||
token_usage = max(full_token_usage, swa_token_usage)
|
||||
token_msg = (
|
||||
f"#full token: {full_num_used}, "
|
||||
f"full token usage: {full_token_usage:.2f}, "
|
||||
f"#swa token: {swa_num_used}, "
|
||||
f"swa token usage: {swa_token_usage:.2f}, "
|
||||
)
|
||||
else:
|
||||
num_used, token_usage, _, _ = self._get_token_info()
|
||||
token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
|
||||
|
||||
if RECORD_STEP_TIME:
|
||||
self.step_time_dict[num_running_reqs].append(
|
||||
gap_latency / self.server_args.decode_log_interval
|
||||
)
|
||||
|
||||
msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"
|
||||
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
|
||||
|
||||
if self.spec_algorithm.is_none():
|
||||
spec_accept_length = 0
|
||||
@@ -1398,7 +1456,7 @@ class Scheduler(
|
||||
if self.enable_metrics:
|
||||
self.stats.num_running_reqs = num_running_reqs
|
||||
self.stats.num_used_tokens = num_used
|
||||
self.stats.token_usage = num_used / self.max_total_num_tokens
|
||||
self.stats.token_usage = round(token_usage, 2)
|
||||
self.stats.cache_hit_rate = 0.0
|
||||
self.stats.gen_throughput = self.last_gen_throughput
|
||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
||||
@@ -1409,24 +1467,34 @@ class Scheduler(
|
||||
self._publish_kv_events()
|
||||
|
||||
def check_memory(self):
|
||||
if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
|
||||
available_token_size = self.token_to_kv_pool_allocator.full_available_size()
|
||||
else:
|
||||
available_token_size = self.token_to_kv_pool_allocator.available_size()
|
||||
available_size = available_token_size + self.tree_cache.evictable_size()
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
memory_leak = available_size != (
|
||||
self.max_total_num_tokens
|
||||
if not self.enable_hierarchical_cache
|
||||
else self.max_total_num_tokens - protected_size
|
||||
)
|
||||
if memory_leak:
|
||||
msg = (
|
||||
"token_to_kv_pool_allocator memory leak detected! "
|
||||
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
|
||||
f"{available_token_size=}\n"
|
||||
f"{self.tree_cache.evictable_size()=}\n"
|
||||
if self.is_hybrid:
|
||||
(
|
||||
full_num_used,
|
||||
swa_num_used,
|
||||
_,
|
||||
_,
|
||||
full_available_size,
|
||||
full_evictable_size,
|
||||
swa_available_size,
|
||||
swa_evictable_size,
|
||||
) = self._get_swa_token_info()
|
||||
memory_leak = full_num_used != 0 or swa_num_used != 0
|
||||
token_msg = (
|
||||
f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
|
||||
f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
|
||||
)
|
||||
else:
|
||||
_, _, available_size, evictable_size = self._get_token_info()
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
memory_leak = (available_size + evictable_size) != (
|
||||
self.max_total_num_tokens
|
||||
if not self.enable_hierarchical_cache
|
||||
else self.max_total_num_tokens - protected_size
|
||||
)
|
||||
token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
|
||||
|
||||
if memory_leak:
|
||||
msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
|
||||
raise ValueError(msg)
|
||||
|
||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
@@ -1450,20 +1518,66 @@ class Scheduler(
|
||||
and time.perf_counter() > self.metrics_collector.last_log_time + 30
|
||||
):
|
||||
# During idle time, also collect metrics every 30 seconds.
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
if self.is_hybrid:
|
||||
(
|
||||
full_num_used,
|
||||
swa_num_used,
|
||||
full_token_usage,
|
||||
swa_token_usage,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = self._get_swa_token_info()
|
||||
num_used = max(full_num_used, swa_num_used)
|
||||
token_usage = max(full_token_usage, swa_token_usage)
|
||||
else:
|
||||
num_used, token_usage, _, _ = self._get_token_info()
|
||||
num_running_reqs = len(self.running_batch.reqs)
|
||||
self.stats.num_running_reqs = num_running_reqs
|
||||
self.stats.num_used_tokens = num_used
|
||||
self.stats.token_usage = num_used / self.max_total_num_tokens
|
||||
self.stats.token_usage = round(token_usage, 2)
|
||||
self.stats.gen_throughput = 0
|
||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
||||
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
||||
self.metrics_collector.log_stats(self.stats)
|
||||
self._publish_kv_events()
|
||||
|
||||
def check_tree_cache(self):
|
||||
if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
|
||||
self.tree_cache.sanity_check()
|
||||
|
||||
def _get_token_info(self):
|
||||
available_size = self.token_to_kv_pool_allocator.available_size()
|
||||
evictable_size = self.tree_cache.evictable_size()
|
||||
num_used = self.max_total_num_tokens - (available_size + evictable_size)
|
||||
token_usage = num_used / self.max_total_num_tokens
|
||||
return num_used, token_usage, available_size, evictable_size
|
||||
|
||||
def _get_swa_token_info(self):
|
||||
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
|
||||
full_evictable_size = self.tree_cache.full_evictable_size()
|
||||
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
|
||||
swa_evictable_size = self.tree_cache.swa_evictable_size()
|
||||
full_num_used = self.full_tokens_per_layer - (
|
||||
full_available_size + full_evictable_size
|
||||
)
|
||||
swa_num_used = self.swa_tokens_per_layer - (
|
||||
swa_available_size + swa_evictable_size
|
||||
)
|
||||
full_token_usage = full_num_used / self.full_tokens_per_layer
|
||||
swa_token_usage = swa_num_used / self.swa_tokens_per_layer
|
||||
return (
|
||||
full_num_used,
|
||||
swa_num_used,
|
||||
full_token_usage,
|
||||
swa_token_usage,
|
||||
full_available_size,
|
||||
full_evictable_size,
|
||||
swa_available_size,
|
||||
swa_evictable_size,
|
||||
)
|
||||
|
||||
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
|
||||
# Merge the prefill batch into the running batch
|
||||
chunked_req_to_exclude = set()
|
||||
@@ -2042,11 +2156,30 @@ class Scheduler(
|
||||
|
||||
if not disable_request_logging():
|
||||
# Print batch size and memory pool info to check whether there are de-sync issues.
|
||||
if self.is_hybrid:
|
||||
(
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
full_available_size,
|
||||
full_evictable_size,
|
||||
swa_available_size,
|
||||
swa_evictable_size,
|
||||
) = self._get_swa_token_info()
|
||||
info_msg = (
|
||||
f"{full_available_size=}, "
|
||||
f"{full_evictable_size=}, "
|
||||
f"{swa_available_size=}, "
|
||||
f"{swa_evictable_size=}, "
|
||||
)
|
||||
else:
|
||||
_, _, available_size, evictable_size = self._get_token_info()
|
||||
info_msg = f"{available_size=}, " f"{evictable_size=}, "
|
||||
logger.error(
|
||||
f"{self.cur_batch.batch_size()=}, "
|
||||
f"{self.cur_batch.reqs=}, "
|
||||
f"{self.token_to_kv_pool_allocator.available_size()=}, "
|
||||
f"{self.tree_cache.evictable_size()=}, "
|
||||
f"{info_msg}"
|
||||
)
|
||||
|
||||
pyspy_dump_schedulers()
|
||||
@@ -2101,11 +2234,24 @@ class Scheduler(
|
||||
|
||||
def get_load(self):
|
||||
# TODO(lsyin): use dynamically maintained num_waiting_tokens
|
||||
load = (
|
||||
self.max_total_num_tokens
|
||||
- self.token_to_kv_pool_allocator.available_size()
|
||||
- self.tree_cache.evictable_size()
|
||||
)
|
||||
if self.is_hybrid:
|
||||
load_full = (
|
||||
self.full_tokens_per_layer
|
||||
- self.token_to_kv_pool_allocator.full_available_size()
|
||||
- self.tree_cache.full_evictable_size()
|
||||
)
|
||||
load_swa = (
|
||||
self.swa_tokens_per_layer
|
||||
- self.token_to_kv_pool_allocator.swa_available_size()
|
||||
- self.tree_cache.swa_evictable_size()
|
||||
)
|
||||
load = max(load_full, load_swa)
|
||||
else:
|
||||
load = (
|
||||
self.max_total_num_tokens
|
||||
- self.token_to_kv_pool_allocator.available_size()
|
||||
- self.tree_cache.evictable_size()
|
||||
)
|
||||
load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
load += sum(
|
||||
|
||||
Reference in New Issue
Block a user