Optimization for AscendPagedTokenToKVPoolAllocator (#8293)
Co-authored-by: ronnie_zheng <zl19940307@163.com> Co-authored-by: VDV1985 <vladdv85@mail.ru>
This commit is contained in:
@@ -632,27 +632,6 @@ def alloc_extend_kernel_ascend(
|
|||||||
out_indices[end_pos[i] - num3 : end_pos[i]] = (
|
out_indices[end_pos[i] - num3 : end_pos[i]] = (
|
||||||
free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
|
free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
|
||||||
).view(-1)
|
).view(-1)
|
||||||
return num_new_pages
|
|
||||||
|
|
||||||
|
|
||||||
def alloc_decode_kernel_ascend(
|
|
||||||
seq_lens,
|
|
||||||
last_loc,
|
|
||||||
free_pages,
|
|
||||||
out_indices,
|
|
||||||
page_size,
|
|
||||||
):
|
|
||||||
num_new_pages = (seq_lens + page_size - 1) // page_size - (
|
|
||||||
seq_lens - 1 + page_size - 1
|
|
||||||
) // page_size
|
|
||||||
end_new_pages = torch.cumsum(num_new_pages, 0)
|
|
||||||
start_new_pages = end_new_pages - num_new_pages
|
|
||||||
for i in range(len(seq_lens)):
|
|
||||||
if num_new_pages[i]:
|
|
||||||
out_indices[i] = free_pages[start_new_pages[i]] * page_size
|
|
||||||
else:
|
|
||||||
out_indices[i] = last_loc[i] + 1
|
|
||||||
return num_new_pages
|
|
||||||
|
|
||||||
|
|
||||||
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||||
@@ -667,7 +646,6 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|||||||
need_sort: bool,
|
need_sort: bool,
|
||||||
):
|
):
|
||||||
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
||||||
self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)
|
|
||||||
|
|
||||||
def alloc_extend(
|
def alloc_extend(
|
||||||
self,
|
self,
|
||||||
@@ -681,17 +659,25 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|||||||
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
||||||
)
|
)
|
||||||
|
|
||||||
bs = len(prefix_lens)
|
estimated_num_new_pages = (
|
||||||
if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
|
(
|
||||||
self.free_pages
|
(seq_lens + self.page_size - 1) // self.page_size
|
||||||
):
|
- (prefix_lens + self.page_size - 1) // self.page_size
|
||||||
|
)
|
||||||
|
.sum()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
if self.need_sort and estimated_num_new_pages > len(self.free_pages):
|
||||||
self.merge_and_sort_free()
|
self.merge_and_sort_free()
|
||||||
|
|
||||||
|
if estimated_num_new_pages > len(self.free_pages):
|
||||||
|
return None
|
||||||
|
|
||||||
out_indices = torch.empty(
|
out_indices = torch.empty(
|
||||||
(extend_num_tokens,), dtype=torch.int32, device=self.device
|
(extend_num_tokens,), dtype=torch.int32, device=self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ret_values = alloc_extend_kernel_ascend(
|
alloc_extend_kernel_ascend(
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
last_loc,
|
last_loc,
|
||||||
@@ -704,11 +690,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|||||||
if self.debug_mode:
|
if self.debug_mode:
|
||||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||||
|
|
||||||
num_new_pages = self.ret_values.sum()
|
self.free_pages = self.free_pages[estimated_num_new_pages:]
|
||||||
if num_new_pages > len(self.free_pages):
|
|
||||||
return None
|
|
||||||
|
|
||||||
self.free_pages = self.free_pages[num_new_pages:]
|
|
||||||
return out_indices
|
return out_indices
|
||||||
|
|
||||||
def alloc_decode(
|
def alloc_decode(
|
||||||
@@ -721,33 +703,26 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|||||||
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
||||||
)
|
)
|
||||||
|
|
||||||
bs = len(seq_lens)
|
need_new_pages = (seq_lens % self.page_size == 1).int()
|
||||||
if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
|
num_new_pages = need_new_pages.sum().item()
|
||||||
self.free_pages
|
|
||||||
):
|
if num_new_pages > len(self.free_pages):
|
||||||
self.merge_and_sort_free()
|
self.merge_and_sort_free()
|
||||||
|
|
||||||
out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
|
if num_new_pages > len(self.free_pages):
|
||||||
|
return None
|
||||||
|
|
||||||
self.ret_values = alloc_decode_kernel_ascend(
|
end_new_pages = torch.cumsum(need_new_pages, 0)
|
||||||
seq_lens,
|
start_new_pages = end_new_pages - need_new_pages
|
||||||
last_loc,
|
if num_new_pages == 0:
|
||||||
self.free_pages,
|
out_indices = last_loc + 1
|
||||||
out_indices,
|
else:
|
||||||
self.page_size,
|
out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
|
||||||
)
|
start_new_pages
|
||||||
|
] * self.page_size * need_new_pages
|
||||||
|
|
||||||
if self.debug_mode:
|
if self.debug_mode:
|
||||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||||
|
|
||||||
num_new_pages = self.ret_values.sum()
|
|
||||||
if num_new_pages > len(self.free_pages):
|
|
||||||
return None
|
|
||||||
|
|
||||||
self.free_pages = self.free_pages[num_new_pages:]
|
self.free_pages = self.free_pages[num_new_pages:]
|
||||||
return out_indices
|
return out_indices.int()
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
super().clear()
|
|
||||||
self.free_pages = self.free_pages.to(torch.int32)
|
|
||||||
self.release_pages = self.release_pages.to(torch.int32)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user