Unify index operations (#620)
This commit is contained in:
@@ -20,8 +20,9 @@ class ReqToTokenPool:
|
||||
return None
|
||||
|
||||
select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
|
||||
self.mem_state[select_index] = 0
|
||||
self.mem_state[select_index] = False
|
||||
self.can_use_mem_size -= need_size
|
||||
|
||||
return select_index.to(torch.int32)
|
||||
|
||||
def free(self, free_index):
|
||||
@@ -29,22 +30,23 @@ class ReqToTokenPool:
|
||||
self.can_use_mem_size += 1
|
||||
else:
|
||||
self.can_use_mem_size += free_index.shape[0]
|
||||
self.mem_state[free_index] = 1
|
||||
|
||||
self.mem_state[free_index] = True
|
||||
|
||||
def clear(self):
|
||||
self.mem_state.fill_(1)
|
||||
self.mem_state.fill_(True)
|
||||
self.can_use_mem_size = len(self.mem_state)
|
||||
|
||||
|
||||
class TokenToKVPool:
|
||||
def __init__(self, size, dtype, head_num, head_dim, layer_num):
|
||||
self.size = size
|
||||
|
||||
# This can be promised:
|
||||
# assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0)
|
||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||
self.mem_state = torch.zeros((self.size + 1,), dtype=torch.bool, device="cuda")
|
||||
self.total_size = self.size
|
||||
self.total_alloc = 0
|
||||
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
|
||||
self.can_use_mem_size = self.size
|
||||
|
||||
# [size, key/value, head_num, head_dim] for each layer
|
||||
self.kv_data = [
|
||||
@@ -73,9 +75,8 @@ class TokenToKVPool:
|
||||
|
||||
addition_size = need_size - buffer_len
|
||||
alloc_size = max(addition_size, self.prefetch_chunk_size)
|
||||
select_index = (
|
||||
torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
|
||||
)
|
||||
select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size]
|
||||
select_index = select_index.to(torch.int32)
|
||||
|
||||
if select_index.shape[0] < addition_size:
|
||||
return None
|
||||
@@ -88,43 +89,20 @@ class TokenToKVPool:
|
||||
|
||||
return ret_index
|
||||
|
||||
def alloc_contiguous(self, need_size):
|
||||
# NOTE: This function is deprecated.
|
||||
empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
|
||||
if empty_index.shape[0] < need_size:
|
||||
return None
|
||||
empty_size = len(empty_index)
|
||||
loc_sum = (
|
||||
empty_index[need_size - 1 :] - empty_index[: empty_size - (need_size - 1)]
|
||||
)
|
||||
can_used_loc = empty_index[: empty_size - (need_size - 1)][
|
||||
loc_sum == need_size - 1
|
||||
]
|
||||
if can_used_loc.shape[0] == 0:
|
||||
return None
|
||||
|
||||
start_loc = can_used_loc[0].item()
|
||||
select_index = torch.arange(start_loc, start_loc + need_size, device="cuda")
|
||||
self.add_refs(select_index)
|
||||
return select_index.to(torch.int32), start_loc, start_loc + need_size
|
||||
|
||||
def used_size(self):
|
||||
return self.total_alloc
|
||||
|
||||
def available_size(self):
|
||||
return self.total_size - self.total_alloc + len(self.prefetch_buffer)
|
||||
return self.can_use_mem_size + len(self.prefetch_buffer)
|
||||
|
||||
def add_refs(self, token_index: torch.Tensor):
|
||||
self.total_alloc += len(token_index)
|
||||
self.mem_state[token_index] ^= True
|
||||
self.can_use_mem_size -= len(token_index)
|
||||
self.mem_state[token_index] = False
|
||||
|
||||
def dec_refs(self, token_index: torch.Tensor):
|
||||
self.total_alloc -= len(token_index)
|
||||
self.mem_state[token_index] ^= True
|
||||
self.can_use_mem_size += len(token_index)
|
||||
self.mem_state[token_index] = True
|
||||
|
||||
def clear(self):
|
||||
self.mem_state.fill_(0)
|
||||
self.total_alloc = 0
|
||||
self.mem_state.fill_(True)
|
||||
self.can_use_mem_size = self.size
|
||||
|
||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||
self.mem_state[0] = True
|
||||
self.mem_state[0] = False
|
||||
|
||||
Reference in New Issue
Block a user