RadixCache method adjust (#977)
This commit is contained in:
@@ -20,10 +20,14 @@ The radix tree data structure for managing the KV cache.
|
||||
import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.base_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
|
||||
|
||||
class TreeNode:
|
||||
@@ -85,40 +89,54 @@ class RadixCache(BasePrefixCache):
|
||||
value = [x for x in key]
|
||||
return self._insert_helper(self.root_node, key, value)
|
||||
|
||||
def cache_req(
|
||||
self,
|
||||
token_ids,
|
||||
last_uncached_pos,
|
||||
req_pool_idx,
|
||||
del_in_memory_pool=True,
|
||||
old_last_node=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Insert the request into radix cache
|
||||
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
|
||||
new_prefix_len = self.insert(token_ids, indices.clone())
|
||||
def cache_finished_req(self, req: "Req", token_ids=None):
|
||||
"""Cache request when it finishes."""
|
||||
if token_ids is None:
|
||||
token_ids = (req.input_ids + req.output_ids)[:-1]
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
]
|
||||
|
||||
if self.disable:
|
||||
if del_in_memory_pool:
|
||||
self.token_to_kv_pool.free(indices)
|
||||
else:
|
||||
return torch.tensor([], dtype=torch.int32), self.root_node
|
||||
self.token_to_kv_pool.free(kv_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
return
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
|
||||
new_prefix_len = self.insert(token_ids, kv_indices.clone())
|
||||
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
|
||||
|
||||
if del_in_memory_pool:
|
||||
self.req_to_token_pool.free(req_pool_idx)
|
||||
else:
|
||||
cached_indices, new_last_node = self.match_prefix(token_ids)
|
||||
assert len(cached_indices) == len(token_ids)
|
||||
# Remove the req slot and release the cache lock
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
self.dec_lock_ref(req.last_node)
|
||||
|
||||
self.req_to_token_pool.req_to_token[
|
||||
req_pool_idx, last_uncached_pos : len(cached_indices)
|
||||
] = cached_indices[last_uncached_pos:]
|
||||
self.dec_lock_ref(old_last_node)
|
||||
self.inc_lock_ref(new_last_node)
|
||||
return cached_indices, new_last_node
|
||||
def cache_unfinished_req(self, req: "Req", token_ids=None):
    """Insert a still-running request's tokens into the radix tree.

    Does nothing when the cache is disabled. Otherwise the request's
    token-to-KV mapping is inserted into the tree, the request's row in
    ``req_to_token`` is rewritten with the indices returned by
    ``match_prefix``, and the lock reference is moved from the request's
    previous ``last_node`` to the newly matched node.

    Args:
        req: The request being cached. Its ``req_pool_idx``,
            ``prefix_indices``, ``last_node`` and (when ``token_ids`` is
            omitted) ``input_ids`` are read; ``prefix_indices`` and
            ``last_node`` are overwritten before returning.
        token_ids: Tokens to cache. Defaults to ``req.input_ids``.

    Returns:
        None.
    """
    if self.disable:
        return

    ids = req.input_ids if token_ids is None else token_ids
    slot = req.req_pool_idx
    old_prefix_len = len(req.prefix_indices)

    request_kv = self.req_to_token_pool.req_to_token[slot, : len(ids)]

    # Radix Cache takes one ref in memory pool: a copy of the indices goes
    # into the tree, then the span between the request's old prefix and the
    # length reported by insert() is released.
    # NOTE(review): presumably that span duplicates entries the tree already
    # held -- confirm against insert().
    inserted_prefix_len = self.insert(ids, request_kv.clone())
    self.token_to_kv_pool.free(request_kv[old_prefix_len:inserted_prefix_len])

    # The prefix indices may have changed after the insert; look them up
    # again and reuse them for this request.
    matched_indices, matched_node = self.match_prefix(ids)
    assert len(matched_indices) == len(ids)
    self.req_to_token_pool.req_to_token[
        slot, old_prefix_len : len(matched_indices)
    ] = matched_indices[old_prefix_len:]

    # Move the lock: release the old node first, then take the new one.
    self.dec_lock_ref(req.last_node)
    self.inc_lock_ref(matched_node)

    req.prefix_indices = matched_indices
    req.last_node = matched_node
|
||||
|
||||
def pretty_print(self):
|
||||
self._print_helper(self.root_node, 0)
|
||||
|
||||
Reference in New Issue
Block a user