diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py index e3314ab60..9fef8d133 100644 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -493,7 +493,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): if self.debug_mode: assert len(torch.unique(out_indices)) == len(out_indices) - num_new_pages = get_num_new_pages(prefix_lens_cpu, seq_lens_cpu, self.page_size) + num_new_pages = get_num_new_pages( + seq_lens=seq_lens_cpu, + page_size=self.page_size, + prefix_lens=prefix_lens_cpu, + ) if num_new_pages > len(self.free_pages): return None @@ -529,7 +533,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): assert len(torch.unique(out_indices)) == len(out_indices) num_new_pages = get_num_new_pages( - seq_lens_cpu - 1, seq_lens_cpu, self.page_size, decode=True + seq_lens=seq_lens_cpu, + page_size=self.page_size, + decode=True, ) if num_new_pages > len(self.free_pages): return None diff --git a/python/sglang/srt/mem_cache/allocator_ascend.py b/python/sglang/srt/mem_cache/allocator_ascend.py index 546e3b45a..0bb1eaf0a 100644 --- a/python/sglang/srt/mem_cache/allocator_ascend.py +++ b/python/sglang/srt/mem_cache/allocator_ascend.py @@ -1,13 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import torch from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator - -if TYPE_CHECKING: - from sglang.srt.mem_cache.memory_pool import KVCache +from sglang.srt.utils import get_num_new_pages def alloc_extend_kernel_ascend( @@ -80,13 +76,10 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): (last_loc + 1) % self.page_size == prefix_lens % self.page_size ) - num_new_pages = ( - ( - (seq_lens_cpu + self.page_size - 1) // self.page_size - - (prefix_lens_cpu + self.page_size - 1) // self.page_size - ) - .sum() - .item() + num_new_pages = get_num_new_pages( + seq_lens=seq_lens_cpu, + page_size=self.page_size, + prefix_lens=prefix_lens_cpu, ) if self.need_sort and num_new_pages > len(self.free_pages): self.merge_and_sort_free() @@ -125,9 +118,11 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): (last_loc + 2) % self.page_size == seq_lens % self.page_size ) - need_new_pages = (seq_lens % self.page_size == 1).int() - need_new_pages_cpu = (seq_lens_cpu % self.page_size == 1).int() - num_new_pages = need_new_pages_cpu.sum().item() + num_new_pages = get_num_new_pages( + seq_lens=seq_lens_cpu, + page_size=self.page_size, + decode=True, + ) if num_new_pages > len(self.free_pages): self.merge_and_sort_free() @@ -135,6 +130,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): if num_new_pages > len(self.free_pages): return None + need_new_pages = (seq_lens % self.page_size == 1).int() end_new_pages = torch.cumsum(need_new_pages, 0) start_new_pages = end_new_pages - need_new_pages if num_new_pages == 0: diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 28bc3f30f..2a26e029c 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -3251,17 +3251,24 @@ def get_extend_input_len_swa_limit( def get_num_new_pages( - prefix_lens: torch.Tensor, seq_lens: torch.Tensor, page_size: int, + prefix_lens: Optional[torch.Tensor] = None, decode: bool = False, ) -> torch.Tensor: """ - Get the number of new pages for the given prefix and sequence lengths. We use cpu tensors to avoid blocking kernel launch. + Get the number of new pages for the given prefix and sequence lengths. + We use cpu tensors to avoid blocking kernel launch. """ cpu_device = torch.device("cpu") - assert prefix_lens.device == cpu_device assert seq_lens.device == cpu_device + + if prefix_lens is None or decode: + # NOTE: Special case for handling decode, which prefix lens is `seq_lens - 1`. + assert decode + return (seq_lens % page_size == 1).int().sum().item() + + assert prefix_lens.device == cpu_device num_pages_after = (seq_lens + page_size - 1) // page_size num_pages_before = (prefix_lens + page_size - 1) // page_size num_new_pages = num_pages_after - num_pages_before