Fix eagle radix cache (#10846)

This commit is contained in:
Ke Bao
2025-09-30 22:59:20 +08:00
committed by GitHub
parent 5a290a5644
commit 91847e382a
6 changed files with 235 additions and 26 deletions

View File

@@ -547,6 +547,8 @@ class Req:
self.host_hit_length = 0
# The node to lock until for swa radix tree lock ref
self.swa_uuid_for_lock: Optional[int] = None
# The prefix length of the last prefix matching
self.last_matched_prefix_len: int = 0
# Whether it is chunked. It increments whenever
# it is chunked, and decrements whenever the chunked request is
@@ -701,6 +703,7 @@ class Req:
token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
),
)
self.last_matched_prefix_len = len(self.prefix_indices)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
def adjust_max_prefix_ids(self):

View File

@@ -756,6 +756,7 @@ class Scheduler(
disable=server_args.disable_radix_cache,
enable_kv_cache_events=self.enable_kv_cache_events,
eviction_policy=server_args.radix_eviction_policy,
is_eagle=self.spec_algorithm.is_eagle(),
)
if (

View File

@@ -23,7 +23,7 @@ import heapq
import time
from collections import defaultdict
from functools import partial
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
import torch
@@ -159,6 +159,16 @@ def get_child_key(key: RadixKey, page_size: int = 1):
return (key.extra_key, plain_key)
def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]:
# EAGLE uses bigram keys in the radix tree since draft sequence is the one-token-shifted version of target
# [1, 2, 3, 4] -> [(1,2), (2,3), (3,4)]
if len(tokens) < 2:
return []
if isinstance(tokens[0], tuple):
return tokens
return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
class RadixCache(BasePrefixCache):
def __init__(
self,
@@ -168,6 +178,7 @@ class RadixCache(BasePrefixCache):
disable: bool = False,
enable_kv_cache_events: bool = False,
eviction_policy: str = "lru",
is_eagle: bool = False,
):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
@@ -175,6 +186,7 @@ class RadixCache(BasePrefixCache):
self.disable = disable
self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue = []
self.is_eagle = is_eagle
if self.token_to_kv_pool_allocator:
self.device = self.token_to_kv_pool_allocator.device
@@ -188,6 +200,11 @@ class RadixCache(BasePrefixCache):
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
if is_eagle:
self.key_convert_fn = _convert_to_bigram_key
else:
self.key_convert_fn = lambda key: key
if eviction_policy.lower() == "lru":
self.eviction_strategy: EvictionStrategy = LRUStrategy()
elif eviction_policy.lower() == "lfu":
@@ -248,6 +265,8 @@ class RadixCache(BasePrefixCache):
to expose a precise boundary; this structural refinement improves
subsequent match efficiency and does not duplicate data.
"""
key.token_ids = self.key_convert_fn(key.token_ids)
if self.disable or len(key) == 0:
return MatchResult(
device_indices=torch.empty(
@@ -278,8 +297,15 @@ class RadixCache(BasePrefixCache):
if self.disable:
return 0
key.token_ids = self.key_convert_fn(key.token_ids)
if value is None:
value = torch.tensor(key.token_ids, dtype=torch.int64)
if self.is_eagle:
# Make sure the value length equals the EAGLE bigram key length
value = value[: len(key)]
return self._insert_helper(self.root_node, key, value)
def cache_finished_req(self, req: Req):
@@ -293,28 +319,39 @@ class RadixCache(BasePrefixCache):
return
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
all_token_len = len(token_ids)
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
req.req_pool_idx, :all_token_len
]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
else:
page_aligned_len = len(kv_indices)
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
if self.is_eagle:
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
page_aligned_token_len = (
page_aligned_len + 1 if self.is_eagle else page_aligned_len
)
old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# prefix_indices includes a partial-page part (for page_size > 1) and one unmatched token (for EAGLE)
old_prefix_len -= 1
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(
RadixKey(token_ids[:page_aligned_len], req.extra_key),
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
page_aligned_kv_indices,
)
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
# Remove req slot release the cache lock
self.req_to_token_pool.free(req.req_pool_idx)
@@ -326,19 +363,32 @@ class RadixCache(BasePrefixCache):
return
token_ids = req.fill_ids
all_token_len = len(token_ids)
# The actual kv len for EAGLE is len(token_ids) - 1, since EAGLE uses bigram keys
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
req.req_pool_idx, :all_token_len
]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
else:
page_aligned_len = len(kv_indices)
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
page_aligned_token_ids = token_ids[:page_aligned_len]
# For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
page_aligned_token_len = (
page_aligned_len + 1 if self.is_eagle else page_aligned_len
)
page_aligned_token_ids = token_ids[:page_aligned_token_len]
old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# prefix_indices includes a partial-page part (for page_size > 1) and one unmatched token (for EAGLE)
old_prefix_len -= 1
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(
@@ -346,29 +396,40 @@ class RadixCache(BasePrefixCache):
page_aligned_kv_indices,
chunked=chunked,
)
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
# The prefix indices could be updated, reuse it
new_indices, new_last_node, _, _ = self.match_prefix(
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
)
self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
new_indices[old_prefix_len:],
)
# The last_matched_prefix_len is not always equal to len(req.prefix_indices)
# since for page_size > 1, the partial part is added to req.prefix_indices, but that part of kv indices is not added to the tree.
# It should be freed in the next cache_unfinished_req and final cache_finished_req to avoid memory leak.
# So we introduce this `last_matched_prefix_len` field to make sure the partial part can be freed correctly.
req.last_matched_prefix_len = len(new_indices)
self.dec_lock_ref(req.last_node)
self.inc_lock_ref(new_last_node)
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
if self.page_size != 1:
# Handle partial page, the partial part should be freed in the next cache_unfinished_req and final cache_finished_req.
req.prefix_indices = torch.cat(
[new_indices, kv_indices[len(new_indices) :]]
)
else:
req.prefix_indices = new_indices
if self.is_eagle:
# Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
req.prefix_indices = torch.cat(
[new_indices, kv_indices[actual_kv_len:]]
)
else:
req.prefix_indices = new_indices
req.last_node = new_last_node
def pretty_print(self):

View File

@@ -77,7 +77,8 @@ DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE = "nytopop/Qwen3-30B-A3B.w8a8"
# EAGLE
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B"
DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
"meta-llama/Llama-3.1-8B-Instruct"
)