from __future__ import annotations

"""
Copyright 2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
Page-aligned memory pool.
"""
import abc
from typing import TYPE_CHECKING

import torch
import triton
import triton.language as tl

from sglang.srt.mem_cache.memory_pool import SWAKVPool
from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2

# The DCU (vendor) decode kernel only exists in DCU builds of sgl_kernel.
# Guard the import so this module still loads when it is unavailable; the
# Triton alloc_decode_kernel below is used as the fallback.
try:
    from sgl_kernel.kvcacheio import dcu_alloc_decode_kernel
except ImportError:
    dcu_alloc_decode_kernel = None

if TYPE_CHECKING:
    from sglang.srt.mem_cache.memory_pool import KVCache

class BaseTokenToKVPoolAllocator(abc.ABC):
@abc.abstractmethod
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
device: str,
kvcache: KVCache,
need_sort: bool,
):
self.size = size
self.page_size = page_size
self.dtype = dtype
self.device = device
self._kvcache = kvcache
self.need_sort = need_sort
self.free_pages = None
self.release_pages = None
self.is_not_in_free_group = True
self.free_group = []
def debug_print(self) -> str:
return ""
def available_size(self):
return (len(self.free_pages) + len(self.release_pages)) * self.page_size
def get_kvcache(self):
return self._kvcache
def restore_state(self, state):
self.free_pages, self.release_pages = state
def backup_state(self):
return (self.free_pages, self.release_pages)
def free_group_begin(self):
self.is_not_in_free_group = False
self.free_group = []
def free_group_end(self):
self.is_not_in_free_group = True
if self.free_group:
self.free(torch.cat(self.free_group))
def merge_and_sort_free(self):
if len(self.release_pages) > 0:
self.free_pages = torch.cat((self.free_pages, self.release_pages))
self.free_pages, _ = torch.sort(self.free_pages)
self.release_pages = torch.empty(
(0,), dtype=self.release_pages.dtype, device=self.device
)
def get_cpu_copy(self, *args, **kwargs):
# FIXME: reuse the get_cpu_copy after paged allocator is implemented
raise NotImplementedError()
def load_cpu_copy(self, *args, **kwargs):
# FIXME: reuse the load_cpu_copy after paged allocator is implemented
raise NotImplementedError()
def alloc_extend(self, *args, **kwargs):
raise NotImplementedError("alloc_extend is only for paged allocator")
def alloc_decode(self, *args, **kwargs):
raise NotImplementedError("alloc_decode is only for paged allocator")
@abc.abstractmethod
def clear(self):
raise NotImplementedError()
@abc.abstractmethod
def alloc(self, need_size: int):
raise NotImplementedError()
@abc.abstractmethod
def free(self, free_index: torch.Tensor):
raise NotImplementedError()
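
# Usage sketch (hypothetical, for illustration only): the free-group API lets
# callers defer many small frees and return them to the pool with a single
# concatenated free, assuming `allocator` is any concrete subclass and
# `req_indices` is a list of int64 index tensors:
#
#   allocator.free_group_begin()
#   for indices in req_indices:
#       allocator.free(indices)   # buffered in free_group, not freed yet
#   allocator.free_group_end()    # one torch.cat + free for the whole batch
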
class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
"""An allocator managing the indices to kv cache data."""
def __init__(
self,
size: int,
dtype: torch.dtype,
device: str,
kvcache: KVCache,
need_sort: bool,
):
super().__init__(size, 1, dtype, device, kvcache, need_sort)
self.clear()
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_pages = torch.arange(
1, self.size + 1, dtype=torch.int64, device=self.device
)
self.is_not_in_free_group = True
self.free_group = []
self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
    def available_size(self):
        # page_size == 1 here, so skip the redundant "* page_size" of the base class.
        return len(self.free_pages) + len(self.release_pages)
def alloc(self, need_size: int):
if self.need_sort and need_size > len(self.free_pages):
self.merge_and_sort_free()
if need_size > len(self.free_pages):
return None
select_index = self.free_pages[:need_size]
self.free_pages = self.free_pages[need_size:]
return select_index
def free(self, free_index: torch.Tensor):
if free_index.numel() == 0:
return
if self.is_not_in_free_group:
if self.need_sort:
self.release_pages = torch.cat((self.release_pages, free_index))
else:
self.free_pages = torch.cat((self.free_pages, free_index))
else:
self.free_group.append(free_index)
def get_cpu_copy(self, indices):
return self._kvcache.get_cpu_copy(indices)
def load_cpu_copy(self, kv_cache_cpu, indices):
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
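
# Example (hypothetical values): with size=8, clear() leaves free slots
# [1..8]; alloc(3) returns tensor([1, 2, 3]) and keeps [4..8] free, and
# free(tensor([1, 2, 3])) returns them to the pool. Slot 0 is reserved for
# dummy writes from padded tokens and is never handed out.
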
class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
"""Allocator for SWA hybrid KV cache."""
def __init__(
self,
size: int,
size_swa: int,
dtype: torch.dtype,
device: str,
kvcache: SWAKVPool,
need_sort: bool,
):
super().__init__(size, 1, dtype, device, kvcache, need_sort)
assert isinstance(kvcache, SWAKVPool)
self._size_full = size
self._size_swa = size_swa
self.full_attn_allocator = TokenToKVPoolAllocator(
size,
dtype,
device,
kvcache.full_kv_pool,
need_sort,
)
self.swa_attn_allocator = TokenToKVPoolAllocator(
size_swa,
dtype,
device,
kvcache.swa_kv_pool,
need_sort,
)
self.full_to_swa_index_mapping = torch.empty(
size + size_swa + 1,
dtype=torch.int64,
device=device,
)
self.clear()
self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping
    def available_size(self):
        raise NotImplementedError(
            "Ambiguous for the hybrid pool; use full_available_size() or "
            "swa_available_size() instead."
        )
def full_available_size(self):
return self.full_attn_allocator.available_size()
def swa_available_size(self):
return self.swa_attn_allocator.available_size()
@property
def size_full(self):
return self._size_full
@property
def size_swa(self):
return self._size_swa
def debug_print(self) -> str:
msg = ""
msg += f"#swa-available-size: {self.swa_attn_allocator.available_size()}, "
msg += (
f"#full-attn-available-size: {self.full_attn_allocator.available_size()}, "
)
return msg
def get_kvcache(self):
return self._kvcache
def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
assert self.full_to_swa_index_mapping is not None
return self.full_to_swa_index_mapping[kv_indices].to(torch.int32)
def alloc(self, need_size: int):
if need_size > self.full_attn_allocator.available_size():
return None
if need_size > self.swa_attn_allocator.available_size():
return None
alloc_full_indices = self.full_attn_allocator.alloc(need_size)
alloc_swa_indices = self.swa_attn_allocator.alloc(need_size)
self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices
return alloc_full_indices
def free(self, free_index: torch.Tensor):
if free_index.numel() == 0:
return
if self.is_not_in_free_group:
self.full_attn_allocator.free(free_index)
self.free_swa(free_index)
else:
self.free_group.append(free_index)
assert (
self.full_attn_allocator.available_size() <= self.full_attn_allocator.size
)
assert self.swa_attn_allocator.available_size() <= self.swa_attn_allocator.size
def free_swa(self, free_index: torch.Tensor):
swa_indices = self.full_to_swa_index_mapping[free_index]
swa_indices = swa_indices[swa_indices > 0]
self.swa_attn_allocator.free(swa_indices)
self.full_to_swa_index_mapping[free_index] = 0
def backup_state(self):
return [
self.full_attn_allocator.backup_state(),
self.swa_attn_allocator.backup_state(),
]
def restore_state(self, state):
assert len(state) == 2
self.full_attn_allocator.restore_state(state[0])
self.swa_attn_allocator.restore_state(state[1])
def clear(self):
self.swa_attn_allocator.clear()
self.full_attn_allocator.clear()
self.full_to_swa_index_mapping.fill_(0)
self.is_not_in_free_group = True
self.free_group = []
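
# Mapping sketch (hypothetical values): alloc(2) may pair full-attention
# indices [5, 6] with SWA indices [3, 4], recording
# full_to_swa_index_mapping[5] == 3 and full_to_swa_index_mapping[6] == 4.
# Entry 0 means "no SWA slot attached", which is why free_swa() filters
# swa_indices > 0 before freeing and then resets the mapping to 0.
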
@triton.jit
def alloc_extend_kernel(
pre_lens_ptr,
seq_lens_ptr,
last_loc_ptr,
free_page_ptr,
out_indices,
bs_upper: tl.constexpr,
page_size: tl.constexpr,
max_num_extend_tokens: tl.constexpr,
):
pid = tl.program_id(0)
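    # Load the lengths of all requests up to and including this one; the
    # masked tl.sum calls below turn them into inclusive prefix sums that
    # give this request's offsets into out_indices and the free page list.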
load_offset = tl.arange(0, bs_upper)
seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
pre_lens = tl.load(pre_lens_ptr + load_offset, mask=load_offset <= pid)
extend_lens = seq_lens - pre_lens
seq_len = tl.load(seq_lens_ptr + pid)
pre_len = tl.load(pre_lens_ptr + pid)
extend_len = seq_len - pre_len
sum_extend_lens = tl.sum(extend_lens)
output_start_loc = sum_extend_lens - extend_len
num_pages_after = (seq_lens + page_size - 1) // page_size
num_pages_before = (pre_lens + page_size - 1) // page_size
num_new_pages = num_pages_after - num_pages_before
num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
pre_len + page_size - 1
) // page_size
sum_num_new_pages = tl.sum(num_new_pages)
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
# Part 1: fill the old partial page
last_loc = tl.load(last_loc_ptr + pid)
num_part1 = (
min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
)
offset_one_page = tl.arange(0, page_size)
tl.store(
out_indices + output_start_loc + offset_one_page,
last_loc + 1 + offset_one_page,
mask=offset_one_page < num_part1,
)
if pre_len + num_part1 == seq_len:
return
# Part 2: fill the new full pages
num_part2 = (
seq_len // page_size * page_size
- (pre_len + page_size - 1) // page_size * page_size
)
offset_many_page = tl.arange(0, max_num_extend_tokens)
page_start = tl.load(
free_page_ptr + new_page_start_loc + offset_many_page // page_size,
mask=offset_many_page < num_part2,
)
tl.store(
out_indices + output_start_loc + num_part1 + offset_many_page,
page_start * page_size + offset_many_page % page_size,
mask=offset_many_page < num_part2,
)
if pre_len + num_part1 + num_part2 == seq_len:
return
# Part 3: fill the new partial page
num_part3 = seq_len - seq_len // page_size * page_size
start_loc = tl.load(
free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1
)
tl.store(
out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
start_loc * page_size + offset_one_page,
mask=offset_one_page < num_part3,
)
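
# Worked example for alloc_extend_kernel (hypothetical values): page_size=4,
# pre_len=5, seq_len=15.
#   Part 1 fills the tail of the old partial page: tokens 5..7   (num_part1=3).
#   Part 2 fills whole new pages:                  tokens 8..11  (num_part2=4).
#   Part 3 starts the new partial page:            tokens 12..14 (num_part3=3).
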
@triton.jit
def alloc_decode_kernel(
seq_lens_ptr,
last_loc_ptr,
free_page_ptr,
out_indices,
bs_upper: tl.constexpr,
page_size: tl.constexpr,
):
pid = tl.program_id(0)
load_offset = tl.arange(0, bs_upper)
seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens)
seq_len = tl.load(seq_lens_ptr + pid)
pre_len = seq_len - 1
num_pages_after = (seq_lens + page_size - 1) // page_size
num_pages_before = (pre_lens + page_size - 1) // page_size
num_new_pages = num_pages_after - num_pages_before
num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
pre_len + page_size - 1
) // page_size
sum_num_new_pages = tl.sum(num_new_pages)
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
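    # num_page_start_loc_self == 0 means the new token still fits in the
    # request's last partial page; otherwise exactly one fresh page is taken
    # from the free page list.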
if num_page_start_loc_self == 0:
last_loc = tl.load(last_loc_ptr + pid)
tl.store(out_indices + pid, last_loc + 1)
else:
page = tl.load(free_page_ptr + new_page_start_loc)
tl.store(out_indices + pid, page * page_size)
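
# Worked example for alloc_decode_kernel (hypothetical values): page_size=4.
# Growing a request from seq_len 6 to 7 stays inside its last page, so the new
# token is written at last_loc + 1; growing from 8 to 9 crosses a page
# boundary, so a fresh page is popped and the token lands at page * page_size.
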
class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
"""
An allocator managing the indices to kv cache data.
This class has the same interface as `TokenToKVPoolAllocator` but the output
of one request is always page-aligned.
TODO: fuse last_loc into the kernel.
"""
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
device: str,
kvcache: KVCache,
need_sort: bool,
):
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
self.num_pages = size // page_size
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
        # Only enable the DCU path when the vendor kernel was importable.
        self.use_dcu_decode_kernel = (
            get_bool_env_var("USE_DCU_DECODE_KERNEL")
            and dcu_alloc_decode_kernel is not None
        )
self.seen_max_num_extend_tokens_next_power_of_2 = 1
self.clear()
def alloc(self, need_size: int):
        # Page-aligned allocation: returns token indices covering whole pages.
if self.debug_mode:
assert (
need_size % self.page_size == 0
), "The allocation size should be page-aligned"
num_pages = need_size // self.page_size
if self.need_sort and num_pages > len(self.free_pages):
self.merge_and_sort_free()
if num_pages > len(self.free_pages):
return None
out_pages = self.free_pages[:num_pages]
self.free_pages = self.free_pages[num_pages:]
out_indices = (
out_pages[:, None] * self.page_size
+ torch.arange(self.page_size, device=self.device)
).reshape(-1)
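        # Example (hypothetical values): with page_size=4, out_pages=[2, 5]
        # expands to token indices [8, 9, 10, 11, 20, 21, 22, 23].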
return out_indices
def alloc_extend(
self,
prefix_lens: torch.Tensor,
prefix_lens_cpu: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
last_loc: torch.Tensor,
extend_num_tokens: int,
):
if self.debug_mode:
assert torch.all(
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
self.seen_max_num_extend_tokens_next_power_of_2 = max(
self.seen_max_num_extend_tokens_next_power_of_2,
next_power_of_2(extend_num_tokens),
)
bs = len(prefix_lens)
        # Conservative upper bound on the pages this extend could need.
        if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
            self.free_pages
        ):
            self.merge_and_sort_free()
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int64, device=self.device
)
alloc_extend_kernel[(bs,)](
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
next_power_of_2(bs),
self.page_size,
self.seen_max_num_extend_tokens_next_power_of_2,
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
num_new_pages = get_num_new_pages(
seq_lens=seq_lens_cpu,
page_size=self.page_size,
prefix_lens=prefix_lens_cpu,
)
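        # This CPU-side page count must agree with the kernel's on-device page
        # math; if the pool is exhausted, None is returned and the allocator
        # state is left unchanged.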
if num_new_pages > len(self.free_pages):
return None
self.free_pages = self.free_pages[num_new_pages:]
return out_indices
def alloc_decode(
self,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
last_loc: torch.Tensor,
):
if self.debug_mode:
assert torch.all(
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)
bs = len(seq_lens)
if self.need_sort and bs > len(self.free_pages):
self.merge_and_sort_free()
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
        if self.use_dcu_decode_kernel:
            # Vendor (DCU) allocation kernel; see the guarded import above.
            dcu_alloc_decode_kernel(
                seq_lens_ptr=seq_lens,
                last_loc_ptr=last_loc,
                free_page_ptr=self.free_pages,
                out_indices=out_indices,
                bs=bs,
                bs_upper=next_power_of_2(bs),
                page_size=self.page_size,
            )
else:
alloc_decode_kernel[(bs,)](
seq_lens,
last_loc,
self.free_pages,
out_indices,
next_power_of_2(bs),
self.page_size,
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
num_new_pages = get_num_new_pages(
seq_lens=seq_lens_cpu,
page_size=self.page_size,
decode=True,
)
if num_new_pages > len(self.free_pages):
return None
self.free_pages = self.free_pages[num_new_pages:]
return out_indices
def free(self, free_index: torch.Tensor):
if free_index.numel() == 0:
return
if self.is_not_in_free_group:
free_page_indices = torch.unique(free_index // self.page_size)
if self.need_sort:
self.release_pages = torch.cat((free_page_indices, self.release_pages))
else:
self.free_pages = torch.cat((free_page_indices, self.free_pages))
else:
self.free_group.append(free_index)
if self.debug_mode:
assert len(torch.unique(self.free_pages)) == len(self.free_pages)
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_pages = torch.arange(
1, self.num_pages + 1, dtype=torch.int64, device=self.device
)
self.is_not_in_free_group = True
self.free_group = []
self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
def get_cpu_copy(self, indices):
return self._kvcache.get_cpu_copy(indices)
def load_cpu_copy(self, kv_cache_cpu, indices):
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
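
# End-to-end sketch (hypothetical, not part of the module): a paged allocator
# over 64 token slots with page_size=4 exposes 16 pages; alloc(8) consumes two
# pages and returns 8 page-aligned token indices, and free() returns whole
# pages to the pool (assuming `kvcache` is a constructed KVCache instance):
#
#   allocator = PagedTokenToKVPoolAllocator(
#       size=64, page_size=4, dtype=torch.float16, device="cuda",
#       kvcache=kvcache, need_sort=False,
#   )
#   indices = allocator.alloc(8)   # two full pages -> 8 token indices
#   allocator.free(indices)        # both pages go back to free_pages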