[minor] simplify the TokenToKVPoolAllocator (#7414)

Liangsheng Yin
2025-06-22 12:37:18 +08:00
committed by GitHub
parent b7a2df0a44
commit 05c9bc8956
14 changed files with 165 additions and 149 deletions
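This change deduplicates the two token-to-KV-pool allocators: the free-list state, free-group batching, backup/restore, and the CPU-copy hooks move into a new abstract BaseTokenToKVPoolAllocator; TokenToKVPoolAllocator moves out of memory_pool.py into allocator.py as a page_size=1 subclass; PagedTokenToKVPoolAllocator inherits the shared logic instead of re-implementing it; and the caches (ChunkCache, HiRadixCache, RadixCache) now annotate against the base class. For orientation, here is a hedged usage sketch of the consolidated interface; the sizes are invented example values, and kvcache=None only works because the sketch never touches the physical cache:

import torch
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator

# A tiny page_size=1 allocator; slot 0 is reserved for padded dummy outputs.
allocator = TokenToKVPoolAllocator(size=8, dtype=torch.float16, device="cpu", kvcache=None)

idx = allocator.alloc(3)           # tensor([1, 2, 3])
print(allocator.available_size())  # 5
allocator.free(idx)                # all 8 slots free again
print(allocator.alloc(100))        # None: larger than the free list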

sglang/srt/mem_cache/allocator.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2025 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,13 +19,132 @@ limitations under the License.
 Page-aligned memory pool.
 """
 
+import abc
+from typing import TYPE_CHECKING
+
 import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.mem_cache.memory_pool import KVCache
 from sglang.srt.utils import get_bool_env_var, next_power_of_2
 
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.memory_pool import KVCache
+
+
+class BaseTokenToKVPoolAllocator(abc.ABC):
+    @abc.abstractmethod
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: KVCache,
+    ):
+        self.size = size
+        self.page_size = page_size
+        self.dtype = dtype
+        self.device = device
+        self._kvcache = kvcache
+
+        self.free_pages = None
+        self.is_not_in_free_group = True
+        self.free_group = []
+
+    def debug_print(self) -> str:
+        return ""
+
+    def available_size(self):
+        return len(self.free_pages) * self.page_size
+
+    def get_kvcache(self):
+        return self._kvcache
+
+    def restore_state(self, free_pages):
+        self.free_pages = free_pages
+
+    def backup_state(self):
+        return self.free_pages
+
+    def free_group_begin(self):
+        self.is_not_in_free_group = False
+        self.free_group = []
+
+    def free_group_end(self):
+        self.is_not_in_free_group = True
+        if self.free_group:
+            self.free(torch.cat(self.free_group))
+
+    def get_cpu_copy(self, *args, **kwargs):
+        # FIXME: reuse the get_cpu_copy after paged allocator is implemented
+        raise NotImplementedError()
+
+    def load_cpu_copy(self, *args, **kwargs):
+        # FIXME: reuse the load_cpu_copy after paged allocator is implemented
+        raise NotImplementedError()
+
+    def alloc_extend(self, *args, **kwargs):
+        raise NotImplementedError("alloc_extend is only for paged allocator")
+
+    def alloc_decode(self, *args, **kwargs):
+        raise NotImplementedError("alloc_decode is only for paged allocator")
+
+    @abc.abstractmethod
+    def clear(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def alloc(self, need_size: int):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def free(self, free_index: torch.Tensor):
+        raise NotImplementedError()
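# --- Example (not part of the diff) -----------------------------------------
# The free-group helpers above let callers batch many free() calls into a
# single torch.cat. A hedged sketch using the page_size=1 subclass defined
# right below; kvcache=None suffices because only the free list is touched.
import torch
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator

allocator = TokenToKVPoolAllocator(size=8, dtype=torch.float16, device="cpu", kvcache=None)
a, b = allocator.alloc(2), allocator.alloc(3)

allocator.free_group_begin()  # frees are now queued, not applied
allocator.free(a)
allocator.free(b)
allocator.free_group_end()    # one torch.cat returns both batches at once
assert allocator.available_size() == 8
# -----------------------------------------------------------------------------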
+
+
+class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
+    """An allocator managing the indices to kv cache data."""
+
+    def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
+        super().__init__(size, 1, dtype, device, kvcache)
+        self.clear()
+
+    def clear(self):
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_pages = torch.arange(
+            1, self.size + 1, dtype=torch.int64, device=self.device
+        )
+        self.is_not_in_free_group = True
+        self.free_group = []
+
+    def available_size(self):
+        # To avoid minor "len(free_pages) * 1" overhead
+        return len(self.free_pages)
+
+    def alloc(self, need_size: int):
+        if need_size > len(self.free_pages):
+            return None
+
+        select_index = self.free_pages[:need_size]
+        self.free_pages = self.free_pages[need_size:]
+        return select_index
+
+    def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
+        if self.is_not_in_free_group:
+            self.free_pages = torch.cat((self.free_pages, free_index))
+        else:
+            self.free_group.append(free_index)
+
+    def get_cpu_copy(self, indices):
+        return self._kvcache.get_cpu_copy(indices)
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
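# --- Example (not part of the diff) -----------------------------------------
# Hedged sketch of the backup/restore pair inherited from the base class,
# e.g. to roll back a tentative allocation. State is just the free-page
# tensor, so snapshotting is cheap; kvcache=None is an illustrative stand-in.
import torch
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator

allocator = TokenToKVPoolAllocator(size=8, dtype=torch.float16, device="cpu", kvcache=None)
saved = allocator.backup_state()   # the current free_pages tensor
allocator.alloc(5)                 # tentative: 3 slots left
allocator.restore_state(saved)     # roll back to the snapshot
assert allocator.available_size() == 8
# -----------------------------------------------------------------------------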
 
 
 @triton.jit
 def alloc_extend_kernel(
@@ -154,7 +275,7 @@ def alloc_decode_kernel(
     tl.store(out_indices + pid, page * page_size)
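# --- Example (not part of the diff) -----------------------------------------
# The surviving context line above stores `page * page_size`, the first token
# slot of a freshly claimed page. A hedged pure-Python rendering of the
# per-request decode-step index math (names and the 0-based position
# convention are mine, not the kernel's):
def decode_out_index(new_token_pos: int, page_size: int,
                     last_loc: int, fresh_page: int) -> int:
    if new_token_pos % page_size == 0:  # the new token opens a page
        return fresh_page * page_size   # first slot of the fresh page
    return last_loc + 1                 # next slot within the current page
# -----------------------------------------------------------------------------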
 
 
-class PagedTokenToKVPoolAllocator:
+class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     """
     An allocator managing the indices to kv cache data.
@@ -172,26 +293,11 @@ class PagedTokenToKVPoolAllocator:
         device: str,
         kvcache: KVCache,
     ):
-        self.size = size
-        self.dtype = dtype
-        self.device = device
-        self.page_size = page_size
+        super().__init__(size, page_size, dtype, device, kvcache)
         self.num_pages = size // page_size
-        self.free_pages = None
-        self.is_not_in_free_group = True
-        self.free_group = []
-        self.clear()
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
-        self._kvcache = kvcache
         self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
-
-    def available_size(self):
-        return len(self.free_pages) * self.page_size
-
-    def get_kvcache(self):
-        return self._kvcache
+        self.clear()
 
     def alloc(self, need_size: int):
         # page-aligned allocation, returning contiguous indices of pages
@@ -298,21 +404,6 @@ class PagedTokenToKVPoolAllocator:
         if self.debug_mode:
             assert len(torch.unique(self.free_pages)) == len(self.free_pages)
 
-    def free_group_begin(self):
-        self.is_not_in_free_group = False
-        self.free_group = []
-
-    def free_group_end(self):
-        self.is_not_in_free_group = True
-        if self.free_group:
-            self.free(torch.cat(self.free_group))
-
-    def backup_state(self):
-        return self.free_pages
-
-    def restore_state(self, free_pages):
-        self.free_pages = free_pages
-
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_pages = torch.arange(

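The paged allocator's bookkeeping is per page, so a free list of N pages stands for N * page_size token slots, which is exactly what the inherited available_size() reports. A hedged sketch of the index arithmetic, with an invented page size:

page_size = 64                   # example value, not from the diff
token_index = 1000               # some kv cache slot
page = token_index // page_size  # page 15 owns this slot
first_slot = page * page_size    # 960, the value the kernels above store
# N free pages <=> N * page_size free token slots, per available_size().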
sglang/srt/mem_cache/chunk_cache.py

@@ -2,12 +2,13 @@ from __future__ import annotations
 
 """Cache for chunked prefill, used when RadixCache is disabled."""
 
-from typing import TYPE_CHECKING, Any, Callable, List, Tuple
+from typing import TYPE_CHECKING, Any
 
 import torch
 
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -17,7 +18,7 @@ class ChunkCache(BasePrefixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
     ):
         self.req_to_token_pool = req_to_token_pool

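The annotation change above is the whole point for call sites: both concrete allocators satisfy the new base type, so ChunkCache (and HiRadixCache and RadixCache below) accept either one unchanged. A small hedged check of the relationship:

from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    PagedTokenToKVPoolAllocator,
    TokenToKVPoolAllocator,
)

# Both concrete allocators plug into caches typed against the base class.
assert issubclass(TokenToKVPoolAllocator, BaseTokenToKVPoolAllocator)
assert issubclass(PagedTokenToKVPoolAllocator, BaseTokenToKVPoolAllocator)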
sglang/srt/mem_cache/hiradix_cache.py

@@ -7,12 +7,12 @@ from typing import List, Optional
 import torch
 
 from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
-    TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.memory_pool_host import (
     MHATokenToKVPoolHost,
@@ -28,7 +28,7 @@ class HiRadixCache(RadixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
         hicache_ratio: float,

sglang/srt/mem_cache/memory_pool.py

@@ -26,7 +26,6 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 import os
-from contextlib import nullcontext
 from typing import List, Optional, Tuple, Union
@@ -167,84 +166,6 @@ class KVCache(abc.ABC):
         raise NotImplementedError()
 
 
-class TokenToKVPoolAllocator:
-    """An allocator managing the indices to kv cache data."""
-
-    def __init__(
-        self,
-        size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-    ):
-        self.size = size
-        self.dtype = dtype
-        self.device = device
-        self.page_size = 1
-
-        self.free_slots = None
-        self.is_not_in_free_group = True
-        self.free_group = []
-        self.clear()
-
-        self._kvcache = kvcache
-
-    def available_size(self):
-        return len(self.free_slots)
-
-    def debug_print(self) -> str:
-        return ""
-
-    def get_kvcache(self):
-        return self._kvcache
-
-    def alloc(self, need_size: int):
-        if need_size > len(self.free_slots):
-            return None
-
-        select_index = self.free_slots[:need_size]
-        self.free_slots = self.free_slots[need_size:]
-        return select_index
-
-    def free(self, free_index: torch.Tensor):
-        if free_index.numel() == 0:
-            return
-
-        if self.is_not_in_free_group:
-            self.free_slots = torch.cat((self.free_slots, free_index))
-        else:
-            self.free_group.append(free_index)
-
-    def free_group_begin(self):
-        self.is_not_in_free_group = False
-        self.free_group = []
-
-    def free_group_end(self):
-        self.is_not_in_free_group = True
-        if self.free_group:
-            self.free(torch.cat(self.free_group))
-
-    def backup_state(self):
-        return self.free_slots
-
-    def restore_state(self, free_slots):
-        self.free_slots = free_slots
-
-    def clear(self):
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = torch.arange(
-            1, self.size + 1, dtype=torch.int64, device=self.device
-        )
-        self.is_not_in_free_group = True
-        self.free_group = []
-
-    def get_cpu_copy(self, indices):
-        return self._kvcache.get_cpu_copy(indices)
-
-    def load_cpu_copy(self, kv_cache_cpu, indices):
-        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
-
-
 class MHATokenToKVPool(KVCache):
     def __init__(

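The deletion above is the counterpart of the first file: the class now lives in allocator.py, with its free_slots field unified under the base class's free_pages name. A hedged sketch of the one-line migration a downstream import would need:

# was: from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator  # new home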
sglang/srt/mem_cache/radix_cache.py

@@ -23,7 +23,7 @@ import heapq
 import time
 from collections import defaultdict
 from functools import partial
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
@@ -31,10 +31,10 @@ from sglang.srt.disaggregation.kv_events import (
     AllBlocksCleared,
     BlockRemoved,
     BlockStored,
     KVCacheEvent,
 )
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -98,7 +98,7 @@ class RadixCache(BasePrefixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
         disable: bool = False,
         enable_kv_cache_events: bool = False,