Hybrid kv cache for LLaMA4 (#6563)

Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Co-authored-by: tarinkk <rt572@physics.rutger.edu>
Co-authored-by: tarinkk <rt572@rutgers.physics.edu>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
This commit is contained in:
tarinkk
2025-06-27 21:58:55 -04:00
committed by GitHub
parent 357921aa51
commit eb6c2c1663
11 changed files with 519 additions and 59 deletions

View File

@@ -126,7 +126,8 @@ from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -570,7 +571,11 @@ class Scheduler(
server_args.chunked_prefill_size is not None
and server_args.disable_radix_cache
):
self.tree_cache = ChunkCache(
if self.model_config.is_hybrid:
ChunkCacheClass = SWAChunkCache
else:
ChunkCacheClass = ChunkCache
self.tree_cache = ChunkCacheClass(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
@@ -1283,9 +1288,8 @@ class Scheduler(
self.last_input_throughput = self.last_prefill_tokens / gap_latency
self.last_prefill_tokens = adder.log_input_tokens
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
self.tree_cache.evictable_size()
)
num_new_seq = len(can_run_list)
@@ -1294,7 +1298,7 @@ class Scheduler(
f"#new-seq: {num_new_seq}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"{usage_msg}"
)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -1337,9 +1341,8 @@ class Scheduler(
self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0
num_running_reqs = len(batch.reqs)
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
self.tree_cache.evictable_size()
)
if RECORD_STEP_TIME:
@@ -1347,12 +1350,7 @@ class Scheduler(
gap_latency / self.server_args.decode_log_interval
)
msg = (
f"Decode batch. "
f"#running-req: {num_running_reqs}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
)
msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"
if self.spec_algorithm.is_none():
spec_accept_length = 0
@@ -1390,10 +1388,11 @@ class Scheduler(
self._publish_kv_events()
def check_memory(self):
available_size = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
available_token_size = self.token_to_kv_pool_allocator.full_available_size()
else:
available_token_size = self.token_to_kv_pool_allocator.available_size()
available_size = available_token_size + self.tree_cache.evictable_size()
protected_size = self.tree_cache.protected_size()
memory_leak = available_size != (
self.max_total_num_tokens
@@ -1404,7 +1403,7 @@ class Scheduler(
msg = (
"token_to_kv_pool_allocator memory leak detected! "
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
f"{available_token_size=}\n"
f"{self.tree_cache.evictable_size()=}\n"
)
raise ValueError(msg)