Fix eagle radix cache (#10846)

This commit is contained in:
Ke Bao
2025-09-30 22:59:20 +08:00
committed by GitHub
parent 5a290a5644
commit 91847e382a
6 changed files with 235 additions and 26 deletions

View File

@@ -307,6 +307,72 @@ class TestRadixCache(unittest.TestCase):
result.device_indices, torch.tensor([10, 20], dtype=torch.int64)
)
def test_insert_and_match_eagle(self):
    """Check insert plus full and partial prefix matching with is_eagle=True."""
    radix_cache = RadixCache(
        req_to_token_pool=None,
        token_to_kv_pool_allocator=None,
        page_size=1,
        disable=False,
        is_eagle=True,
    )
    token_key = RadixKey([1, 2, 3, 4])
    kv_indices = torch.tensor([10, 20, 30, 40], dtype=torch.int64)

    # The cache starts empty, so the insert must report no existing prefix.
    reused_len = radix_cache.insert(token_key, kv_indices)
    self.assertEqual(reused_len, 0)

    # Four tokens are expected to count as three entries: the last token
    # is ignored in the bigram key.
    self.assertEqual(radix_cache.total_size(), 3)
    self.assertEqual(radix_cache.evictable_size(), 3)

    # Looking up the full sequence should return the first three indices.
    full_match = radix_cache.match_prefix(RadixKey([1, 2, 3, 4]))
    expected_full = torch.tensor([10, 20, 30], dtype=torch.int64)
    self.assertEqual(len(full_match.device_indices), 3)
    torch.testing.assert_close(full_match.device_indices, expected_full)

    # Looking up a two-token prefix should return only the first index.
    partial_match = radix_cache.match_prefix(RadixKey([1, 2]))
    expected_partial = torch.tensor([10], dtype=torch.int64)
    self.assertEqual(len(partial_match.device_indices), 1)
    torch.testing.assert_close(partial_match.device_indices, expected_partial)
def test_insert_and_match_eagle_page_size(self):
    """Check insert and prefix matching with is_eagle=True and page_size=2."""
    radix_cache = RadixCache(
        req_to_token_pool=None,
        token_to_kv_pool_allocator=None,
        page_size=2,
        disable=False,
        is_eagle=True,
    )
    token_key = RadixKey([1, 2, 3])
    kv_indices = torch.tensor([10, 20, 30], dtype=torch.int64)

    # The cache starts empty, so the insert must report no existing prefix.
    reused_len = radix_cache.insert(token_key, kv_indices)
    self.assertEqual(reused_len, 0)

    # Only one page (two entries) is expected to be inserted.
    self.assertEqual(radix_cache.total_size(), 2)
    self.assertEqual(radix_cache.evictable_size(), 2)

    # A longer query that extends the inserted tokens should hit the page.
    page_match = radix_cache.match_prefix(RadixKey([1, 2, 3, 4]))
    expected_page = torch.tensor([10, 20], dtype=torch.int64)
    self.assertEqual(len(page_match.device_indices), 2)
    torch.testing.assert_close(page_match.device_indices, expected_page)

    # A two-token query is expected to match nothing.
    no_match = radix_cache.match_prefix(RadixKey([1, 2]))
    expected_empty = torch.tensor([], dtype=torch.int64)
    self.assertEqual(len(no_match.device_indices), 0)
    torch.testing.assert_close(no_match.device_indices, expected_empty)
def test_insert_with_none_value(self):
"""Test insert with None value (should use token_ids as list)."""
cache = RadixCache(