"""
Unit tests for the RadixCache implementation.

This module tests the core functionality of RadixCache, RadixKey, and TreeNode,
following SGLang testing patterns.

Test Coverage:
- RadixKey: token ID management, indexing and slicing, iteration, representation
- TreeNode: node properties, host reference counting, hash values, ordering
- RadixCache: insert/match operations, eviction, lock reference counting,
  extra_key isolation, page alignment
- KV cache events (BlockStored / BlockRemoved)
- Boundary conditions with parameterized (subTest-based) testing

Usage:
    python test_radix_cache_unit.py
    python -m pytest test_radix_cache_unit.py -v
    python -m pytest test_radix_cache_unit.py::TestRadixCache::test_insert_and_match_basic
"""

import contextlib
import io
import time
import unittest
import unittest.mock

import torch

from sglang.srt.disaggregation.kv_events import BlockRemoved, BlockStored
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode

# Shared test constant: a page size larger than 1, for page-aligned scenarios.
# NOTE(review): not referenced by the tests in this file at present.
DEFAULT_PAGE_SIZE: int = 4


class TestRadixKey(unittest.TestCase):
    """Unit tests for ``RadixKey``: construction, sizing, iteration, indexing, repr."""

    def test_init_basic(self):
        """A key built from token ids alone keeps them and has no extra_key."""
        ids = [1, 2, 3, 4]
        radix_key = RadixKey(ids)
        self.assertEqual(radix_key.token_ids, ids)
        self.assertIsNone(radix_key.extra_key)

    def test_init_with_extra_key(self):
        """The optional second constructor argument is stored as extra_key."""
        ids, tag = [1, 2, 3], "test_key"
        radix_key = RadixKey(ids, tag)
        self.assertEqual(radix_key.token_ids, ids)
        self.assertEqual(radix_key.extra_key, tag)

    def test_len(self):
        """len() counts token ids, including zero for an empty key."""
        self.assertEqual(len(RadixKey([1, 2, 3])), 3)
        self.assertEqual(len(RadixKey([])), 0)

    def test_iter(self):
        """Iterating a key yields its token ids in order."""
        ids = [1, 2, 3, 4]
        self.assertEqual(list(RadixKey(ids)), ids)

    def test_len_and_iter(self):
        """len() and iteration both agree with the underlying token list."""
        for ids, expected_len in (([1, 2, 3], 3), ([], 0), ([42], 1)):
            with self.subTest(tokens=ids):
                radix_key = RadixKey(ids)
                self.assertEqual(len(radix_key), expected_len)
                self.assertEqual(list(radix_key), ids)

    def test_getitem_int(self):
        """Integer indexing returns a one-token RadixKey, not a bare int."""
        for position, expected in ((0, [10]), (-1, [30]), (2, [30])):
            source = [10, 20, 30]
            with self.subTest(tokens=source, index=position):
                item = RadixKey(source)[position]
                self.assertIsInstance(item, RadixKey)
                self.assertEqual(item.token_ids, expected)

    def test_getitem_slice(self):
        """Slicing returns a RadixKey that keeps extra_key; edge slices work."""
        radix_key = RadixKey([1, 2, 3, 4, 5], "extra")

        middle = radix_key[1:4]
        self.assertIsInstance(middle, RadixKey)
        self.assertEqual(middle.token_ids, [2, 3, 4])
        self.assertEqual(middle.extra_key, "extra")

        # An empty slice and a full slice are both valid.
        self.assertEqual(radix_key[2:2].token_ids, [])
        self.assertEqual(radix_key[:].token_ids, [1, 2, 3, 4, 5])

    def test_getitem_invalid_index(self):
        """Indexing past the end raises IndexError."""
        radix_key = RadixKey([1, 2, 3])
        with self.assertRaises(IndexError):
            radix_key[10]

    def test_repr(self):
        """repr() names the class and shows both extra_key and the token ids."""
        text = repr(RadixKey([1, 2, 3], "test"))
        for fragment in ("RadixKey", "extra_key='test'", "[1, 2, 3]"):
            self.assertIn(fragment, text)

    def test_repr_long_token_ids(self):
        """repr() abbreviates long token lists with an ellipsis."""
        text = repr(RadixKey(list(range(15))))
        self.assertIn("...", text)


