diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py index 7dd488e9c..58afbf312 100644 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -51,6 +51,7 @@ class BaseTokenToKVPoolAllocator(abc.ABC): self._kvcache = kvcache self.free_pages = None + self.release_pages = None self.is_not_in_free_group = True self.free_group = [] @@ -58,16 +59,16 @@ class BaseTokenToKVPoolAllocator(abc.ABC): return "" def available_size(self): - return len(self.free_pages) * self.page_size + return (len(self.free_pages) + len(self.release_pages)) * self.page_size def get_kvcache(self): return self._kvcache - def restore_state(self, free_pages): - self.free_pages = free_pages + def restore_state(self, state): + self.free_pages, self.release_pages = state def backup_state(self): - return self.free_pages + return (self.free_pages, self.release_pages) def free_group_begin(self): self.is_not_in_free_group = False @@ -78,6 +79,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC): if self.free_group: self.free(torch.cat(self.free_group)) + def merge_and_sort_free(self): + if len(self.release_pages) > 0: + self.free_pages = torch.cat((self.free_pages, self.release_pages)) + self.free_pages, _ = torch.sort(self.free_pages) + self.release_pages = torch.empty( + (0,), dtype=self.release_pages.dtype, device=self.device + ) + def get_cpu_copy(self, *args, **kwargs): # FIXME: reuse the get_cpu_copy after paged allocator is implemented raise NotImplementedError() @@ -119,12 +128,15 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ) self.is_not_in_free_group = True self.free_group = [] + self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device) def available_size(self): # To avoid minor "len(free_pages) * 1" overhead - return len(self.free_pages) + return len(self.free_pages) + len(self.release_pages) def alloc(self, need_size: int): + if need_size > len(self.free_pages): + self.merge_and_sort_free() if need_size > len(self.free_pages): return None @@ -137,7 +149,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): return if self.is_not_in_free_group: - self.free_pages = torch.cat((self.free_pages, free_index)) + self.release_pages = torch.cat((self.release_pages, free_index)) else: self.free_group.append(free_index) @@ -421,6 +433,8 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ), "The allocation size should be page-aligned" num_pages = need_size // self.page_size + if num_pages > len(self.free_pages): + self.merge_and_sort_free() if num_pages > len(self.free_pages): return None @@ -446,6 +460,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): (last_loc + 1) % self.page_size == prefix_lens % self.page_size ) + estimated_num_new_pages = ( + ( + (seq_lens + self.page_size - 1) // self.page_size + - (prefix_lens + self.page_size - 1) // self.page_size + ) + .sum() + .item() + ) + if estimated_num_new_pages > len(self.free_pages): + self.merge_and_sort_free() + bs = len(prefix_lens) out_indices = torch.empty( (extend_num_tokens,), dtype=torch.int64, device=self.device @@ -483,6 +508,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): (last_loc + 2) % self.page_size == seq_lens % self.page_size ) + estimated_num_new_pages = ( + ( + (seq_lens + self.page_size - 1) // self.page_size + - (seq_lens - 1 + self.page_size - 1) // self.page_size + ) + .sum() + .item() + ) + if estimated_num_new_pages > len(self.free_pages): + self.merge_and_sort_free() + bs = len(seq_lens) out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device) alloc_decode_kernel[(bs,)]( @@ -511,7 +547,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): if self.is_not_in_free_group: free_page_indices = torch.unique(free_index // self.page_size) - self.free_pages = torch.cat((free_page_indices, self.free_pages)) + self.release_pages = torch.cat((free_page_indices, self.release_pages)) else: self.free_group.append(free_index) @@ -525,6 +561,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ) self.is_not_in_free_group = True self.free_group = [] + self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device) def get_cpu_copy(self, indices): return self._kvcache.get_cpu_copy(indices) @@ -633,6 +670,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): (last_loc + 1) % self.page_size == prefix_lens % self.page_size ) + estimated_num_new_pages = ( + ( + (seq_lens + self.page_size - 1) // self.page_size + - (prefix_lens + self.page_size - 1) // self.page_size + ) + .sum() + .item() + ) + if estimated_num_new_pages > len(self.free_pages): + self.merge_and_sort_free() + bs = len(prefix_lens) out_indices = torch.empty( (extend_num_tokens,), dtype=torch.int32, device=self.device @@ -668,6 +716,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): (last_loc + 2) % self.page_size == seq_lens % self.page_size ) + estimated_num_new_pages = ( + ( + (seq_lens + self.page_size - 1) // self.page_size + - (seq_lens - 1 + self.page_size - 1) // self.page_size + ) + .sum() + .item() + ) + if estimated_num_new_pages > len(self.free_pages): + self.merge_and_sort_free() + bs = len(seq_lens) out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device) @@ -692,3 +751,4 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): def clear(self): super().clear() self.free_pages = self.free_pages.to(torch.int32) + self.release_pages = self.release_pages.to(torch.int32)