Fix memory pool index error (#616)

This commit is contained in:
Ying Sheng
2024-07-13 16:45:11 -07:00
committed by GitHub
parent 0feca02dd9
commit 5949b1ca0e
4 changed files with 9 additions and 11 deletions

View File

@@ -46,7 +46,7 @@ class TokenToKVPool:
# [size, key/value, head_num, head_dim] for each layer
self.kv_data = [
torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda")
torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
for _ in range(layer_num)
]
@@ -127,4 +127,4 @@ class TokenToKVPool:
self.total_ref_ct = 0
# We also add one slot. This slot is used for writing dummy output from padded tokens.
self.add_refs(torch.tensor([0], dtype=torch.int32))
self.add_refs(torch.tensor([0], dtype=torch.int32))