class TestTreeNode(unittest.TestCase):
    """Test cases for TreeNode class."""

    def setUp(self):
        """Reset the node-id counter to 0 and restore it after each test.

        ``TreeNode.counter`` is class-level state that seeds node ids. Without
        the cleanup, the rewind done here would leak into any later test in
        the same process.
        """
        saved_counter = TreeNode.counter
        self.addCleanup(setattr, TreeNode, "counter", saved_counter)
        TreeNode.counter = 0

    def test_init_basic(self):
        """Test basic initialization of TreeNode."""
        node = TreeNode()
        self.assertEqual(node.id, 0)
        self.assertEqual(len(node.children), 0)
        self.assertIsNone(node.parent)
        self.assertIsNone(node.key)
        self.assertIsNone(node.value)
        self.assertEqual(node.lock_ref, 0)
        self.assertEqual(node.hit_count, 0)
        self.assertEqual(node.host_ref_counter, 0)
        self.assertIsNone(node.host_value)
        self.assertIsNone(node.hash_value)

    def test_init_with_id(self):
        """Test initialization with custom ID."""
        node = TreeNode(id=42)
        self.assertEqual(node.id, 42)
        # Passing an explicit id still advances the shared counter.
        node2 = TreeNode()
        self.assertEqual(node2.id, 1)

    def test_counter_increment(self):
        """Test that counter increments properly."""
        node1 = TreeNode()
        node2 = TreeNode()
        self.assertEqual(node1.id, 0)
        self.assertEqual(node2.id, 1)

    def test_evicted_backuped_properties(self):
        """Test evicted and backuped properties.

        Per the table: ``evicted`` is True exactly when ``value`` is unset, and
        ``backuped`` is True exactly when ``host_value`` is set.
        """
        # (has_value, has_host_value, expected_evicted, expected_backuped)
        test_cases = [
            (False, False, True, False),
            (True, False, False, False),
            (True, True, False, True),
            (False, True, True, True),
        ]

        for (
            has_value,
            has_host_value,
            expected_evicted,
            expected_backuped,
        ) in test_cases:
            with self.subTest(has_value=has_value, has_host_value=has_host_value):
                node = TreeNode()

                if has_value:
                    node.value = torch.tensor([1, 2, 3])
                if has_host_value:
                    node.host_value = torch.tensor([4, 5, 6])

                self.assertEqual(node.evicted, expected_evicted)
                self.assertEqual(node.backuped, expected_backuped)

    def test_protect_release_host(self):
        """Test protect_host and release_host methods."""
        node = TreeNode()
        self.assertEqual(node.host_ref_counter, 0)

        node.protect_host()
        self.assertEqual(node.host_ref_counter, 1)

        node.release_host()
        self.assertEqual(node.host_ref_counter, 0)

        # Releasing more times than protected is an error.
        with self.assertRaises(RuntimeError):
            node.release_host()

    def test_get_last_hash_value(self):
        """Test get_last_hash_value method."""
        node = TreeNode()
        self.assertIsNone(node.get_last_hash_value())

        node.hash_value = ["hash1", "hash2", "hash3"]
        self.assertEqual(node.get_last_hash_value(), "hash3")

    def test_lt_comparison(self):
        """Test less-than comparison based on last_access_time.

        The timestamps are assigned explicitly rather than produced by
        sleeping between constructions. A 1 ms sleep does not guarantee
        distinct timestamps if the underlying clock is coarser than 1 ms,
        which made the original version of this test flaky.
        """
        node1 = TreeNode()
        node2 = TreeNode()
        node1.last_access_time = 1.0
        node2.last_access_time = 2.0

        self.assertTrue(node1 < node2)
        self.assertFalse(node2 < node1)


