Optimization for AscendPagedTokenToKVPoolAllocator (#8293)
Co-authored-by: ronnie_zheng <zl19940307@163.com> Co-authored-by: VDV1985 <vladdv85@mail.ru>
This commit is contained in:
@@ -632,27 +632,6 @@ def alloc_extend_kernel_ascend(
|
||||
out_indices[end_pos[i] - num3 : end_pos[i]] = (
|
||||
free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
|
||||
).view(-1)
|
||||
return num_new_pages
|
||||
|
||||
|
||||
def alloc_decode_kernel_ascend(
|
||||
seq_lens,
|
||||
last_loc,
|
||||
free_pages,
|
||||
out_indices,
|
||||
page_size,
|
||||
):
|
||||
num_new_pages = (seq_lens + page_size - 1) // page_size - (
|
||||
seq_lens - 1 + page_size - 1
|
||||
) // page_size
|
||||
end_new_pages = torch.cumsum(num_new_pages, 0)
|
||||
start_new_pages = end_new_pages - num_new_pages
|
||||
for i in range(len(seq_lens)):
|
||||
if num_new_pages[i]:
|
||||
out_indices[i] = free_pages[start_new_pages[i]] * page_size
|
||||
else:
|
||||
out_indices[i] = last_loc[i] + 1
|
||||
return num_new_pages
|
||||
|
||||
|
||||
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
@@ -667,7 +646,6 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
need_sort: bool,
|
||||
):
|
||||
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
||||
self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)
|
||||
|
||||
def alloc_extend(
|
||||
self,
|
||||
@@ -681,17 +659,25 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
||||
)
|
||||
|
||||
bs = len(prefix_lens)
|
||||
if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
|
||||
self.free_pages
|
||||
):
|
||||
estimated_num_new_pages = (
|
||||
(
|
||||
(seq_lens + self.page_size - 1) // self.page_size
|
||||
- (prefix_lens + self.page_size - 1) // self.page_size
|
||||
)
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
if self.need_sort and estimated_num_new_pages > len(self.free_pages):
|
||||
self.merge_and_sort_free()
|
||||
|
||||
if estimated_num_new_pages > len(self.free_pages):
|
||||
return None
|
||||
|
||||
out_indices = torch.empty(
|
||||
(extend_num_tokens,), dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
self.ret_values = alloc_extend_kernel_ascend(
|
||||
alloc_extend_kernel_ascend(
|
||||
prefix_lens,
|
||||
seq_lens,
|
||||
last_loc,
|
||||
@@ -704,11 +690,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
if self.debug_mode:
|
||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||
|
||||
num_new_pages = self.ret_values.sum()
|
||||
if num_new_pages > len(self.free_pages):
|
||||
return None
|
||||
|
||||
self.free_pages = self.free_pages[num_new_pages:]
|
||||
self.free_pages = self.free_pages[estimated_num_new_pages:]
|
||||
return out_indices
|
||||
|
||||
def alloc_decode(
|
||||
@@ -721,33 +703,26 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
||||
)
|
||||
|
||||
bs = len(seq_lens)
|
||||
if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
|
||||
self.free_pages
|
||||
):
|
||||
need_new_pages = (seq_lens % self.page_size == 1).int()
|
||||
num_new_pages = need_new_pages.sum().item()
|
||||
|
||||
if num_new_pages > len(self.free_pages):
|
||||
self.merge_and_sort_free()
|
||||
|
||||
out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
|
||||
if num_new_pages > len(self.free_pages):
|
||||
return None
|
||||
|
||||
self.ret_values = alloc_decode_kernel_ascend(
|
||||
seq_lens,
|
||||
last_loc,
|
||||
self.free_pages,
|
||||
out_indices,
|
||||
self.page_size,
|
||||
)
|
||||
end_new_pages = torch.cumsum(need_new_pages, 0)
|
||||
start_new_pages = end_new_pages - need_new_pages
|
||||
if num_new_pages == 0:
|
||||
out_indices = last_loc + 1
|
||||
else:
|
||||
out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
|
||||
start_new_pages
|
||||
] * self.page_size * need_new_pages
|
||||
|
||||
if self.debug_mode:
|
||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||
|
||||
num_new_pages = self.ret_values.sum()
|
||||
if num_new_pages > len(self.free_pages):
|
||||
return None
|
||||
|
||||
self.free_pages = self.free_pages[num_new_pages:]
|
||||
return out_indices
|
||||
|
||||
def clear(self):
|
||||
super().clear()
|
||||
self.free_pages = self.free_pages.to(torch.int32)
|
||||
self.release_pages = self.release_pages.to(torch.int32)
|
||||
return out_indices.int()
|
||||
|
||||
Reference in New Issue
Block a user