Clean up ascend allocator (#11152)

This commit is contained in:
Liangsheng Yin
2025-10-02 20:34:26 +08:00
committed by GitHub
parent 083629c235
commit 7d00479950
3 changed files with 29 additions and 20 deletions

View File

@@ -493,7 +493,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
if self.debug_mode: if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices) assert len(torch.unique(out_indices)) == len(out_indices)
num_new_pages = get_num_new_pages(prefix_lens_cpu, seq_lens_cpu, self.page_size) num_new_pages = get_num_new_pages(
seq_lens=seq_lens_cpu,
page_size=self.page_size,
prefix_lens=prefix_lens_cpu,
)
if num_new_pages > len(self.free_pages): if num_new_pages > len(self.free_pages):
return None return None
@@ -529,7 +533,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
assert len(torch.unique(out_indices)) == len(out_indices) assert len(torch.unique(out_indices)) == len(out_indices)
num_new_pages = get_num_new_pages( num_new_pages = get_num_new_pages(
seq_lens_cpu - 1, seq_lens_cpu, self.page_size, decode=True seq_lens=seq_lens_cpu,
page_size=self.page_size,
decode=True,
) )
if num_new_pages > len(self.free_pages): if num_new_pages > len(self.free_pages):
return None return None

View File

@@ -1,13 +1,9 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING
import torch import torch
from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
from sglang.srt.utils import get_num_new_pages
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import KVCache
def alloc_extend_kernel_ascend( def alloc_extend_kernel_ascend(
@@ -80,13 +76,10 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
(last_loc + 1) % self.page_size == prefix_lens % self.page_size (last_loc + 1) % self.page_size == prefix_lens % self.page_size
) )
num_new_pages = ( num_new_pages = get_num_new_pages(
( seq_lens=seq_lens_cpu,
(seq_lens_cpu + self.page_size - 1) // self.page_size page_size=self.page_size,
- (prefix_lens_cpu + self.page_size - 1) // self.page_size prefix_lens=prefix_lens_cpu,
)
.sum()
.item()
) )
if self.need_sort and num_new_pages > len(self.free_pages): if self.need_sort and num_new_pages > len(self.free_pages):
self.merge_and_sort_free() self.merge_and_sort_free()
@@ -125,9 +118,11 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
(last_loc + 2) % self.page_size == seq_lens % self.page_size (last_loc + 2) % self.page_size == seq_lens % self.page_size
) )
need_new_pages = (seq_lens % self.page_size == 1).int() num_new_pages = get_num_new_pages(
need_new_pages_cpu = (seq_lens_cpu % self.page_size == 1).int() seq_lens=seq_lens_cpu,
num_new_pages = need_new_pages_cpu.sum().item() page_size=self.page_size,
decode=True,
)
if num_new_pages > len(self.free_pages): if num_new_pages > len(self.free_pages):
self.merge_and_sort_free() self.merge_and_sort_free()
@@ -135,6 +130,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
if num_new_pages > len(self.free_pages): if num_new_pages > len(self.free_pages):
return None return None
need_new_pages = (seq_lens % self.page_size == 1).int()
end_new_pages = torch.cumsum(need_new_pages, 0) end_new_pages = torch.cumsum(need_new_pages, 0)
start_new_pages = end_new_pages - need_new_pages start_new_pages = end_new_pages - need_new_pages
if num_new_pages == 0: if num_new_pages == 0:

View File

@@ -3251,17 +3251,24 @@ def get_extend_input_len_swa_limit(
def get_num_new_pages( def get_num_new_pages(
prefix_lens: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
page_size: int, page_size: int,
prefix_lens: Optional[torch.Tensor] = None,
decode: bool = False, decode: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Get the number of new pages for the given prefix and sequence lengths. We use cpu tensors to avoid blocking kernel launch. Get the number of new pages for the given prefix and sequence lengths.
We use cpu tensors to avoid blocking kernel launch.
""" """
cpu_device = torch.device("cpu") cpu_device = torch.device("cpu")
assert prefix_lens.device == cpu_device
assert seq_lens.device == cpu_device assert seq_lens.device == cpu_device
if prefix_lens is None or decode:
# NOTE: Special case for decode, where each prefix length is implicitly `seq_lens - 1`.
assert decode
return (seq_lens % page_size == 1).int().sum().item()
assert prefix_lens.device == cpu_device
num_pages_after = (seq_lens + page_size - 1) // page_size num_pages_after = (seq_lens + page_size - 1) // page_size
num_pages_before = (prefix_lens + page_size - 1) // page_size num_pages_before = (prefix_lens + page_size - 1) // page_size
num_new_pages = num_pages_after - num_pages_before num_new_pages = num_pages_after - num_pages_before