[Performance][PD Disaggregation] optimize TokenToKVPoolAllocator by sorting free pages (#8133)
Signed-off-by: Xingrui Yi <yixingrui@linux.alibaba.com> Co-authored-by: Xingrui Yi <yixingrui@linux.alibaba.com>
This commit is contained in:
@@ -51,6 +51,7 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
|||||||
self._kvcache = kvcache
|
self._kvcache = kvcache
|
||||||
|
|
||||||
self.free_pages = None
|
self.free_pages = None
|
||||||
|
self.release_pages = None
|
||||||
self.is_not_in_free_group = True
|
self.is_not_in_free_group = True
|
||||||
self.free_group = []
|
self.free_group = []
|
||||||
|
|
||||||
@@ -58,16 +59,16 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
def available_size(self):
|
def available_size(self):
|
||||||
return len(self.free_pages) * self.page_size
|
return (len(self.free_pages) + len(self.release_pages)) * self.page_size
|
||||||
|
|
||||||
def get_kvcache(self):
|
def get_kvcache(self):
|
||||||
return self._kvcache
|
return self._kvcache
|
||||||
|
|
||||||
def restore_state(self, free_pages):
|
def restore_state(self, state):
|
||||||
self.free_pages = free_pages
|
self.free_pages, self.release_pages = state
|
||||||
|
|
||||||
def backup_state(self):
|
def backup_state(self):
|
||||||
return self.free_pages
|
return (self.free_pages, self.release_pages)
|
||||||
|
|
||||||
def free_group_begin(self):
|
def free_group_begin(self):
|
||||||
self.is_not_in_free_group = False
|
self.is_not_in_free_group = False
|
||||||
@@ -78,6 +79,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
|||||||
if self.free_group:
|
if self.free_group:
|
||||||
self.free(torch.cat(self.free_group))
|
self.free(torch.cat(self.free_group))
|
||||||
|
|
||||||
|
def merge_and_sort_free(self):
|
||||||
|
if len(self.release_pages) > 0:
|
||||||
|
self.free_pages = torch.cat((self.free_pages, self.release_pages))
|
||||||
|
self.free_pages, _ = torch.sort(self.free_pages)
|
||||||
|
self.release_pages = torch.empty(
|
||||||
|
(0,), dtype=self.release_pages.dtype, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
def get_cpu_copy(self, *args, **kwargs):
|
def get_cpu_copy(self, *args, **kwargs):
|
||||||
# FIXME: reuse the get_cpu_copy after paged allocator is implemented
|
# FIXME: reuse the get_cpu_copy after paged allocator is implemented
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -119,12 +128,15 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
)
|
)
|
||||||
self.is_not_in_free_group = True
|
self.is_not_in_free_group = True
|
||||||
self.free_group = []
|
self.free_group = []
|
||||||
|
self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||||
|
|
||||||
def available_size(self):
|
def available_size(self):
|
||||||
# To avoid minor "len(free_pages) * 1" overhead
|
# To avoid minor "len(free_pages) * 1" overhead
|
||||||
return len(self.free_pages)
|
return len(self.free_pages) + len(self.release_pages)
|
||||||
|
|
||||||
def alloc(self, need_size: int):
|
def alloc(self, need_size: int):
|
||||||
|
if need_size > len(self.free_pages):
|
||||||
|
self.merge_and_sort_free()
|
||||||
if need_size > len(self.free_pages):
|
if need_size > len(self.free_pages):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -137,7 +149,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if self.is_not_in_free_group:
|
if self.is_not_in_free_group:
|
||||||
self.free_pages = torch.cat((self.free_pages, free_index))
|
self.release_pages = torch.cat((self.release_pages, free_index))
|
||||||
else:
|
else:
|
||||||
self.free_group.append(free_index)
|
self.free_group.append(free_index)
|
||||||
|
|
||||||
@@ -421,6 +433,8 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
), "The allocation size should be page-aligned"
|
), "The allocation size should be page-aligned"
|
||||||
|
|
||||||
num_pages = need_size // self.page_size
|
num_pages = need_size // self.page_size
|
||||||
|
if num_pages > len(self.free_pages):
|
||||||
|
self.merge_and_sort_free()
|
||||||
if num_pages > len(self.free_pages):
|
if num_pages > len(self.free_pages):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -446,6 +460,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
estimated_num_new_pages = (
|
||||||
|
(
|
||||||
|
(seq_lens + self.page_size - 1) // self.page_size
|
||||||
|
- (prefix_lens + self.page_size - 1) // self.page_size
|
||||||
|
)
|
||||||
|
.sum()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
if estimated_num_new_pages > len(self.free_pages):
|
||||||
|
self.merge_and_sort_free()
|
||||||
|
|
||||||
bs = len(prefix_lens)
|
bs = len(prefix_lens)
|
||||||
out_indices = torch.empty(
|
out_indices = torch.empty(
|
||||||
(extend_num_tokens,), dtype=torch.int64, device=self.device
|
(extend_num_tokens,), dtype=torch.int64, device=self.device
|
||||||
@@ -483,6 +508,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
estimated_num_new_pages = (
|
||||||
|
(
|
||||||
|
(seq_lens + self.page_size - 1) // self.page_size
|
||||||
|
- (seq_lens - 1 + self.page_size - 1) // self.page_size
|
||||||
|
)
|
||||||
|
.sum()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
if estimated_num_new_pages > len(self.free_pages):
|
||||||
|
self.merge_and_sort_free()
|
||||||
|
|
||||||
bs = len(seq_lens)
|
bs = len(seq_lens)
|
||||||
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
||||||
alloc_decode_kernel[(bs,)](
|
alloc_decode_kernel[(bs,)](
|
||||||
@@ -511,7 +547,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
|
|
||||||
if self.is_not_in_free_group:
|
if self.is_not_in_free_group:
|
||||||
free_page_indices = torch.unique(free_index // self.page_size)
|
free_page_indices = torch.unique(free_index // self.page_size)
|
||||||
self.free_pages = torch.cat((free_page_indices, self.free_pages))
|
self.release_pages = torch.cat((free_page_indices, self.release_pages))
|
||||||
else:
|
else:
|
||||||
self.free_group.append(free_index)
|
self.free_group.append(free_index)
|
||||||
|
|
||||||
@@ -525,6 +561,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
)
|
)
|
||||||
self.is_not_in_free_group = True
|
self.is_not_in_free_group = True
|
||||||
self.free_group = []
|
self.free_group = []
|
||||||
|
self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||||
|
|
||||||
def get_cpu_copy(self, indices):
|
def get_cpu_copy(self, indices):
|
||||||
return self._kvcache.get_cpu_copy(indices)
|
return self._kvcache.get_cpu_copy(indices)
|
||||||
@@ -633,6 +670,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|||||||
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
estimated_num_new_pages = (
|
||||||
|
(
|
||||||
|
(seq_lens + self.page_size - 1) // self.page_size
|
||||||
|
- (prefix_lens + self.page_size - 1) // self.page_size
|
||||||
|
)
|
||||||
|
.sum()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
if estimated_num_new_pages > len(self.free_pages):
|
||||||
|
self.merge_and_sort_free()
|
||||||
|
|
||||||
bs = len(prefix_lens)
|
bs = len(prefix_lens)
|
||||||
out_indices = torch.empty(
|
out_indices = torch.empty(
|
||||||
(extend_num_tokens,), dtype=torch.int32, device=self.device
|
(extend_num_tokens,), dtype=torch.int32, device=self.device
|
||||||
@@ -668,6 +716,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|||||||
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
estimated_num_new_pages = (
|
||||||
|
(
|
||||||
|
(seq_lens + self.page_size - 1) // self.page_size
|
||||||
|
- (seq_lens - 1 + self.page_size - 1) // self.page_size
|
||||||
|
)
|
||||||
|
.sum()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
if estimated_num_new_pages > len(self.free_pages):
|
||||||
|
self.merge_and_sort_free()
|
||||||
|
|
||||||
bs = len(seq_lens)
|
bs = len(seq_lens)
|
||||||
out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
|
out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
|
||||||
|
|
||||||
@@ -692,3 +751,4 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|||||||
def clear(self):
|
def clear(self):
|
||||||
super().clear()
|
super().clear()
|
||||||
self.free_pages = self.free_pages.to(torch.int32)
|
self.free_pages = self.free_pages.to(torch.int32)
|
||||||
|
self.release_pages = self.release_pages.to(torch.int32)
|
||||||
|
|||||||
Reference in New Issue
Block a user