Refactors radix cache for extra key support (#10317)
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
@@ -4,7 +4,8 @@ import torch
|
||||
|
||||
from sglang.srt.mem_cache.allocator import SWAKVPool, SWATokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import SWARadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixKey
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
|
||||
|
||||
class TestSWA(unittest.TestCase):
|
||||
@@ -19,7 +20,7 @@ class TestSWA(unittest.TestCase):
|
||||
def test_swa_memory_pool(self):
|
||||
size = 16
|
||||
size_swa = 16
|
||||
num_head = 8
|
||||
head_num = 8
|
||||
head_dim = 128
|
||||
num_layers = 48
|
||||
global_interval = 4
|
||||
@@ -34,14 +35,20 @@ class TestSWA(unittest.TestCase):
|
||||
size=size,
|
||||
size_swa=size_swa,
|
||||
dtype=dtype,
|
||||
num_head=num_head,
|
||||
head_num=head_num,
|
||||
head_dim=head_dim,
|
||||
swa_attention_layer_ids=swa_attention_layer_ids,
|
||||
full_attention_layer_ids=full_attention_layer_ids,
|
||||
enable_kvcache_transpose=False,
|
||||
device=device,
|
||||
)
|
||||
alloc = SWATokenToKVPoolAllocator(
|
||||
size=size, size_swa=size_swa, dtype=dtype, device=device, kvcache=pool
|
||||
size=size,
|
||||
size_swa=size_swa,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
kvcache=pool,
|
||||
need_sort=False,
|
||||
)
|
||||
assert alloc.available_size() == size + size_swa
|
||||
index = alloc.alloc(1)
|
||||
@@ -57,7 +64,7 @@ class TestSWA(unittest.TestCase):
|
||||
kv_size = 128
|
||||
kv_size_swa = 64
|
||||
sliding_window_size = 4
|
||||
num_head = 8
|
||||
head_num = 8
|
||||
head_dim = 128
|
||||
num_layers = 48
|
||||
global_interval = 4
|
||||
@@ -80,10 +87,11 @@ class TestSWA(unittest.TestCase):
|
||||
size=kv_size,
|
||||
size_swa=kv_size_swa,
|
||||
dtype=dtype,
|
||||
num_head=num_head,
|
||||
head_num=head_num,
|
||||
head_dim=head_dim,
|
||||
swa_attention_layer_ids=swa_attention_layer_ids,
|
||||
full_attention_layer_ids=full_attention_layer_ids,
|
||||
enable_kvcache_transpose=False,
|
||||
device=device,
|
||||
)
|
||||
# setup token to kv pool allocator
|
||||
@@ -93,6 +101,7 @@ class TestSWA(unittest.TestCase):
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
kvcache=kv_pool,
|
||||
need_sort=False,
|
||||
)
|
||||
# setup radix cache
|
||||
tree = SWARadixCache(
|
||||
@@ -112,7 +121,7 @@ class TestSWA(unittest.TestCase):
|
||||
print(
|
||||
f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}"
|
||||
)
|
||||
prefix_len = tree.insert(req1_token_ids, req1_kv_indices)
|
||||
prefix_len = tree.insert(RadixKey(req1_token_ids), req1_kv_indices)
|
||||
print(
|
||||
f"req1: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
|
||||
)
|
||||
@@ -121,7 +130,7 @@ class TestSWA(unittest.TestCase):
|
||||
print(
|
||||
f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}"
|
||||
)
|
||||
prefix_len = tree.insert(req2_token_ids, req2_kv_indices)
|
||||
prefix_len = tree.insert(RadixKey(req2_token_ids), req2_kv_indices)
|
||||
print(
|
||||
f"req2: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
|
||||
)
|
||||
@@ -130,7 +139,7 @@ class TestSWA(unittest.TestCase):
|
||||
print(
|
||||
f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}"
|
||||
)
|
||||
prefix_len = tree.insert(req3_token_ids, req3_kv_indices)
|
||||
prefix_len = tree.insert(RadixKey(req3_token_ids), req3_kv_indices)
|
||||
print(
|
||||
f"req3: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
|
||||
)
|
||||
@@ -139,7 +148,7 @@ class TestSWA(unittest.TestCase):
|
||||
print(
|
||||
f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}"
|
||||
)
|
||||
prefix_len = tree.insert(req4_token_ids, req4_kv_indices)
|
||||
prefix_len = tree.insert(RadixKey(req4_token_ids), req4_kv_indices)
|
||||
print(
|
||||
f"req4: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
|
||||
)
|
||||
@@ -161,21 +170,23 @@ class TestSWA(unittest.TestCase):
|
||||
tree.pretty_print()
|
||||
|
||||
req5_token_ids = [1, 2, 3, 4, 5]
|
||||
kv_indices, last_node = tree.match_prefix(req5_token_ids)
|
||||
result = tree.match_prefix(RadixKey(req5_token_ids))
|
||||
kv_indices, last_node = result.device_indices, result.last_device_node
|
||||
print(
|
||||
f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
|
||||
)
|
||||
assert len(kv_indices) == 0
|
||||
|
||||
req6_token_ids = [1, 2, 3, 4, 5, 60, 70]
|
||||
kv_indices, last_node = tree.match_prefix(req6_token_ids)
|
||||
result = tree.match_prefix(RadixKey(req6_token_ids))
|
||||
kv_indices, last_node = result.device_indices, result.last_device_node
|
||||
print(
|
||||
f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
|
||||
)
|
||||
assert len(kv_indices) == 7
|
||||
assert len(last_node.key) == 2
|
||||
assert last_node.key[0] == 60
|
||||
assert last_node.key[1] == 70
|
||||
assert last_node.key.token_ids[0] == 60
|
||||
assert last_node.key.token_ids[1] == 70
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user