from __future__ import annotations

"""Cache for chunked prefill, used when RadixCache is disabled."""

from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple

import torch

from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req

class ChunkCache(BasePrefixCache):
    """Trivial prefix cache used when RadixCache is disabled.

    No KV data is ever shared between requests: prefix lookups always
    come back empty, and a finished request simply hands its slots back
    to the pools.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        page_size: int,
    ):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = page_size

    def reset(self):
        """Nothing is cached, so there is nothing to clear."""
        pass

    def match_prefix(self, **unused_kwargs) -> MatchResult:
        """Report a zero-length match: reuse is never possible here."""
        no_indices = torch.empty((0,), dtype=torch.int64)
        return MatchResult(
            device_indices=no_indices,
            last_device_node=None,
            last_host_node=None,
        )

    def cache_finished_req(self, req: Req):
        """Return every pool slot held by a finished request."""
        # For decode server: if req.output_ids is empty, we want to free all
        # req.origin_input_ids.
        seq_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
        kv_slots = self.req_to_token_pool.req_to_token[req.req_pool_idx, :seq_len]
        self.req_to_token_pool.free(req.req_pool_idx)
        self.token_to_kv_pool_allocator.free(kv_slots)

    def cache_unfinished_req(self, req: Req):
        """Remember the KV slots of a chunked request still being prefilled."""
        row = self.req_to_token_pool.req_to_token[req.req_pool_idx]
        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
        req.prefix_indices = row[: len(req.fill_ids)]

    def evict(self, num_tokens: int):
        """No-op: there is no cached data to evict."""
        pass

    def inc_lock_ref(self, node: Any):
        """No tree nodes exist, so lock counting is a no-op."""
        return 0

    def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
        """No tree nodes exist, so lock counting is a no-op."""
        return 0

    def pretty_print(self):
        """Nothing to display for an empty cache."""
        return ""
|
|
|
|
|
|
class SWAChunkCache(ChunkCache):
    """ChunkCache with support for hybrid KV cache operations."""

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: SWATokenToKVPoolAllocator,
        page_size: int,
    ):
        super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size)
        # Only the SWA-aware allocator exposes `free_swa`.
        assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)

    def evict_swa(
        self,
        req: Req,
        prelen: int,
        attention_chunk_size: int,
    ):
        """Free sliding-window KV slots that fell behind the window.

        Releases slots in `[req.evicted_seqlen_local, boundary)`, where
        `boundary` is `prelen` rounded down to a multiple of
        `attention_chunk_size`, then advances the eviction watermark.
        """
        # Act only once at least one full attention chunk can be released.
        if prelen < req.evicted_seqlen_local + attention_chunk_size:
            return
        boundary = (prelen // attention_chunk_size) * attention_chunk_size
        stale_slots = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, req.evicted_seqlen_local : boundary
        ]
        self.token_to_kv_pool_allocator.free_swa(stale_slots)
        req.evicted_seqlen_local = boundary

    def evict(self, num_tokens: int):
        """No-op: only `evict_swa` releases slots for this cache."""
        pass