Support chunked prefill when radix cache is disabled (#811)

This commit is contained in:
Liangsheng Yin
2024-08-01 00:29:01 -07:00
committed by GitHub
parent ca600e8cd6
commit c020f9ceda
9 changed files with 163 additions and 26 deletions

View File

@@ -23,6 +23,8 @@ from collections import defaultdict
import torch
+from sglang.srt.mem_cache.base_cache import BasePrefixCache
class TreeNode:
def __init__(self):
@@ -46,7 +48,7 @@ def _key_match(key0, key1):
return i
-class RadixCache:
+class RadixCache(BasePrefixCache):
def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
@@ -62,7 +64,7 @@ class RadixCache:
self.root_node.lock_ref = 1
self.evictable_size_ = 0
-def match_prefix(self, key):
+def match_prefix(self, key, **kwargs):
if self.disable:
return [], self.root_node
@@ -90,6 +92,7 @@ class RadixCache:
req_pool_idx,
del_in_memory_pool=True,
old_last_node=None,
+**kwargs,
):
# Insert the request into radix cache
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]