Optimize mem indices management (#619)

Liangsheng Yin
2024-07-13 23:39:37 -07:00
committed by GitHub
parent 5d264a90ac
commit 564a898ad9
15 changed files with 251 additions and 178 deletions


@@ -39,10 +39,12 @@ class ReqToTokenPool:
 class TokenToKVPool:
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
         self.size = size
         # mem_state is the reference counter.
         # This can be promised:
         #   assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0)
-        self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda")
-        self.total_ref_ct = 0
+        # We also add one slot. This slot is used for writing dummy output from padded tokens.
+        self.mem_state = torch.zeros((self.size + 1,), dtype=torch.bool, device="cuda")
+        self.total_size = self.size
+        self.total_alloc = 0
         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [
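For reference, a minimal sketch of the bookkeeping state this hunk introduces (the size value is illustrative; the real pool allocates the mask on CUDA):

import torch

size = 8
# One extra slot absorbs dummy writes from padded tokens.
mem_state = torch.zeros((size + 1,), dtype=torch.bool)
total_size, total_alloc = size, 0

# With a bool mask, the invariant documented in the comment above holds by
# construction: every entry is either 0 (free) or 1 (allocated).
assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0)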
@@ -71,7 +73,9 @@ class TokenToKVPool:
         addition_size = need_size - buffer_len
         alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
+        select_index = (
+            torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
+        )
 
         if select_index.shape[0] < addition_size:
             return None
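Free-slot selection works the same way outside the pool; a standalone sketch (the mask contents and the alloc_size / addition_size values are illustrative):

import torch

mem_state = torch.tensor([True, False, True, False, False])
alloc_size, addition_size = 3, 2
select_index = (
    torch.nonzero(mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
)
# When fewer free slots exist than requested, the allocation path above returns None.
result = None if select_index.shape[0] < addition_size else select_index
print(result)  # tensor([1, 3, 4], dtype=torch.int32)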
@@ -105,26 +109,22 @@ class TokenToKVPool:
         return select_index.to(torch.int32), start_loc, start_loc + need_size
 
     def used_size(self):
-        return len(torch.nonzero(self.mem_state).squeeze(1))
+        return self.total_alloc
 
     def available_size(self):
-        return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)
+        return self.total_size - self.total_alloc + len(self.prefetch_buffer)
 
     def add_refs(self, token_index: torch.Tensor):
-        self.total_ref_ct += len(token_index)
-        self.mem_state[token_index] += 1
+        self.total_alloc += len(token_index)
+        self.mem_state[token_index] ^= True
 
     def dec_refs(self, token_index: torch.Tensor):
-        self.total_ref_ct -= len(token_index)
-        self.mem_state[token_index] -= 1
-        num_freed = torch.sum(self.mem_state[token_index] == 0)
-        return num_freed
+        self.total_alloc -= len(token_index)
+        self.mem_state[token_index] ^= True
 
     def clear(self):
         self.mem_state.fill_(0)
-        self.total_ref_ct = 0
+        self.total_alloc = 0
 
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.add_refs(torch.tensor([0], dtype=torch.int32))
+        self.mem_state[0] = True
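Taken together, the commit replaces per-token reference counts with a boolean occupancy mask plus a running total_alloc counter, so used_size and available_size become O(1) arithmetic instead of scanning mem_state. A simplified, CPU-only sketch of that bookkeeping (the class name and demo sizes are illustrative; the KV buffers and prefetch buffer are omitted):

import torch


class TokenStatePoolSketch:
    def __init__(self, size: int):
        self.size = size
        # One extra slot is reserved for dummy writes from padded tokens.
        self.mem_state = torch.zeros((size + 1,), dtype=torch.bool)
        self.total_size = size
        self.total_alloc = 0
        self.clear()

    def used_size(self):
        # O(1): no torch.nonzero scan of mem_state anymore.
        return self.total_alloc

    def available_size(self):
        # The real pool additionally adds len(self.prefetch_buffer) here.
        return self.total_size - self.total_alloc

    def add_refs(self, token_index: torch.Tensor):
        self.total_alloc += len(token_index)
        # XOR with True flips the selected free (False) slots to allocated (True).
        self.mem_state[token_index] ^= True

    def dec_refs(self, token_index: torch.Tensor):
        self.total_alloc -= len(token_index)
        self.mem_state[token_index] ^= True

    def clear(self):
        self.mem_state.fill_(0)
        self.total_alloc = 0
        self.mem_state[0] = True  # the reserved padded-token slot


pool = TokenStatePoolSketch(size=4)
idx = torch.tensor([1, 2])
pool.add_refs(idx)
print(pool.used_size(), pool.available_size())  # 2 2
pool.dec_refs(idx)
print(pool.used_size(), pool.available_size())  # 0 4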