|
|
|
|
@@ -23,7 +23,7 @@ import heapq
|
|
|
|
|
import time
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
|
|
|
|
|
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
@@ -159,6 +159,16 @@ def get_child_key(key: RadixKey, page_size: int = 1):
|
|
|
|
|
return (key.extra_key, plain_key)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]:
|
|
|
|
|
# EAGLE uses bigram keys in the radix tree since draft sequence is the one-token-shifted version of target
|
|
|
|
|
# [1, 2, 3, 4] -> [(1,2), (2,3), (3,4)]
|
|
|
|
|
if len(tokens) < 2:
|
|
|
|
|
return []
|
|
|
|
|
if isinstance(tokens[0], tuple):
|
|
|
|
|
return tokens
|
|
|
|
|
return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RadixCache(BasePrefixCache):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
@@ -168,6 +178,7 @@ class RadixCache(BasePrefixCache):
|
|
|
|
|
disable: bool = False,
|
|
|
|
|
enable_kv_cache_events: bool = False,
|
|
|
|
|
eviction_policy: str = "lru",
|
|
|
|
|
is_eagle: bool = False,
|
|
|
|
|
):
|
|
|
|
|
self.req_to_token_pool = req_to_token_pool
|
|
|
|
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
|
|
|
|
@@ -175,6 +186,7 @@ class RadixCache(BasePrefixCache):
|
|
|
|
|
self.disable = disable
|
|
|
|
|
self.enable_kv_cache_events = enable_kv_cache_events
|
|
|
|
|
self.kv_event_queue = []
|
|
|
|
|
self.is_eagle = is_eagle
|
|
|
|
|
|
|
|
|
|
if self.token_to_kv_pool_allocator:
|
|
|
|
|
self.device = self.token_to_kv_pool_allocator.device
|
|
|
|
|
@@ -188,6 +200,11 @@ class RadixCache(BasePrefixCache):
|
|
|
|
|
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
|
|
|
|
|
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
|
|
|
|
|
|
|
|
|
|
if is_eagle:
|
|
|
|
|
self.key_convert_fn = _convert_to_bigram_key
|
|
|
|
|
else:
|
|
|
|
|
self.key_convert_fn = lambda key: key
|
|
|
|
|
|
|
|
|
|
if eviction_policy.lower() == "lru":
|
|
|
|
|
self.eviction_strategy: EvictionStrategy = LRUStrategy()
|
|
|
|
|
elif eviction_policy.lower() == "lfu":
|
|
|
|
|
@@ -248,6 +265,8 @@ class RadixCache(BasePrefixCache):
|
|
|
|
|
to expose a precise boundary; this structural refinement improves
|
|
|
|
|
subsequent match efficiency and does not duplicate data.
|
|
|
|
|
"""
|
|
|
|
|
key.token_ids = self.key_convert_fn(key.token_ids)
|
|
|
|
|
|
|
|
|
|
if self.disable or len(key) == 0:
|
|
|
|
|
return MatchResult(
|
|
|
|
|
device_indices=torch.empty(
|
|
|
|
|
@@ -278,8 +297,15 @@ class RadixCache(BasePrefixCache):
|
|
|
|
|
if self.disable:
|
|
|
|
|
return 0
|
|
|
|
|
|
|
|
|
|
key.token_ids = self.key_convert_fn(key.token_ids)
|
|
|
|
|
|
|
|
|
|
if value is None:
|
|
|
|
|
value = torch.tensor(key.token_ids, dtype=torch.int64)
|
|
|
|
|
|
|
|
|
|
if self.is_eagle:
|
|
|
|
|
# Make sure the value len equal to the EAGLE bigram key len
|
|
|
|
|
value = value[: len(key)]
|
|
|
|
|
|
|
|
|
|
return self._insert_helper(self.root_node, key, value)
|
|
|
|
|
|
|
|
|
|
def cache_finished_req(self, req: Req):
|
|
|
|
|
@@ -293,28 +319,39 @@ class RadixCache(BasePrefixCache):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
|
|
|
|
all_token_len = len(token_ids)
|
|
|
|
|
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
|
|
|
|
kv_indices = self.req_to_token_pool.req_to_token[
|
|
|
|
|
req.req_pool_idx, : len(token_ids)
|
|
|
|
|
req.req_pool_idx, :all_token_len
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
if self.page_size != 1:
|
|
|
|
|
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
|
|
|
|
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
|
|
|
|
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
|
|
|
|
dtype=torch.int64, copy=True
|
|
|
|
|
)
|
|
|
|
|
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
|
|
|
|
else:
|
|
|
|
|
page_aligned_len = len(kv_indices)
|
|
|
|
|
page_aligned_len = actual_kv_len
|
|
|
|
|
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
|
|
|
|
if self.is_eagle:
|
|
|
|
|
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
|
|
|
|
|
|
|
|
|
page_aligned_token_len = (
|
|
|
|
|
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
old_prefix_len = len(req.prefix_indices)
|
|
|
|
|
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
|
|
|
|
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
|
|
|
|
|
old_prefix_len -= 1
|
|
|
|
|
|
|
|
|
|
# Radix Cache takes one ref in memory pool
|
|
|
|
|
new_prefix_len = self.insert(
|
|
|
|
|
RadixKey(token_ids[:page_aligned_len], req.extra_key),
|
|
|
|
|
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
|
|
|
|
|
page_aligned_kv_indices,
|
|
|
|
|
)
|
|
|
|
|
self.token_to_kv_pool_allocator.free(
|
|
|
|
|
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
|
|
|
|
)
|
|
|
|
|
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
|
|
|
|
|
|
|
|
|
|
# Remove req slot release the cache lock
|
|
|
|
|
self.req_to_token_pool.free(req.req_pool_idx)
|
|
|
|
|
@@ -326,19 +363,32 @@ class RadixCache(BasePrefixCache):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
token_ids = req.fill_ids
|
|
|
|
|
all_token_len = len(token_ids)
|
|
|
|
|
# The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key
|
|
|
|
|
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
|
|
|
|
kv_indices = self.req_to_token_pool.req_to_token[
|
|
|
|
|
req.req_pool_idx, : len(token_ids)
|
|
|
|
|
req.req_pool_idx, :all_token_len
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
if self.page_size != 1:
|
|
|
|
|
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
|
|
|
|
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
|
|
|
|
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
|
|
|
|
dtype=torch.int64, copy=True
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
page_aligned_len = len(kv_indices)
|
|
|
|
|
page_aligned_len = actual_kv_len
|
|
|
|
|
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
|
|
|
|
page_aligned_token_ids = token_ids[:page_aligned_len]
|
|
|
|
|
|
|
|
|
|
# For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
|
|
|
|
|
page_aligned_token_len = (
|
|
|
|
|
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
|
|
|
|
)
|
|
|
|
|
page_aligned_token_ids = token_ids[:page_aligned_token_len]
|
|
|
|
|
|
|
|
|
|
old_prefix_len = len(req.prefix_indices)
|
|
|
|
|
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
|
|
|
|
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
|
|
|
|
|
old_prefix_len -= 1
|
|
|
|
|
|
|
|
|
|
# Radix Cache takes one ref in memory pool
|
|
|
|
|
new_prefix_len = self.insert(
|
|
|
|
|
@@ -346,29 +396,40 @@ class RadixCache(BasePrefixCache):
|
|
|
|
|
page_aligned_kv_indices,
|
|
|
|
|
chunked=chunked,
|
|
|
|
|
)
|
|
|
|
|
self.token_to_kv_pool_allocator.free(
|
|
|
|
|
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
|
|
|
|
)
|
|
|
|
|
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
|
|
|
|
|
|
|
|
|
|
# The prefix indices could be updated, reuse it
|
|
|
|
|
new_indices, new_last_node, _, _ = self.match_prefix(
|
|
|
|
|
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
|
|
|
|
|
)
|
|
|
|
|
self.req_to_token_pool.write(
|
|
|
|
|
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
|
|
|
|
|
new_indices[len(req.prefix_indices) :],
|
|
|
|
|
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
|
|
|
|
|
new_indices[old_prefix_len:],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# The last_matched_prefix_len is not always equal to len(req.prefix_indices)
|
|
|
|
|
# since for page_size > 1, the partial part is added to req.prefix_indices, but that part of kv indices is not added to the tree.
|
|
|
|
|
# It should be freed in the next cache_unfinished_req and final cache_finished_req to avoid memory leak.
|
|
|
|
|
# So we introduce this `last_matched_prefix_len` field to make sure the partial part can be freed correctly.
|
|
|
|
|
req.last_matched_prefix_len = len(new_indices)
|
|
|
|
|
|
|
|
|
|
self.dec_lock_ref(req.last_node)
|
|
|
|
|
self.inc_lock_ref(new_last_node)
|
|
|
|
|
|
|
|
|
|
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
|
|
|
|
if self.page_size != 1:
|
|
|
|
|
# Handle partial page, the partial part should be freed in the next cache_unfinished_req and final cache_finished_req.
|
|
|
|
|
req.prefix_indices = torch.cat(
|
|
|
|
|
[new_indices, kv_indices[len(new_indices) :]]
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
req.prefix_indices = new_indices
|
|
|
|
|
if self.is_eagle:
|
|
|
|
|
# Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
|
|
|
|
|
req.prefix_indices = torch.cat(
|
|
|
|
|
[new_indices, kv_indices[actual_kv_len:]]
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
req.prefix_indices = new_indices
|
|
|
|
|
req.last_node = new_last_node
|
|
|
|
|
|
|
|
|
|
def pretty_print(self):
|
|
|
|
|
|