[Optimization] Update estimated_num_new_pages logic in TokenToKVPoolAllocator (#8794)
Signed-off-by: Xingrui Yi <yixingrui@linux.alibaba.com>
Co-authored-by: Xingrui Yi <yixingrui@linux.alibaba.com>
This commit is contained in:
@@ -43,12 +43,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
kvcache: KVCache,
|
||||
need_sort: bool,
|
||||
):
|
||||
self.size = size
|
||||
self.page_size = page_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self._kvcache = kvcache
|
||||
self.need_sort = need_sort
|
||||
|
||||
self.free_pages = None
|
||||
self.release_pages = None
|
||||
@@ -79,6 +81,9 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
||||
if self.free_group:
|
||||
self.free(torch.cat(self.free_group))
|
||||
|
||||
# Estimate how many new pages an allocation may need, using plain integer
# arithmetic on the arguments and self.page_size.
#
# bs: multiplier applied to the per-call page count; the call sites in this
#     diff pass len(prefix_lens) or len(seq_lens).
# extend_num_tokens: token count rounded up to whole pages; the call sites in
#     this diff pass either extend_num_tokens or the literal 1.
#
# Returns bs * ceil(extend_num_tokens / self.page_size).
#
# The call sites in this diff compare the result with len(self.free_pages)
# (only when self.need_sort is set) to decide whether to call
# merge_and_sort_free().
# NOTE(review): when extend_num_tokens is the batch-wide total, multiplying by
# bs looks like a deliberate over-estimate (upper bound) rather than an exact
# count -- confirm against the callers.
def estimated_num_new_pages(self, bs, extend_num_tokens):
|
||||
# Ceiling division by page_size, then scaled by bs.
return bs * ((extend_num_tokens + self.page_size - 1) // self.page_size)
|
||||
|
||||
def merge_and_sort_free(self):
|
||||
if len(self.release_pages) > 0:
|
||||
self.free_pages = torch.cat((self.free_pages, self.release_pages))
|
||||
@@ -117,8 +122,15 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
||||
class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
"""An allocator managing the indices to kv cache data."""
|
||||
|
||||
def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
|
||||
super().__init__(size, 1, dtype, device, kvcache)
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
kvcache: KVCache,
|
||||
need_sort: bool,
|
||||
):
|
||||
super().__init__(size, 1, dtype, device, kvcache, need_sort)
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
@@ -135,7 +147,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
return len(self.free_pages) + len(self.release_pages)
|
||||
|
||||
def alloc(self, need_size: int):
|
||||
if need_size > len(self.free_pages):
|
||||
if self.need_sort and need_size > len(self.free_pages):
|
||||
self.merge_and_sort_free()
|
||||
if need_size > len(self.free_pages):
|
||||
return None
|
||||
@@ -149,7 +161,10 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
return
|
||||
|
||||
if self.is_not_in_free_group:
|
||||
self.release_pages = torch.cat((self.release_pages, free_index))
|
||||
if self.need_sort:
|
||||
self.release_pages = torch.cat((self.release_pages, free_index))
|
||||
else:
|
||||
self.free_pages = torch.cat((self.free_pages, free_index))
|
||||
else:
|
||||
self.free_group.append(free_index)
|
||||
|
||||
@@ -170,8 +185,9 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
kvcache: SWAKVPool,
|
||||
need_sort: bool,
|
||||
):
|
||||
super().__init__(size, 1, dtype, device, kvcache)
|
||||
super().__init__(size, 1, dtype, device, kvcache, need_sort)
|
||||
assert isinstance(kvcache, SWAKVPool)
|
||||
self._size_full = size
|
||||
self._size_swa = size_swa
|
||||
@@ -180,12 +196,14 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
dtype,
|
||||
device,
|
||||
kvcache.full_kv_pool,
|
||||
need_sort,
|
||||
)
|
||||
self.swa_attn_allocator = TokenToKVPoolAllocator(
|
||||
size_swa,
|
||||
dtype,
|
||||
device,
|
||||
kvcache.swa_kv_pool,
|
||||
need_sort,
|
||||
)
|
||||
self.full_to_swa_index_mapping = torch.empty(
|
||||
size + size_swa + 1,
|
||||
@@ -418,8 +436,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
kvcache: KVCache,
|
||||
need_sort: bool,
|
||||
):
|
||||
super().__init__(size, page_size, dtype, device, kvcache)
|
||||
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
||||
self.num_pages = size // page_size
|
||||
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
||||
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
||||
@@ -433,7 +452,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
), "The allocation size should be page-aligned"
|
||||
|
||||
num_pages = need_size // self.page_size
|
||||
if num_pages > len(self.free_pages):
|
||||
if self.need_sort and num_pages > len(self.free_pages):
|
||||
self.merge_and_sort_free()
|
||||
if num_pages > len(self.free_pages):
|
||||
return None
|
||||
@@ -460,18 +479,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
||||
)
|
||||
|
||||
estimated_num_new_pages = (
|
||||
(
|
||||
(seq_lens + self.page_size - 1) // self.page_size
|
||||
- (prefix_lens + self.page_size - 1) // self.page_size
|
||||
)
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
if estimated_num_new_pages > len(self.free_pages):
|
||||
bs = len(prefix_lens)
|
||||
if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
|
||||
self.free_pages
|
||||
):
|
||||
self.merge_and_sort_free()
|
||||
|
||||
bs = len(prefix_lens)
|
||||
out_indices = torch.empty(
|
||||
(extend_num_tokens,), dtype=torch.int64, device=self.device
|
||||
)
|
||||
@@ -508,18 +521,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
||||
)
|
||||
|
||||
estimated_num_new_pages = (
|
||||
(
|
||||
(seq_lens + self.page_size - 1) // self.page_size
|
||||
- (seq_lens - 1 + self.page_size - 1) // self.page_size
|
||||
)
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
if estimated_num_new_pages > len(self.free_pages):
|
||||
bs = len(seq_lens)
|
||||
if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
|
||||
self.free_pages
|
||||
):
|
||||
self.merge_and_sort_free()
|
||||
|
||||
bs = len(seq_lens)
|
||||
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
||||
alloc_decode_kernel[(bs,)](
|
||||
seq_lens,
|
||||
@@ -547,7 +554,10 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
|
||||
if self.is_not_in_free_group:
|
||||
free_page_indices = torch.unique(free_index // self.page_size)
|
||||
self.release_pages = torch.cat((free_page_indices, self.release_pages))
|
||||
if self.need_sort:
|
||||
self.release_pages = torch.cat((free_page_indices, self.release_pages))
|
||||
else:
|
||||
self.free_pages = torch.cat((free_page_indices, self.free_pages))
|
||||
else:
|
||||
self.free_group.append(free_index)
|
||||
|
||||
@@ -654,8 +664,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
kvcache: KVCache,
|
||||
need_sort: bool,
|
||||
):
|
||||
super().__init__(size, page_size, dtype, device, kvcache)
|
||||
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
||||
self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)
|
||||
|
||||
def alloc_extend(
|
||||
@@ -670,18 +681,12 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
||||
)
|
||||
|
||||
estimated_num_new_pages = (
|
||||
(
|
||||
(seq_lens + self.page_size - 1) // self.page_size
|
||||
- (prefix_lens + self.page_size - 1) // self.page_size
|
||||
)
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
if estimated_num_new_pages > len(self.free_pages):
|
||||
bs = len(prefix_lens)
|
||||
if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
|
||||
self.free_pages
|
||||
):
|
||||
self.merge_and_sort_free()
|
||||
|
||||
bs = len(prefix_lens)
|
||||
out_indices = torch.empty(
|
||||
(extend_num_tokens,), dtype=torch.int32, device=self.device
|
||||
)
|
||||
@@ -716,18 +721,12 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
||||
)
|
||||
|
||||
estimated_num_new_pages = (
|
||||
(
|
||||
(seq_lens + self.page_size - 1) // self.page_size
|
||||
- (seq_lens - 1 + self.page_size - 1) // self.page_size
|
||||
)
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
if estimated_num_new_pages > len(self.free_pages):
|
||||
bs = len(seq_lens)
|
||||
if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
|
||||
self.free_pages
|
||||
):
|
||||
self.merge_and_sort_free()
|
||||
|
||||
bs = len(seq_lens)
|
||||
out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
|
||||
|
||||
self.ret_values = alloc_decode_kernel_ascend(
|
||||
|
||||
@@ -1300,6 +1300,8 @@ class ModelRunner:
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
need_sort=self.server_args.disaggregation_mode
|
||||
in ("decode", "prefill"),
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
|
||||
@@ -1307,6 +1309,8 @@ class ModelRunner:
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
need_sort=self.server_args.disaggregation_mode
|
||||
in ("decode", "prefill"),
|
||||
)
|
||||
else:
|
||||
if _is_npu:
|
||||
@@ -1316,6 +1320,8 @@ class ModelRunner:
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
need_sort=self.server_args.disaggregation_mode
|
||||
in ("decode", "prefill"),
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
|
||||
@@ -1324,6 +1330,8 @@ class ModelRunner:
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
need_sort=self.server_args.disaggregation_mode
|
||||
in ("decode", "prefill"),
|
||||
)
|
||||
else:
|
||||
assert self.is_draft_worker
|
||||
|
||||
Reference in New Issue
Block a user