class TestRadixCache(unittest.TestCase):
    """Test cases for RadixCache class."""

    def setUp(self):
        """Reset the node-id counter to 0 and restore it after each test.

        ``TreeNode.counter`` is class-level state. Restoring it keeps the
        rewind done here from leaking into unrelated tests in the same process.
        """
        saved_counter = TreeNode.counter
        self.addCleanup(setattr, TreeNode, "counter", saved_counter)
        TreeNode.counter = 0

    @staticmethod
    def _make_cache(page_size=1, allocator=None, **kwargs):
        """Build a RadixCache with no request-to-token pool.

        ``allocator`` is passed as ``token_to_kv_pool_allocator``. Tests that
        evict need one (a mock suffices) because eviction frees through it.
        Extra keyword arguments (e.g. ``disable``, ``enable_kv_cache_events``)
        are forwarded to the constructor unchanged.
        """
        return RadixCache(
            req_to_token_pool=None,
            token_to_kv_pool_allocator=allocator,
            page_size=page_size,
            **kwargs,
        )

    @staticmethod
    def _indices(values):
        """Return ``values`` as the int64 index tensor these tests insert."""
        return torch.tensor(values, dtype=torch.int64)

    @staticmethod
    def _cpu_allocator():
        """Return a mock KV allocator that reports a CPU device."""
        allocator = unittest.mock.Mock()
        allocator.device = torch.device("cpu")
        return allocator

    def test_init_variations(self):
        """Test cache initialization with different parameters."""
        test_cases = [
            (1, False, False),
            (4, False, True),
            (1, True, False),
        ]

        for page_size, disable, enable_events in test_cases:
            with self.subTest(
                page_size=page_size, disable=disable, enable_events=enable_events
            ):
                cache = self._make_cache(
                    page_size=page_size,
                    disable=disable,
                    enable_kv_cache_events=enable_events,
                )

                self.assertEqual(cache.page_size, page_size)
                self.assertEqual(cache.disable, disable)
                self.assertEqual(cache.enable_kv_cache_events, enable_events)
                # With no allocator supplied, the cache reports a CPU device.
                self.assertEqual(cache.device, torch.device("cpu"))
                self.assertIsNotNone(cache.root_node)
                self.assertEqual(len(cache.root_node.key), 0)

    def test_reset(self):
        """Test reset method."""
        cache = self._make_cache()

        cache.insert(RadixKey([1, 2, 3]), self._indices([10, 20, 30]))
        self.assertGreater(cache.total_size(), 0)

        cache.reset()
        self.assertEqual(cache.total_size(), 0)
        self.assertEqual(cache.evictable_size(), 0)
        self.assertEqual(cache.protected_size(), 0)

    def test_insert_and_match_basic(self):
        """Test basic insert and match operations (and the disabled no-op path)."""
        for disable_cache in [False, True]:
            with self.subTest(disable_cache=disable_cache):
                cache = self._make_cache(disable=disable_cache)

                key = RadixKey([1, 2, 3])
                value = self._indices([10, 20, 30])
                prefix_len = cache.insert(key, value)

                if disable_cache:
                    # A disabled cache stores nothing.
                    self.assertEqual(prefix_len, 0)
                    self.assertEqual(cache.total_size(), 0)
                    continue

                self.assertEqual(prefix_len, 0)  # No existing prefix
                self.assertEqual(cache.total_size(), 3)
                self.assertEqual(cache.evictable_size(), 3)

                # Full match returns every stored index.
                result = cache.match_prefix(RadixKey([1, 2, 3]))
                self.assertEqual(len(result.device_indices), 3)
                torch.testing.assert_close(result.device_indices, value)

                # Partial match returns only the matched prefix.
                result = cache.match_prefix(RadixKey([1, 2]))
                self.assertEqual(len(result.device_indices), 2)
                torch.testing.assert_close(
                    result.device_indices, self._indices([10, 20])
                )

    def test_insert_with_none_value(self):
        """Test insert with a None value (the value is derived from token_ids)."""
        cache = self._make_cache()

        prefix_len = cache.insert(RadixKey([1, 2, 3]), None)

        self.assertEqual(prefix_len, 0)
        self.assertEqual(cache.total_size(), 3)

    def test_total_size(self):
        """Test total_size calculation."""
        cache = self._make_cache()
        self.assertEqual(cache.total_size(), 0)

        cache.insert(RadixKey([1, 2, 3]), self._indices([10, 20, 30]))
        self.assertEqual(cache.total_size(), 3)

        cache.insert(RadixKey([4, 5]), self._indices([40, 50]))
        self.assertEqual(cache.total_size(), 5)

    def test_kv_cache_events(self):
        """Test KV cache events functionality."""
        test_cases = [
            (1, True),
            (2, True),
            (1, False),
        ]

        for page_size, enable_events in test_cases:
            with self.subTest(page_size=page_size, enable_events=enable_events):
                cache = self._make_cache(
                    page_size=page_size, enable_kv_cache_events=enable_events
                )

                cache.insert(RadixKey([1, 2, 3, 4, 5]), None)
                events = cache.take_events()

                if enable_events:
                    self.assertGreater(len(events), 0)
                    # Other event types may be present; check the stored blocks.
                    block_stored_events = [
                        e for e in events if isinstance(e, BlockStored)
                    ]
                    self.assertGreater(len(block_stored_events), 0)
                    for event in block_stored_events:
                        self.assertLessEqual(len(event.token_ids), page_size)
                else:
                    self.assertEqual(len(events), 0)

    def test_kv_cache_events_with_eviction(self):
        """Test KV cache events include both store and removal events."""
        cache = self._make_cache(
            allocator=self._cpu_allocator(), enable_kv_cache_events=True
        )

        cache.insert(RadixKey([1, 2, 3]), self._indices([10, 20, 30]))
        cache.evict(3)

        events = cache.take_events()
        self.assertGreater(len(events), 0)

        event_types = [type(event).__name__ for event in events]
        self.assertIn("BlockStored", event_types)

        # Require at least one removal event; otherwise the content check
        # below would pass vacuously on an empty list.
        remove_events = [e for e in events if isinstance(e, BlockRemoved)]
        self.assertGreater(len(remove_events), 0)
        for event in remove_events:
            self.assertGreater(len(event.block_hashes), 0)

    def test_extra_key_isolation(self):
        """Test that keys with different extra_key values are isolated."""
        cache = self._make_cache()

        # Same token sequence under three different extra keys.
        cache.insert(RadixKey([1, 2, 3], "key1"), self._indices([10, 20, 30]))
        cache.insert(RadixKey([1, 2, 3], "key2"), self._indices([40, 50, 60]))
        cache.insert(RadixKey([1, 2, 3], None), self._indices([70, 80, 90]))

        result1 = cache.match_prefix(RadixKey([1, 2, 3], "key1"))
        result2 = cache.match_prefix(RadixKey([1, 2, 3], "key2"))
        result3 = cache.match_prefix(RadixKey([1, 2, 3], None))
        result4 = cache.match_prefix(RadixKey([1, 2, 3], "nonexistent"))

        # Each lookup matches only the data stored under its own extra_key.
        self.assertEqual(len(result1.device_indices), 3)
        torch.testing.assert_close(
            result1.device_indices, self._indices([10, 20, 30])
        )

        self.assertEqual(len(result2.device_indices), 3)
        torch.testing.assert_close(
            result2.device_indices, self._indices([40, 50, 60])
        )

        self.assertEqual(len(result3.device_indices), 3)
        torch.testing.assert_close(
            result3.device_indices, self._indices([70, 80, 90])
        )

        # An unseen extra_key matches nothing.
        self.assertEqual(len(result4.device_indices), 0)

    def test_lock_ref_operations(self):
        """Test lock reference counting moves tokens between evictable/protected."""
        cache = self._make_cache()
        cache.insert(RadixKey([1, 2, 3]), self._indices([10, 20, 30]))

        node = cache.match_prefix(RadixKey([1, 2, 3])).last_device_node

        initial_evictable = cache.evictable_size()
        initial_protected = cache.protected_size()

        cache.inc_lock_ref(node)
        self.assertEqual(cache.protected_size(), initial_protected + 3)
        self.assertEqual(cache.evictable_size(), initial_evictable - 3)

        cache.dec_lock_ref(node)
        self.assertEqual(cache.protected_size(), initial_protected)
        self.assertEqual(cache.evictable_size(), initial_evictable)

    def test_evict_functionality(self):
        """Test eviction frees memory through the allocator and shrinks the cache."""
        mock_allocator = self._cpu_allocator()
        cache = self._make_cache(allocator=mock_allocator)

        cache.insert(RadixKey([1, 2]), self._indices([10, 20]))
        cache.insert(RadixKey([3, 4]), self._indices([30, 40]))
        initial_size = cache.total_size()

        cache.evict(2)

        mock_allocator.free.assert_called()
        self.assertLess(cache.total_size(), initial_size)

    def test_page_alignment_boundary(self):
        """Test matched lengths are page-aligned for several page sizes."""
        test_cases = [
            (1, 5),
            (2, 5),
            (4, 6),
        ]

        for page_size, sequence_length in test_cases:
            with self.subTest(page_size=page_size, sequence_length=sequence_length):
                cache = self._make_cache(page_size=page_size)

                tokens = list(range(sequence_length))
                cache.insert(RadixKey(tokens), self._indices(tokens))

                result = cache.match_prefix(RadixKey(tokens))
                match_len = len(result.device_indices)
                self.assertGreater(match_len, 0)
                self.assertEqual(match_len % page_size, 0)

    def test_pretty_print_basic(self):
        """Test pretty_print produces output without raising."""
        cache = self._make_cache()
        cache.insert(RadixKey([1, 2, 3]), self._indices([10, 20, 30]))

        # Capture stdout so the dump does not clutter the test log. Any
        # exception propagates and unittest reports it as an error.
        captured = io.StringIO()
        with contextlib.redirect_stdout(captured):
            cache.pretty_print()
        self.assertTrue(captured.getvalue())

    def test_all_values_flatten(self):
        """Test all_values_flatten returns every stored value."""
        cache = self._make_cache()
        cache.insert(RadixKey([1, 2]), self._indices([10, 20]))
        cache.insert(RadixKey([3, 4]), self._indices([30, 40]))

        all_values = cache.all_values_flatten()
        self.assertEqual(len(all_values), 4)
        # Order is not guaranteed, so compare as a set.
        self.assertEqual(set(all_values.tolist()), {10, 20, 30, 40})

    def test_advanced_prefix_match_with_node_splits(self):
        """Advanced prefix matching: splits inside nodes and across pages."""
        for page_size in [1, 2]:
            with self.subTest(page_size=page_size):
                cache = self._make_cache(page_size=page_size)

                # A long sequence that later matches will split.
                seq1 = [1, 2, 3, 4, 5, 6, 7, 8]
                val1 = self._indices([x * 10 for x in seq1])
                cache.insert(RadixKey(seq1), val1)

                # A diverging branch creates an internal node on the path.
                seq2 = [1, 2, 9, 10]
                val2 = self._indices([x * 10 for x in seq2])
                cache.insert(RadixKey(seq2), val2)

                baseline_total = cache.total_size()
                # seq1 (8 tokens) plus seq2's unshared suffix [9, 10].
                expected_total = 10
                self.assertEqual(baseline_total, expected_total)

                # Split inside an existing node: first 4 tokens of seq1, then diverge.
                query1 = [1, 2, 3, 4, 999, 1000]
                result1 = cache.match_prefix(RadixKey(query1))
                torch.testing.assert_close(result1.device_indices, val1[:4])
                # A structural split during matching must not change stored data.
                self.assertEqual(cache.total_size(), baseline_total)

                # A full match of the long sequence still returns all indices.
                result_full = cache.match_prefix(RadixKey(seq1))
                torch.testing.assert_close(result_full.device_indices, val1)

                # A deeper split: match 6 tokens, then diverge.
                query2 = [1, 2, 3, 4, 5, 6, 777, 888]
                result2 = cache.match_prefix(RadixKey(query2))
                torch.testing.assert_close(result2.device_indices, val1[:6])
                self.assertEqual(cache.total_size(), baseline_total)

                # The short diverging branch returns exactly its own indices.
                result_branch = cache.match_prefix(RadixKey(seq2))
                torch.testing.assert_close(result_branch.device_indices, val2)


if __name__ == "__main__":
    # Run the whole suite when this file is executed directly as a script.
    unittest.main()