Adjust RadixCache methods (#977)

This commit is contained in:
Liangsheng Yin
2024-08-07 15:52:24 -07:00
committed by GitHub
parent f724f1f1e9
commit 7623091d97
5 changed files with 140 additions and 118 deletions

View File

@@ -20,10 +20,14 @@ The radix tree data structure for managing the KV cache.
import heapq
import time
from collections import defaultdict
from typing import TYPE_CHECKING
import torch
from sglang.srt.mem_cache.base_cache import BasePrefixCache
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
class TreeNode:
@@ -85,40 +89,54 @@ class RadixCache(BasePrefixCache):
value = [x for x in key]
return self._insert_helper(self.root_node, key, value)
def cache_finished_req(self, req: "Req", token_ids=None):
    """Cache a finished request's KV indices into the radix tree.

    Replaces the old ``cache_req(del_in_memory_pool=True, ...)`` path:
    the request's token-to-KV mapping is inserted into the radix cache,
    duplicate KV entries are released (the radix cache holds one
    reference in the memory pool), and the request slot is freed.

    Args:
        req: The finished request; provides ``input_ids``,
            ``output_ids``, ``req_pool_idx``, ``prefix_indices`` and
            ``last_node``.
        token_ids: Token ids to cache. Defaults to the request's full
            sequence minus the last generated token, which has no KV
            entry yet.

    Returns:
        None.
    """
    if token_ids is None:
        # Drop the last output token: its KV cache entry does not exist.
        token_ids = (req.input_ids + req.output_ids)[:-1]

    kv_indices = self.req_to_token_pool.req_to_token[
        req.req_pool_idx, : len(token_ids)
    ]

    if self.disable:
        # Radix cache disabled: release the KV memory and the request
        # slot directly, nothing to insert.
        self.token_to_kv_pool.free(kv_indices)
        self.req_to_token_pool.free(req.req_pool_idx)
        return

    # Radix cache takes one ref in the memory pool: free the span that
    # was newly deduplicated by the insert (beyond the cached prefix).
    new_prefix_len = self.insert(token_ids, kv_indices.clone())
    self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])

    # Remove the request slot and release the cache lock held on the
    # request's last matched tree node.
    self.req_to_token_pool.free(req.req_pool_idx)
    self.dec_lock_ref(req.last_node)
def cache_unfinished_req(self, req: "Req", token_ids=None):
    """Cache request when it is unfinished.

    Inserts the request's current token-to-KV mapping into the radix
    tree, then re-points the request at the (possibly updated) cached
    prefix and moves its tree lock to the new last node. The request
    slot is kept alive, unlike the finished-request path.
    """
    if self.disable:
        # Radix cache disabled: nothing to insert or update.
        return

    if token_ids is None:
        token_ids = req.input_ids

    kv_indices = self.req_to_token_pool.req_to_token[
        req.req_pool_idx, : len(token_ids)
    ]

    # Radix Cache takes one ref in memory pool
    new_prefix_len = self.insert(token_ids, kv_indices.clone())
    self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])

    # The prefix indices could be updated, reuse it
    new_indices, new_last_node = self.match_prefix(token_ids)
    assert len(new_indices) == len(token_ids)
    # Rewrite the request's token table to point at the canonical cached
    # indices beyond the previously matched prefix.
    self.req_to_token_pool.req_to_token[
        req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
    ] = new_indices[len(req.prefix_indices) :]

    # Transfer the cache lock from the old last node to the new one
    # before updating the request's view of its prefix.
    self.dec_lock_ref(req.last_node)
    self.inc_lock_ref(new_last_node)
    req.prefix_indices = new_indices
    req.last_node = new_last_node
def pretty_print(self):
    """Dump the radix tree to stdout for debugging, starting at the root."""
    root = self.root_node
    self._print_helper(root, 0)