diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py index 64e5447b6..0bf8cc2e1 100644 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -632,27 +632,6 @@ def alloc_extend_kernel_ascend( out_indices[end_pos[i] - num3 : end_pos[i]] = ( free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3] ).view(-1) - return num_new_pages - - -def alloc_decode_kernel_ascend( - seq_lens, - last_loc, - free_pages, - out_indices, - page_size, -): - num_new_pages = (seq_lens + page_size - 1) // page_size - ( - seq_lens - 1 + page_size - 1 - ) // page_size - end_new_pages = torch.cumsum(num_new_pages, 0) - start_new_pages = end_new_pages - num_new_pages - for i in range(len(seq_lens)): - if num_new_pages[i]: - out_indices[i] = free_pages[start_new_pages[i]] * page_size - else: - out_indices[i] = last_loc[i] + 1 - return num_new_pages class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): @@ -667,7 +646,6 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): need_sort: bool, ): super().__init__(size, page_size, dtype, device, kvcache, need_sort) - self.ret_values = torch.empty((), dtype=torch.int32, device=self.device) def alloc_extend( self, @@ -681,17 +659,25 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): (last_loc + 1) % self.page_size == prefix_lens % self.page_size ) - bs = len(prefix_lens) - if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len( - self.free_pages - ): + estimated_num_new_pages = ( + ( + (seq_lens + self.page_size - 1) // self.page_size + - (prefix_lens + self.page_size - 1) // self.page_size + ) + .sum() + .item() + ) + if self.need_sort and estimated_num_new_pages > len(self.free_pages): self.merge_and_sort_free() + if estimated_num_new_pages > len(self.free_pages): + return None + out_indices = torch.empty( (extend_num_tokens,), dtype=torch.int32, device=self.device ) - self.ret_values = alloc_extend_kernel_ascend( + alloc_extend_kernel_ascend( prefix_lens, seq_lens, last_loc, @@ -704,11 +690,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): if self.debug_mode: assert len(torch.unique(out_indices)) == len(out_indices) - num_new_pages = self.ret_values.sum() - if num_new_pages > len(self.free_pages): - return None - - self.free_pages = self.free_pages[num_new_pages:] + self.free_pages = self.free_pages[estimated_num_new_pages:] return out_indices def alloc_decode( @@ -721,33 +703,26 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): (last_loc + 2) % self.page_size == seq_lens % self.page_size ) - bs = len(seq_lens) - if self.need_sort and self.estimated_num_new_pages(bs, 1) > len( - self.free_pages - ): + need_new_pages = (seq_lens % self.page_size == 1).int() + num_new_pages = need_new_pages.sum().item() + + if num_new_pages > len(self.free_pages): self.merge_and_sort_free() - out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device) + if num_new_pages > len(self.free_pages): + return None - self.ret_values = alloc_decode_kernel_ascend( - seq_lens, - last_loc, - self.free_pages, - out_indices, - self.page_size, - ) + end_new_pages = torch.cumsum(need_new_pages, 0) + start_new_pages = end_new_pages - need_new_pages + if num_new_pages == 0: + out_indices = last_loc + 1 + else: + out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[ + start_new_pages + ] * self.page_size * need_new_pages if self.debug_mode: assert len(torch.unique(out_indices)) == len(out_indices) - num_new_pages = self.ret_values.sum() - if num_new_pages > len(self.free_pages): - return None - self.free_pages = self.free_pages[num_new_pages:] - return out_indices - - def clear(self): - super().clear() - self.free_pages = self.free_pages.to(torch.int32) - self.release_pages = self.release_pages.to(torch.int32) + return out_indices.int()