[Minor] Refactors KV memory pool (#9842)

This commit is contained in:
Xinyuan Tong
2025-09-06 00:06:08 +00:00
committed by GitHub
parent f84db115b1
commit 273b28344b
2 changed files with 59 additions and 61 deletions

View File

@@ -31,16 +31,18 @@ class TestSWA(unittest.TestCase):
i for i in range(num_layers) if i not in full_attention_layer_ids_set
]
pool = SWAKVPool(
size,
size_swa,
dtype,
num_head,
head_dim,
swa_attention_layer_ids,
full_attention_layer_ids,
device,
size=size,
size_swa=size_swa,
dtype=dtype,
num_head=num_head,
head_dim=head_dim,
swa_attention_layer_ids=swa_attention_layer_ids,
full_attention_layer_ids=full_attention_layer_ids,
device=device,
)
alloc = SWATokenToKVPoolAllocator(
size=size, size_swa=size_swa, dtype=dtype, device=device, kvcache=pool
)
alloc = SWATokenToKVPoolAllocator(size, size_swa, dtype, device, pool)
assert alloc.available_size() == size + size_swa
index = alloc.alloc(1)
assert alloc.available_size() == size_swa + size_swa - 2
@@ -75,18 +77,22 @@ class TestSWA(unittest.TestCase):
)
# setup kv pool
kv_pool = SWAKVPool(
kv_size,
kv_size_swa,
dtype,
num_head,
head_dim,
swa_attention_layer_ids,
full_attention_layer_ids,
device,
size=kv_size,
size_swa=kv_size_swa,
dtype=dtype,
num_head=num_head,
head_dim=head_dim,
swa_attention_layer_ids=swa_attention_layer_ids,
full_attention_layer_ids=full_attention_layer_ids,
device=device,
)
# setup token to kv pool allocator
allocator = SWATokenToKVPoolAllocator(
kv_size, kv_size_swa, dtype, device, kv_pool
size=kv_size,
size_swa=kv_size_swa,
dtype=dtype,
device=device,
kvcache=kv_pool,
)
# setup radix cache
tree = SWARadixCache(