Support chunked prefill when radix cache is disabled (#811)
This commit is contained in:
43
python/sglang/srt/mem_cache/base_cache.py
Normal file
43
python/sglang/srt/mem_cache/base_cache.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BasePrefixCache(ABC):
|
||||
"""Cache can be indexed by either rid or key."""
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def match_prefix(self, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_req(self, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evict(self, num_tokens, evict_callback):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def inc_lock_ref(self, node):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def dec_lock_ref(self, node):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evictable_size(self):
|
||||
pass
|
||||
|
||||
def total_size(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def pretty_print(self):
|
||||
raise NotImplementedError
|
||||
60
python/sglang/srt/mem_cache/chunk_cache.py
Normal file
60
python/sglang/srt/mem_cache/chunk_cache.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||
|
||||
from sglang.srt.mem_cache.base_cache import BasePrefixCache
|
||||
|
||||
|
||||
class ChunkCacheEntry:
|
||||
def __init__(self, rid, value):
|
||||
self.rid = rid
|
||||
self.value = value
|
||||
|
||||
|
||||
class ChunkCache(BasePrefixCache):
|
||||
def __init__(self, req_to_token_pool, token_to_kv_pool):
|
||||
self.disable = True
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.entries = {}
|
||||
|
||||
def match_prefix(self, rid, **kwargs):
|
||||
if rid not in self.entries:
|
||||
return [], None
|
||||
|
||||
entry = self.entries[rid]
|
||||
return entry.value, entry
|
||||
|
||||
def cache_req(
|
||||
self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
|
||||
):
|
||||
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
|
||||
if del_in_memory_pool:
|
||||
assert rid in self.entries
|
||||
self.req_to_token_pool.free(req_pool_idx)
|
||||
self.token_to_kv_pool.free(indices)
|
||||
return
|
||||
|
||||
if rid not in self.entries:
|
||||
self.entries[rid] = ChunkCacheEntry(rid, indices)
|
||||
|
||||
entry = self.entries[rid]
|
||||
entry.value = indices
|
||||
return indices, entry
|
||||
|
||||
def insert(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def evict(self, num_tokens, evict_callback):
|
||||
pass
|
||||
|
||||
def inc_lock_ref(self, node):
|
||||
return 0
|
||||
|
||||
def dec_lock_ref(self, node):
|
||||
return 0
|
||||
|
||||
def evictable_size(self):
|
||||
return 0
|
||||
@@ -23,6 +23,8 @@ from collections import defaultdict
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.base_cache import BasePrefixCache
|
||||
|
||||
|
||||
class TreeNode:
|
||||
def __init__(self):
|
||||
@@ -46,7 +48,7 @@ def _key_match(key0, key1):
|
||||
return i
|
||||
|
||||
|
||||
class RadixCache:
|
||||
class RadixCache(BasePrefixCache):
|
||||
def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
@@ -62,7 +64,7 @@ class RadixCache:
|
||||
self.root_node.lock_ref = 1
|
||||
self.evictable_size_ = 0
|
||||
|
||||
def match_prefix(self, key):
|
||||
def match_prefix(self, key, **kwargs):
|
||||
if self.disable:
|
||||
return [], self.root_node
|
||||
|
||||
@@ -90,6 +92,7 @@ class RadixCache:
|
||||
req_pool_idx,
|
||||
del_in_memory_pool=True,
|
||||
old_last_node=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Insert the request into radix cache
|
||||
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
|
||||
|
||||
Reference in New Issue
Block a user