Hybrid kv cache for LLaMA4 (#6563)
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Co-authored-by: tarinkk <rt572@physics.rutger.edu> Co-authored-by: tarinkk <rt572@rutgers.physics.edu> Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
This commit is contained in:
@@ -126,7 +126,8 @@ from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.managers.utils import validate_input_length
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
@@ -570,7 +571,11 @@ class Scheduler(
|
||||
server_args.chunked_prefill_size is not None
|
||||
and server_args.disable_radix_cache
|
||||
):
|
||||
self.tree_cache = ChunkCache(
|
||||
if self.model_config.is_hybrid:
|
||||
ChunkCacheClass = SWAChunkCache
|
||||
else:
|
||||
ChunkCacheClass = ChunkCache
|
||||
self.tree_cache = ChunkCacheClass(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
page_size=self.page_size,
|
||||
@@ -1283,9 +1288,8 @@ class Scheduler(
|
||||
self.last_input_throughput = self.last_prefill_tokens / gap_latency
|
||||
self.last_prefill_tokens = adder.log_input_tokens
|
||||
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
|
||||
self.tree_cache.evictable_size()
|
||||
)
|
||||
|
||||
num_new_seq = len(can_run_list)
|
||||
@@ -1294,7 +1298,7 @@ class Scheduler(
|
||||
f"#new-seq: {num_new_seq}, "
|
||||
f"#new-token: {adder.log_input_tokens}, "
|
||||
f"#cached-token: {adder.log_hit_tokens}, "
|
||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
||||
f"{usage_msg}"
|
||||
)
|
||||
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
@@ -1337,9 +1341,8 @@ class Scheduler(
|
||||
self.last_gen_throughput = self.num_generated_tokens / gap_latency
|
||||
self.num_generated_tokens = 0
|
||||
num_running_reqs = len(batch.reqs)
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
|
||||
self.tree_cache.evictable_size()
|
||||
)
|
||||
|
||||
if RECORD_STEP_TIME:
|
||||
@@ -1347,12 +1350,7 @@ class Scheduler(
|
||||
gap_latency / self.server_args.decode_log_interval
|
||||
)
|
||||
|
||||
msg = (
|
||||
f"Decode batch. "
|
||||
f"#running-req: {num_running_reqs}, "
|
||||
f"#token: {num_used}, "
|
||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
||||
)
|
||||
msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"
|
||||
|
||||
if self.spec_algorithm.is_none():
|
||||
spec_accept_length = 0
|
||||
@@ -1390,10 +1388,11 @@ class Scheduler(
|
||||
self._publish_kv_events()
|
||||
|
||||
def check_memory(self):
|
||||
available_size = (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
|
||||
available_token_size = self.token_to_kv_pool_allocator.full_available_size()
|
||||
else:
|
||||
available_token_size = self.token_to_kv_pool_allocator.available_size()
|
||||
available_size = available_token_size + self.tree_cache.evictable_size()
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
memory_leak = available_size != (
|
||||
self.max_total_num_tokens
|
||||
@@ -1404,7 +1403,7 @@ class Scheduler(
|
||||
msg = (
|
||||
"token_to_kv_pool_allocator memory leak detected! "
|
||||
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
|
||||
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
||||
f"{available_token_size=}\n"
|
||||
f"{self.tree_cache.evictable_size()=}\n"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
Reference in New Issue
Block a user