EAGLE cache fix for SWARadixCache (#11231)

Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
This commit is contained in:
Ke Bao
2025-10-07 18:21:37 +08:00
committed by GitHub
parent 8a8a608af9
commit 24bc3fb0f9
8 changed files with 248 additions and 31 deletions

View File

@@ -777,6 +777,7 @@ class Scheduler(
sliding_window_size=self.sliding_window_size,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
is_eagle=self.spec_algorithm.is_eagle(),
)
elif server_args.enable_lmcache:
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (

View File

@@ -274,10 +274,15 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
self.full_to_swa_index_mapping[free_index] = 0
def backup_state(self):
raise NotImplementedError
return [
self.full_attn_allocator.backup_state(),
self.swa_attn_allocator.backup_state(),
]
def restore_state(self, state):
raise NotImplementedError
assert len(state) == 2
self.full_attn_allocator.restore_state(state[0])
self.swa_attn_allocator.restore_state(state[1])
def clear(self):
self.swa_attn_allocator.clear()

View File

@@ -749,6 +749,7 @@ class SWAKVPool(KVCache):
self,
size: int,
size_swa: int,
dtype: torch.dtype,
swa_attention_layer_ids: List[int],
full_attention_layer_ids: List[int],
enable_kvcache_transpose: bool,
@@ -757,6 +758,7 @@ class SWAKVPool(KVCache):
):
self.size = size
self.size_swa = size_swa
self.dtype = dtype
self.swa_layer_nums = len(swa_attention_layer_ids)
self.full_layer_nums = len(full_attention_layer_ids)
kwargs["page_size"] = 1
@@ -766,11 +768,13 @@ class SWAKVPool(KVCache):
self.swa_kv_pool = token_to_kv_pool_class(
size=size_swa,
dtype=dtype,
layer_num=self.swa_layer_nums,
**kwargs,
)
self.full_kv_pool = token_to_kv_pool_class(
size=size,
dtype=dtype,
layer_num=self.full_layer_nums,
**kwargs,
)

View File

@@ -326,6 +326,8 @@ class RadixCache(BasePrefixCache):
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
all_token_len = len(token_ids)
# For the EAGLE radix cache, we convert the key to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so the length decreases by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1).
# The corresponding KV length should therefore also be reduced by 1. We then get actual_kv_len and use it for the later calculations and slicing.
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :all_token_len
@@ -349,7 +351,8 @@ class RadixCache(BasePrefixCache):
old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
# In the EAGLE chunked prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:])
# Here we subtract 1 to make sure the KV of the unmatched token is freed correctly, avoiding a memory leak
old_prefix_len -= 1
# Radix Cache takes one ref in memory pool
@@ -370,7 +373,8 @@ class RadixCache(BasePrefixCache):
token_ids = req.fill_ids
all_token_len = len(token_ids)
# The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key
# For the EAGLE radix cache, we convert the key to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so the length decreases by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1).
# The corresponding KV length should therefore also be reduced by 1. We then get actual_kv_len and use it for the later calculations and slicing.
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :all_token_len
@@ -393,7 +397,8 @@ class RadixCache(BasePrefixCache):
old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
# In the EAGLE chunked prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:])
# Here we subtract 1 to make sure the KV of the unmatched token is freed correctly, avoiding a memory leak
old_prefix_len -= 1
# Radix Cache takes one ref in memory pool

View File

@@ -32,6 +32,7 @@ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import (
RadixKey,
_convert_to_bigram_key,
_key_match_page_size1,
_key_match_paged,
get_child_key,
@@ -327,12 +328,14 @@ class SWARadixCache(BasePrefixCache):
sliding_window_size: int,
page_size: int,
disable: bool = False,
is_eagle: bool = False,
):
assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = page_size
self.disable = disable
self.is_eagle = is_eagle
if self.token_to_kv_pool_allocator:
self.device = self.token_to_kv_pool_allocator.device
@@ -346,6 +349,11 @@ class SWARadixCache(BasePrefixCache):
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
if is_eagle:
self.key_convert_fn = _convert_to_bigram_key
else:
self.key_convert_fn = lambda key: key
self.sliding_window_size = sliding_window_size
self.reset()
@@ -376,6 +384,8 @@ class SWARadixCache(BasePrefixCache):
The last node create a new child if the prefix is shorter
than the last node's value.
"""
key.token_ids = self.key_convert_fn(key.token_ids)
if self.disable or len(key) == 0:
return MatchResult(
device_indices=torch.empty(
@@ -406,8 +416,15 @@ class SWARadixCache(BasePrefixCache):
if self.disable:
return 0
key.token_ids = self.key_convert_fn(key.token_ids)
if value is None:
value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
if self.is_eagle:
# Make sure the value length equals the EAGLE bigram key length
value = value[: len(key)]
return self._insert_helper(self.root_node, key, value, prev_prefix_len)
def cache_finished_req(self, req: Req) -> None:
@@ -422,25 +439,41 @@ class SWARadixCache(BasePrefixCache):
return
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
all_token_len = len(token_ids)
# For the EAGLE radix cache, we convert the key to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so the length decreases by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1).
# The corresponding KV length should therefore also be reduced by 1. We then get actual_kv_len and use it for the later calculations and slicing.
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
req.req_pool_idx, :all_token_len
]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
else:
page_aligned_len = len(kv_indices)
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.clone()
if self.is_eagle:
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
page_aligned_token_len = (
page_aligned_len + 1 if self.is_eagle else page_aligned_len
)
old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# In the EAGLE chunked prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:])
# Here we subtract 1 to make sure the KV of the unmatched token is freed correctly, avoiding a memory leak
old_prefix_len -= 1
# Radix Cache takes one ref in memory pool
# insert the token_ids and kv_indices into the radix tree
# Note: the insert function already frees the overlapped kv_indices
new_prefix_len = self.insert(
RadixKey(token_ids[:page_aligned_len], req.extra_key),
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
page_aligned_kv_indices,
len(req.prefix_indices),
old_prefix_len,
)
# Remove req slot release the cache lock
@@ -459,39 +492,56 @@ class SWARadixCache(BasePrefixCache):
return
token_ids = req.fill_ids
all_token_len = len(token_ids)
# For the EAGLE radix cache, we convert the key to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so the length decreases by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1).
# The corresponding KV length should therefore also be reduced by 1. We then get actual_kv_len and use it for the later calculations and slicing.
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
req.req_pool_idx, :all_token_len
]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
else:
page_aligned_len = len(kv_indices)
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.clone()
page_aligned_token_ids = token_ids[:page_aligned_len]
# For EAGLE, page_aligned_len refers to the bigram key, so the normal key length should be 1 greater
page_aligned_token_len = (
page_aligned_len + 1 if self.is_eagle else page_aligned_len
)
page_aligned_token_ids = token_ids[:page_aligned_token_len]
old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# In the EAGLE chunked prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:])
# Here we subtract 1 to make sure the KV of the unmatched token is freed correctly, avoiding a memory leak
old_prefix_len -= 1
# Radix Cache takes one ref in memory pool
# Note: the insert function already frees the overlapped kv_indices
new_prefix_len = self.insert(
RadixKey(page_aligned_token_ids, req.extra_key),
page_aligned_kv_indices,
len(req.prefix_indices),
old_prefix_len,
)
# The prefix indices could be updated, reuse it
new_indices, new_last_node, _, _ = self.match_prefix(
RadixKey(page_aligned_token_ids, req.extra_key)
)
assert len(req.prefix_indices) <= len(
assert old_prefix_len <= len(
new_indices
), f"{req.prefix_indices=}, {new_indices=}"
assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
new_indices[old_prefix_len:],
)
req.last_matched_prefix_len = len(new_indices)
self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
swa_uuid_for_lock = self.inc_lock_ref(new_last_node)
@@ -501,7 +551,13 @@ class SWARadixCache(BasePrefixCache):
[new_indices, kv_indices[len(new_indices) :]]
)
else:
req.prefix_indices = new_indices
if self.is_eagle:
# Attach the KV index of the last token for EAGLE; it can be used in chunked prefill
req.prefix_indices = torch.cat(
[new_indices, kv_indices[actual_kv_len:]]
)
else:
req.prefix_indices = new_indices
req.last_node = new_last_node
req.swa_uuid_for_lock = swa_uuid_for_lock

View File

@@ -27,7 +27,11 @@ if _is_cuda:
def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
"""Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
return (
_is_cuda
and hasattr(forward_batch.token_to_kv_pool, "dtype")
and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
)
def create_fused_set_kv_buffer_arg(