Clean up ascend allocator (#11152)

This commit is contained in:
Liangsheng Yin
2025-10-02 20:34:26 +08:00
committed by GitHub
parent 083629c235
commit 7d00479950
3 changed files with 29 additions and 20 deletions

View File

@@ -493,7 +493,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
num_new_pages = get_num_new_pages(prefix_lens_cpu, seq_lens_cpu, self.page_size)
num_new_pages = get_num_new_pages(
seq_lens=seq_lens_cpu,
page_size=self.page_size,
prefix_lens=prefix_lens_cpu,
)
if num_new_pages > len(self.free_pages):
return None
@@ -529,7 +533,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
assert len(torch.unique(out_indices)) == len(out_indices)
num_new_pages = get_num_new_pages(
seq_lens_cpu - 1, seq_lens_cpu, self.page_size, decode=True
seq_lens=seq_lens_cpu,
page_size=self.page_size,
decode=True,
)
if num_new_pages > len(self.free_pages):
return None

View File

@@ -1,13 +1,9 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import KVCache
from sglang.srt.utils import get_num_new_pages
def alloc_extend_kernel_ascend(
@@ -80,13 +76,10 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
num_new_pages = (
(
(seq_lens_cpu + self.page_size - 1) // self.page_size
- (prefix_lens_cpu + self.page_size - 1) // self.page_size
)
.sum()
.item()
num_new_pages = get_num_new_pages(
seq_lens=seq_lens_cpu,
page_size=self.page_size,
prefix_lens=prefix_lens_cpu,
)
if self.need_sort and num_new_pages > len(self.free_pages):
self.merge_and_sort_free()
@@ -125,9 +118,11 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)
need_new_pages = (seq_lens % self.page_size == 1).int()
need_new_pages_cpu = (seq_lens_cpu % self.page_size == 1).int()
num_new_pages = need_new_pages_cpu.sum().item()
num_new_pages = get_num_new_pages(
seq_lens=seq_lens_cpu,
page_size=self.page_size,
decode=True,
)
if num_new_pages > len(self.free_pages):
self.merge_and_sort_free()
@@ -135,6 +130,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
if num_new_pages > len(self.free_pages):
return None
need_new_pages = (seq_lens % self.page_size == 1).int()
end_new_pages = torch.cumsum(need_new_pages, 0)
start_new_pages = end_new_pages - need_new_pages
if num_new_pages == 0:

View File

@@ -3251,17 +3251,24 @@ def get_extend_input_len_swa_limit(
def get_num_new_pages(
prefix_lens: torch.Tensor,
seq_lens: torch.Tensor,
page_size: int,
prefix_lens: Optional[torch.Tensor] = None,
decode: bool = False,
) -> torch.Tensor:
"""
Get the number of new pages for the given prefix and sequence lengths. We use cpu tensors to avoid blocking kernel launch.
Get the number of new pages for the given prefix and sequence lengths.
We use cpu tensors to avoid blocking kernel launch.
"""
cpu_device = torch.device("cpu")
assert prefix_lens.device == cpu_device
assert seq_lens.device == cpu_device
if prefix_lens is None or decode:
# NOTE: Special case for handling decode, which prefix lens is `seq_lens - 1`.
assert decode
return (seq_lens % page_size == 1).int().sum().item()
assert prefix_lens.device == cpu_device
num_pages_after = (seq_lens + page_size - 1) // page_size
num_pages_before = (prefix_lens + page_size - 1) // page_size
num_new_pages = num_pages_after - num_pages_before