Increase the capacity of the memory pool (#643)

This commit is contained in:
Ying Sheng
2024-07-17 15:44:41 -07:00
committed by GitHub
parent abd5385ac5
commit 476584cb6e
4 changed files with 18 additions and 16 deletions

View File

@@ -44,7 +44,7 @@ class ReqToTokenPool:
class TokenToKVPool:
"""A memory pool that maps a token to its kv cache locations"""
def __init__(self, size, dtype, head_num, head_dim, layer_num):
def __init__(self, size: int, dtype: torch.dtype, head_num: int, head_dim: int, layer_num: int):
self.size = size
# We also add one slot. This slot is used for writing dummy output from padded tokens.
@@ -63,16 +63,16 @@ class TokenToKVPool:
self.can_use_mem_size = self.size
self.clear()
def get_key_buffer(self, layer_id):
def get_key_buffer(self, layer_id: int):
return self.kv_data[layer_id][:, 0]
def get_value_buffer(self, layer_id):
def get_value_buffer(self, layer_id: int):
return self.kv_data[layer_id][:, 1]
def available_size(self):
    """Total allocatable slots: the unreserved pool size plus any prefetched entries."""
    prefetched = len(self.prefetch_buffer)
    return self.can_use_mem_size + prefetched
def alloc(self, need_size):
def alloc(self, need_size: int):
buffer_len = len(self.prefetch_buffer)
if need_size <= buffer_len:
select_index = self.prefetch_buffer[:need_size]