EAGLE cache fix for SWARadixCache (#11231)
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
This commit is contained in:
@@ -777,6 +777,7 @@ class Scheduler(
|
||||
sliding_window_size=self.sliding_window_size,
|
||||
page_size=self.page_size,
|
||||
disable=server_args.disable_radix_cache,
|
||||
is_eagle=self.spec_algorithm.is_eagle(),
|
||||
)
|
||||
elif server_args.enable_lmcache:
|
||||
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
|
||||
|
||||
@@ -274,10 +274,15 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
self.full_to_swa_index_mapping[free_index] = 0
|
||||
|
||||
def backup_state(self):
|
||||
raise NotImplementedError
|
||||
return [
|
||||
self.full_attn_allocator.backup_state(),
|
||||
self.swa_attn_allocator.backup_state(),
|
||||
]
|
||||
|
||||
def restore_state(self, state):
|
||||
raise NotImplementedError
|
||||
assert len(state) == 2
|
||||
self.full_attn_allocator.restore_state(state[0])
|
||||
self.swa_attn_allocator.restore_state(state[1])
|
||||
|
||||
def clear(self):
|
||||
self.swa_attn_allocator.clear()
|
||||
|
||||
@@ -749,6 +749,7 @@ class SWAKVPool(KVCache):
|
||||
self,
|
||||
size: int,
|
||||
size_swa: int,
|
||||
dtype: torch.dtype,
|
||||
swa_attention_layer_ids: List[int],
|
||||
full_attention_layer_ids: List[int],
|
||||
enable_kvcache_transpose: bool,
|
||||
@@ -757,6 +758,7 @@ class SWAKVPool(KVCache):
|
||||
):
|
||||
self.size = size
|
||||
self.size_swa = size_swa
|
||||
self.dtype = dtype
|
||||
self.swa_layer_nums = len(swa_attention_layer_ids)
|
||||
self.full_layer_nums = len(full_attention_layer_ids)
|
||||
kwargs["page_size"] = 1
|
||||
@@ -766,11 +768,13 @@ class SWAKVPool(KVCache):
|
||||
|
||||
self.swa_kv_pool = token_to_kv_pool_class(
|
||||
size=size_swa,
|
||||
dtype=dtype,
|
||||
layer_num=self.swa_layer_nums,
|
||||
**kwargs,
|
||||
)
|
||||
self.full_kv_pool = token_to_kv_pool_class(
|
||||
size=size,
|
||||
dtype=dtype,
|
||||
layer_num=self.full_layer_nums,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -326,6 +326,8 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
all_token_len = len(token_ids)
|
||||
# For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so its length decreases by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1).
|
||||
# The corresponding kv length must therefore also be reduced by 1. This gives actual_kv_len, which is used for the later calculations and slicing.
|
||||
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, :all_token_len
|
||||
@@ -349,7 +351,8 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
old_prefix_len = len(req.prefix_indices)
|
||||
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
||||
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
|
||||
# In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
|
||||
# Subtract 1 here so that the kv of the unmatched token is freed correctly, avoiding a memory leak.
|
||||
old_prefix_len -= 1
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
@@ -370,7 +373,8 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
token_ids = req.fill_ids
|
||||
all_token_len = len(token_ids)
|
||||
# The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key
|
||||
# For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so its length decreases by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1).
|
||||
# The corresponding kv length must therefore also be reduced by 1. This gives actual_kv_len, which is used for the later calculations and slicing.
|
||||
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, :all_token_len
|
||||
@@ -393,7 +397,8 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
old_prefix_len = len(req.prefix_indices)
|
||||
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
||||
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
|
||||
# In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
|
||||
# Subtract 1 here so that the kv of the unmatched token is freed correctly, avoiding a memory leak.
|
||||
old_prefix_len -= 1
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
|
||||
@@ -32,6 +32,7 @@ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import (
|
||||
RadixKey,
|
||||
_convert_to_bigram_key,
|
||||
_key_match_page_size1,
|
||||
_key_match_paged,
|
||||
get_child_key,
|
||||
@@ -327,12 +328,14 @@ class SWARadixCache(BasePrefixCache):
|
||||
sliding_window_size: int,
|
||||
page_size: int,
|
||||
disable: bool = False,
|
||||
is_eagle: bool = False,
|
||||
):
|
||||
assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.page_size = page_size
|
||||
self.disable = disable
|
||||
self.is_eagle = is_eagle
|
||||
|
||||
if self.token_to_kv_pool_allocator:
|
||||
self.device = self.token_to_kv_pool_allocator.device
|
||||
@@ -346,6 +349,11 @@ class SWARadixCache(BasePrefixCache):
|
||||
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
|
||||
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
|
||||
|
||||
if is_eagle:
|
||||
self.key_convert_fn = _convert_to_bigram_key
|
||||
else:
|
||||
self.key_convert_fn = lambda key: key
|
||||
|
||||
self.sliding_window_size = sliding_window_size
|
||||
self.reset()
|
||||
|
||||
@@ -376,6 +384,8 @@ class SWARadixCache(BasePrefixCache):
|
||||
The last node create a new child if the prefix is shorter
|
||||
than the last node's value.
|
||||
"""
|
||||
key.token_ids = self.key_convert_fn(key.token_ids)
|
||||
|
||||
if self.disable or len(key) == 0:
|
||||
return MatchResult(
|
||||
device_indices=torch.empty(
|
||||
@@ -406,8 +416,15 @@ class SWARadixCache(BasePrefixCache):
|
||||
if self.disable:
|
||||
return 0
|
||||
|
||||
key.token_ids = self.key_convert_fn(key.token_ids)
|
||||
|
||||
if value is None:
|
||||
value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
|
||||
|
||||
if self.is_eagle:
|
||||
# Make sure the value length equals the EAGLE bigram key length
|
||||
value = value[: len(key)]
|
||||
|
||||
return self._insert_helper(self.root_node, key, value, prev_prefix_len)
|
||||
|
||||
def cache_finished_req(self, req: Req) -> None:
|
||||
@@ -422,25 +439,41 @@ class SWARadixCache(BasePrefixCache):
|
||||
return
|
||||
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
all_token_len = len(token_ids)
|
||||
# For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so its length decreases by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1).
|
||||
# The corresponding kv length must therefore also be reduced by 1. This gives actual_kv_len, which is used for the later calculations and slicing.
|
||||
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
req.req_pool_idx, :all_token_len
|
||||
]
|
||||
|
||||
if self.page_size != 1:
|
||||
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
||||
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||
else:
|
||||
page_aligned_len = len(kv_indices)
|
||||
page_aligned_len = actual_kv_len
|
||||
page_aligned_kv_indices = kv_indices.clone()
|
||||
if self.is_eagle:
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||
|
||||
page_aligned_token_len = (
|
||||
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
||||
)
|
||||
|
||||
old_prefix_len = len(req.prefix_indices)
|
||||
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
||||
# In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
|
||||
# Subtract 1 here so that the kv of the unmatched token is freed correctly, avoiding a memory leak.
|
||||
old_prefix_len -= 1
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
# insert the token_ids and kv_indices into the radix tree
|
||||
# Note: the insert function already frees the overlapped kv_indices
|
||||
new_prefix_len = self.insert(
|
||||
RadixKey(token_ids[:page_aligned_len], req.extra_key),
|
||||
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
len(req.prefix_indices),
|
||||
old_prefix_len,
|
||||
)
|
||||
|
||||
# Remove req slot release the cache lock
|
||||
@@ -459,39 +492,56 @@ class SWARadixCache(BasePrefixCache):
|
||||
return
|
||||
|
||||
token_ids = req.fill_ids
|
||||
all_token_len = len(token_ids)
|
||||
# For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so its length decreases by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1).
|
||||
# The corresponding kv length must therefore also be reduced by 1. This gives actual_kv_len, which is used for the later calculations and slicing.
|
||||
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
req.req_pool_idx, :all_token_len
|
||||
]
|
||||
|
||||
if self.page_size != 1:
|
||||
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
||||
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
|
||||
else:
|
||||
page_aligned_len = len(kv_indices)
|
||||
page_aligned_len = actual_kv_len
|
||||
page_aligned_kv_indices = kv_indices.clone()
|
||||
page_aligned_token_ids = token_ids[:page_aligned_len]
|
||||
|
||||
# For EAGLE, page_aligned_len counts bigram-key entries, so the plain token key must be one element longer (+1)
|
||||
page_aligned_token_len = (
|
||||
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
||||
)
|
||||
page_aligned_token_ids = token_ids[:page_aligned_token_len]
|
||||
|
||||
old_prefix_len = len(req.prefix_indices)
|
||||
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
||||
# In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
|
||||
# Subtract 1 here so that the kv of the unmatched token is freed correctly, avoiding a memory leak.
|
||||
old_prefix_len -= 1
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
# Note: the insert function already frees the overlapped kv_indices
|
||||
new_prefix_len = self.insert(
|
||||
RadixKey(page_aligned_token_ids, req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
len(req.prefix_indices),
|
||||
old_prefix_len,
|
||||
)
|
||||
|
||||
# The prefix indices could be updated, reuse it
|
||||
new_indices, new_last_node, _, _ = self.match_prefix(
|
||||
RadixKey(page_aligned_token_ids, req.extra_key)
|
||||
)
|
||||
assert len(req.prefix_indices) <= len(
|
||||
assert old_prefix_len <= len(
|
||||
new_indices
|
||||
), f"{req.prefix_indices=}, {new_indices=}"
|
||||
assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
|
||||
new_indices[len(req.prefix_indices) :],
|
||||
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
|
||||
new_indices[old_prefix_len:],
|
||||
)
|
||||
|
||||
req.last_matched_prefix_len = len(new_indices)
|
||||
|
||||
self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
|
||||
swa_uuid_for_lock = self.inc_lock_ref(new_last_node)
|
||||
|
||||
@@ -501,7 +551,13 @@ class SWARadixCache(BasePrefixCache):
|
||||
[new_indices, kv_indices[len(new_indices) :]]
|
||||
)
|
||||
else:
|
||||
req.prefix_indices = new_indices
|
||||
if self.is_eagle:
|
||||
# Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
|
||||
req.prefix_indices = torch.cat(
|
||||
[new_indices, kv_indices[actual_kv_len:]]
|
||||
)
|
||||
else:
|
||||
req.prefix_indices = new_indices
|
||||
req.last_node = new_last_node
|
||||
req.swa_uuid_for_lock = swa_uuid_for_lock
|
||||
|
||||
|
||||
@@ -27,7 +27,11 @@ if _is_cuda:
|
||||
|
||||
def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
|
||||
"""Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
|
||||
return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
|
||||
return (
|
||||
_is_cuda
|
||||
and hasattr(forward_batch.token_to_kv_pool, "dtype")
|
||||
and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
|
||||
)
|
||||
|
||||
|
||||
def create_fused_set_kv_buffer_arg(
|
||||
|
||||
Reference in New Issue
Block a user