Hybrid kv cache for LLaMA4 (#6563)

Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Co-authored-by: tarinkk <rt572@physics.rutger.edu>
Co-authored-by: tarinkk <rt572@rutgers.physics.edu>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
This commit is contained in:
tarinkk
2025-06-27 21:58:55 -04:00
committed by GitHub
parent 357921aa51
commit eb6c2c1663
11 changed files with 519 additions and 59 deletions

View File

@@ -56,7 +56,7 @@ from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
@@ -485,6 +485,9 @@ class Req:
# for cross-encoder model
self.token_type_ids = token_type_ids
# The length of KV that have been removed in local attention chunked prefill
self.evicted_seqlen_local = 0
# Sampling info
if isinstance(sampling_params.custom_params, dict):
sampling_params = copy.copy(sampling_params)
@@ -1191,6 +1194,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.req_to_token_pool.write(
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
)
if isinstance(self.tree_cache, SWAChunkCache):
self.tree_cache.evict(
req, pre_len, self.model_config.attention_chunk_size
)
# If input_embeds are available, store them
if req.input_embeds is not None:
@@ -1383,7 +1390,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
* buf_multiplier
* self.token_to_kv_pool_allocator.page_size
)
if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
return True
@@ -1564,6 +1570,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.seq_lens.add_(1)
self.seq_lens_sum += bs
# free memory
if isinstance(self.tree_cache, SWAChunkCache):
for req in self.reqs:
self.tree_cache.evict(
req, req.seqlen - 1, self.model_config.attention_chunk_size
)
# Allocate memory
if self.token_to_kv_pool_allocator.page_size == 1:
self.out_cache_loc = self.alloc_token_slots(bs)
@@ -1798,7 +1811,6 @@ class ModelWorkerBatch:
seq_lens: torch.Tensor
# The indices of output tokens in the token_to_kv_pool_allocator
out_cache_loc: torch.Tensor
# The sequence length tensor on CPU
seq_lens_cpu: Optional[torch.Tensor]
seq_lens_sum: int

View File

@@ -126,7 +126,8 @@ from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -570,7 +571,11 @@ class Scheduler(
server_args.chunked_prefill_size is not None
and server_args.disable_radix_cache
):
self.tree_cache = ChunkCache(
if self.model_config.is_hybrid:
ChunkCacheClass = SWAChunkCache
else:
ChunkCacheClass = ChunkCache
self.tree_cache = ChunkCacheClass(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
@@ -1283,9 +1288,8 @@ class Scheduler(
self.last_input_throughput = self.last_prefill_tokens / gap_latency
self.last_prefill_tokens = adder.log_input_tokens
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
self.tree_cache.evictable_size()
)
num_new_seq = len(can_run_list)
@@ -1294,7 +1298,7 @@ class Scheduler(
f"#new-seq: {num_new_seq}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"{usage_msg}"
)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -1337,9 +1341,8 @@ class Scheduler(
self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0
num_running_reqs = len(batch.reqs)
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
self.tree_cache.evictable_size()
)
if RECORD_STEP_TIME:
@@ -1347,12 +1350,7 @@ class Scheduler(
gap_latency / self.server_args.decode_log_interval
)
msg = (
f"Decode batch. "
f"#running-req: {num_running_reqs}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
)
msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"
if self.spec_algorithm.is_none():
spec_accept_length = 0
@@ -1390,10 +1388,11 @@ class Scheduler(
self._publish_kv_events()
def check_memory(self):
available_size = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
available_token_size = self.token_to_kv_pool_allocator.full_available_size()
else:
available_token_size = self.token_to_kv_pool_allocator.available_size()
available_size = available_token_size + self.tree_cache.evictable_size()
protected_size = self.tree_cache.protected_size()
memory_leak = available_size != (
self.max_total_num_tokens
@@ -1404,7 +1403,7 @@ class Scheduler(
msg = (
"token_to_kv_pool_allocator memory leak detected! "
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
f"{available_token_size=}\n"
f"{self.tree_cache.evictable_size()=}\n"
)
raise ValueError(msg)