From 24bc3fb0f9149215ffe80ce8ceece0dac3ad1441 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Tue, 7 Oct 2025 18:21:37 +0800 Subject: [PATCH] EAGLE cache fix for SWARadixCache (#11231) Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com> --- python/sglang/srt/managers/scheduler.py | 1 + python/sglang/srt/mem_cache/allocator.py | 9 +- python/sglang/srt/mem_cache/memory_pool.py | 4 + python/sglang/srt/mem_cache/radix_cache.py | 11 +- .../sglang/srt/mem_cache/swa_radix_cache.py | 84 +++++++-- python/sglang/srt/models/utils.py | 6 +- test/srt/run_suite.py | 1 + test/srt/test_swa_unittest.py | 163 ++++++++++++++++-- 8 files changed, 248 insertions(+), 31 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e5d6fab7b..8d91df038 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -777,6 +777,7 @@ class Scheduler( sliding_window_size=self.sliding_window_size, page_size=self.page_size, disable=server_args.disable_radix_cache, + is_eagle=self.spec_algorithm.is_eagle(), ) elif server_args.enable_lmcache: from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import ( diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py index 9fef8d133..4fefac941 100644 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -274,10 +274,15 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): self.full_to_swa_index_mapping[free_index] = 0 def backup_state(self): - raise NotImplementedError + return [ + self.full_attn_allocator.backup_state(), + self.swa_attn_allocator.backup_state(), + ] def restore_state(self, state): - raise NotImplementedError + assert len(state) == 2 + self.full_attn_allocator.restore_state(state[0]) + self.swa_attn_allocator.restore_state(state[1]) def clear(self): self.swa_attn_allocator.clear() diff --git 
a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 11249ff79..01f422504 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -749,6 +749,7 @@ class SWAKVPool(KVCache): self, size: int, size_swa: int, + dtype: torch.dtype, swa_attention_layer_ids: List[int], full_attention_layer_ids: List[int], enable_kvcache_transpose: bool, @@ -757,6 +758,7 @@ class SWAKVPool(KVCache): ): self.size = size self.size_swa = size_swa + self.dtype = dtype self.swa_layer_nums = len(swa_attention_layer_ids) self.full_layer_nums = len(full_attention_layer_ids) kwargs["page_size"] = 1 @@ -766,11 +768,13 @@ class SWAKVPool(KVCache): self.swa_kv_pool = token_to_kv_pool_class( size=size_swa, + dtype=dtype, layer_num=self.swa_layer_nums, **kwargs, ) self.full_kv_pool = token_to_kv_pool_class( size=size, + dtype=dtype, layer_num=self.full_layer_nums, **kwargs, ) diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index edb0495dc..dac120016 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -326,6 +326,8 @@ class RadixCache(BasePrefixCache): token_ids = (req.origin_input_ids + req.output_ids)[:-1] all_token_len = len(token_ids) + # For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so the key length shrinks by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1). + # The corresponding kv length must therefore also shrink by 1. This gives actual_kv_len, which is used for the later calculations and slicing. 
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len kv_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, :all_token_len @@ -349,7 +351,8 @@ class RadixCache(BasePrefixCache): old_prefix_len = len(req.prefix_indices) if self.is_eagle and old_prefix_len > req.last_matched_prefix_len: - # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE) + # In the EAGLE chunked prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]). + # Subtract 1 here so that the kv of the unmatched token is freed correctly, avoiding a memory leak. old_prefix_len -= 1 # Radix Cache takes one ref in memory pool @@ -370,7 +373,8 @@ class RadixCache(BasePrefixCache): token_ids = req.fill_ids all_token_len = len(token_ids) - # The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key + # For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so the key length shrinks by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1). + # The corresponding kv length must therefore also shrink by 1. This gives actual_kv_len, which is used for the later calculations and slicing. 
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len kv_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, :all_token_len @@ -393,7 +397,8 @@ class RadixCache(BasePrefixCache): old_prefix_len = len(req.prefix_indices) if self.is_eagle and old_prefix_len > req.last_matched_prefix_len: - # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE) + # In the EAGLE chunked prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]). + # Subtract 1 here so that the kv of the unmatched token is freed correctly, avoiding a memory leak. old_prefix_len -= 1 # Radix Cache takes one ref in memory pool diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 592f1198f..764def85c 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -32,6 +32,7 @@ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.radix_cache import ( RadixKey, + _convert_to_bigram_key, _key_match_page_size1, _key_match_paged, get_child_key, @@ -327,12 +328,14 @@ class SWARadixCache(BasePrefixCache): sliding_window_size: int, page_size: int, disable: bool = False, + is_eagle: bool = False, ): assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator) self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.page_size = page_size self.disable = disable + self.is_eagle = is_eagle if self.token_to_kv_pool_allocator: self.device = self.token_to_kv_pool_allocator.device @@ -346,6 +349,11 @@ class SWARadixCache(BasePrefixCache): self.key_match_fn = partial(_key_match_paged, page_size=page_size) self.get_child_key_fn = partial(get_child_key, page_size=page_size) + if is_eagle: + self.key_convert_fn = _convert_to_bigram_key + else: + 
self.key_convert_fn = lambda key: key + self.sliding_window_size = sliding_window_size self.reset() @@ -376,6 +384,8 @@ class SWARadixCache(BasePrefixCache): The last node create a new child if the prefix is shorter than the last node's value. """ + key.token_ids = self.key_convert_fn(key.token_ids) + if self.disable or len(key) == 0: return MatchResult( device_indices=torch.empty( @@ -406,8 +416,15 @@ class SWARadixCache(BasePrefixCache): if self.disable: return 0 + key.token_ids = self.key_convert_fn(key.token_ids) + if value is None: value = torch.tensor([x for x in key.token_ids], dtype=torch.int64) + + if self.is_eagle: + # Make sure the value length equals the EAGLE bigram key length + value = value[: len(key)] + return self._insert_helper(self.root_node, key, value, prev_prefix_len) def cache_finished_req(self, req: Req) -> None: @@ -422,25 +439,41 @@ class SWARadixCache(BasePrefixCache): return token_ids = (req.origin_input_ids + req.output_ids)[:-1] + all_token_len = len(token_ids) + # For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so the key length shrinks by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1). + # The corresponding kv length must therefore also shrink by 1. This gives actual_kv_len, which is used for the later calculations and slicing. 
+ actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(token_ids) + req.req_pool_idx, :all_token_len ] if self.page_size != 1: - page_aligned_len = len(kv_indices) // self.page_size * self.page_size + page_aligned_len = actual_kv_len // self.page_size * self.page_size page_aligned_kv_indices = kv_indices[:page_aligned_len].clone() self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) else: - page_aligned_len = len(kv_indices) + page_aligned_len = actual_kv_len page_aligned_kv_indices = kv_indices.clone() + if self.is_eagle: + self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) + + page_aligned_token_len = ( + page_aligned_len + 1 if self.is_eagle else page_aligned_len + ) + + old_prefix_len = len(req.prefix_indices) + if self.is_eagle and old_prefix_len > req.last_matched_prefix_len: + # In the EAGLE chunked prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]). + # Subtract 1 here so that the kv of the unmatched token is freed correctly, avoiding a memory leak. + old_prefix_len -= 1 # Radix Cache takes one ref in memory pool # insert the token_ids and kv_indices into the radix tree # Note: the insert function already frees the overlapped kv_indices new_prefix_len = self.insert( - RadixKey(token_ids[:page_aligned_len], req.extra_key), + RadixKey(token_ids[:page_aligned_token_len], req.extra_key), page_aligned_kv_indices, - len(req.prefix_indices), + old_prefix_len, ) # Remove req slot release the cache lock @@ -459,39 +492,56 @@ class SWARadixCache(BasePrefixCache): return token_ids = req.fill_ids + all_token_len = len(token_ids) + # For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so the key length shrinks by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1). + # The corresponding kv length must therefore also shrink by 1. 
This gives actual_kv_len, which is used for the later calculations and slicing. + actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(token_ids) + req.req_pool_idx, :all_token_len ] if self.page_size != 1: - page_aligned_len = len(kv_indices) // self.page_size * self.page_size + page_aligned_len = actual_kv_len // self.page_size * self.page_size page_aligned_kv_indices = kv_indices[:page_aligned_len].clone() else: - page_aligned_len = len(kv_indices) + page_aligned_len = actual_kv_len page_aligned_kv_indices = kv_indices.clone() - page_aligned_token_ids = token_ids[:page_aligned_len] + + # For EAGLE, page_aligned_len is the bigram key length; the plain token key is one longer (+1) + page_aligned_token_len = ( + page_aligned_len + 1 if self.is_eagle else page_aligned_len + ) + page_aligned_token_ids = token_ids[:page_aligned_token_len] + + old_prefix_len = len(req.prefix_indices) + if self.is_eagle and old_prefix_len > req.last_matched_prefix_len: + # In the EAGLE chunked prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]). + # Subtract 1 here so that the kv of the unmatched token is freed correctly, avoiding a memory leak. + old_prefix_len -= 1 # Radix Cache takes one ref in memory pool # Note: the insert function already frees the overlapped kv_indices new_prefix_len = self.insert( RadixKey(page_aligned_token_ids, req.extra_key), page_aligned_kv_indices, - len(req.prefix_indices), + old_prefix_len, ) # The prefix indices could be updated, reuse it new_indices, new_last_node, _, _ = self.match_prefix( RadixKey(page_aligned_token_ids, req.extra_key) ) - assert len(req.prefix_indices) <= len( + assert old_prefix_len <= len( new_indices ), f"{req.prefix_indices=}, {new_indices=}" assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}" self.req_to_token_pool.write( - (req.req_pool_idx, slice(len(req.prefix_indices), 
len(new_indices))), - new_indices[len(req.prefix_indices) :], + (req.req_pool_idx, slice(old_prefix_len, len(new_indices))), + new_indices[old_prefix_len:], ) + req.last_matched_prefix_len = len(new_indices) + self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock) swa_uuid_for_lock = self.inc_lock_ref(new_last_node) @@ -501,7 +551,13 @@ class SWARadixCache(BasePrefixCache): [new_indices, kv_indices[len(new_indices) :]] ) else: - req.prefix_indices = new_indices + if self.is_eagle: + # Attach the kv index of the last token for EAGLE, it can be used in chunked prefill + req.prefix_indices = torch.cat( + [new_indices, kv_indices[actual_kv_len:]] + ) + else: + req.prefix_indices = new_indices req.last_node = new_last_node req.swa_uuid_for_lock = swa_uuid_for_lock diff --git a/python/sglang/srt/models/utils.py b/python/sglang/srt/models/utils.py index f4c2a0e3e..3adab87fe 100644 --- a/python/sglang/srt/models/utils.py +++ b/python/sglang/srt/models/utils.py @@ -27,7 +27,11 @@ if _is_cuda: def enable_fused_set_kv_buffer(forward_batch: ForwardBatch): """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache.""" - return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16 + return ( + _is_cuda + and hasattr(forward_batch.token_to_kv_pool, "dtype") + and forward_batch.token_to_kv_pool.dtype == torch.bfloat16 + ) def create_fused_set_kv_buffer_arg( diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 4d97e97c0..a2ec504cd 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -113,6 +113,7 @@ suites = { TestFile("test_srt_engine.py", 261), TestFile("test_srt_endpoint.py", 130), TestFile("test_start_profile.py", 60), + TestFile("test_swa_unittest.py", 1), TestFile("test_torch_compile.py", 76), TestFile("test_torch_compile_moe.py", 172), TestFile("test_torch_native_attention_backend.py", 123), diff --git a/test/srt/test_swa_unittest.py b/test/srt/test_swa_unittest.py index 68c76e1f5..b11435b8f 100644 --- 
a/test/srt/test_swa_unittest.py +++ b/test/srt/test_swa_unittest.py @@ -50,9 +50,14 @@ class TestSWA(unittest.TestCase): kvcache=pool, need_sort=False, ) - assert alloc.available_size() == size + size_swa + self.assertEqual( + alloc.full_available_size() + alloc.swa_available_size(), size + size_swa + ) index = alloc.alloc(1) - assert alloc.available_size() == size_swa + size_swa - 2 + self.assertEqual( + alloc.full_available_size() + alloc.swa_available_size(), + size_swa + size_swa - 2, + ) alloc.free_swa(index) result = alloc.translate_loc_from_full_to_swa(index) print(result) @@ -117,7 +122,7 @@ class TestSWA(unittest.TestCase): f"[Start] allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" ) req1_token_ids, req1_kv_indices = [1, 2, 3], allocator.alloc(3) - assert len(req1_token_ids) == len(req1_kv_indices) + self.assertEqual(len(req1_token_ids), len(req1_kv_indices)) print( f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}" ) @@ -126,7 +131,7 @@ class TestSWA(unittest.TestCase): f"req1: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" ) req2_token_ids, req2_kv_indices = [1, 2, 3, 4, 5, 6, 7], allocator.alloc(7) - assert len(req2_token_ids) == len(req2_kv_indices) + self.assertEqual(len(req2_token_ids), len(req2_kv_indices)) print( f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}" ) @@ -135,7 +140,7 @@ class TestSWA(unittest.TestCase): f"req2: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" ) req3_token_ids, req3_kv_indices = [10, 11, 12], allocator.alloc(3) - assert len(req3_token_ids) == len(req3_kv_indices) + self.assertEqual(len(req3_token_ids), len(req3_kv_indices)) print( f"req3: inserting, req3_token_ids: 
{req3_token_ids}, req3_kv_indices: {req3_kv_indices}" ) @@ -144,7 +149,7 @@ class TestSWA(unittest.TestCase): f"req3: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" ) req4_token_ids, req4_kv_indices = [1, 2, 3, 4, 5, 60, 70], allocator.alloc(7) - assert len(req4_token_ids) == len(req4_kv_indices) + self.assertEqual(len(req4_token_ids), len(req4_kv_indices)) print( f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}" ) @@ -175,7 +180,7 @@ class TestSWA(unittest.TestCase): print( f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" ) - assert len(kv_indices) == 0 + self.assertEqual(len(kv_indices), 0) req6_token_ids = [1, 2, 3, 4, 5, 60, 70] result = tree.match_prefix(RadixKey(req6_token_ids)) @@ -183,10 +188,146 @@ class TestSWA(unittest.TestCase): print( f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" ) - assert len(kv_indices) == 7 - assert len(last_node.key) == 2 - assert last_node.key.token_ids[0] == 60 - assert last_node.key.token_ids[1] == 70 + self.assertEqual(len(kv_indices), 7) + self.assertEqual(len(last_node.key), 2) + self.assertEqual(last_node.key.token_ids[0], 60) + self.assertEqual(last_node.key.token_ids[1], 70) + + def test_swa_radix_cache_eagle(self): + # args + req_size = 10 + max_context_len = 128 + kv_size = 128 + kv_size_swa = 64 + sliding_window_size = 4 + head_num = 8 + head_dim = 128 + num_layers = 48 + global_interval = 4 + dtype = torch.bfloat16 + device = "cuda" + full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)] + full_attention_layer_ids_set = set(full_attention_layer_ids) + swa_attention_layer_ids = [ + i for i in range(num_layers) if i not in full_attention_layer_ids_set + ] + # setup req to token pool + req_to_token_pool = ReqToTokenPool( + size=req_size, + 
max_context_len=max_context_len, + device=device, + enable_memory_saver=False, + ) + # setup kv pool + kv_pool = SWAKVPool( + size=kv_size, + size_swa=kv_size_swa, + dtype=dtype, + head_num=head_num, + head_dim=head_dim, + swa_attention_layer_ids=swa_attention_layer_ids, + full_attention_layer_ids=full_attention_layer_ids, + enable_kvcache_transpose=False, + device=device, + ) + # setup token to kv pool allocator + allocator = SWATokenToKVPoolAllocator( + size=kv_size, + size_swa=kv_size_swa, + dtype=dtype, + device=device, + kvcache=kv_pool, + need_sort=False, + ) + # setup radix cache + tree = SWARadixCache( + req_to_token_pool=req_to_token_pool, + token_to_kv_pool_allocator=allocator, + sliding_window_size=sliding_window_size, + page_size=1, + disable=False, + is_eagle=True, + ) + + # test + print( + f"[Start] allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" + ) + req1_token_ids, req1_kv_indices = [1, 2, 3], allocator.alloc(3) + self.assertEqual(len(req1_token_ids), len(req1_kv_indices)) + print( + f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}" + ) + prefix_len = tree.insert(RadixKey(req1_token_ids), req1_kv_indices) + self.assertEqual(prefix_len, 0) + print( + f"req1: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" + ) + req2_token_ids, req2_kv_indices = [1, 2, 3, 4, 5, 6, 7], allocator.alloc(7) + self.assertEqual(len(req2_token_ids), len(req2_kv_indices)) + print( + f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}" + ) + prefix_len = tree.insert(RadixKey(req2_token_ids), req2_kv_indices) + self.assertEqual(prefix_len, 2) + print( + f"req2: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" + ) + req3_token_ids, 
req3_kv_indices = [10, 11, 12], allocator.alloc(3) + self.assertEqual(len(req3_token_ids), len(req3_kv_indices)) + print( + f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}" + ) + prefix_len = tree.insert(RadixKey(req3_token_ids), req3_kv_indices) + self.assertEqual(prefix_len, 0) + print( + f"req3: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" + ) + req4_token_ids, req4_kv_indices = [1, 2, 3, 4, 5, 60, 70], allocator.alloc(7) + self.assertEqual(len(req4_token_ids), len(req4_kv_indices)) + print( + f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}" + ) + prefix_len = tree.insert(RadixKey(req4_token_ids), req4_kv_indices) + self.assertEqual(prefix_len, 4) + print( + f"req4: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" + ) + + tree.pretty_print() + full_num_tokens, swa_num_tokens = 1, 0 + print(f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token") + tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens) + tree.pretty_print() + + full_num_tokens, swa_num_tokens = 0, 1 + print(f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token") + tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens) + tree.pretty_print() + + full_num_tokens, swa_num_tokens = 1, 2 + print(f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token") + tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens) + tree.pretty_print() + + req5_token_ids = [1, 2, 3, 4, 5] + result = tree.match_prefix(RadixKey(req5_token_ids)) + kv_indices, last_node = result.device_indices, result.last_device_node + print( + f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" + ) + 
self.assertEqual(len(kv_indices), 0) # no swa prefix matched + + req6_token_ids = [1, 2, 3, 4, 5, 60, 70] + result = tree.match_prefix(RadixKey(req6_token_ids)) + kv_indices, last_node = result.device_indices, result.last_device_node + print( + f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" + ) + self.assertEqual(len(kv_indices), 6) + self.assertEqual(len(last_node.key), 2) + self.assertEqual(last_node.key.token_ids[0], (5, 60)) + self.assertEqual(last_node.key.token_ids[1], (60, 70)) if __name__ == "__main__":