Support page size > 1 (#4356)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, List, Tuple
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
|
||||
class BasePrefixCache(ABC):
|
||||
@@ -26,24 +26,22 @@ class BasePrefixCache(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evict(self, num_tokens: int, evict_callback: Callable):
|
||||
def evict(self, num_tokens: int):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def inc_lock_ref(self, node):
|
||||
def inc_lock_ref(self, node: Any):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def dec_lock_ref(self, node):
|
||||
def dec_lock_ref(self, node: Any):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evictable_size(self):
|
||||
pass
|
||||
return 0
|
||||
|
||||
@abstractmethod
|
||||
def protected_size(self):
|
||||
raise NotImplementedError()
|
||||
return 0
|
||||
|
||||
def total_size(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -24,73 +25,40 @@ class ChunkCache(BasePrefixCache):
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
):
|
||||
self.disable = True
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.entries: Dict[str, ChunkCacheEntry] = {}
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.entries = {}
|
||||
pass
|
||||
|
||||
def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
|
||||
if rid not in self.entries:
|
||||
return [], None
|
||||
|
||||
entry = self.entries[rid]
|
||||
max_prefix_len = len(key)
|
||||
return entry.value[:max_prefix_len], entry
|
||||
|
||||
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
||||
if token_ids is None:
|
||||
token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
|
||||
else:
|
||||
token_id_len = len(token_ids)
|
||||
def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
|
||||
return [], None
|
||||
|
||||
def cache_finished_req(self, req: Req):
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, :token_id_len
|
||||
req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
|
||||
]
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||
|
||||
if req.rid in self.entries:
|
||||
del self.entries[req.rid]
|
||||
|
||||
def cache_unfinished_req(self, req: Req):
|
||||
token_id_len = len(req.fill_ids)
|
||||
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, :token_id_len
|
||||
req.req_pool_idx, : len(req.fill_ids)
|
||||
]
|
||||
|
||||
if req.rid not in self.entries:
|
||||
self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
|
||||
|
||||
entry = self.entries[req.rid]
|
||||
entry.value = kv_indices
|
||||
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
||||
req.prefix_indices = kv_indices
|
||||
req.last_node = entry
|
||||
|
||||
def insert(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def evict(self, num_tokens: int, evict_callback: Callable):
|
||||
def evict(self, num_tokens: int):
|
||||
pass
|
||||
|
||||
def inc_lock_ref(self, node):
|
||||
def inc_lock_ref(self, node: Any):
|
||||
return 0
|
||||
|
||||
def dec_lock_ref(self, node):
|
||||
return 0
|
||||
|
||||
def evictable_size(self):
|
||||
return 0
|
||||
|
||||
def pretty_print(self):
|
||||
return ""
|
||||
|
||||
def protected_size(self):
|
||||
def dec_lock_ref(self, node: Any):
|
||||
return 0
|
||||
|
||||
def pretty_print(self):
|
||||
|
||||
@@ -7,13 +7,13 @@ from typing import List, Optional
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPoolHost,
|
||||
ReqToTokenPool,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||
from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -122,7 +122,7 @@ class HiRadixCache(RadixCache):
|
||||
def evictable_size(self):
|
||||
return self.evictable_size_
|
||||
|
||||
def evict(self, num_tokens: int, evict_callback=None):
|
||||
def evict(self, num_tokens: int):
|
||||
leaves = self._collect_leaves_device()
|
||||
heapq.heapify(leaves)
|
||||
|
||||
|
||||
@@ -129,6 +129,7 @@ class TokenToKVPoolAllocator:
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.page_size = 1
|
||||
|
||||
self.free_slots = None
|
||||
self.is_not_in_free_group = True
|
||||
@@ -149,15 +150,14 @@ class TokenToKVPoolAllocator:
|
||||
|
||||
select_index = self.free_slots[:need_size]
|
||||
self.free_slots = self.free_slots[need_size:]
|
||||
|
||||
return select_index.to(self.device, non_blocking=True)
|
||||
return select_index
|
||||
|
||||
def free(self, free_index: torch.Tensor):
|
||||
if free_index.numel() == 0:
|
||||
return
|
||||
|
||||
if self.is_not_in_free_group:
|
||||
self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
|
||||
self.free_slots = torch.concat((self.free_slots, free_index))
|
||||
else:
|
||||
self.free_group.append(free_index)
|
||||
|
||||
@@ -172,7 +172,9 @@ class TokenToKVPoolAllocator:
|
||||
|
||||
def clear(self):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
|
||||
self.free_slots = torch.arange(
|
||||
1, self.size + 1, dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.is_in_free_group = False
|
||||
self.free_group = []
|
||||
|
||||
@@ -182,6 +184,7 @@ class MHATokenToKVPool(KVCache):
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
page_size: int,
|
||||
dtype: torch.dtype,
|
||||
head_num: int,
|
||||
head_dim: int,
|
||||
@@ -190,6 +193,7 @@ class MHATokenToKVPool(KVCache):
|
||||
enable_memory_saver: bool,
|
||||
):
|
||||
self.size = size
|
||||
self.page_size = page_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||
@@ -207,6 +211,8 @@ class MHATokenToKVPool(KVCache):
|
||||
self._create_buffers()
|
||||
|
||||
self.layer_transfer_counter = None
|
||||
self.capture_mode = False
|
||||
self.alt_stream = torch.cuda.Stream()
|
||||
|
||||
k_size, v_size = self.get_kv_size_bytes()
|
||||
logger.info(
|
||||
@@ -218,16 +224,16 @@ class MHATokenToKVPool(KVCache):
|
||||
# [size, head_num, head_dim] for each layer
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.k_buffer = [
|
||||
torch.empty(
|
||||
(self.size + 1, self.head_num, self.head_dim),
|
||||
torch.zeros(
|
||||
(self.size + self.page_size, self.head_num, self.head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.empty(
|
||||
(self.size + 1, self.head_num, self.head_dim),
|
||||
torch.zeros(
|
||||
(self.size + self.page_size, self.head_num, self.head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
@@ -315,14 +321,44 @@ class MHATokenToKVPool(KVCache):
|
||||
cache_v.div_(v_scale)
|
||||
cache_k = cache_k.to(self.dtype)
|
||||
cache_v = cache_v.to(self.dtype)
|
||||
|
||||
if self.store_dtype != self.dtype:
|
||||
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
|
||||
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
|
||||
cache_k = cache_k.view(self.store_dtype)
|
||||
cache_v = cache_v.view(self.store_dtype)
|
||||
|
||||
if self.capture_mode:
|
||||
self.alt_stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(self.alt_stream):
|
||||
self.k_buffer[layer_id][loc] = cache_k
|
||||
self.v_buffer[layer_id][loc] = cache_v
|
||||
torch.cuda.current_stream().wait_stream(self.alt_stream)
|
||||
else:
|
||||
self.k_buffer[layer_id][loc] = cache_k
|
||||
self.v_buffer[layer_id][loc] = cache_v
|
||||
|
||||
|
||||
@torch.compile
|
||||
def fused_downcast(
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
store_dtype: torch.dtype,
|
||||
max_fp8: float,
|
||||
min_fp8: float,
|
||||
):
|
||||
cache_k = cache_k / k_scale
|
||||
cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
|
||||
cache_v = cache_v / v_scale
|
||||
cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
|
||||
cache_k = cache_k.to(dtype)
|
||||
cache_v = cache_v.to(dtype)
|
||||
cache_k = cache_k.view(store_dtype)
|
||||
cache_v = cache_v.view(store_dtype)
|
||||
return cache_k, cache_v
|
||||
|
||||
|
||||
# This compiled version is slower in the unit test
|
||||
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
@@ -335,6 +371,7 @@ class MLATokenToKVPool(KVCache):
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
page_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_lora_rank: int,
|
||||
qk_rope_head_dim: int,
|
||||
@@ -359,8 +396,8 @@ class MLATokenToKVPool(KVCache):
|
||||
with memory_saver_adapter.region():
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.kv_buffer = [
|
||||
torch.empty(
|
||||
(size + 1, 1, kv_lora_rank + qk_rope_head_dim),
|
||||
torch.zeros(
|
||||
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=device,
|
||||
)
|
||||
@@ -400,6 +437,7 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
page_size: int,
|
||||
dtype: torch.dtype,
|
||||
head_num: int,
|
||||
head_dim: int,
|
||||
@@ -409,6 +447,7 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
enable_memory_saver: bool,
|
||||
):
|
||||
self.size = size
|
||||
self.page_size = page_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||
@@ -423,17 +462,21 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
with memory_saver_adapter.region():
|
||||
# [size, head_num, head_dim] for each layer
|
||||
self.k_buffer = [
|
||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
|
||||
torch.zeros(
|
||||
(size + page_size, head_num, head_dim), dtype=dtype, device=device
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
|
||||
torch.zeros(
|
||||
(size + page_size, head_num, head_dim), dtype=dtype, device=device
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
# [size, head_num, heavy_channel_num] for each layer
|
||||
self.label_buffer = [
|
||||
torch.empty(
|
||||
torch.zeros(
|
||||
(size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
@@ -528,7 +571,7 @@ class MHATokenToKVPoolHost:
|
||||
f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
|
||||
)
|
||||
|
||||
self.kv_buffer = torch.empty(
|
||||
self.kv_buffer = torch.zeros(
|
||||
(2, self.layer_num, self.size, self.head_num, self.head_dim),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
@@ -548,9 +591,6 @@ class MHATokenToKVPoolHost:
|
||||
def get_flat_data(self, indices):
|
||||
return self.kv_buffer[:, :, indices]
|
||||
|
||||
def get_flat_data_by_layer(self, indices, layer_id):
|
||||
return self.kv_buffer[:, layer_id, indices]
|
||||
|
||||
def assign_flat_data(self, indices, flat_data):
|
||||
self.kv_buffer[:, :, indices] = flat_data
|
||||
|
||||
|
||||
283
python/sglang/srt/mem_cache/paged_allocator.py
Normal file
283
python/sglang/srt/mem_cache/paged_allocator.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Copyright 2025 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""
|
||||
Page-aligned memory pool.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||
from sglang.srt.utils import get_bool_env_var, next_power_of_2
|
||||
|
||||
|
||||
@triton.jit
|
||||
def alloc_extend_kernel(
|
||||
pre_lens_ptr,
|
||||
seq_lens_ptr,
|
||||
last_loc_ptr,
|
||||
free_page_ptr,
|
||||
out_indices,
|
||||
ret_values,
|
||||
bs_upper: tl.constexpr,
|
||||
page_size: tl.constexpr,
|
||||
max_num_extend_tokens: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
load_offset = tl.arange(0, bs_upper)
|
||||
seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
|
||||
pre_lens = tl.load(pre_lens_ptr + load_offset, mask=load_offset <= pid)
|
||||
extend_lens = seq_lens - pre_lens
|
||||
|
||||
seq_len = tl.load(seq_lens_ptr + pid)
|
||||
pre_len = tl.load(pre_lens_ptr + pid)
|
||||
extend_len = seq_len - pre_len
|
||||
|
||||
sum_extend_lens = tl.sum(extend_lens)
|
||||
output_start_loc = sum_extend_lens - extend_len
|
||||
|
||||
num_pages_after = (seq_lens + page_size - 1) // page_size
|
||||
num_pages_before = (pre_lens + page_size - 1) // page_size
|
||||
num_new_pages = num_pages_after - num_pages_before
|
||||
|
||||
num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
|
||||
pre_len + page_size - 1
|
||||
) // page_size
|
||||
sum_num_new_pages = tl.sum(num_new_pages)
|
||||
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
|
||||
|
||||
# Return value
|
||||
if pid == tl.num_programs(0) - 1:
|
||||
merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
|
||||
tl.int64
|
||||
)
|
||||
tl.store(ret_values, merged_value)
|
||||
|
||||
# Part 1: fill the old partial page
|
||||
last_loc = tl.load(last_loc_ptr + pid)
|
||||
num_part1 = (
|
||||
min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
|
||||
)
|
||||
offset_one_page = tl.arange(0, page_size)
|
||||
tl.store(
|
||||
out_indices + output_start_loc + offset_one_page,
|
||||
last_loc + 1 + offset_one_page,
|
||||
mask=offset_one_page < num_part1,
|
||||
)
|
||||
if pre_len + num_part1 == seq_len:
|
||||
return
|
||||
|
||||
# Part 2: fill the new full pages
|
||||
num_part2 = (
|
||||
seq_len // page_size * page_size
|
||||
- (pre_len + page_size - 1) // page_size * page_size
|
||||
)
|
||||
|
||||
offset_many_page = tl.arange(0, max_num_extend_tokens)
|
||||
page_start = tl.load(
|
||||
free_page_ptr + new_page_start_loc + offset_many_page // page_size,
|
||||
mask=offset_many_page < num_part2,
|
||||
)
|
||||
tl.store(
|
||||
out_indices + output_start_loc + num_part1 + offset_many_page,
|
||||
page_start * page_size + offset_many_page % page_size,
|
||||
mask=offset_many_page < num_part2,
|
||||
)
|
||||
if pre_len + num_part1 + num_part2 == seq_len:
|
||||
return
|
||||
|
||||
# Part 3: fill the new partial page
|
||||
num_part3 = seq_len - seq_len // page_size * page_size
|
||||
start_loc = tl.load(
|
||||
free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1
|
||||
)
|
||||
tl.store(
|
||||
out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
|
||||
start_loc * page_size + offset_one_page,
|
||||
mask=offset_one_page < num_part3,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def alloc_decode_kernel(
|
||||
seq_lens_ptr,
|
||||
last_loc_ptr,
|
||||
free_page_ptr,
|
||||
out_indices,
|
||||
ret_values,
|
||||
bs_upper: tl.constexpr,
|
||||
page_size: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
load_offset = tl.arange(0, bs_upper)
|
||||
seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
|
||||
pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens)
|
||||
|
||||
seq_len = tl.load(seq_lens_ptr + pid)
|
||||
pre_len = seq_len - 1
|
||||
|
||||
num_pages_after = (seq_lens + page_size - 1) // page_size
|
||||
num_pages_before = (pre_lens + page_size - 1) // page_size
|
||||
num_new_pages = num_pages_after - num_pages_before
|
||||
|
||||
num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
|
||||
pre_len + page_size - 1
|
||||
) // page_size
|
||||
sum_num_new_pages = tl.sum(num_new_pages)
|
||||
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
|
||||
|
||||
# Return value
|
||||
if pid == tl.num_programs(0) - 1:
|
||||
tl.store(ret_values, sum_num_new_pages)
|
||||
|
||||
if num_page_start_loc_self == 0:
|
||||
last_loc = tl.load(last_loc_ptr + pid)
|
||||
tl.store(out_indices + pid, last_loc + 1)
|
||||
else:
|
||||
page = tl.load(free_page_ptr + new_page_start_loc)
|
||||
tl.store(out_indices + pid, page * page_size)
|
||||
|
||||
|
||||
class PagedTokenToKVPoolAllocator:
|
||||
"""
|
||||
An allocator managing the indices to kv cache data.
|
||||
|
||||
This class has the same interface as `TokenToKVPoolAllocator` but the output
|
||||
of one request is always page-aligned.
|
||||
|
||||
TODO: fuse last_loc into the kernel.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
page_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
kvcache: KVCache,
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.page_size = page_size
|
||||
self.num_pages = size // page_size
|
||||
|
||||
self.free_pages = None
|
||||
self.is_not_in_free_group = True
|
||||
self.free_group = []
|
||||
self.clear()
|
||||
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
||||
|
||||
self._kvcache = kvcache
|
||||
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
||||
|
||||
def available_size(self):
|
||||
return len(self.free_pages) * self.page_size
|
||||
|
||||
def alloc_extend(
|
||||
self,
|
||||
prefix_lens: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
extend_num_tokens: int,
|
||||
):
|
||||
if self.debug_mode:
|
||||
assert torch.all(
|
||||
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
||||
)
|
||||
|
||||
bs = len(prefix_lens)
|
||||
out_indices = torch.empty(
|
||||
(extend_num_tokens,), dtype=torch.int64, device=self.device
|
||||
)
|
||||
alloc_extend_kernel[(bs,)](
|
||||
prefix_lens,
|
||||
seq_lens,
|
||||
last_loc,
|
||||
self.free_pages,
|
||||
out_indices,
|
||||
self.ret_values,
|
||||
next_power_of_2(bs),
|
||||
self.page_size,
|
||||
next_power_of_2(extend_num_tokens),
|
||||
)
|
||||
|
||||
merged_value = self.ret_values.item()
|
||||
num_new_pages = merged_value >> 32
|
||||
if num_new_pages > len(self.free_pages):
|
||||
return None
|
||||
|
||||
self.free_pages = self.free_pages[num_new_pages:]
|
||||
return out_indices
|
||||
|
||||
def alloc_decode(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
):
|
||||
if self.debug_mode:
|
||||
assert torch.all(
|
||||
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
||||
)
|
||||
|
||||
bs = len(seq_lens)
|
||||
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
||||
alloc_decode_kernel[(bs,)](
|
||||
seq_lens,
|
||||
last_loc,
|
||||
self.free_pages,
|
||||
out_indices,
|
||||
self.ret_values,
|
||||
next_power_of_2(bs),
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
num_new_pages = self.ret_values.item()
|
||||
if num_new_pages > len(self.free_pages):
|
||||
return None
|
||||
|
||||
self.free_pages = self.free_pages[num_new_pages:]
|
||||
return out_indices
|
||||
|
||||
def free(self, free_index: torch.Tensor):
|
||||
if free_index.numel() == 0:
|
||||
return
|
||||
|
||||
if self.is_not_in_free_group:
|
||||
free_page_indices = torch.unique(free_index // self.page_size)
|
||||
self.free_pages = torch.cat((free_page_indices, self.free_pages))
|
||||
else:
|
||||
self.free_group.append(free_index)
|
||||
|
||||
def free_group_begin(self):
|
||||
self.is_not_in_free_group = False
|
||||
self.free_group = []
|
||||
|
||||
def free_group_end(self):
|
||||
self.is_not_in_free_group = True
|
||||
if self.free_group:
|
||||
self.free(torch.concat(self.free_group))
|
||||
|
||||
def clear(self):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.free_pages = torch.arange(
|
||||
1, self.num_pages + 1, dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.is_in_free_group = False
|
||||
self.free_group = []
|
||||
@@ -22,7 +22,8 @@ The radix tree data structure for managing the KV cache.
|
||||
import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -67,7 +68,7 @@ class TreeNode:
|
||||
return self.last_access_time < other.last_access_time
|
||||
|
||||
|
||||
def _key_match(key0: List, key1: List):
|
||||
def _key_match_page_size1(key0: List, key1: List):
|
||||
i = 0
|
||||
for k0, k1 in zip(key0, key1):
|
||||
if k0 != k1:
|
||||
@@ -76,16 +77,42 @@ def _key_match(key0: List, key1: List):
|
||||
return i
|
||||
|
||||
|
||||
def _key_match_paged(key0: List, key1: List, page_size: int):
|
||||
min_len = min(len(key0), len(key1))
|
||||
|
||||
i = 0
|
||||
while i < min_len:
|
||||
if key0[i : i + page_size] != key1[i : i + page_size]:
|
||||
break
|
||||
i += page_size
|
||||
|
||||
return i
|
||||
|
||||
|
||||
class RadixCache(BasePrefixCache):
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
page_size: int,
|
||||
disable: bool = False,
|
||||
):
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.page_size = page_size
|
||||
self.disable = disable
|
||||
|
||||
if self.token_to_kv_pool_allocator:
|
||||
self.device = self.token_to_kv_pool_allocator.device
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
if self.page_size == 1:
|
||||
self.key_match_fn = _key_match_page_size1
|
||||
self.get_child_key_fn = lambda key: key[0]
|
||||
else:
|
||||
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
|
||||
self.get_child_key_fn = lambda key: tuple(key[:page_size])
|
||||
self.reset()
|
||||
|
||||
##### Public API #####
|
||||
@@ -109,14 +136,25 @@ class RadixCache(BasePrefixCache):
|
||||
The last node create a new child if the prefix is shorter
|
||||
than the last node's value.
|
||||
"""
|
||||
if self.disable:
|
||||
return [], self.root_node
|
||||
if self.disable or len(key) == 0:
|
||||
return (
|
||||
torch.empty(
|
||||
(0,),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
self.root_node,
|
||||
)
|
||||
|
||||
if self.page_size != 1:
|
||||
page_aligned_len = len(key) // self.page_size * self.page_size
|
||||
key = key[:page_aligned_len]
|
||||
|
||||
value, last_node = self._match_prefix_helper(self.root_node, key)
|
||||
if value:
|
||||
value = torch.concat(value)
|
||||
else:
|
||||
value = torch.tensor([], dtype=torch.int32)
|
||||
value = torch.empty((0,), dtype=torch.int32, device=self.device)
|
||||
return value, last_node
|
||||
|
||||
def insert(self, key: List, value=None):
|
||||
@@ -127,29 +165,33 @@ class RadixCache(BasePrefixCache):
|
||||
value = [x for x in key]
|
||||
return self._insert_helper(self.root_node, key, value)
|
||||
|
||||
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
||||
def cache_finished_req(self, req: Req):
|
||||
"""Cache request when it finishes."""
|
||||
if self.disable:
|
||||
if token_ids is None:
|
||||
token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
|
||||
else:
|
||||
token_ids_len = len(token_ids)
|
||||
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, :token_ids_len
|
||||
req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
|
||||
]
|
||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
return
|
||||
|
||||
if token_ids is None:
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
]
|
||||
|
||||
if self.page_size != 1:
|
||||
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||
else:
|
||||
page_aligned_len = len(kv_indices)
|
||||
page_aligned_kv_indices = kv_indices.clone()
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
new_prefix_len = self.insert(token_ids, kv_indices.clone())
|
||||
new_prefix_len = self.insert(
|
||||
token_ids[:page_aligned_len], page_aligned_kv_indices
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
)
|
||||
@@ -158,27 +200,32 @@ class RadixCache(BasePrefixCache):
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
self.dec_lock_ref(req.last_node)
|
||||
|
||||
def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
||||
def cache_unfinished_req(self, req: Req):
|
||||
"""Cache request when it is unfinished."""
|
||||
if self.disable:
|
||||
return
|
||||
|
||||
if token_ids is None:
|
||||
token_ids = req.fill_ids
|
||||
|
||||
token_ids = req.fill_ids
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
]
|
||||
|
||||
if self.page_size != 1:
|
||||
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
|
||||
else:
|
||||
page_aligned_len = len(kv_indices)
|
||||
page_aligned_kv_indices = kv_indices.clone()
|
||||
page_aligned_token_ids = token_ids[:page_aligned_len]
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
new_prefix_len = self.insert(token_ids, kv_indices.clone())
|
||||
new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
)
|
||||
|
||||
# The prefix indices could be updated, reuse it
|
||||
new_indices, new_last_node = self.match_prefix(token_ids)
|
||||
assert len(new_indices) == len(token_ids)
|
||||
new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
|
||||
new_indices[len(req.prefix_indices) :],
|
||||
@@ -186,7 +233,14 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
self.dec_lock_ref(req.last_node)
|
||||
self.inc_lock_ref(new_last_node)
|
||||
req.prefix_indices = new_indices
|
||||
|
||||
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
||||
if self.page_size != 1:
|
||||
req.prefix_indices = torch.cat(
|
||||
[new_indices, kv_indices[len(new_indices) :]]
|
||||
)
|
||||
else:
|
||||
req.prefix_indices = new_indices
|
||||
req.last_node = new_last_node
|
||||
|
||||
def pretty_print(self):
|
||||
@@ -196,7 +250,7 @@ class RadixCache(BasePrefixCache):
|
||||
def total_size(self):
|
||||
return self._total_size_helper()
|
||||
|
||||
def evict(self, num_tokens: int, evict_callback: Callable):
|
||||
def evict(self, num_tokens: int):
|
||||
if self.disable:
|
||||
return
|
||||
|
||||
@@ -212,7 +266,7 @@ class RadixCache(BasePrefixCache):
|
||||
if x.lock_ref > 0:
|
||||
continue
|
||||
|
||||
evict_callback(x.value)
|
||||
self.token_to_kv_pool_allocator.free(x.value)
|
||||
num_evicted += len(x.value)
|
||||
self._delete_leaf(x)
|
||||
|
||||
@@ -254,15 +308,29 @@ class RadixCache(BasePrefixCache):
|
||||
# protected size refers to the size of the cache that is locked
|
||||
return self.protected_size_
|
||||
|
||||
def all_values_flatten(self):
|
||||
values = []
|
||||
|
||||
def _dfs_helper(node: TreeNode):
|
||||
for _, child in node.children.items():
|
||||
values.append(child.value)
|
||||
_dfs_helper(child)
|
||||
|
||||
_dfs_helper(self.root_node)
|
||||
return torch.concat(values)
|
||||
|
||||
##### Internal Helper Functions #####
|
||||
|
||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||
node.last_access_time = time.time()
|
||||
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
value = []
|
||||
while len(key) > 0 and key[0] in node.children.keys():
|
||||
child = node.children[key[0]]
|
||||
while len(key) > 0 and child_key in node.children.keys():
|
||||
child = node.children[child_key]
|
||||
child.last_access_time = time.time()
|
||||
prefix_len = _key_match(child.key, key)
|
||||
prefix_len = self.key_match_fn(child.key, key)
|
||||
if prefix_len < len(child.key):
|
||||
new_node = self._split_node(child.key, child, prefix_len)
|
||||
value.append(new_node.value)
|
||||
@@ -272,12 +340,16 @@ class RadixCache(BasePrefixCache):
|
||||
value.append(child.value)
|
||||
node = child
|
||||
key = key[prefix_len:]
|
||||
|
||||
if len(key):
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
return value, node
|
||||
|
||||
def _split_node(self, key, child: TreeNode, split_len: int):
|
||||
# new_node -> child
|
||||
new_node = TreeNode()
|
||||
new_node.children = {key[split_len]: child}
|
||||
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
||||
new_node.parent = child.parent
|
||||
new_node.lock_ref = child.lock_ref
|
||||
new_node.key = child.key[:split_len]
|
||||
@@ -285,7 +357,7 @@ class RadixCache(BasePrefixCache):
|
||||
child.parent = new_node
|
||||
child.key = child.key[split_len:]
|
||||
child.value = child.value[split_len:]
|
||||
new_node.parent.children[key[0]] = new_node
|
||||
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
||||
return new_node
|
||||
|
||||
def _insert_helper(self, node: TreeNode, key: List, value):
|
||||
@@ -293,11 +365,13 @@ class RadixCache(BasePrefixCache):
|
||||
if len(key) == 0:
|
||||
return 0
|
||||
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
total_prefix_length = 0
|
||||
while len(key) > 0 and key[0] in node.children.keys():
|
||||
node = node.children[key[0]]
|
||||
while len(key) > 0 and child_key in node.children.keys():
|
||||
node = node.children[child_key]
|
||||
node.last_access_time = time.time()
|
||||
prefix_len = _key_match(node.key, key)
|
||||
prefix_len = self.key_match_fn(node.key, key)
|
||||
total_prefix_length += prefix_len
|
||||
key = key[prefix_len:]
|
||||
value = value[prefix_len:]
|
||||
@@ -306,12 +380,15 @@ class RadixCache(BasePrefixCache):
|
||||
new_node = self._split_node(node.key, node, prefix_len)
|
||||
node = new_node
|
||||
|
||||
if len(key):
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
if len(key):
|
||||
new_node = TreeNode()
|
||||
new_node.parent = node
|
||||
new_node.key = key
|
||||
new_node.value = value
|
||||
node.children[key[0]] = new_node
|
||||
node.children[child_key] = new_node
|
||||
self.evictable_size_ += len(value)
|
||||
return total_prefix_length
|
||||
|
||||
@@ -326,9 +403,13 @@ class RadixCache(BasePrefixCache):
|
||||
current_node.key[:10],
|
||||
f"r={current_node.lock_ref}",
|
||||
)
|
||||
for _, child in current_node.children.items():
|
||||
for key, child in current_node.children.items():
|
||||
stack.append((child, current_indent + 2))
|
||||
|
||||
assert key == self.get_child_key_fn(
|
||||
child.key
|
||||
), f"{key=}, {self.get_child_key_fn(child.key)=}"
|
||||
|
||||
def _delete_leaf(self, node):
|
||||
for k, v in node.parent.children.items():
|
||||
if v == node:
|
||||
@@ -363,7 +444,7 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tree = RadixCache(None, None, False)
|
||||
tree = RadixCache(None, None, page_size=1, disable=False)
|
||||
|
||||
tree.insert("Hello")
|
||||
tree.insert("Hello")
|
||||
|
||||
Reference in New Issue
Block a user