Hybrid kv cache for LLaMA4 (#6563)
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Co-authored-by: tarinkk <rt572@physics.rutger.edu> Co-authored-by: tarinkk <rt572@rutgers.physics.edu> Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
This commit is contained in:
@@ -56,7 +56,7 @@ from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.metrics.collector import TimeStats
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
@@ -485,6 +485,9 @@ class Req:
|
||||
# for cross-encoder model
|
||||
self.token_type_ids = token_type_ids
|
||||
|
||||
# The length of the KV cache that has been evicted during local-attention chunked prefill
|
||||
self.evicted_seqlen_local = 0
|
||||
|
||||
# Sampling info
|
||||
if isinstance(sampling_params.custom_params, dict):
|
||||
sampling_params = copy.copy(sampling_params)
|
||||
@@ -1191,6 +1194,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
|
||||
)
|
||||
if isinstance(self.tree_cache, SWAChunkCache):
|
||||
self.tree_cache.evict(
|
||||
req, pre_len, self.model_config.attention_chunk_size
|
||||
)
|
||||
|
||||
# If input_embeds are available, store them
|
||||
if req.input_embeds is not None:
|
||||
@@ -1383,7 +1390,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
* buf_multiplier
|
||||
* self.token_to_kv_pool_allocator.page_size
|
||||
)
|
||||
|
||||
if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
|
||||
return True
|
||||
|
||||
@@ -1564,6 +1570,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.seq_lens.add_(1)
|
||||
self.seq_lens_sum += bs
|
||||
|
||||
# free memory
|
||||
if isinstance(self.tree_cache, SWAChunkCache):
|
||||
for req in self.reqs:
|
||||
self.tree_cache.evict(
|
||||
req, req.seqlen - 1, self.model_config.attention_chunk_size
|
||||
)
|
||||
|
||||
# Allocate memory
|
||||
if self.token_to_kv_pool_allocator.page_size == 1:
|
||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||
@@ -1798,7 +1811,6 @@ class ModelWorkerBatch:
|
||||
seq_lens: torch.Tensor
|
||||
# The indices of output tokens in the token_to_kv_pool_allocator
|
||||
out_cache_loc: torch.Tensor
|
||||
|
||||
# The sequence length tensor on CPU
|
||||
seq_lens_cpu: Optional[torch.Tensor]
|
||||
seq_lens_sum: int
|
||||
|
||||
@@ -126,7 +126,8 @@ from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.managers.utils import validate_input_length
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
@@ -570,7 +571,11 @@ class Scheduler(
|
||||
server_args.chunked_prefill_size is not None
|
||||
and server_args.disable_radix_cache
|
||||
):
|
||||
self.tree_cache = ChunkCache(
|
||||
if self.model_config.is_hybrid:
|
||||
ChunkCacheClass = SWAChunkCache
|
||||
else:
|
||||
ChunkCacheClass = ChunkCache
|
||||
self.tree_cache = ChunkCacheClass(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
page_size=self.page_size,
|
||||
@@ -1283,9 +1288,8 @@ class Scheduler(
|
||||
self.last_input_throughput = self.last_prefill_tokens / gap_latency
|
||||
self.last_prefill_tokens = adder.log_input_tokens
|
||||
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
|
||||
self.tree_cache.evictable_size()
|
||||
)
|
||||
|
||||
num_new_seq = len(can_run_list)
|
||||
@@ -1294,7 +1298,7 @@ class Scheduler(
|
||||
f"#new-seq: {num_new_seq}, "
|
||||
f"#new-token: {adder.log_input_tokens}, "
|
||||
f"#cached-token: {adder.log_hit_tokens}, "
|
||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
||||
f"{usage_msg}"
|
||||
)
|
||||
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
@@ -1337,9 +1341,8 @@ class Scheduler(
|
||||
self.last_gen_throughput = self.num_generated_tokens / gap_latency
|
||||
self.num_generated_tokens = 0
|
||||
num_running_reqs = len(batch.reqs)
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
|
||||
self.tree_cache.evictable_size()
|
||||
)
|
||||
|
||||
if RECORD_STEP_TIME:
|
||||
@@ -1347,12 +1350,7 @@ class Scheduler(
|
||||
gap_latency / self.server_args.decode_log_interval
|
||||
)
|
||||
|
||||
msg = (
|
||||
f"Decode batch. "
|
||||
f"#running-req: {num_running_reqs}, "
|
||||
f"#token: {num_used}, "
|
||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
||||
)
|
||||
msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"
|
||||
|
||||
if self.spec_algorithm.is_none():
|
||||
spec_accept_length = 0
|
||||
@@ -1390,10 +1388,11 @@ class Scheduler(
|
||||
self._publish_kv_events()
|
||||
|
||||
def check_memory(self):
|
||||
available_size = (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
|
||||
available_token_size = self.token_to_kv_pool_allocator.full_available_size()
|
||||
else:
|
||||
available_token_size = self.token_to_kv_pool_allocator.available_size()
|
||||
available_size = available_token_size + self.tree_cache.evictable_size()
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
memory_leak = available_size != (
|
||||
self.max_total_num_tokens
|
||||
@@ -1404,7 +1403,7 @@ class Scheduler(
|
||||
msg = (
|
||||
"token_to_kv_pool_allocator memory leak detected! "
|
||||
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
|
||||
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
||||
f"{available_token_size=}\n"
|
||||
f"{self.tree_cache.evictable_size()=}\n"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
Reference in New Issue
Block a user