Fix eagle radix cache (#10846)
This commit is contained in:
@@ -547,6 +547,8 @@ class Req:
|
||||
self.host_hit_length = 0
|
||||
# The node to lock until for swa radix tree lock ref
|
||||
self.swa_uuid_for_lock: Optional[int] = None
|
||||
# The prefix length of the last prefix matching
|
||||
self.last_matched_prefix_len: int = 0
|
||||
|
||||
# Whether or not it is chunked. It increments whenever
|
||||
# it is chunked, and decrement whenever chunked request is
|
||||
@@ -701,6 +703,7 @@ class Req:
|
||||
token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
|
||||
),
|
||||
)
|
||||
self.last_matched_prefix_len = len(self.prefix_indices)
|
||||
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
||||
|
||||
def adjust_max_prefix_ids(self):
|
||||
|
||||
@@ -756,6 +756,7 @@ class Scheduler(
|
||||
disable=server_args.disable_radix_cache,
|
||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||
eviction_policy=server_args.radix_eviction_policy,
|
||||
is_eagle=self.spec_algorithm.is_eagle(),
|
||||
)
|
||||
|
||||
if (
|
||||
|
||||
@@ -23,7 +23,7 @@ import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -159,6 +159,16 @@ def get_child_key(key: RadixKey, page_size: int = 1):
|
||||
return (key.extra_key, plain_key)
|
||||
|
||||
|
||||
def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]:
|
||||
# EAGLE uses bigram keys in the radix tree since draft sequence is the one-token-shifted version of target
|
||||
# [1, 2, 3, 4] -> [(1,2), (2,3), (3,4)]
|
||||
if len(tokens) < 2:
|
||||
return []
|
||||
if isinstance(tokens[0], tuple):
|
||||
return tokens
|
||||
return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
|
||||
|
||||
|
||||
class RadixCache(BasePrefixCache):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -168,6 +178,7 @@ class RadixCache(BasePrefixCache):
|
||||
disable: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
eviction_policy: str = "lru",
|
||||
is_eagle: bool = False,
|
||||
):
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
@@ -175,6 +186,7 @@ class RadixCache(BasePrefixCache):
|
||||
self.disable = disable
|
||||
self.enable_kv_cache_events = enable_kv_cache_events
|
||||
self.kv_event_queue = []
|
||||
self.is_eagle = is_eagle
|
||||
|
||||
if self.token_to_kv_pool_allocator:
|
||||
self.device = self.token_to_kv_pool_allocator.device
|
||||
@@ -188,6 +200,11 @@ class RadixCache(BasePrefixCache):
|
||||
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
|
||||
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
|
||||
|
||||
if is_eagle:
|
||||
self.key_convert_fn = _convert_to_bigram_key
|
||||
else:
|
||||
self.key_convert_fn = lambda key: key
|
||||
|
||||
if eviction_policy.lower() == "lru":
|
||||
self.eviction_strategy: EvictionStrategy = LRUStrategy()
|
||||
elif eviction_policy.lower() == "lfu":
|
||||
@@ -248,6 +265,8 @@ class RadixCache(BasePrefixCache):
|
||||
to expose a precise boundary; this structural refinement improves
|
||||
subsequent match efficiency and does not duplicate data.
|
||||
"""
|
||||
key.token_ids = self.key_convert_fn(key.token_ids)
|
||||
|
||||
if self.disable or len(key) == 0:
|
||||
return MatchResult(
|
||||
device_indices=torch.empty(
|
||||
@@ -278,8 +297,15 @@ class RadixCache(BasePrefixCache):
|
||||
if self.disable:
|
||||
return 0
|
||||
|
||||
key.token_ids = self.key_convert_fn(key.token_ids)
|
||||
|
||||
if value is None:
|
||||
value = torch.tensor(key.token_ids, dtype=torch.int64)
|
||||
|
||||
if self.is_eagle:
|
||||
# Make sure the value length equals the EAGLE bigram key length
|
||||
value = value[: len(key)]
|
||||
|
||||
return self._insert_helper(self.root_node, key, value)
|
||||
|
||||
def cache_finished_req(self, req: Req):
|
||||
@@ -293,28 +319,39 @@ class RadixCache(BasePrefixCache):
|
||||
return
|
||||
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
all_token_len = len(token_ids)
|
||||
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
req.req_pool_idx, :all_token_len
|
||||
]
|
||||
|
||||
if self.page_size != 1:
|
||||
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
||||
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
||||
dtype=torch.int64, copy=True
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||
else:
|
||||
page_aligned_len = len(kv_indices)
|
||||
page_aligned_len = actual_kv_len
|
||||
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||
if self.is_eagle:
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||
|
||||
page_aligned_token_len = (
|
||||
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
||||
)
|
||||
|
||||
old_prefix_len = len(req.prefix_indices)
|
||||
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
||||
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
|
||||
old_prefix_len -= 1
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
new_prefix_len = self.insert(
|
||||
RadixKey(token_ids[:page_aligned_len], req.extra_key),
|
||||
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
|
||||
|
||||
# Remove the req slot and release the cache lock
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
@@ -326,19 +363,32 @@ class RadixCache(BasePrefixCache):
|
||||
return
|
||||
|
||||
token_ids = req.fill_ids
|
||||
all_token_len = len(token_ids)
|
||||
# The actual kv len for EAGLE is len(token_ids) - 1, since EAGLE uses bigram keys
|
||||
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
req.req_pool_idx, :all_token_len
|
||||
]
|
||||
|
||||
if self.page_size != 1:
|
||||
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
||||
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
||||
dtype=torch.int64, copy=True
|
||||
)
|
||||
else:
|
||||
page_aligned_len = len(kv_indices)
|
||||
page_aligned_len = actual_kv_len
|
||||
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||
page_aligned_token_ids = token_ids[:page_aligned_len]
|
||||
|
||||
# For EAGLE, page_aligned_len counts bigram keys, so the plain token key length is one more
|
||||
page_aligned_token_len = (
|
||||
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
||||
)
|
||||
page_aligned_token_ids = token_ids[:page_aligned_token_len]
|
||||
|
||||
old_prefix_len = len(req.prefix_indices)
|
||||
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
||||
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
|
||||
old_prefix_len -= 1
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
new_prefix_len = self.insert(
|
||||
@@ -346,29 +396,40 @@ class RadixCache(BasePrefixCache):
|
||||
page_aligned_kv_indices,
|
||||
chunked=chunked,
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
|
||||
|
||||
# The prefix indices could be updated, reuse it
|
||||
new_indices, new_last_node, _, _ = self.match_prefix(
|
||||
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
|
||||
)
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
|
||||
new_indices[len(req.prefix_indices) :],
|
||||
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
|
||||
new_indices[old_prefix_len:],
|
||||
)
|
||||
|
||||
# The last_matched_prefix_len is not always equal to len(req.prefix_indices)
|
||||
# since for page_size > 1, the partial part is added to req.prefix_indices, but that part of kv indices is not added to the tree.
|
||||
# It should be freed in the next cache_unfinished_req and final cache_finished_req to avoid memory leak.
|
||||
# So we introduce this `last_matched_prefix_len` field to make sure the partial part can be freed correctly.
|
||||
req.last_matched_prefix_len = len(new_indices)
|
||||
|
||||
self.dec_lock_ref(req.last_node)
|
||||
self.inc_lock_ref(new_last_node)
|
||||
|
||||
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
||||
if self.page_size != 1:
|
||||
# Handle partial page, the partial part should be freed in the next cache_unfinished_req and final cache_finished_req.
|
||||
req.prefix_indices = torch.cat(
|
||||
[new_indices, kv_indices[len(new_indices) :]]
|
||||
)
|
||||
else:
|
||||
req.prefix_indices = new_indices
|
||||
if self.is_eagle:
|
||||
# Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
|
||||
req.prefix_indices = torch.cat(
|
||||
[new_indices, kv_indices[actual_kv_len:]]
|
||||
)
|
||||
else:
|
||||
req.prefix_indices = new_indices
|
||||
req.last_node = new_last_node
|
||||
|
||||
def pretty_print(self):
|
||||
|
||||
Reference in New Issue
Block a user