RadixCache method adjust (#977)

This commit is contained in:
Liangsheng Yin
2024-08-07 15:52:24 -07:00
committed by GitHub
parent f724f1f1e9
commit 7623091d97
5 changed files with 140 additions and 118 deletions

View File

@@ -17,7 +17,11 @@ class BasePrefixCache(ABC):
pass
@abstractmethod
def cache_req(self, **kwargs):
# NOTE(review): this excerpt is a rendered diff with markers stripped — the
# `def cache_req` line above is the removed (pre-commit) name and the
# `def cache_finished_req` line below is its replacement; only one of the
# two exists in the actual file.
def cache_finished_req(self, **kwargs):
# Abstract hook: subclasses release/cache a request's KV entries once the
# request has finished.
pass
@abstractmethod
def cache_unfinished_req(self, **kwargs):
# Abstract hook: subclasses cache the KV entries of a still-running
# request (e.g. between chunked-prefill steps) so later steps reuse them.
pass
@abstractmethod

View File

@@ -1,6 +1,11 @@
"""Cache for chunked prefill, used when RadixCache is disabled."""
from sglang.srt.mem_cache.base_cache import BasePrefixCache
from typing import TYPE_CHECKING
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
class ChunkCacheEntry:
@@ -27,22 +32,31 @@ class ChunkCache(BasePrefixCache):
entry = self.entries[rid]
return entry.value, entry
# NOTE(review): the lines below are a diff interleave of the removed
# `cache_req` (old API: rid/req_pool_idx/del_in_memory_pool arguments) and
# the added `cache_finished_req` (new API: takes the Req object). Lines
# still using `rid` / `indices` / `req_pool_idx` belong to the old version.
def cache_req(
self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
):
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
if del_in_memory_pool:
assert rid in self.entries
self.req_to_token_pool.free(req_pool_idx)
self.token_to_kv_pool.free(indices)
return
# New version: free the request's slots in both pools when it finishes.
def cache_finished_req(self, req: "Req", token_ids=None):
if token_ids is None:
# Default to all generated tokens except the last one (its KV entry
# was never written, since generation stopped at it).
token_ids = (req.input_ids + req.output_ids)[:-1]
# (old-version residue — `rid`/`indices` are undefined in the new method)
if rid not in self.entries:
self.entries[rid] = ChunkCacheEntry(rid, indices)
# KV-pool slots holding this request's first len(token_ids) positions.
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
assert req.rid in self.entries
# Finished: release both the req-to-token slot and the KV slots.
self.req_to_token_pool.free(req.req_pool_idx)
self.token_to_kv_pool.free(kv_indices)
# (old-version residue — the old cache_req returned the cached entry)
entry = self.entries[rid]
entry.value = indices
return indices, entry
def cache_unfinished_req(self, req: "Req", token_ids=None):
"""Record a still-running request's KV indices so later chunks reuse them.

Defaults token_ids to req.input_ids (the tokens prefilled so far), looks
up their slots in the req-to-token pool, and stores them in a per-request
ChunkCacheEntry keyed by req.rid (created on first call).

Returns:
(kv_indices, entry): the cached indices and the cache entry holding them.
"""
if token_ids is None:
token_ids = req.input_ids
# KV-pool slots for the first len(token_ids) positions of this request.
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
if req.rid not in self.entries:
self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
# Refresh the entry on every call so it always reflects the latest chunk.
entry = self.entries[req.rid]
entry.value = kv_indices
return kv_indices, entry
def insert(self):
# ChunkCache keeps no radix tree (it is the stand-in used when RadixCache
# is disabled — see module docstring), so prefix insertion is unsupported.
raise NotImplementedError

View File

@@ -20,10 +20,14 @@ The radix tree data structure for managing the KV cache.
import heapq
import time
from collections import defaultdict
from typing import TYPE_CHECKING
import torch
from sglang.srt.mem_cache.base_cache import BasePrefixCache
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
class TreeNode:
@@ -85,40 +89,54 @@ class RadixCache(BasePrefixCache):
value = [x for x in key]
return self._insert_helper(self.root_node, key, value)
# NOTE(review): diff interleave — the removed `cache_req` (old API with
# token_ids/last_uncached_pos/req_pool_idx/del_in_memory_pool/old_last_node)
# is mixed with the added `cache_finished_req` (new API: takes the Req).
# Lines referencing `del_in_memory_pool`, `last_uncached_pos`,
# `old_last_node`, `req_pool_idx`, or `indices` are old-version residue.
def cache_req(
self,
token_ids,
last_uncached_pos,
req_pool_idx,
del_in_memory_pool=True,
old_last_node=None,
**kwargs,
):
# Insert the request into radix cache
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
new_prefix_len = self.insert(token_ids, indices.clone())
# New version: fold a finished request's KV into the radix tree, drop the
# duplicated references, and release its req slot and tree lock.
def cache_finished_req(self, req: "Req", token_ids=None):
"""Cache request when it finishes."""
if token_ids is None:
# All generated tokens except the last (its KV was never written).
token_ids = (req.input_ids + req.output_ids)[:-1]
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
if self.disable:
# (old-version residue)
if del_in_memory_pool:
self.token_to_kv_pool.free(indices)
else:
return torch.tensor([], dtype=torch.int32), self.root_node
# Caching disabled: just release everything the request held.
self.token_to_kv_pool.free(kv_indices)
self.req_to_token_pool.free(req.req_pool_idx)
return
# Radix Cache takes one ref in memory pool
# (old-version residue)
self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
new_prefix_len = self.insert(token_ids, kv_indices.clone())
# Free the span now duplicated inside the tree (the tree holds its ref).
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
# (old-version residue: the del_in_memory_pool branch and the return of
# cached_indices/new_last_node below belong to the removed cache_req)
if del_in_memory_pool:
self.req_to_token_pool.free(req_pool_idx)
else:
cached_indices, new_last_node = self.match_prefix(token_ids)
assert len(cached_indices) == len(token_ids)
# Free the req slot and release the tree lock the request was holding.
self.req_to_token_pool.free(req.req_pool_idx)
self.dec_lock_ref(req.last_node)
self.req_to_token_pool.req_to_token[
req_pool_idx, last_uncached_pos : len(cached_indices)
] = cached_indices[last_uncached_pos:]
self.dec_lock_ref(old_last_node)
self.inc_lock_ref(new_last_node)
return cached_indices, new_last_node
def cache_unfinished_req(self, req: "Req", token_ids=None):
"""Cache request when it is unfinished.

Inserts the tokens computed so far (default: req.input_ids) into the
radix tree, frees the KV slots now duplicated inside the tree, repoints
the request's req-to-token entries at the tree-owned indices, and moves
the request's tree lock from its old last node to the new one. Updates
req.prefix_indices / req.last_node in place.
"""
if self.disable:
return
if token_ids is None:
token_ids = req.input_ids
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
# Radix Cache takes one ref in memory pool; free the span that insert()
# just duplicated into the tree.
new_prefix_len = self.insert(token_ids, kv_indices.clone())
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
# The tree-owned indices may differ from kv_indices after insert(), so
# re-match and reuse them as the request's prefix from here on.
new_indices, new_last_node = self.match_prefix(token_ids)
assert len(new_indices) == len(token_ids)
self.req_to_token_pool.req_to_token[
req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
] = new_indices[len(req.prefix_indices) :]
# Move the protective lock from the previous last node to the new one so
# the request's prefix cannot be evicted while it is still running.
self.dec_lock_ref(req.last_node)
self.inc_lock_ref(new_last_node)
req.prefix_indices = new_indices
req.last_node = new_last_node
def pretty_print(self):
self._print_helper(self.root_node, 0)