148 lines
4.4 KiB
Python
148 lines
4.4 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
import torch
|
|
|
|
from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
|
|
|
|
if TYPE_CHECKING:
|
|
from sglang.srt.mem_cache.memory_pool import KVCache
|
|
|
|
|
|
def alloc_extend_kernel_ascend(
    prefix_lens,
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
    device,
):
    """Fill `out_indices` with KV-cache slot indices for an extend (prefill) batch.

    Pure-PyTorch fallback (no Triton) used on Ascend NPUs. For each request the
    new token slots are produced in three segments, written contiguously into
    `out_indices` at that request's offset:

      1. the tail of the partially-filled last prefix page (continues from
         ``last_loc + 1``),
      2. whole new pages taken in order from the front of `free_pages`,
      3. a final partial page when ``seq_len`` is not page-aligned.

    Args:
        prefix_lens: per-request count of already-cached tokens.
        seq_lens: per-request total length after the extend.
        last_loc: per-request index of the last already-allocated slot.
        free_pages: page numbers available for allocation; consumed front-first.
        out_indices: preallocated flat output of size ``(seq_lens - prefix_lens).sum()``.
        page_size: tokens per page.
        device: device on which to build the intra-page offset vector.
    """
    ext_lens = seq_lens - prefix_lens
    ext_end = torch.cumsum(ext_lens, 0)
    ext_start = ext_end - ext_lens

    # Pages each request must newly acquire (ceil over seq minus ceil over prefix),
    # and how many of those are completely filled by the extend.
    pages_total = (seq_lens + page_size - 1) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    pages_full = seq_lens // page_size - (prefix_lens + page_size - 1) // page_size
    has_tail_page = pages_total - pages_full

    # Per-request slice of `free_pages` assigned to it.
    page_end = torch.cumsum(pages_total, 0)
    page_start = page_end - pages_total

    offsets = torch.arange(page_size, device=device, dtype=torch.int32)

    for req in range(len(prefix_lens)):
        seq_len = seq_lens[req]
        prefix_len = prefix_lens[req]
        base = ext_start[req]

        # Segment 1: finish the last (partial) prefix page.
        prefix_page_boundary = (prefix_len + page_size - 1) // page_size * page_size
        head = min(seq_len, prefix_page_boundary) - prefix_len
        if head:
            out_indices[base : base + head] = (
                last_loc[req] + 1 + offsets[:head].view(-1)
            )

        # Segment 2: fully occupied fresh pages.
        body = (
            seq_len // page_size - (prefix_len + page_size - 1) // page_size
        ) * page_size
        if body:
            full_page_bases = (
                free_pages[page_start[req] : page_end[req] - has_tail_page[req]]
                * page_size
            )
            out_indices[base + head : base + head + body] = (
                full_page_bases.view(-1, 1) + offsets.view(1, -1)
            ).view(-1)

        # Segment 3: leftover tokens on one final partial page.
        tail = seq_len - seq_len // page_size * page_size
        if tail:
            out_indices[ext_end[req] - tail : ext_end[req]] = (
                free_pages[page_end[req] - 1] * page_size + offsets[:tail]
            ).view(-1)
|
|
|
|
|
|
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
    """Paged token-to-KV-pool allocator for Ascend NPUs.

    Overrides the extend/decode allocation paths to use a pure-PyTorch
    fallback kernel (``alloc_extend_kernel_ascend``) instead of a Triton
    kernel. Pages are always consumed from the front of ``self.free_pages``.
    """

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate KV-cache slots for an extend (prefill) batch.

        Args:
            prefix_lens: per-request count of already-cached tokens.
            seq_lens: per-request total length after the extend.
            last_loc: per-request index of the last already-allocated slot.
            extend_num_tokens: total new tokens, i.e.
                ``(seq_lens - prefix_lens).sum()``.

        Returns:
            int32 tensor of ``extend_num_tokens`` slot indices, or ``None``
            when there are not enough free pages.
        """
        if self.debug_mode:
            # The slot after last_loc must line up with the first new token's
            # position within its page.
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        # Pages to acquire = ceil-pages covering seq_len minus ceil-pages
        # already covering the prefix, summed over the batch.
        num_new_pages = (
            (
                (seq_lens + self.page_size - 1) // self.page_size
                - (prefix_lens + self.page_size - 1) // self.page_size
            )
            .sum()
            .item()
        )
        # Merge released pages back into the free list only when we would
        # otherwise run short.
        if self.need_sort and num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages > len(self.free_pages):
            return None  # out of memory; caller handles the failure

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int32, device=self.device
        )

        alloc_extend_kernel_ascend(
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
            self.device,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        # The kernel takes pages from the front of free_pages in order;
        # drop exactly the pages it consumed.
        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate one KV-cache slot per request for a decode step.

        Args:
            seq_lens: per-request length *including* the new decode token.
            last_loc: per-request index of the last already-allocated slot.

        Returns:
            int32 tensor with one slot index per request, or ``None`` when
            there are not enough free pages.
        """
        if self.debug_mode:
            # last_loc + 2 == position after the new token; must line up with
            # seq_len within the page.
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        # A request needs a fresh page exactly when its new token lands on
        # the first slot of a page.
        need_new_pages = (seq_lens % self.page_size == 1).int()
        num_new_pages = need_new_pages.sum().item()

        # Consistent with alloc_extend: reclaim released pages only when
        # actually short of free pages.
        if self.need_sort and num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages > len(self.free_pages):
            return None

        end_new_pages = torch.cumsum(need_new_pages, 0)
        start_new_pages = end_new_pages - need_new_pages
        if num_new_pages == 0:
            # Nobody crosses a page boundary: everyone continues in place.
            out_indices = last_loc + 1
        else:
            # For requests with need_new_pages == 0 that come after all the
            # consumed pages, start_new_pages equals num_new_pages — one past
            # the last taken page — which is out of bounds when free_pages is
            # exactly depleted. The gathered value is discarded (multiplied
            # by need_new_pages == 0), so clamp the index to keep the gather
            # in range.
            safe_idx = torch.clamp(start_new_pages, max=len(self.free_pages) - 1)
            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
                safe_idx
            ] * self.page_size * need_new_pages

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices.int()
|