Clean up ascend allocator (#11152)
This commit is contained in:
@@ -493,7 +493,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
if self.debug_mode:
|
if self.debug_mode:
|
||||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||||
|
|
||||||
num_new_pages = get_num_new_pages(prefix_lens_cpu, seq_lens_cpu, self.page_size)
|
num_new_pages = get_num_new_pages(
|
||||||
|
seq_lens=seq_lens_cpu,
|
||||||
|
page_size=self.page_size,
|
||||||
|
prefix_lens=prefix_lens_cpu,
|
||||||
|
)
|
||||||
if num_new_pages > len(self.free_pages):
|
if num_new_pages > len(self.free_pages):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -529,7 +533,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||||
|
|
||||||
num_new_pages = get_num_new_pages(
|
num_new_pages = get_num_new_pages(
|
||||||
seq_lens_cpu - 1, seq_lens_cpu, self.page_size, decode=True
|
seq_lens=seq_lens_cpu,
|
||||||
|
page_size=self.page_size,
|
||||||
|
decode=True,
|
||||||
)
|
)
|
||||||
if num_new_pages > len(self.free_pages):
|
if num_new_pages > len(self.free_pages):
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -1,13 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
|
from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
|
||||||
|
from sglang.srt.utils import get_num_new_pages
|
||||||
if TYPE_CHECKING:
|
|
||||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
|
||||||
|
|
||||||
|
|
||||||
def alloc_extend_kernel_ascend(
|
def alloc_extend_kernel_ascend(
|
||||||
@@ -80,13 +76,10 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|||||||
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
||||||
)
|
)
|
||||||
|
|
||||||
num_new_pages = (
|
num_new_pages = get_num_new_pages(
|
||||||
(
|
seq_lens=seq_lens_cpu,
|
||||||
(seq_lens_cpu + self.page_size - 1) // self.page_size
|
page_size=self.page_size,
|
||||||
- (prefix_lens_cpu + self.page_size - 1) // self.page_size
|
prefix_lens=prefix_lens_cpu,
|
||||||
)
|
|
||||||
.sum()
|
|
||||||
.item()
|
|
||||||
)
|
)
|
||||||
if self.need_sort and num_new_pages > len(self.free_pages):
|
if self.need_sort and num_new_pages > len(self.free_pages):
|
||||||
self.merge_and_sort_free()
|
self.merge_and_sort_free()
|
||||||
@@ -125,9 +118,11 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|||||||
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
||||||
)
|
)
|
||||||
|
|
||||||
need_new_pages = (seq_lens % self.page_size == 1).int()
|
num_new_pages = get_num_new_pages(
|
||||||
need_new_pages_cpu = (seq_lens_cpu % self.page_size == 1).int()
|
seq_lens=seq_lens_cpu,
|
||||||
num_new_pages = need_new_pages_cpu.sum().item()
|
page_size=self.page_size,
|
||||||
|
decode=True,
|
||||||
|
)
|
||||||
|
|
||||||
if num_new_pages > len(self.free_pages):
|
if num_new_pages > len(self.free_pages):
|
||||||
self.merge_and_sort_free()
|
self.merge_and_sort_free()
|
||||||
@@ -135,6 +130,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|||||||
if num_new_pages > len(self.free_pages):
|
if num_new_pages > len(self.free_pages):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
need_new_pages = (seq_lens % self.page_size == 1).int()
|
||||||
end_new_pages = torch.cumsum(need_new_pages, 0)
|
end_new_pages = torch.cumsum(need_new_pages, 0)
|
||||||
start_new_pages = end_new_pages - need_new_pages
|
start_new_pages = end_new_pages - need_new_pages
|
||||||
if num_new_pages == 0:
|
if num_new_pages == 0:
|
||||||
|
|||||||
@@ -3251,17 +3251,24 @@ def get_extend_input_len_swa_limit(
|
|||||||
|
|
||||||
|
|
||||||
def get_num_new_pages(
|
def get_num_new_pages(
|
||||||
prefix_lens: torch.Tensor,
|
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
|
prefix_lens: Optional[torch.Tensor] = None,
|
||||||
decode: bool = False,
|
decode: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Get the number of new pages for the given prefix and sequence lengths. We use cpu tensors to avoid blocking kernel launch.
|
Get the number of new pages for the given prefix and sequence lengths.
|
||||||
|
We use cpu tensors to avoid blocking kernel launch.
|
||||||
"""
|
"""
|
||||||
cpu_device = torch.device("cpu")
|
cpu_device = torch.device("cpu")
|
||||||
assert prefix_lens.device == cpu_device
|
|
||||||
assert seq_lens.device == cpu_device
|
assert seq_lens.device == cpu_device
|
||||||
|
|
||||||
|
if prefix_lens is None or decode:
|
||||||
|
# NOTE: Special case for handling decode, which prefix lens is `seq_lens - 1`.
|
||||||
|
assert decode
|
||||||
|
return (seq_lens % page_size == 1).int().sum().item()
|
||||||
|
|
||||||
|
assert prefix_lens.device == cpu_device
|
||||||
num_pages_after = (seq_lens + page_size - 1) // page_size
|
num_pages_after = (seq_lens + page_size - 1) // page_size
|
||||||
num_pages_before = (prefix_lens + page_size - 1) // page_size
|
num_pages_before = (prefix_lens + page_size - 1) // page_size
|
||||||
num_new_pages = num_pages_after - num_pages_before
|
num_new_pages = num_pages_after - num_pages_before
|
||||||
|
|||||||
Reference in New Issue
Block a user