Support page size > 1 (#4356)

Lianmin Zheng
2025-03-12 22:22:39 -07:00
committed by GitHub
parent 2f6bacee03
commit c76040e31b
23 changed files with 877 additions and 284 deletions

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple
from typing import Any, List, Tuple
class BasePrefixCache(ABC):
@@ -26,24 +26,22 @@ class BasePrefixCache(ABC):
pass
@abstractmethod
def evict(self, num_tokens: int, evict_callback: Callable):
def evict(self, num_tokens: int):
pass
@abstractmethod
def inc_lock_ref(self, node):
def inc_lock_ref(self, node: Any):
pass
@abstractmethod
def dec_lock_ref(self, node):
def dec_lock_ref(self, node: Any):
pass
@abstractmethod
def evictable_size(self):
pass
return 0
@abstractmethod
def protected_size(self):
raise NotImplementedError()
return 0
def total_size(self):
raise NotImplementedError()

View File

@@ -1,7 +1,8 @@
from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
import torch
@@ -24,73 +25,40 @@ class ChunkCache(BasePrefixCache):
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
):
self.disable = True
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.entries: Dict[str, ChunkCacheEntry] = {}
self.reset()
def reset(self):
self.entries = {}
pass
def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
if rid not in self.entries:
return [], None
entry = self.entries[rid]
max_prefix_len = len(key)
return entry.value[:max_prefix_len], entry
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
if token_ids is None:
token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
else:
token_id_len = len(token_ids)
def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
return [], None
def cache_finished_req(self, req: Req):
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :token_id_len
req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
]
self.req_to_token_pool.free(req.req_pool_idx)
self.token_to_kv_pool_allocator.free(kv_indices)
if req.rid in self.entries:
del self.entries[req.rid]
def cache_unfinished_req(self, req: Req):
token_id_len = len(req.fill_ids)
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :token_id_len
req.req_pool_idx, : len(req.fill_ids)
]
if req.rid not in self.entries:
self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
entry = self.entries[req.rid]
entry.value = kv_indices
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
req.prefix_indices = kv_indices
req.last_node = entry
def insert(self):
raise NotImplementedError()
def evict(self, num_tokens: int, evict_callback: Callable):
def evict(self, num_tokens: int):
pass
def inc_lock_ref(self, node):
def inc_lock_ref(self, node: Any):
return 0
def dec_lock_ref(self, node):
return 0
def evictable_size(self):
return 0
def pretty_print(self):
return ""
def protected_size(self):
def dec_lock_ref(self, node: Any):
return 0
def pretty_print(self):
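
After this change ChunkCache keeps no per-request state: the entries table is gone, match_prefix always reports an empty prefix, and evict is a no-op, so the only remaining work is freeing pool slots when a request finishes. A quick sketch of the resulting behaviour (passing None for the pools is purely illustrative; a real scheduler wires in actual pool objects):

from sglang.srt.mem_cache.chunk_cache import ChunkCache

cache = ChunkCache(req_to_token_pool=None, token_to_kv_pool_allocator=None)
prefix, last_node = cache.match_prefix(rid=0, key=[1, 2, 3])
assert prefix == [] and last_node is None  # nothing is ever matched
cache.evict(num_tokens=128)                # no-op: there is nothing to evict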

View File

@@ -7,13 +7,13 @@ from typing import List, Optional
import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPoolHost,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match
logger = logging.getLogger(__name__)
@@ -122,7 +122,7 @@ class HiRadixCache(RadixCache):
def evictable_size(self):
return self.evictable_size_
def evict(self, num_tokens: int, evict_callback=None):
def evict(self, num_tokens: int):
leaves = self._collect_leaves_device()
heapq.heapify(leaves)

View File

@@ -129,6 +129,7 @@ class TokenToKVPoolAllocator:
self.size = size
self.dtype = dtype
self.device = device
self.page_size = 1
self.free_slots = None
self.is_not_in_free_group = True
@@ -149,15 +150,14 @@ class TokenToKVPoolAllocator:
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
return select_index.to(self.device, non_blocking=True)
return select_index
def free(self, free_index: torch.Tensor):
if free_index.numel() == 0:
return
if self.is_not_in_free_group:
self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
self.free_slots = torch.concat((self.free_slots, free_index))
else:
self.free_group.append(free_index)
@@ -172,7 +172,9 @@ class TokenToKVPoolAllocator:
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
self.free_slots = torch.arange(
1, self.size + 1, dtype=torch.int64, device=self.device
)
self.is_in_free_group = False
self.free_group = []
@@ -182,6 +184,7 @@ class MHATokenToKVPool(KVCache):
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
@@ -190,6 +193,7 @@ class MHATokenToKVPool(KVCache):
enable_memory_saver: bool,
):
self.size = size
self.page_size = page_size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -207,6 +211,8 @@ class MHATokenToKVPool(KVCache):
self._create_buffers()
self.layer_transfer_counter = None
self.capture_mode = False
self.alt_stream = torch.cuda.Stream()
k_size, v_size = self.get_kv_size_bytes()
logger.info(
@@ -218,16 +224,16 @@ class MHATokenToKVPool(KVCache):
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.k_buffer = [
torch.empty(
(self.size + 1, self.head_num, self.head_dim),
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
self.v_buffer = [
torch.empty(
(self.size + 1, self.head_num, self.head_dim),
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
@@ -315,14 +321,44 @@ class MHATokenToKVPool(KVCache):
cache_v.div_(v_scale)
cache_k = cache_k.to(self.dtype)
cache_v = cache_v.to(self.dtype)
if self.store_dtype != self.dtype:
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
cache_k = cache_k.view(self.store_dtype)
cache_v = cache_v.view(self.store_dtype)
if self.capture_mode:
self.alt_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.alt_stream):
self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v
torch.cuda.current_stream().wait_stream(self.alt_stream)
else:
self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v
@torch.compile
def fused_downcast(
cache_k: torch.Tensor,
cache_v: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
dtype: torch.dtype,
store_dtype: torch.dtype,
max_fp8: float,
min_fp8: float,
):
cache_k = cache_k / k_scale
cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
cache_v = cache_v / v_scale
cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
cache_k = cache_k.to(dtype)
cache_v = cache_v.to(dtype)
cache_k = cache_k.view(store_dtype)
cache_v = cache_v.view(store_dtype)
return cache_k, cache_v
# This compiled version is slower in the unit test
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
@torch.compile(dynamic=True, backend=get_compiler_backend())
@@ -335,6 +371,7 @@ class MLATokenToKVPool(KVCache):
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
kv_lora_rank: int,
qk_rope_head_dim: int,
@@ -359,8 +396,8 @@ class MLATokenToKVPool(KVCache):
with memory_saver_adapter.region():
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.empty(
(size + 1, 1, kv_lora_rank + qk_rope_head_dim),
torch.zeros(
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
dtype=self.store_dtype,
device=device,
)
@@ -400,6 +437,7 @@ class DoubleSparseTokenToKVPool(KVCache):
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
@@ -409,6 +447,7 @@ class DoubleSparseTokenToKVPool(KVCache):
enable_memory_saver: bool,
):
self.size = size
self.page_size = page_size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -423,17 +462,21 @@ class DoubleSparseTokenToKVPool(KVCache):
with memory_saver_adapter.region():
# [size, head_num, head_dim] for each layer
self.k_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
torch.zeros(
(size + page_size, head_num, head_dim), dtype=dtype, device=device
)
for _ in range(layer_num)
]
self.v_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
torch.zeros(
(size + page_size, head_num, head_dim), dtype=dtype, device=device
)
for _ in range(layer_num)
]
# [size, head_num, heavy_channel_num] for each layer
self.label_buffer = [
torch.empty(
torch.zeros(
(size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
)
for _ in range(layer_num)
@@ -528,7 +571,7 @@ class MHATokenToKVPoolHost:
f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
)
self.kv_buffer = torch.empty(
self.kv_buffer = torch.zeros(
(2, self.layer_num, self.size, self.head_num, self.head_dim),
dtype=self.dtype,
device=self.device,
@@ -548,9 +591,6 @@ class MHATokenToKVPoolHost:
def get_flat_data(self, indices):
return self.kv_buffer[:, :, indices]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[:, layer_id, indices]
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, :, indices] = flat_data
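
The KV buffers above grow from size + 1 rows to size + page_size rows. Slot 0 keeps its old role of absorbing dummy writes from padded tokens, but with paging the whole first page stays reserved, and the highest allocatable page must still fit inside the buffer. A back-of-the-envelope check of that sizing (a sketch, not library code):

def required_rows(size: int, page_size: int) -> int:
    # Pages 1..size // page_size are handed out by the allocator; page 0 is
    # reserved so slot 0 can keep absorbing dummy writes from padded tokens.
    num_allocatable_pages = size // page_size
    last_row = num_allocatable_pages * page_size + (page_size - 1)
    return last_row + 1

assert required_rows(1024, 1) == 1025    # old `size + 1` layout
assert required_rows(1024, 16) == 1040   # new `size + page_size` layout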

View File

@@ -0,0 +1,283 @@
"""
Copyright 2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
Page-aligned memory pool.
"""
import torch
import triton
import triton.language as tl
from sglang.srt.mem_cache.memory_pool import KVCache
from sglang.srt.utils import get_bool_env_var, next_power_of_2
@triton.jit
def alloc_extend_kernel(
pre_lens_ptr,
seq_lens_ptr,
last_loc_ptr,
free_page_ptr,
out_indices,
ret_values,
bs_upper: tl.constexpr,
page_size: tl.constexpr,
max_num_extend_tokens: tl.constexpr,
):
pid = tl.program_id(0)
load_offset = tl.arange(0, bs_upper)
seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
pre_lens = tl.load(pre_lens_ptr + load_offset, mask=load_offset <= pid)
extend_lens = seq_lens - pre_lens
seq_len = tl.load(seq_lens_ptr + pid)
pre_len = tl.load(pre_lens_ptr + pid)
extend_len = seq_len - pre_len
sum_extend_lens = tl.sum(extend_lens)
output_start_loc = sum_extend_lens - extend_len
num_pages_after = (seq_lens + page_size - 1) // page_size
num_pages_before = (pre_lens + page_size - 1) // page_size
num_new_pages = num_pages_after - num_pages_before
num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
pre_len + page_size - 1
) // page_size
sum_num_new_pages = tl.sum(num_new_pages)
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
# Return value
if pid == tl.num_programs(0) - 1:
merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
tl.int64
)
tl.store(ret_values, merged_value)
# Part 1: fill the old partial page
last_loc = tl.load(last_loc_ptr + pid)
num_part1 = (
min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
)
offset_one_page = tl.arange(0, page_size)
tl.store(
out_indices + output_start_loc + offset_one_page,
last_loc + 1 + offset_one_page,
mask=offset_one_page < num_part1,
)
if pre_len + num_part1 == seq_len:
return
# Part 2: fill the new full pages
num_part2 = (
seq_len // page_size * page_size
- (pre_len + page_size - 1) // page_size * page_size
)
offset_many_page = tl.arange(0, max_num_extend_tokens)
page_start = tl.load(
free_page_ptr + new_page_start_loc + offset_many_page // page_size,
mask=offset_many_page < num_part2,
)
tl.store(
out_indices + output_start_loc + num_part1 + offset_many_page,
page_start * page_size + offset_many_page % page_size,
mask=offset_many_page < num_part2,
)
if pre_len + num_part1 + num_part2 == seq_len:
return
# Part 3: fill the new partial page
num_part3 = seq_len - seq_len // page_size * page_size
start_loc = tl.load(
free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1
)
tl.store(
out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
start_loc * page_size + offset_one_page,
mask=offset_one_page < num_part3,
)
@triton.jit
def alloc_decode_kernel(
seq_lens_ptr,
last_loc_ptr,
free_page_ptr,
out_indices,
ret_values,
bs_upper: tl.constexpr,
page_size: tl.constexpr,
):
pid = tl.program_id(0)
load_offset = tl.arange(0, bs_upper)
seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens)
seq_len = tl.load(seq_lens_ptr + pid)
pre_len = seq_len - 1
num_pages_after = (seq_lens + page_size - 1) // page_size
num_pages_before = (pre_lens + page_size - 1) // page_size
num_new_pages = num_pages_after - num_pages_before
num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
pre_len + page_size - 1
) // page_size
sum_num_new_pages = tl.sum(num_new_pages)
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
# Return value
if pid == tl.num_programs(0) - 1:
tl.store(ret_values, sum_num_new_pages)
if num_page_start_loc_self == 0:
last_loc = tl.load(last_loc_ptr + pid)
tl.store(out_indices + pid, last_loc + 1)
else:
page = tl.load(free_page_ptr + new_page_start_loc)
tl.store(out_indices + pid, page * page_size)
class PagedTokenToKVPoolAllocator:
"""
An allocator managing the indices to kv cache data.
This class has the same interface as `TokenToKVPoolAllocator` but the output
of one request is always page-aligned.
TODO: fuse last_loc into the kernel.
"""
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
device: str,
kvcache: KVCache,
):
self.size = size
self.dtype = dtype
self.device = device
self.page_size = page_size
self.num_pages = size // page_size
self.free_pages = None
self.is_not_in_free_group = True
self.free_group = []
self.clear()
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
self._kvcache = kvcache
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
def available_size(self):
return len(self.free_pages) * self.page_size
def alloc_extend(
self,
prefix_lens: torch.Tensor,
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
extend_num_tokens: int,
):
if self.debug_mode:
assert torch.all(
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
bs = len(prefix_lens)
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int64, device=self.device
)
alloc_extend_kernel[(bs,)](
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.ret_values,
next_power_of_2(bs),
self.page_size,
next_power_of_2(extend_num_tokens),
)
merged_value = self.ret_values.item()
num_new_pages = merged_value >> 32
if num_new_pages > len(self.free_pages):
return None
self.free_pages = self.free_pages[num_new_pages:]
return out_indices
def alloc_decode(
self,
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
):
if self.debug_mode:
assert torch.all(
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)
bs = len(seq_lens)
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
alloc_decode_kernel[(bs,)](
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.ret_values,
next_power_of_2(bs),
self.page_size,
)
num_new_pages = self.ret_values.item()
if num_new_pages > len(self.free_pages):
return None
self.free_pages = self.free_pages[num_new_pages:]
return out_indices
def free(self, free_index: torch.Tensor):
if free_index.numel() == 0:
return
if self.is_not_in_free_group:
free_page_indices = torch.unique(free_index // self.page_size)
self.free_pages = torch.cat((free_page_indices, self.free_pages))
else:
self.free_group.append(free_index)
def free_group_begin(self):
self.is_not_in_free_group = False
self.free_group = []
def free_group_end(self):
self.is_not_in_free_group = True
if self.free_group:
self.free(torch.concat(self.free_group))
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_pages = torch.arange(
1, self.num_pages + 1, dtype=torch.int64, device=self.device
)
self.is_in_free_group = False
self.free_group = []
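
alloc_extend_kernel splits each request's newly extended tokens into three runs: top up the last partially filled page, claim whole pages from the free list, then start one trailing partial page; prefix sums over the batch decide where each request writes into out_indices and which free pages it owns. The per-request logic, rewritten in plain Python, may be easier to follow than the Triton version (a sketch under simplified assumptions, one request at a time instead of a whole batch):

def alloc_extend_one(pre_len, seq_len, last_loc, free_pages, page_size):
    """Return KV slot indices for tokens pre_len..seq_len-1 plus the leftover
    free pages, mirroring the three parts of alloc_extend_kernel."""
    out = []
    # Part 1: fill the rest of the last partially used page.
    part1_end = min(seq_len, (pre_len + page_size - 1) // page_size * page_size)
    out += [last_loc + 1 + i for i in range(part1_end - pre_len)]
    if part1_end == seq_len:
        return out, free_pages
    # Part 2: claim whole new pages from the free list.
    full_end = seq_len // page_size * page_size
    num_full_pages = (full_end - part1_end) // page_size
    for page in free_pages[:num_full_pages]:
        out += [page * page_size + i for i in range(page_size)]
    free_pages = free_pages[num_full_pages:]
    if full_end == seq_len:
        return out, free_pages
    # Part 3: start one trailing partial page.
    page = free_pages[0]
    out += [page * page_size + i for i in range(seq_len - full_end)]
    return out, free_pages[1:]

# page_size=4: two slots top up the page ending at slot 11, page 3 is consumed
# whole, and page 7 starts the trailing partial page.
indices, rest = alloc_extend_one(pre_len=6, seq_len=13, last_loc=9,
                                 free_pages=[3, 7, 9], page_size=4)
assert indices == [10, 11, 12, 13, 14, 15, 28]
assert rest == [9]

In the real kernel the batch totals are packed into a single int64 (num_new_pages << 32 | num_extended_tokens), so alloc_extend only needs one .item() readback to decide whether enough free pages remain.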

View File

@@ -22,7 +22,8 @@ The radix tree data structure for managing the KV cache.
import heapq
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
@@ -67,7 +68,7 @@ class TreeNode:
return self.last_access_time < other.last_access_time
def _key_match(key0: List, key1: List):
def _key_match_page_size1(key0: List, key1: List):
i = 0
for k0, k1 in zip(key0, key1):
if k0 != k1:
@@ -76,16 +77,42 @@ def _key_match(key0: List, key1: List):
return i
def _key_match_paged(key0: List, key1: List, page_size: int):
min_len = min(len(key0), len(key1))
i = 0
while i < min_len:
if key0[i : i + page_size] != key1[i : i + page_size]:
break
i += page_size
return i
class RadixCache(BasePrefixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
page_size: int,
disable: bool = False,
):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = page_size
self.disable = disable
if self.token_to_kv_pool_allocator:
self.device = self.token_to_kv_pool_allocator.device
else:
self.device = torch.device("cpu")
if self.page_size == 1:
self.key_match_fn = _key_match_page_size1
self.get_child_key_fn = lambda key: key[0]
else:
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
self.get_child_key_fn = lambda key: tuple(key[:page_size])
self.reset()
##### Public API #####
@@ -109,14 +136,25 @@ class RadixCache(BasePrefixCache):
The last node creates a new child if the prefix is shorter
than the last node's value.
"""
if self.disable:
return [], self.root_node
if self.disable or len(key) == 0:
return (
torch.empty(
(0,),
dtype=torch.int32,
device=self.device,
),
self.root_node,
)
if self.page_size != 1:
page_aligned_len = len(key) // self.page_size * self.page_size
key = key[:page_aligned_len]
value, last_node = self._match_prefix_helper(self.root_node, key)
if value:
value = torch.concat(value)
else:
value = torch.tensor([], dtype=torch.int32)
value = torch.empty((0,), dtype=torch.int32, device=self.device)
return value, last_node
def insert(self, key: List, value=None):
@@ -127,29 +165,33 @@ class RadixCache(BasePrefixCache):
value = [x for x in key]
return self._insert_helper(self.root_node, key, value)
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
def cache_finished_req(self, req: Req):
"""Cache request when it finishes."""
if self.disable:
if token_ids is None:
token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
else:
token_ids_len = len(token_ids)
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :token_ids_len
req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
]
self.token_to_kv_pool_allocator.free(kv_indices)
self.req_to_token_pool.free(req.req_pool_idx)
return
if token_ids is None:
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
else:
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.clone()
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone())
new_prefix_len = self.insert(
token_ids[:page_aligned_len], page_aligned_kv_indices
)
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
@@ -158,27 +200,32 @@ class RadixCache(BasePrefixCache):
self.req_to_token_pool.free(req.req_pool_idx)
self.dec_lock_ref(req.last_node)
def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
def cache_unfinished_req(self, req: Req):
"""Cache request when it is unfinished."""
if self.disable:
return
if token_ids is None:
token_ids = req.fill_ids
token_ids = req.fill_ids
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
else:
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.clone()
page_aligned_token_ids = token_ids[:page_aligned_len]
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone())
new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
# The prefix indices could be updated; reuse them
new_indices, new_last_node = self.match_prefix(token_ids)
assert len(new_indices) == len(token_ids)
new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],
@@ -186,7 +233,14 @@ class RadixCache(BasePrefixCache):
self.dec_lock_ref(req.last_node)
self.inc_lock_ref(new_last_node)
req.prefix_indices = new_indices
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
if self.page_size != 1:
req.prefix_indices = torch.cat(
[new_indices, kv_indices[len(new_indices) :]]
)
else:
req.prefix_indices = new_indices
req.last_node = new_last_node
def pretty_print(self):
@@ -196,7 +250,7 @@ class RadixCache(BasePrefixCache):
def total_size(self):
return self._total_size_helper()
def evict(self, num_tokens: int, evict_callback: Callable):
def evict(self, num_tokens: int):
if self.disable:
return
@@ -212,7 +266,7 @@ class RadixCache(BasePrefixCache):
if x.lock_ref > 0:
continue
evict_callback(x.value)
self.token_to_kv_pool_allocator.free(x.value)
num_evicted += len(x.value)
self._delete_leaf(x)
@@ -254,15 +308,29 @@ class RadixCache(BasePrefixCache):
# protected size refers to the size of the cache that is locked
return self.protected_size_
def all_values_flatten(self):
values = []
def _dfs_helper(node: TreeNode):
for _, child in node.children.items():
values.append(child.value)
_dfs_helper(child)
_dfs_helper(self.root_node)
return torch.concat(values)
##### Internal Helper Functions #####
def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.time()
child_key = self.get_child_key_fn(key)
value = []
while len(key) > 0 and key[0] in node.children.keys():
child = node.children[key[0]]
while len(key) > 0 and child_key in node.children.keys():
child = node.children[child_key]
child.last_access_time = time.time()
prefix_len = _key_match(child.key, key)
prefix_len = self.key_match_fn(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
value.append(new_node.value)
@@ -272,12 +340,16 @@ class RadixCache(BasePrefixCache):
value.append(child.value)
node = child
key = key[prefix_len:]
if len(key):
child_key = self.get_child_key_fn(key)
return value, node
def _split_node(self, key, child: TreeNode, split_len: int):
# new_node -> child
new_node = TreeNode()
new_node.children = {key[split_len]: child}
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
new_node.parent = child.parent
new_node.lock_ref = child.lock_ref
new_node.key = child.key[:split_len]
@@ -285,7 +357,7 @@ class RadixCache(BasePrefixCache):
child.parent = new_node
child.key = child.key[split_len:]
child.value = child.value[split_len:]
new_node.parent.children[key[0]] = new_node
new_node.parent.children[self.get_child_key_fn(key)] = new_node
return new_node
def _insert_helper(self, node: TreeNode, key: List, value):
@@ -293,11 +365,13 @@ class RadixCache(BasePrefixCache):
if len(key) == 0:
return 0
child_key = self.get_child_key_fn(key)
total_prefix_length = 0
while len(key) > 0 and key[0] in node.children.keys():
node = node.children[key[0]]
while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key]
node.last_access_time = time.time()
prefix_len = _key_match(node.key, key)
prefix_len = self.key_match_fn(node.key, key)
total_prefix_length += prefix_len
key = key[prefix_len:]
value = value[prefix_len:]
@@ -306,12 +380,15 @@ class RadixCache(BasePrefixCache):
new_node = self._split_node(node.key, node, prefix_len)
node = new_node
if len(key):
child_key = self.get_child_key_fn(key)
if len(key):
new_node = TreeNode()
new_node.parent = node
new_node.key = key
new_node.value = value
node.children[key[0]] = new_node
node.children[child_key] = new_node
self.evictable_size_ += len(value)
return total_prefix_length
@@ -326,9 +403,13 @@ class RadixCache(BasePrefixCache):
current_node.key[:10],
f"r={current_node.lock_ref}",
)
for _, child in current_node.children.items():
for key, child in current_node.children.items():
stack.append((child, current_indent + 2))
assert key == self.get_child_key_fn(
child.key
), f"{key=}, {self.get_child_key_fn(child.key)=}"
def _delete_leaf(self, node):
for k, v in node.parent.children.items():
if v == node:
@@ -363,7 +444,7 @@ class RadixCache(BasePrefixCache):
if __name__ == "__main__":
tree = RadixCache(None, None, False)
tree = RadixCache(None, None, page_size=1, disable=False)
tree.insert("Hello")
tree.insert("Hello")
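
With page_size > 1 the radix tree works in whole-page units: keys are compared page by page, children are indexed by a tuple holding their first page, and only page-aligned prefixes are inserted (a finished request's trailing partial page is freed rather than cached). A small illustration of the matching rules defined above, assuming page_size = 4:

# A mismatch inside a page discards that whole page.
key0 = [1, 2, 3, 4, 5, 6, 7, 8]
key1 = [1, 2, 3, 4, 5, 6, 9, 9]
assert _key_match_page_size1(key0, key1) == 6          # token-level match
assert _key_match_paged(key0, key1, page_size=4) == 4  # stops at the page boundary

# Child lookup keys become the first page as a tuple (key[0] when page_size == 1).
get_child_key_fn = lambda key: tuple(key[:4])
assert get_child_key_fn(key0) == (1, 2, 3, 4)

# Inserts and finished-request caching first truncate to a page boundary.
token_ids = list(range(10))
page_aligned_len = len(token_ids) // 4 * 4
assert token_ids[:page_aligned_len] == [0, 1, 2, 3, 4, 5, 6, 7]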