Support chunked prefill when radix cache is disabled (#811)

This commit is contained in:
Liangsheng Yin
2024-08-01 00:29:01 -07:00
committed by GitHub
parent ca600e8cd6
commit c020f9ceda
9 changed files with 163 additions and 26 deletions

View File

@@ -0,0 +1,43 @@
from abc import ABC, abstractmethod
class BasePrefixCache(ABC):
    """Cache can be indexed by either rid or key.

    Abstract interface shared by the prefix-cache implementations.
    Concrete subclasses must override every method marked abstract below;
    ``total_size`` and ``pretty_print`` are optional and raise
    ``NotImplementedError`` unless overridden.
    """

    @abstractmethod
    def reset(self):
        """Reset the cache; the exact semantics are defined by the subclass."""

    @abstractmethod
    def match_prefix(self, **kwargs):
        """Look up cached state; the accepted keywords are subclass-specific."""

    @abstractmethod
    def insert(self, **kwargs):
        """Insert into the cache; the accepted keywords are subclass-specific."""

    @abstractmethod
    def cache_req(self, **kwargs):
        """Cache a request; the accepted keywords are subclass-specific."""

    @abstractmethod
    def evict(self, num_tokens, evict_callback):
        """Evict from the cache given a token count and a callback."""

    @abstractmethod
    def inc_lock_ref(self, node):
        """Increment the lock reference associated with ``node``."""

    @abstractmethod
    def dec_lock_ref(self, node):
        """Decrement the lock reference associated with ``node``."""

    @abstractmethod
    def evictable_size(self):
        """Report the evictable size as defined by the subclass."""

    def total_size(self):
        """Optional hook; not implemented in the base class."""
        raise NotImplementedError()

    def pretty_print(self):
        """Optional hook; not implemented in the base class."""
        raise NotImplementedError()

View File

@@ -0,0 +1,60 @@
"""Cache for chunked prefill, used when RadixCache is disabled."""
from sglang.srt.mem_cache.base_cache import BasePrefixCache
class ChunkCacheEntry:
    """Mutable record pairing a request id with the value cached for it.

    Attributes are plain and public: ``rid`` is the request id the entry
    belongs to, ``value`` is whatever the owning cache stored for that id.
    """

    def __init__(self, rid, value):
        self.rid, self.value = rid, value
class ChunkCache(BasePrefixCache):
    """Per-request cache for chunked prefill, used when RadixCache is disabled.

    Entries are keyed by request id (``rid``) rather than by token prefix.
    Each entry stores the slice of ``req_to_token_pool.req_to_token`` that
    was recorded for the request on the most recent ``cache_req`` call.
    """

    def __init__(self, req_to_token_pool, token_to_kv_pool):
        # Always True for this cache; RadixCache exposes the same attribute.
        self.disable = True
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool = token_to_kv_pool
        self.reset()

    def reset(self):
        """Drop every cached entry."""
        self.entries = {}

    def match_prefix(self, rid, **kwargs):
        """Return ``(value, entry)`` for ``rid``, or ``([], None)`` if unknown.

        Extra keyword arguments are accepted and ignored so callers can use
        one call shape for every ``BasePrefixCache`` implementation.
        """
        entry = self.entries.get(rid)
        if entry is None:
            return [], None
        return entry.value, entry

    def cache_req(
        self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
    ):
        """Record or release the token indices held by a request.

        Args:
            rid: Request id used as the cache key.
            token_ids: Tokens of the request; only ``len(token_ids)`` is used,
                to slice the request's row of ``req_to_token``.
            req_pool_idx: Row of the request in ``req_to_token_pool``.
            del_in_memory_pool: When true, free the request's pool slot and
                its token indices and forget the entry. When false, store or
                refresh the entry with the current indices.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            ``None`` when ``del_in_memory_pool`` is true, otherwise the tuple
            ``(indices, entry)``.
        """
        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]

        if del_in_memory_pool:
            assert rid in self.entries
            self.req_to_token_pool.free(req_pool_idx)
            self.token_to_kv_pool.free(indices)
            # Forget the entry once its slots are freed. Otherwise the dict
            # grows with every finished request and a later lookup would
            # return indices that now point at released slots.
            # pop() with a default stays safe when asserts are stripped (-O).
            self.entries.pop(rid, None)
            return None

        entry = self.entries.get(rid)
        if entry is None:
            entry = ChunkCacheEntry(rid, indices)
            self.entries[rid] = entry
        else:
            entry.value = indices
        return indices, entry

    def insert(self, **kwargs):
        """Unsupported for this cache; always raises ``NotImplementedError``.

        Accepts ``**kwargs`` to match the abstract signature, so a call with
        keywords raises ``NotImplementedError`` rather than ``TypeError``.
        """
        raise NotImplementedError

    def evict(self, num_tokens, evict_callback):
        """No-op: this cache does not evict anything."""
        pass

    def inc_lock_ref(self, node):
        """No-op lock counter; always returns 0."""
        return 0

    def dec_lock_ref(self, node):
        """No-op lock counter; always returns 0."""
        return 0

    def evictable_size(self):
        """Always 0: this cache reports nothing as evictable."""
        return 0

View File

@@ -23,6 +23,8 @@ from collections import defaultdict
import torch
from sglang.srt.mem_cache.base_cache import BasePrefixCache
class TreeNode:
def __init__(self):
@@ -46,7 +48,7 @@ def _key_match(key0, key1):
return i
class RadixCache:
class RadixCache(BasePrefixCache):
def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
@@ -62,7 +64,7 @@ class RadixCache:
self.root_node.lock_ref = 1
self.evictable_size_ = 0
def match_prefix(self, key):
def match_prefix(self, key, **kwargs):
if self.disable:
return [], self.root_node
@@ -90,6 +92,7 @@ class RadixCache:
req_pool_idx,
del_in_memory_pool=True,
old_last_node=None,
**kwargs,
):
# Insert the request into radix cache
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]