Fix memory pool index error (#616)
This commit is contained in:
@@ -46,7 +46,7 @@ class TokenToKVPool:
         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [
-            torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
             for _ in range(layer_num)
         ]
@@ -127,4 +127,4 @@ class TokenToKVPool:
         self.total_ref_ct = 0

-        self.add_refs(torch.tensor([0], dtype=torch.int32))
+        # We also add one slot. This slot is used for writing dummy output from padded tokens.
+        self.add_refs(torch.tensor([0], dtype=torch.int32))
Reference in New Issue
Block a user