[Minor] Refactors KV memory pool (#9842)
This commit is contained in:
@@ -31,16 +31,18 @@ class TestSWA(unittest.TestCase):
|
||||
i for i in range(num_layers) if i not in full_attention_layer_ids_set
|
||||
]
|
||||
pool = SWAKVPool(
|
||||
size,
|
||||
size_swa,
|
||||
dtype,
|
||||
num_head,
|
||||
head_dim,
|
||||
swa_attention_layer_ids,
|
||||
full_attention_layer_ids,
|
||||
device,
|
||||
size=size,
|
||||
size_swa=size_swa,
|
||||
dtype=dtype,
|
||||
num_head=num_head,
|
||||
head_dim=head_dim,
|
||||
swa_attention_layer_ids=swa_attention_layer_ids,
|
||||
full_attention_layer_ids=full_attention_layer_ids,
|
||||
device=device,
|
||||
)
|
||||
alloc = SWATokenToKVPoolAllocator(
|
||||
size=size, size_swa=size_swa, dtype=dtype, device=device, kvcache=pool
|
||||
)
|
||||
alloc = SWATokenToKVPoolAllocator(size, size_swa, dtype, device, pool)
|
||||
assert alloc.available_size() == size + size_swa
|
||||
index = alloc.alloc(1)
|
||||
assert alloc.available_size() == size_swa + size_swa - 2
|
||||
@@ -75,18 +77,22 @@ class TestSWA(unittest.TestCase):
|
||||
)
|
||||
# setup kv pool
|
||||
kv_pool = SWAKVPool(
|
||||
kv_size,
|
||||
kv_size_swa,
|
||||
dtype,
|
||||
num_head,
|
||||
head_dim,
|
||||
swa_attention_layer_ids,
|
||||
full_attention_layer_ids,
|
||||
device,
|
||||
size=kv_size,
|
||||
size_swa=kv_size_swa,
|
||||
dtype=dtype,
|
||||
num_head=num_head,
|
||||
head_dim=head_dim,
|
||||
swa_attention_layer_ids=swa_attention_layer_ids,
|
||||
full_attention_layer_ids=full_attention_layer_ids,
|
||||
device=device,
|
||||
)
|
||||
# setup token to kv pool allocator
|
||||
allocator = SWATokenToKVPoolAllocator(
|
||||
kv_size, kv_size_swa, dtype, device, kv_pool
|
||||
size=kv_size,
|
||||
size_swa=kv_size_swa,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
kvcache=kv_pool,
|
||||
)
|
||||
# setup radix cache
|
||||
tree = SWARadixCache(
|
||||
|
||||
Reference in New Issue
Block a user