Memory pool chunked prefetch (#614)

This commit is contained in:
Liangsheng Yin
2024-07-13 15:24:03 -07:00
committed by GitHub
parent 65c6577696
commit 10143e1a5f
5 changed files with 30 additions and 39 deletions

View File

@@ -50,6 +50,10 @@ class TokenToKVPool:
for _ in range(layer_num)
]
# Prefetch buffer
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
self.prefetch_chunk_size = 256
self.clear()
def get_key_buffer(self, layer_id):
@@ -59,14 +63,29 @@ class TokenToKVPool:
return self.kv_data[layer_id][:, 1]
def alloc(self, need_size):
    """Allocate ``need_size`` free KV-cache slots, served via a prefetch buffer.

    Free indices are harvested from ``mem_state`` in chunks of at least
    ``prefetch_chunk_size`` to amortize the cost of the ``torch.nonzero``
    scan; surplus indices are parked in ``prefetch_buffer`` so subsequent
    small allocations can be served without rescanning.

    Args:
        need_size: number of slots requested.

    Returns:
        An int32 tensor of slot indices, or ``None`` if the pool cannot
        satisfy the request (the prefetch buffer is left untouched then).
    """
    buffer_len = len(self.prefetch_buffer)
    if need_size <= buffer_len:
        # Fast path: serve entirely from already-prefetched indices.
        select_index = self.prefetch_buffer[:need_size]
        self.prefetch_buffer = self.prefetch_buffer[need_size:]
        return select_index.to(torch.int32)

    addition_size = need_size - buffer_len
    # Over-allocate up to one chunk so future calls hit the fast path.
    alloc_size = max(addition_size, self.prefetch_chunk_size)
    select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size]
    if select_index.shape[0] < addition_size:
        # Not enough free slots to extend the buffer; fail the request
        # without consuming what is already prefetched.
        return None

    # Mark the newly harvested slots as referenced, then serve the request
    # from the front of the (extended) prefetch buffer.
    self.add_refs(select_index)
    self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
    ret_index = self.prefetch_buffer[:need_size]
    self.prefetch_buffer = self.prefetch_buffer[need_size:]
    return ret_index.to(torch.int32)
def alloc_contiguous(self, need_size):
# NOTE: This function is deprecated.
empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
if empty_index.shape[0] < need_size:
return None
@@ -89,7 +108,7 @@ class TokenToKVPool:
return len(torch.nonzero(self.mem_state).squeeze(1))
def available_size(self):
    """Return the total number of allocatable slots.

    Counts free entries in ``mem_state`` plus the indices already reserved
    in ``prefetch_buffer``: prefetched slots are marked as referenced in
    ``mem_state`` but are still handed out by ``alloc``, so they remain
    available from the caller's point of view.
    """
    return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)
def add_refs(self, token_index: torch.Tensor):
self.total_ref_ct += len(token_index)