Support page size > 1 (#4356)
This commit is contained in:
283
python/sglang/srt/mem_cache/paged_allocator.py
Normal file
283
python/sglang/srt/mem_cache/paged_allocator.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Copyright 2025 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""
|
||||
Page-aligned memory pool.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||
from sglang.srt.utils import get_bool_env_var, next_power_of_2
|
||||
|
||||
|
||||
@triton.jit
|
||||
def alloc_extend_kernel(
|
||||
pre_lens_ptr,
|
||||
seq_lens_ptr,
|
||||
last_loc_ptr,
|
||||
free_page_ptr,
|
||||
out_indices,
|
||||
ret_values,
|
||||
bs_upper: tl.constexpr,
|
||||
page_size: tl.constexpr,
|
||||
max_num_extend_tokens: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
load_offset = tl.arange(0, bs_upper)
|
||||
seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
|
||||
pre_lens = tl.load(pre_lens_ptr + load_offset, mask=load_offset <= pid)
|
||||
extend_lens = seq_lens - pre_lens
|
||||
|
||||
seq_len = tl.load(seq_lens_ptr + pid)
|
||||
pre_len = tl.load(pre_lens_ptr + pid)
|
||||
extend_len = seq_len - pre_len
|
||||
|
||||
sum_extend_lens = tl.sum(extend_lens)
|
||||
output_start_loc = sum_extend_lens - extend_len
|
||||
|
||||
num_pages_after = (seq_lens + page_size - 1) // page_size
|
||||
num_pages_before = (pre_lens + page_size - 1) // page_size
|
||||
num_new_pages = num_pages_after - num_pages_before
|
||||
|
||||
num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
|
||||
pre_len + page_size - 1
|
||||
) // page_size
|
||||
sum_num_new_pages = tl.sum(num_new_pages)
|
||||
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
|
||||
|
||||
# Return value
|
||||
if pid == tl.num_programs(0) - 1:
|
||||
merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
|
||||
tl.int64
|
||||
)
|
||||
tl.store(ret_values, merged_value)
|
||||
|
||||
# Part 1: fill the old partial page
|
||||
last_loc = tl.load(last_loc_ptr + pid)
|
||||
num_part1 = (
|
||||
min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
|
||||
)
|
||||
offset_one_page = tl.arange(0, page_size)
|
||||
tl.store(
|
||||
out_indices + output_start_loc + offset_one_page,
|
||||
last_loc + 1 + offset_one_page,
|
||||
mask=offset_one_page < num_part1,
|
||||
)
|
||||
if pre_len + num_part1 == seq_len:
|
||||
return
|
||||
|
||||
# Part 2: fill the new full pages
|
||||
num_part2 = (
|
||||
seq_len // page_size * page_size
|
||||
- (pre_len + page_size - 1) // page_size * page_size
|
||||
)
|
||||
|
||||
offset_many_page = tl.arange(0, max_num_extend_tokens)
|
||||
page_start = tl.load(
|
||||
free_page_ptr + new_page_start_loc + offset_many_page // page_size,
|
||||
mask=offset_many_page < num_part2,
|
||||
)
|
||||
tl.store(
|
||||
out_indices + output_start_loc + num_part1 + offset_many_page,
|
||||
page_start * page_size + offset_many_page % page_size,
|
||||
mask=offset_many_page < num_part2,
|
||||
)
|
||||
if pre_len + num_part1 + num_part2 == seq_len:
|
||||
return
|
||||
|
||||
# Part 3: fill the new partial page
|
||||
num_part3 = seq_len - seq_len // page_size * page_size
|
||||
start_loc = tl.load(
|
||||
free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1
|
||||
)
|
||||
tl.store(
|
||||
out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
|
||||
start_loc * page_size + offset_one_page,
|
||||
mask=offset_one_page < num_part3,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def alloc_decode_kernel(
|
||||
seq_lens_ptr,
|
||||
last_loc_ptr,
|
||||
free_page_ptr,
|
||||
out_indices,
|
||||
ret_values,
|
||||
bs_upper: tl.constexpr,
|
||||
page_size: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
load_offset = tl.arange(0, bs_upper)
|
||||
seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
|
||||
pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens)
|
||||
|
||||
seq_len = tl.load(seq_lens_ptr + pid)
|
||||
pre_len = seq_len - 1
|
||||
|
||||
num_pages_after = (seq_lens + page_size - 1) // page_size
|
||||
num_pages_before = (pre_lens + page_size - 1) // page_size
|
||||
num_new_pages = num_pages_after - num_pages_before
|
||||
|
||||
num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
|
||||
pre_len + page_size - 1
|
||||
) // page_size
|
||||
sum_num_new_pages = tl.sum(num_new_pages)
|
||||
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
|
||||
|
||||
# Return value
|
||||
if pid == tl.num_programs(0) - 1:
|
||||
tl.store(ret_values, sum_num_new_pages)
|
||||
|
||||
if num_page_start_loc_self == 0:
|
||||
last_loc = tl.load(last_loc_ptr + pid)
|
||||
tl.store(out_indices + pid, last_loc + 1)
|
||||
else:
|
||||
page = tl.load(free_page_ptr + new_page_start_loc)
|
||||
tl.store(out_indices + pid, page * page_size)
|
||||
|
||||
|
||||
class PagedTokenToKVPoolAllocator:
|
||||
"""
|
||||
An allocator managing the indices to kv cache data.
|
||||
|
||||
This class has the same interface as `TokenToKVPoolAllocator` but the output
|
||||
of one request is always page-aligned.
|
||||
|
||||
TODO: fuse last_loc into the kernel.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
page_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
kvcache: KVCache,
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.page_size = page_size
|
||||
self.num_pages = size // page_size
|
||||
|
||||
self.free_pages = None
|
||||
self.is_not_in_free_group = True
|
||||
self.free_group = []
|
||||
self.clear()
|
||||
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
||||
|
||||
self._kvcache = kvcache
|
||||
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
||||
|
||||
def available_size(self):
|
||||
return len(self.free_pages) * self.page_size
|
||||
|
||||
def alloc_extend(
|
||||
self,
|
||||
prefix_lens: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
extend_num_tokens: int,
|
||||
):
|
||||
if self.debug_mode:
|
||||
assert torch.all(
|
||||
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
||||
)
|
||||
|
||||
bs = len(prefix_lens)
|
||||
out_indices = torch.empty(
|
||||
(extend_num_tokens,), dtype=torch.int64, device=self.device
|
||||
)
|
||||
alloc_extend_kernel[(bs,)](
|
||||
prefix_lens,
|
||||
seq_lens,
|
||||
last_loc,
|
||||
self.free_pages,
|
||||
out_indices,
|
||||
self.ret_values,
|
||||
next_power_of_2(bs),
|
||||
self.page_size,
|
||||
next_power_of_2(extend_num_tokens),
|
||||
)
|
||||
|
||||
merged_value = self.ret_values.item()
|
||||
num_new_pages = merged_value >> 32
|
||||
if num_new_pages > len(self.free_pages):
|
||||
return None
|
||||
|
||||
self.free_pages = self.free_pages[num_new_pages:]
|
||||
return out_indices
|
||||
|
||||
def alloc_decode(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
):
|
||||
if self.debug_mode:
|
||||
assert torch.all(
|
||||
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
||||
)
|
||||
|
||||
bs = len(seq_lens)
|
||||
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
||||
alloc_decode_kernel[(bs,)](
|
||||
seq_lens,
|
||||
last_loc,
|
||||
self.free_pages,
|
||||
out_indices,
|
||||
self.ret_values,
|
||||
next_power_of_2(bs),
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
num_new_pages = self.ret_values.item()
|
||||
if num_new_pages > len(self.free_pages):
|
||||
return None
|
||||
|
||||
self.free_pages = self.free_pages[num_new_pages:]
|
||||
return out_indices
|
||||
|
||||
def free(self, free_index: torch.Tensor):
|
||||
if free_index.numel() == 0:
|
||||
return
|
||||
|
||||
if self.is_not_in_free_group:
|
||||
free_page_indices = torch.unique(free_index // self.page_size)
|
||||
self.free_pages = torch.cat((free_page_indices, self.free_pages))
|
||||
else:
|
||||
self.free_group.append(free_index)
|
||||
|
||||
def free_group_begin(self):
|
||||
self.is_not_in_free_group = False
|
||||
self.free_group = []
|
||||
|
||||
def free_group_end(self):
|
||||
self.is_not_in_free_group = True
|
||||
if self.free_group:
|
||||
self.free(torch.concat(self.free_group))
|
||||
|
||||
def clear(self):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.free_pages = torch.arange(
|
||||
1, self.num_pages + 1, dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.is_in_free_group = False
|
||||
self.free_group = []
|
||||
Reference in New Issue
Block a user