[Performance][PD Disaggregation] optimize TokenToKVPoolAllocator by sorting free pages (#8133)
Signed-off-by: Xingrui Yi <yixingrui@linux.alibaba.com> Co-authored-by: Xingrui Yi <yixingrui@linux.alibaba.com>
This commit is contained in:
@@ -51,6 +51,7 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
||||
self._kvcache = kvcache
|
||||
|
||||
self.free_pages = None
|
||||
self.release_pages = None
|
||||
self.is_not_in_free_group = True
|
||||
self.free_group = []
|
||||
|
||||
@@ -58,16 +59,16 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
||||
return ""
|
||||
|
||||
def available_size(self):
|
||||
return len(self.free_pages) * self.page_size
|
||||
return (len(self.free_pages) + len(self.release_pages)) * self.page_size
|
||||
|
||||
def get_kvcache(self):
|
||||
return self._kvcache
|
||||
|
||||
def restore_state(self, free_pages):
|
||||
self.free_pages = free_pages
|
||||
def restore_state(self, state):
|
||||
self.free_pages, self.release_pages = state
|
||||
|
||||
def backup_state(self):
|
||||
return self.free_pages
|
||||
return (self.free_pages, self.release_pages)
|
||||
|
||||
def free_group_begin(self):
|
||||
self.is_not_in_free_group = False
|
||||
@@ -78,6 +79,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
||||
if self.free_group:
|
||||
self.free(torch.cat(self.free_group))
|
||||
|
||||
def merge_and_sort_free(self):
|
||||
if len(self.release_pages) > 0:
|
||||
self.free_pages = torch.cat((self.free_pages, self.release_pages))
|
||||
self.free_pages, _ = torch.sort(self.free_pages)
|
||||
self.release_pages = torch.empty(
|
||||
(0,), dtype=self.release_pages.dtype, device=self.device
|
||||
)
|
||||
|
||||
def get_cpu_copy(self, *args, **kwargs):
|
||||
# FIXME: reuse the get_cpu_copy after paged allocator is implemented
|
||||
raise NotImplementedError()
|
||||
@@ -119,12 +128,15 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
)
|
||||
self.is_not_in_free_group = True
|
||||
self.free_group = []
|
||||
self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
|
||||
def available_size(self):
|
||||
# To avoid minor "len(free_pages) * 1" overhead
|
||||
return len(self.free_pages)
|
||||
return len(self.free_pages) + len(self.release_pages)
|
||||
|
||||
def alloc(self, need_size: int):
|
||||
if need_size > len(self.free_pages):
|
||||
self.merge_and_sort_free()
|
||||
if need_size > len(self.free_pages):
|
||||
return None
|
||||
|
||||
@@ -137,7 +149,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
return
|
||||
|
||||
if self.is_not_in_free_group:
|
||||
self.free_pages = torch.cat((self.free_pages, free_index))
|
||||
self.release_pages = torch.cat((self.release_pages, free_index))
|
||||
else:
|
||||
self.free_group.append(free_index)
|
||||
|
||||
@@ -421,6 +433,8 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
), "The allocation size should be page-aligned"
|
||||
|
||||
num_pages = need_size // self.page_size
|
||||
if num_pages > len(self.free_pages):
|
||||
self.merge_and_sort_free()
|
||||
if num_pages > len(self.free_pages):
|
||||
return None
|
||||
|
||||
@@ -446,6 +460,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
||||
)
|
||||
|
||||
estimated_num_new_pages = (
|
||||
(
|
||||
(seq_lens + self.page_size - 1) // self.page_size
|
||||
- (prefix_lens + self.page_size - 1) // self.page_size
|
||||
)
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
if estimated_num_new_pages > len(self.free_pages):
|
||||
self.merge_and_sort_free()
|
||||
|
||||
bs = len(prefix_lens)
|
||||
out_indices = torch.empty(
|
||||
(extend_num_tokens,), dtype=torch.int64, device=self.device
|
||||
@@ -483,6 +508,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
||||
)
|
||||
|
||||
estimated_num_new_pages = (
|
||||
(
|
||||
(seq_lens + self.page_size - 1) // self.page_size
|
||||
- (seq_lens - 1 + self.page_size - 1) // self.page_size
|
||||
)
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
if estimated_num_new_pages > len(self.free_pages):
|
||||
self.merge_and_sort_free()
|
||||
|
||||
bs = len(seq_lens)
|
||||
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
||||
alloc_decode_kernel[(bs,)](
|
||||
@@ -511,7 +547,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
|
||||
if self.is_not_in_free_group:
|
||||
free_page_indices = torch.unique(free_index // self.page_size)
|
||||
self.free_pages = torch.cat((free_page_indices, self.free_pages))
|
||||
self.release_pages = torch.cat((free_page_indices, self.release_pages))
|
||||
else:
|
||||
self.free_group.append(free_index)
|
||||
|
||||
@@ -525,6 +561,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
)
|
||||
self.is_not_in_free_group = True
|
||||
self.free_group = []
|
||||
self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
|
||||
def get_cpu_copy(self, indices):
|
||||
return self._kvcache.get_cpu_copy(indices)
|
||||
@@ -633,6 +670,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
||||
)
|
||||
|
||||
estimated_num_new_pages = (
|
||||
(
|
||||
(seq_lens + self.page_size - 1) // self.page_size
|
||||
- (prefix_lens + self.page_size - 1) // self.page_size
|
||||
)
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
if estimated_num_new_pages > len(self.free_pages):
|
||||
self.merge_and_sort_free()
|
||||
|
||||
bs = len(prefix_lens)
|
||||
out_indices = torch.empty(
|
||||
(extend_num_tokens,), dtype=torch.int32, device=self.device
|
||||
@@ -668,6 +716,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
||||
)
|
||||
|
||||
estimated_num_new_pages = (
|
||||
(
|
||||
(seq_lens + self.page_size - 1) // self.page_size
|
||||
- (seq_lens - 1 + self.page_size - 1) // self.page_size
|
||||
)
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
if estimated_num_new_pages > len(self.free_pages):
|
||||
self.merge_and_sort_free()
|
||||
|
||||
bs = len(seq_lens)
|
||||
out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
|
||||
|
||||
@@ -692,3 +751,4 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
def clear(self):
|
||||
super().clear()
|
||||
self.free_pages = self.free_pages.to(torch.int32)
|
||||
self.release_pages = self.release_pages.to(torch.int32)
|
||||
|
||||
Reference in New Issue
Block a user