Fix eagle radix cache (#10846)
This commit is contained in:
@@ -307,6 +307,72 @@ class TestRadixCache(unittest.TestCase):
|
||||
result.device_indices, torch.tensor([10, 20], dtype=torch.int64)
|
||||
)
|
||||
|
||||
def test_insert_and_match_eagle(self):
    """Test insert and match operations for EAGLE.

    Builds a cache with ``is_eagle=True`` and ``page_size=1``, inserts a
    four-token key, and checks the reported sizes and the indices returned
    by full and partial prefix lookups.
    """
    eagle_cache = RadixCache(
        req_to_token_pool=None,
        token_to_kv_pool_allocator=None,
        page_size=1,
        disable=False,
        is_eagle=True,
    )

    inserted_key = RadixKey([1, 2, 3, 4])
    inserted_value = torch.tensor([10, 20, 30, 40], dtype=torch.int64)
    reused_len = eagle_cache.insert(inserted_key, inserted_value)

    # The cache was empty, so there is no existing prefix to reuse.
    self.assertEqual(reused_len, 0)
    # The last token is ignored in the bigram key: 4 tokens -> 3 entries.
    self.assertEqual(eagle_cache.total_size(), 3)
    self.assertEqual(eagle_cache.evictable_size(), 3)

    # (query tokens, expected device indices): full match, then partial.
    lookups = (
        ([1, 2, 3, 4], [10, 20, 30]),
        ([1, 2], [10]),
    )
    for query_tokens, expected_indices in lookups:
        matched = eagle_cache.match_prefix(RadixKey(query_tokens))
        self.assertEqual(len(matched.device_indices), len(expected_indices))
        torch.testing.assert_close(
            matched.device_indices,
            torch.tensor(expected_indices, dtype=torch.int64),
        )
|
||||
|
||||
def test_insert_and_match_eagle_page_size(self):
    """Test insert and match operations for EAGLE and page_size > 1.

    Builds a cache with ``is_eagle=True`` and ``page_size=2``, inserts a
    three-token key, and checks the reported sizes plus one matching and
    one non-matching prefix lookup.
    """
    paged_cache = RadixCache(
        req_to_token_pool=None,
        token_to_kv_pool_allocator=None,
        page_size=2,
        disable=False,
        is_eagle=True,
    )

    inserted_key = RadixKey([1, 2, 3])
    inserted_value = torch.tensor([10, 20, 30], dtype=torch.int64)
    reused_len = paged_cache.insert(inserted_key, inserted_value)

    # The cache was empty, so there is no existing prefix to reuse.
    self.assertEqual(reused_len, 0)
    # Only one page (two entries) is inserted.
    self.assertEqual(paged_cache.total_size(), 2)
    self.assertEqual(paged_cache.evictable_size(), 2)

    # (query tokens, expected device indices): matched, then unmatched.
    lookups = (
        ([1, 2, 3, 4], [10, 20]),
        ([1, 2], []),
    )
    for query_tokens, expected_indices in lookups:
        matched = paged_cache.match_prefix(RadixKey(query_tokens))
        self.assertEqual(len(matched.device_indices), len(expected_indices))
        torch.testing.assert_close(
            matched.device_indices,
            torch.tensor(expected_indices, dtype=torch.int64),
        )
|
||||
|
||||
def test_insert_with_none_value(self):
|
||||
"""Test insert with None value (should use token_ids as list)."""
|
||||
cache = RadixCache(
|
||||
|
||||
Reference in New Issue
Block a user