Files
sglang/python/sglang/srt/mem_cache/chunk_cache.py

68 lines
1.8 KiB
Python
Raw Normal View History

2024-08-11 02:44:59 -07:00
from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
2025-03-12 22:22:39 -07:00
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
import torch
2024-08-07 15:52:24 -07:00
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
2024-08-07 15:52:24 -07:00
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
class ChunkCacheEntry:
    """Simple record that pairs a request id with an associated tensor.

    Attributes are stored as given; nothing is validated or copied.
    """

    def __init__(self, rid: str, value: torch.Tensor):
        # rid: identifier of the owning request.
        # value: tensor payload kept for that request (presumably KV slot
        #        indices -- confirm against callers).
        self.rid, self.value = rid, value
class ChunkCache(BasePrefixCache):
    """Prefix-cache stand-in for chunked prefill when RadixCache is disabled.

    No prefix is ever stored or matched. The class only tracks the two memory
    pools so that a request's slots can be released when it finishes, and so
    that an in-flight (chunked) request can see the slots it has filled so far.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        page_size: int,
    ):
        """Remember the pools and page size; no other state is created."""
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = page_size

    def reset(self):
        """No-op: this cache holds no state of its own to clear."""
        pass

    def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
        """Always report a miss.

        All keyword arguments are ignored. Returns an empty index list and
        ``None`` as the second element (note: the annotation says ``int``,
        but the value actually returned is ``None``).
        """
        return [], None

    def cache_finished_req(self, req: Req):
        """Release every slot held by a finished request.

        Frees the request's row in ``req_to_token_pool`` and the KV indices
        read from that row in ``token_to_kv_pool_allocator``.

        The number of indices released is
        ``len(origin_input_ids) + max(len(output_ids) - 1, 0)``. The ``- 1``
        presumably reflects that the most recent output token has no KV slot
        yet -- confirm against the scheduler. The clamp to zero matters when
        ``output_ids`` is empty: without it the slice stop would be one short
        of the prompt length (leaving the last prompt slot unreleased), and
        with an empty prompt as well it would become ``-1``, which slices
        *all but the last column* of the row and frees slots that were never
        owned by this request.
        """
        num_output_slots = max(len(req.output_ids) - 1, 0)
        num_slots = len(req.origin_input_ids) + num_output_slots

        # Read the indices before freeing the row they live in.
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, :num_slots
        ]
        self.req_to_token_pool.free(req.req_pool_idx)
        self.token_to_kv_pool_allocator.free(kv_indices)

    def cache_unfinished_req(self, req: Req):
        """Expose the slots filled so far on a still-running request.

        Nothing is freed or stored in the cache; the first ``len(req.fill_ids)``
        entries of the request's row are assigned to ``req.prefix_indices``.
        """
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(req.fill_ids)
        ]

        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
        req.prefix_indices = kv_indices

    def insert(self):
        """Unsupported: always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def evict(self, num_tokens: int):
        """No-op: there is nothing cached to evict."""
        pass

    def inc_lock_ref(self, node: Any):
        """No-op lock acquire; ``node`` is ignored and 0 is returned."""
        return 0

    def dec_lock_ref(self, node: Any):
        """No-op lock release; ``node`` is ignored and 0 is returned."""
        return 0

    def pretty_print(self):
        """Return an empty string, since there is no cache content to show."""
        return ""