Increase the capacity of the memory pool (#643)
This commit is contained in:
@@ -44,7 +44,7 @@ class ReqToTokenPool:
|
||||
class TokenToKVPool:
|
||||
"""A memory pool that maps a token to its kv cache locations"""
|
||||
|
||||
def __init__(self, size, dtype, head_num, head_dim, layer_num):
|
||||
def __init__(self, size: int, dtype: torch.dtype, head_num: int, head_dim: int, layer_num: int):
|
||||
self.size = size
|
||||
|
||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||
@@ -63,16 +63,16 @@ class TokenToKVPool:
|
||||
self.can_use_mem_size = self.size
|
||||
self.clear()
|
||||
|
||||
def get_key_buffer(self, layer_id):
|
||||
def get_key_buffer(self, layer_id: int):
|
||||
return self.kv_data[layer_id][:, 0]
|
||||
|
||||
def get_value_buffer(self, layer_id):
|
||||
def get_value_buffer(self, layer_id: int):
|
||||
return self.kv_data[layer_id][:, 1]
|
||||
|
||||
def available_size(self):
|
||||
return self.can_use_mem_size + len(self.prefetch_buffer)
|
||||
|
||||
def alloc(self, need_size):
|
||||
def alloc(self, need_size: int):
|
||||
buffer_len = len(self.prefetch_buffer)
|
||||
if need_size <= buffer_len:
|
||||
select_index = self.prefetch_buffer[:need_size]
|
||||
|
||||
Reference in New Issue
Block a user