[minor] simplify the TokenToKVPoolAllocator (#7414)
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
Copyright 2025 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -17,13 +19,132 @@ limitations under the License.
|
||||
Page-aligned memory pool.
|
||||
"""
|
||||
|
||||
import abc
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||
from sglang.srt.utils import get_bool_env_var, next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||
|
||||
|
||||
class BaseTokenToKVPoolAllocator(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
page_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
kvcache: KVCache,
|
||||
):
|
||||
self.size = size
|
||||
self.page_size = page_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self._kvcache = kvcache
|
||||
|
||||
self.free_pages = None
|
||||
self.is_not_in_free_group = True
|
||||
self.free_group = []
|
||||
|
||||
def debug_print(self) -> str:
|
||||
return ""
|
||||
|
||||
def available_size(self):
|
||||
return len(self.free_pages) * self.page_size
|
||||
|
||||
def get_kvcache(self):
|
||||
return self._kvcache
|
||||
|
||||
def restore_state(self, free_pages):
|
||||
self.free_pages = free_pages
|
||||
|
||||
def backup_state(self):
|
||||
return self.free_pages
|
||||
|
||||
def free_group_begin(self):
|
||||
self.is_not_in_free_group = False
|
||||
self.free_group = []
|
||||
|
||||
def free_group_end(self):
|
||||
self.is_not_in_free_group = True
|
||||
if self.free_group:
|
||||
self.free(torch.cat(self.free_group))
|
||||
|
||||
def get_cpu_copy(self, *args, **kwargs):
|
||||
# FIXME: reuse the get_cpu_copy after paged allocator is implemented
|
||||
raise NotImplementedError()
|
||||
|
||||
def load_cpu_copy(self, *args, **kwargs):
|
||||
# FIXME: reuse the load_cpu_copy after paged allocator is implemented
|
||||
raise NotImplementedError()
|
||||
|
||||
def alloc_extend(self, *args, **kwargs):
|
||||
raise NotImplementedError("alloc_extend is only for paged allocator")
|
||||
|
||||
def alloc_decode(self, *args, **kwargs):
|
||||
raise NotImplementedError("alloc_decode is only for paged allocator")
|
||||
|
||||
@abc.abstractmethod
|
||||
def clear(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def alloc(self, need_size: int):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def free(self, free_index: torch.Tensor):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
"""An allocator managing the indices to kv cache data."""
|
||||
|
||||
def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
|
||||
super().__init__(size, 1, dtype, device, kvcache)
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.free_pages = torch.arange(
|
||||
1, self.size + 1, dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.is_not_in_free_group = True
|
||||
self.free_group = []
|
||||
|
||||
def available_size(self):
|
||||
# To avoid minor "len(free_pages) * 1" overhead
|
||||
return len(self.free_pages)
|
||||
|
||||
def alloc(self, need_size: int):
|
||||
if need_size > len(self.free_pages):
|
||||
return None
|
||||
|
||||
select_index = self.free_pages[:need_size]
|
||||
self.free_pages = self.free_pages[need_size:]
|
||||
return select_index
|
||||
|
||||
def free(self, free_index: torch.Tensor):
|
||||
if free_index.numel() == 0:
|
||||
return
|
||||
|
||||
if self.is_not_in_free_group:
|
||||
self.free_pages = torch.cat((self.free_pages, free_index))
|
||||
else:
|
||||
self.free_group.append(free_index)
|
||||
|
||||
def get_cpu_copy(self, indices):
|
||||
return self._kvcache.get_cpu_copy(indices)
|
||||
|
||||
def load_cpu_copy(self, kv_cache_cpu, indices):
|
||||
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def alloc_extend_kernel(
|
||||
@@ -154,7 +275,7 @@ def alloc_decode_kernel(
|
||||
tl.store(out_indices + pid, page * page_size)
|
||||
|
||||
|
||||
class PagedTokenToKVPoolAllocator:
|
||||
class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
"""
|
||||
An allocator managing the indices to kv cache data.
|
||||
|
||||
@@ -172,26 +293,11 @@ class PagedTokenToKVPoolAllocator:
|
||||
device: str,
|
||||
kvcache: KVCache,
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.page_size = page_size
|
||||
super().__init__(size, page_size, dtype, device, kvcache)
|
||||
self.num_pages = size // page_size
|
||||
|
||||
self.free_pages = None
|
||||
self.is_not_in_free_group = True
|
||||
self.free_group = []
|
||||
self.clear()
|
||||
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
||||
|
||||
self._kvcache = kvcache
|
||||
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
||||
|
||||
def available_size(self):
|
||||
return len(self.free_pages) * self.page_size
|
||||
|
||||
def get_kvcache(self):
|
||||
return self._kvcache
|
||||
self.clear()
|
||||
|
||||
def alloc(self, need_size: int):
|
||||
# page-aligned allocation, returning contiguous indices of pages
|
||||
@@ -298,21 +404,6 @@ class PagedTokenToKVPoolAllocator:
|
||||
if self.debug_mode:
|
||||
assert len(torch.unique(self.free_pages)) == len(self.free_pages)
|
||||
|
||||
def free_group_begin(self):
|
||||
self.is_not_in_free_group = False
|
||||
self.free_group = []
|
||||
|
||||
def free_group_end(self):
|
||||
self.is_not_in_free_group = True
|
||||
if self.free_group:
|
||||
self.free(torch.cat(self.free_group))
|
||||
|
||||
def backup_state(self):
|
||||
return self.free_pages
|
||||
|
||||
def restore_state(self, free_pages):
|
||||
self.free_pages = free_pages
|
||||
|
||||
def clear(self):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.free_pages = torch.arange(
|
||||
@@ -2,12 +2,13 @@ from __future__ import annotations
|
||||
|
||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
@@ -17,7 +18,7 @@ class ChunkCache(BasePrefixCache):
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||
page_size: int,
|
||||
):
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
|
||||
@@ -7,12 +7,12 @@ from typing import List, Optional
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
MLATokenToKVPool,
|
||||
ReqToTokenPool,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool_host import (
|
||||
MHATokenToKVPoolHost,
|
||||
@@ -28,7 +28,7 @@ class HiRadixCache(RadixCache):
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||
tp_cache_group: torch.distributed.ProcessGroup,
|
||||
page_size: int,
|
||||
hicache_ratio: float,
|
||||
|
||||
@@ -26,7 +26,6 @@ KVCache actually holds the physical kv cache.
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
@@ -167,84 +166,6 @@ class KVCache(abc.ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TokenToKVPoolAllocator:
|
||||
"""An allocator managing the indices to kv cache data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
kvcache: KVCache,
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.page_size = 1
|
||||
|
||||
self.free_slots = None
|
||||
self.is_not_in_free_group = True
|
||||
self.free_group = []
|
||||
self.clear()
|
||||
|
||||
self._kvcache = kvcache
|
||||
|
||||
def available_size(self):
|
||||
return len(self.free_slots)
|
||||
|
||||
def debug_print(self) -> str:
|
||||
return ""
|
||||
|
||||
def get_kvcache(self):
|
||||
return self._kvcache
|
||||
|
||||
def alloc(self, need_size: int):
|
||||
if need_size > len(self.free_slots):
|
||||
return None
|
||||
|
||||
select_index = self.free_slots[:need_size]
|
||||
self.free_slots = self.free_slots[need_size:]
|
||||
return select_index
|
||||
|
||||
def free(self, free_index: torch.Tensor):
|
||||
if free_index.numel() == 0:
|
||||
return
|
||||
|
||||
if self.is_not_in_free_group:
|
||||
self.free_slots = torch.cat((self.free_slots, free_index))
|
||||
else:
|
||||
self.free_group.append(free_index)
|
||||
|
||||
def free_group_begin(self):
|
||||
self.is_not_in_free_group = False
|
||||
self.free_group = []
|
||||
|
||||
def free_group_end(self):
|
||||
self.is_not_in_free_group = True
|
||||
if self.free_group:
|
||||
self.free(torch.cat(self.free_group))
|
||||
|
||||
def backup_state(self):
|
||||
return self.free_slots
|
||||
|
||||
def restore_state(self, free_slots):
|
||||
self.free_slots = free_slots
|
||||
|
||||
def clear(self):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.free_slots = torch.arange(
|
||||
1, self.size + 1, dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.is_not_in_free_group = True
|
||||
self.free_group = []
|
||||
|
||||
def get_cpu_copy(self, indices):
|
||||
return self._kvcache.get_cpu_copy(indices)
|
||||
|
||||
def load_cpu_copy(self, kv_cache_cpu, indices):
|
||||
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
|
||||
|
||||
|
||||
class MHATokenToKVPool(KVCache):
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -23,7 +23,7 @@ import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -31,10 +31,10 @@ from sglang.srt.disaggregation.kv_events import (
|
||||
AllBlocksCleared,
|
||||
BlockRemoved,
|
||||
BlockStored,
|
||||
KVCacheEvent,
|
||||
)
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
@@ -98,7 +98,7 @@ class RadixCache(BasePrefixCache):
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||
page_size: int,
|
||||
disable: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
|
||||
Reference in New Issue
Block a user