Files
sglang/python/sglang/srt/mem_cache/allocator_ascend.py
2025-08-18 16:54:36 -07:00

148 lines
4.4 KiB
Python

from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import KVCache
def alloc_extend_kernel_ascend(
    prefix_lens,
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
    device,
):
    """Fill ``out_indices`` with the KV slots for every request's new tokens.

    Request ``i`` contributes ``seq_lens[i] - prefix_lens[i]`` entries, laid
    out contiguously in ``out_indices`` in batch order. Its entries are
    produced in up to three parts:

    1. the unused tail of the page that holds the last prefix token,
       continuing at ``last_loc[i] + 1``;
    2. whole new pages taken from ``free_pages``;
    3. the head of one more new page when the sequence ends mid-page.

    A request whose new tokens all fit in part 1 consumes no free page and
    must skip parts 2 and 3.

    Args:
        prefix_lens: 1-D integer tensor, tokens already placed per request.
        seq_lens: 1-D integer tensor, total tokens per request after the
            extension.
        last_loc: 1-D integer tensor, slot index of each request's last
            prefix token (only read when part 1 is non-empty).
        free_pages: 1-D integer tensor of free page ids, consumed from the
            front. Page ``p`` covers slots ``p * page_size`` through
            ``p * page_size + page_size - 1``.
        out_indices: Preallocated 1-D tensor, written in place.
        page_size: Number of token slots per page.
        device: Device on which the per-page offset tensor is created.

    Returns:
        None. The result is written into ``out_indices``.
    """
    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)

    # Read the lengths on the host once; the loop needs Python ints for
    # slicing and branching, and this avoids a tensor -> scalar conversion
    # on every comparison and slice bound.
    prefix_list = prefix_lens.tolist()
    seq_list = seq_lens.tolist()

    out_pos = 0  # next unwritten position in out_indices
    page_pos = 0  # next unused entry of free_pages
    for i, (pre_len, seq_len) in enumerate(zip(prefix_list, seq_list)):
        remaining = seq_len - pre_len

        # Part 1: finish the partially filled page the prefix ends in.
        prefix_page_end = (pre_len + page_size - 1) // page_size * page_size
        num_tail = min(seq_len, prefix_page_end) - pre_len
        if num_tail > 0:
            out_indices[out_pos : out_pos + num_tail] = (
                last_loc[i] + 1 + pos_in_page[:num_tail]
            )
            out_pos += num_tail
            remaining -= num_tail

        # Everything fit into the existing page (or nothing to extend):
        # no free page may be touched for this request.
        if remaining <= 0:
            continue

        # Part 2: whole new pages. We are page-aligned at this point.
        num_full_pages = remaining // page_size
        if num_full_pages > 0:
            num_full_tokens = num_full_pages * page_size
            page_starts = (
                free_pages[page_pos : page_pos + num_full_pages] * page_size
            )
            out_indices[out_pos : out_pos + num_full_tokens] = (
                page_starts.view(-1, 1) + pos_in_page.view(1, -1)
            ).view(-1)
            out_pos += num_full_tokens
            page_pos += num_full_pages
            remaining -= num_full_tokens

        # Part 3: leading slots of one more new page for the leftover tokens.
        if remaining > 0:
            out_indices[out_pos : out_pos + remaining] = (
                free_pages[page_pos] * page_size + pos_in_page[:remaining]
            ).view(-1)
            out_pos += remaining
            page_pos += 1
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
    """Paged KV-slot allocator whose allocation paths use plain torch ops.

    Both methods take pages from the front of ``self.free_pages`` and return
    flat slot indices (``page_id * page_size + offset``). They rely on
    ``page_size``, ``free_pages``, ``device``, ``debug_mode``, ``need_sort``
    and ``merge_and_sort_free()``, none of which are defined in this class
    (presumably provided by ``PagedTokenToKVPoolAllocator`` - verify there).
    """

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate slots for extending each request from its prefix length.

        Args:
            prefix_lens: Tokens already placed per request.
            seq_lens: Total tokens per request after the extension.
            last_loc: Slot index of each request's last prefix token.
            extend_num_tokens: Length of the returned index tensor; expected
                to equal ``(seq_lens - prefix_lens).sum()``.

        Returns:
            An int32 tensor of ``extend_num_tokens`` slot indices on
            ``self.device``, or ``None`` when there are not enough free pages.
        """
        if self.debug_mode:
            # The slot after last_loc must sit at the same in-page offset as
            # the first new token.
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        # Pages after the extension minus pages before it, summed over the batch.
        num_new_pages = (
            (
                (seq_lens + self.page_size - 1) // self.page_size
                - (prefix_lens + self.page_size - 1) // self.page_size
            )
            .sum()
            .item()
        )

        if self.need_sort and num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages > len(self.free_pages):
            return None

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int32, device=self.device
        )

        alloc_extend_kernel_ascend(
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
            self.device,
        )

        if self.debug_mode:
            # Every allocated slot must be distinct.
            assert len(torch.unique(out_indices)) == len(out_indices)

        # Drop the pages that were just handed out.
        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate one slot per request for a single decoded token.

        A request gets the slot right after ``last_loc`` unless its new token
        is the first one of a page (``seq_lens % page_size == 1``), in which
        case it gets the first slot of a fresh page.

        Args:
            seq_lens: Tokens per request including the token being decoded.
            last_loc: Slot index of each request's previous token.

        Returns:
            An int32 tensor with one slot index per request, or ``None`` when
            there are not enough free pages.
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        need_new_pages = (seq_lens % self.page_size == 1).int()
        num_new_pages = need_new_pages.sum().item()

        # NOTE(review): unlike alloc_extend, this merge is not gated on
        # self.need_sort - confirm that is intended.
        if num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages > len(self.free_pages):
            return None

        if num_new_pages == 0:
            out_indices = last_loc + 1
        else:
            # Exclusive prefix sum: for a request that needs a page this is
            # the position of its page within the pages handed out now.
            end_new_pages = torch.cumsum(need_new_pages, 0)
            start_new_pages = end_new_pages - need_new_pages
            # Requests that need no page still take part in the gather below
            # (their value is multiplied by 0). After the last allocating
            # request their index equals num_new_pages, which is out of
            # bounds when exactly num_new_pages pages are free, so clamp it
            # to a valid position.
            page_slot = torch.clamp(start_new_pages, max=num_new_pages - 1)
            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
                page_slot
            ] * self.page_size * need_new_pages

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices.int()