Optimize mem indices management (#619)
This commit is contained in:
@@ -39,10 +39,12 @@ class ReqToTokenPool:
|
||||
class TokenToKVPool:
|
||||
def __init__(self, size, dtype, head_num, head_dim, layer_num):
|
||||
self.size = size
|
||||
# mem_state is the reference counter.
|
||||
# This can be promised:
|
||||
# assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0)
|
||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||
self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda")
|
||||
self.total_ref_ct = 0
|
||||
self.mem_state = torch.zeros((self.size + 1,), dtype=torch.bool, device="cuda")
|
||||
self.total_size = self.size
|
||||
self.total_alloc = 0
|
||||
|
||||
# [size, key/value, head_num, head_dim] for each layer
|
||||
self.kv_data = [
|
||||
@@ -71,7 +73,9 @@ class TokenToKVPool:
|
||||
|
||||
addition_size = need_size - buffer_len
|
||||
alloc_size = max(addition_size, self.prefetch_chunk_size)
|
||||
select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
|
||||
select_index = (
|
||||
torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
|
||||
)
|
||||
|
||||
if select_index.shape[0] < addition_size:
|
||||
return None
|
||||
@@ -105,26 +109,22 @@ class TokenToKVPool:
|
||||
return select_index.to(torch.int32), start_loc, start_loc + need_size
|
||||
|
||||
def used_size(self):
|
||||
return len(torch.nonzero(self.mem_state).squeeze(1))
|
||||
return self.total_alloc
|
||||
|
||||
def available_size(self):
|
||||
return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)
|
||||
return self.total_size - self.total_alloc + len(self.prefetch_buffer)
|
||||
|
||||
def add_refs(self, token_index: torch.Tensor):
|
||||
self.total_ref_ct += len(token_index)
|
||||
self.mem_state[token_index] += 1
|
||||
self.total_alloc += len(token_index)
|
||||
self.mem_state[token_index] ^= True
|
||||
|
||||
def dec_refs(self, token_index: torch.Tensor):
|
||||
self.total_ref_ct -= len(token_index)
|
||||
self.mem_state[token_index] -= 1
|
||||
|
||||
num_freed = torch.sum(self.mem_state[token_index] == 0)
|
||||
|
||||
return num_freed
|
||||
self.total_alloc -= len(token_index)
|
||||
self.mem_state[token_index] ^= True
|
||||
|
||||
def clear(self):
|
||||
self.mem_state.fill_(0)
|
||||
self.total_ref_ct = 0
|
||||
self.total_alloc = 0
|
||||
|
||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||
self.add_refs(torch.tensor([0], dtype=torch.int32))
|
||||
self.mem_state[0] = True
|
||||
|
||||
Reference in New Issue
Block a user