Clean up ascend allocator (#11152)
This commit is contained in:
@@ -493,7 +493,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
if self.debug_mode:
|
||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||
|
||||
num_new_pages = get_num_new_pages(prefix_lens_cpu, seq_lens_cpu, self.page_size)
|
||||
num_new_pages = get_num_new_pages(
|
||||
seq_lens=seq_lens_cpu,
|
||||
page_size=self.page_size,
|
||||
prefix_lens=prefix_lens_cpu,
|
||||
)
|
||||
if num_new_pages > len(self.free_pages):
|
||||
return None
|
||||
|
||||
@@ -529,7 +533,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||
|
||||
num_new_pages = get_num_new_pages(
|
||||
seq_lens_cpu - 1, seq_lens_cpu, self.page_size, decode=True
|
||||
seq_lens=seq_lens_cpu,
|
||||
page_size=self.page_size,
|
||||
decode=True,
|
||||
)
|
||||
if num_new_pages > len(self.free_pages):
|
||||
return None
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||
from sglang.srt.utils import get_num_new_pages
|
||||
|
||||
|
||||
def alloc_extend_kernel_ascend(
|
||||
@@ -80,13 +76,10 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
||||
)
|
||||
|
||||
num_new_pages = (
|
||||
(
|
||||
(seq_lens_cpu + self.page_size - 1) // self.page_size
|
||||
- (prefix_lens_cpu + self.page_size - 1) // self.page_size
|
||||
)
|
||||
.sum()
|
||||
.item()
|
||||
num_new_pages = get_num_new_pages(
|
||||
seq_lens=seq_lens_cpu,
|
||||
page_size=self.page_size,
|
||||
prefix_lens=prefix_lens_cpu,
|
||||
)
|
||||
if self.need_sort and num_new_pages > len(self.free_pages):
|
||||
self.merge_and_sort_free()
|
||||
@@ -125,9 +118,11 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
||||
)
|
||||
|
||||
need_new_pages = (seq_lens % self.page_size == 1).int()
|
||||
need_new_pages_cpu = (seq_lens_cpu % self.page_size == 1).int()
|
||||
num_new_pages = need_new_pages_cpu.sum().item()
|
||||
num_new_pages = get_num_new_pages(
|
||||
seq_lens=seq_lens_cpu,
|
||||
page_size=self.page_size,
|
||||
decode=True,
|
||||
)
|
||||
|
||||
if num_new_pages > len(self.free_pages):
|
||||
self.merge_and_sort_free()
|
||||
@@ -135,6 +130,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
if num_new_pages > len(self.free_pages):
|
||||
return None
|
||||
|
||||
need_new_pages = (seq_lens % self.page_size == 1).int()
|
||||
end_new_pages = torch.cumsum(need_new_pages, 0)
|
||||
start_new_pages = end_new_pages - need_new_pages
|
||||
if num_new_pages == 0:
|
||||
|
||||
@@ -3251,17 +3251,24 @@ def get_extend_input_len_swa_limit(
|
||||
|
||||
|
||||
def get_num_new_pages(
|
||||
prefix_lens: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
page_size: int,
|
||||
prefix_lens: Optional[torch.Tensor] = None,
|
||||
decode: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Get the number of new pages for the given prefix and sequence lengths. We use cpu tensors to avoid blocking kernel launch.
|
||||
Get the number of new pages for the given prefix and sequence lengths.
|
||||
We use cpu tensors to avoid blocking kernel launch.
|
||||
"""
|
||||
cpu_device = torch.device("cpu")
|
||||
assert prefix_lens.device == cpu_device
|
||||
assert seq_lens.device == cpu_device
|
||||
|
||||
if prefix_lens is None or decode:
|
||||
# NOTE: Special case for handling decode, which prefix lens is `seq_lens - 1`.
|
||||
assert decode
|
||||
return (seq_lens % page_size == 1).int().sum().item()
|
||||
|
||||
assert prefix_lens.device == cpu_device
|
||||
num_pages_after = (seq_lens + page_size - 1) // page_size
|
||||
num_pages_before = (prefix_lens + page_size - 1) // page_size
|
||||
num_new_pages = num_pages_after - num_pages_before
|
||||
|
||||
Reference in New Issue
Block a user