[minor] simplify the TokenToKVPoolAllocator (#7414)
This commit is contained in:
@@ -21,13 +21,11 @@ Life cycle of a request in the decode server
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
@@ -47,12 +45,9 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
prepare_abort,
|
prepare_abort,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
||||||
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
|
||||||
KVCache,
|
|
||||||
ReqToTokenPool,
|
|
||||||
TokenToKVPoolAllocator,
|
|
||||||
)
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
from sglang.srt.utils import require_mlp_sync
|
from sglang.srt.utils import require_mlp_sync
|
||||||
@@ -141,7 +136,7 @@ class DecodePreallocQueue:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
req_to_token_pool: ReqToTokenPool,
|
req_to_token_pool: ReqToTokenPool,
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||||
draft_token_to_kv_pool: Optional[KVCache],
|
draft_token_to_kv_pool: Optional[KVCache],
|
||||||
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
||||||
metadata_buffers: MetadataBuffers,
|
metadata_buffers: MetadataBuffers,
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ from collections import deque
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
|
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
|
||||||
|
|||||||
@@ -18,12 +18,13 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import threading
|
import threading
|
||||||
from queue import Empty, Full, PriorityQueue, Queue
|
from queue import Empty, Full, PriorityQueue, Queue
|
||||||
from typing import List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
|
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -163,7 +164,7 @@ class HiCacheController:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||||
mem_pool_host: HostKVCache,
|
mem_pool_host: HostKVCache,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
load_cache_event: threading.Event = None,
|
load_cache_event: threading.Event = None,
|
||||||
|
|||||||
@@ -54,9 +54,10 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||||
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
||||||
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||||
from sglang.srt.metrics.collector import TimeStats
|
from sglang.srt.metrics.collector import TimeStats
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
@@ -810,7 +811,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
# Request, memory pool, and cache
|
# Request, memory pool, and cache
|
||||||
reqs: List[Req]
|
reqs: List[Req]
|
||||||
req_to_token_pool: ReqToTokenPool = None
|
req_to_token_pool: ReqToTokenPool = None
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
|
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
|
||||||
tree_cache: BasePrefixCache = None
|
tree_cache: BasePrefixCache = None
|
||||||
|
|
||||||
# Batch configs
|
# Batch configs
|
||||||
@@ -907,7 +908,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
cls,
|
cls,
|
||||||
reqs: List[Req],
|
reqs: List[Req],
|
||||||
req_to_token_pool: ReqToTokenPool,
|
req_to_token_pool: ReqToTokenPool,
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||||
tree_cache: BasePrefixCache,
|
tree_cache: BasePrefixCache,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
enable_overlap: bool,
|
enable_overlap: bool,
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
# Copyright 2023-2024 SGLang Team
|
# Copyright 2023-2024 SGLang Team
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -18,15 +20,17 @@ import random
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Dict, List, Optional, Set, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
|
|
||||||
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
|
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
|
||||||
# This can prevent the server from being too conservative.
|
# This can prevent the server from being too conservative.
|
||||||
# Note that this only clips the estimation in the scheduler but does not change the stop
|
# Note that this only clips the estimation in the scheduler but does not change the stop
|
||||||
@@ -265,7 +269,7 @@ class PrefillAdder:
|
|||||||
self,
|
self,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
tree_cache: BasePrefixCache,
|
tree_cache: BasePrefixCache,
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||||
running_batch: ScheduleBatch,
|
running_batch: ScheduleBatch,
|
||||||
new_token_ratio: float,
|
new_token_ratio: float,
|
||||||
rem_input_tokens: int,
|
rem_input_tokens: int,
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ import time
|
|||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from http import HTTPStatus
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -57,7 +58,7 @@ class TpModelWorker:
|
|||||||
nccl_port: int,
|
nccl_port: int,
|
||||||
is_draft_worker: bool = False,
|
is_draft_worker: bool = False,
|
||||||
req_to_token_pool: Optional[ReqToTokenPool] = None,
|
req_to_token_pool: Optional[ReqToTokenPool] = None,
|
||||||
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
|
token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
|
||||||
):
|
):
|
||||||
# Parse args
|
# Parse args
|
||||||
self.tp_size = server_args.tp_size
|
self.tp_size = server_args.tp_size
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Copyright 2025 SGLang Team
|
Copyright 2025 SGLang Team
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -17,13 +19,132 @@ limitations under the License.
|
|||||||
Page-aligned memory pool.
|
Page-aligned memory pool.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import abc
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
|
||||||
from sglang.srt.utils import get_bool_env_var, next_power_of_2
|
from sglang.srt.utils import get_bool_env_var, next_power_of_2
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTokenToKVPoolAllocator(abc.ABC):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
size: int,
|
||||||
|
page_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
|
kvcache: KVCache,
|
||||||
|
):
|
||||||
|
self.size = size
|
||||||
|
self.page_size = page_size
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
self._kvcache = kvcache
|
||||||
|
|
||||||
|
self.free_pages = None
|
||||||
|
self.is_not_in_free_group = True
|
||||||
|
self.free_group = []
|
||||||
|
|
||||||
|
def debug_print(self) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def available_size(self):
|
||||||
|
return len(self.free_pages) * self.page_size
|
||||||
|
|
||||||
|
def get_kvcache(self):
|
||||||
|
return self._kvcache
|
||||||
|
|
||||||
|
def restore_state(self, free_pages):
|
||||||
|
self.free_pages = free_pages
|
||||||
|
|
||||||
|
def backup_state(self):
|
||||||
|
return self.free_pages
|
||||||
|
|
||||||
|
def free_group_begin(self):
|
||||||
|
self.is_not_in_free_group = False
|
||||||
|
self.free_group = []
|
||||||
|
|
||||||
|
def free_group_end(self):
|
||||||
|
self.is_not_in_free_group = True
|
||||||
|
if self.free_group:
|
||||||
|
self.free(torch.cat(self.free_group))
|
||||||
|
|
||||||
|
def get_cpu_copy(self, *args, **kwargs):
|
||||||
|
# FIXME: reuse the get_cpu_copy after paged allocator is implemented
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def load_cpu_copy(self, *args, **kwargs):
|
||||||
|
# FIXME: reuse the load_cpu_copy after paged allocator is implemented
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def alloc_extend(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError("alloc_extend is only for paged allocator")
|
||||||
|
|
||||||
|
def alloc_decode(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError("alloc_decode is only for paged allocator")
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def clear(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def alloc(self, need_size: int):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def free(self, free_index: torch.Tensor):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||||
|
"""An allocator managing the indices to kv cache data."""
|
||||||
|
|
||||||
|
def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
|
||||||
|
super().__init__(size, 1, dtype, device, kvcache)
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||||
|
self.free_pages = torch.arange(
|
||||||
|
1, self.size + 1, dtype=torch.int64, device=self.device
|
||||||
|
)
|
||||||
|
self.is_not_in_free_group = True
|
||||||
|
self.free_group = []
|
||||||
|
|
||||||
|
def available_size(self):
|
||||||
|
# To avoid minor "len(free_pages) * 1" overhead
|
||||||
|
return len(self.free_pages)
|
||||||
|
|
||||||
|
def alloc(self, need_size: int):
|
||||||
|
if need_size > len(self.free_pages):
|
||||||
|
return None
|
||||||
|
|
||||||
|
select_index = self.free_pages[:need_size]
|
||||||
|
self.free_pages = self.free_pages[need_size:]
|
||||||
|
return select_index
|
||||||
|
|
||||||
|
def free(self, free_index: torch.Tensor):
|
||||||
|
if free_index.numel() == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.is_not_in_free_group:
|
||||||
|
self.free_pages = torch.cat((self.free_pages, free_index))
|
||||||
|
else:
|
||||||
|
self.free_group.append(free_index)
|
||||||
|
|
||||||
|
def get_cpu_copy(self, indices):
|
||||||
|
return self._kvcache.get_cpu_copy(indices)
|
||||||
|
|
||||||
|
def load_cpu_copy(self, kv_cache_cpu, indices):
|
||||||
|
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def alloc_extend_kernel(
|
def alloc_extend_kernel(
|
||||||
@@ -154,7 +275,7 @@ def alloc_decode_kernel(
|
|||||||
tl.store(out_indices + pid, page * page_size)
|
tl.store(out_indices + pid, page * page_size)
|
||||||
|
|
||||||
|
|
||||||
class PagedTokenToKVPoolAllocator:
|
class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||||
"""
|
"""
|
||||||
An allocator managing the indices to kv cache data.
|
An allocator managing the indices to kv cache data.
|
||||||
|
|
||||||
@@ -172,26 +293,11 @@ class PagedTokenToKVPoolAllocator:
|
|||||||
device: str,
|
device: str,
|
||||||
kvcache: KVCache,
|
kvcache: KVCache,
|
||||||
):
|
):
|
||||||
self.size = size
|
super().__init__(size, page_size, dtype, device, kvcache)
|
||||||
self.dtype = dtype
|
|
||||||
self.device = device
|
|
||||||
self.page_size = page_size
|
|
||||||
self.num_pages = size // page_size
|
self.num_pages = size // page_size
|
||||||
|
|
||||||
self.free_pages = None
|
|
||||||
self.is_not_in_free_group = True
|
|
||||||
self.free_group = []
|
|
||||||
self.clear()
|
|
||||||
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
||||||
|
|
||||||
self._kvcache = kvcache
|
|
||||||
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
||||||
|
self.clear()
|
||||||
def available_size(self):
|
|
||||||
return len(self.free_pages) * self.page_size
|
|
||||||
|
|
||||||
def get_kvcache(self):
|
|
||||||
return self._kvcache
|
|
||||||
|
|
||||||
def alloc(self, need_size: int):
|
def alloc(self, need_size: int):
|
||||||
# page-aligned allocation, returning contiguous indices of pages
|
# page-aligned allocation, returning contiguous indices of pages
|
||||||
@@ -298,21 +404,6 @@ class PagedTokenToKVPoolAllocator:
|
|||||||
if self.debug_mode:
|
if self.debug_mode:
|
||||||
assert len(torch.unique(self.free_pages)) == len(self.free_pages)
|
assert len(torch.unique(self.free_pages)) == len(self.free_pages)
|
||||||
|
|
||||||
def free_group_begin(self):
|
|
||||||
self.is_not_in_free_group = False
|
|
||||||
self.free_group = []
|
|
||||||
|
|
||||||
def free_group_end(self):
|
|
||||||
self.is_not_in_free_group = True
|
|
||||||
if self.free_group:
|
|
||||||
self.free(torch.cat(self.free_group))
|
|
||||||
|
|
||||||
def backup_state(self):
|
|
||||||
return self.free_pages
|
|
||||||
|
|
||||||
def restore_state(self, free_pages):
|
|
||||||
self.free_pages = free_pages
|
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||||
self.free_pages = torch.arange(
|
self.free_pages = torch.arange(
|
||||||
@@ -2,12 +2,13 @@ from __future__ import annotations
|
|||||||
|
|
||||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import Req
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
@@ -17,7 +18,7 @@ class ChunkCache(BasePrefixCache):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
req_to_token_pool: ReqToTokenPool,
|
req_to_token_pool: ReqToTokenPool,
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
):
|
):
|
||||||
self.req_to_token_pool = req_to_token_pool
|
self.req_to_token_pool = req_to_token_pool
|
||||||
|
|||||||
@@ -7,12 +7,12 @@ from typing import List, Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.managers.cache_controller import HiCacheController
|
from sglang.srt.managers.cache_controller import HiCacheController
|
||||||
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
MHATokenToKVPool,
|
MHATokenToKVPool,
|
||||||
MLATokenToKVPool,
|
MLATokenToKVPool,
|
||||||
ReqToTokenPool,
|
ReqToTokenPool,
|
||||||
TokenToKVPoolAllocator,
|
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.memory_pool_host import (
|
from sglang.srt.mem_cache.memory_pool_host import (
|
||||||
MHATokenToKVPoolHost,
|
MHATokenToKVPoolHost,
|
||||||
@@ -28,7 +28,7 @@ class HiRadixCache(RadixCache):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
req_to_token_pool: ReqToTokenPool,
|
req_to_token_pool: ReqToTokenPool,
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||||
tp_cache_group: torch.distributed.ProcessGroup,
|
tp_cache_group: torch.distributed.ProcessGroup,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
hicache_ratio: float,
|
hicache_ratio: float,
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ KVCache actually holds the physical kv cache.
|
|||||||
|
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -167,84 +166,6 @@ class KVCache(abc.ABC):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class TokenToKVPoolAllocator:
|
|
||||||
"""An allocator managing the indices to kv cache data."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
size: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
device: str,
|
|
||||||
kvcache: KVCache,
|
|
||||||
):
|
|
||||||
self.size = size
|
|
||||||
self.dtype = dtype
|
|
||||||
self.device = device
|
|
||||||
self.page_size = 1
|
|
||||||
|
|
||||||
self.free_slots = None
|
|
||||||
self.is_not_in_free_group = True
|
|
||||||
self.free_group = []
|
|
||||||
self.clear()
|
|
||||||
|
|
||||||
self._kvcache = kvcache
|
|
||||||
|
|
||||||
def available_size(self):
|
|
||||||
return len(self.free_slots)
|
|
||||||
|
|
||||||
def debug_print(self) -> str:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
def get_kvcache(self):
|
|
||||||
return self._kvcache
|
|
||||||
|
|
||||||
def alloc(self, need_size: int):
|
|
||||||
if need_size > len(self.free_slots):
|
|
||||||
return None
|
|
||||||
|
|
||||||
select_index = self.free_slots[:need_size]
|
|
||||||
self.free_slots = self.free_slots[need_size:]
|
|
||||||
return select_index
|
|
||||||
|
|
||||||
def free(self, free_index: torch.Tensor):
|
|
||||||
if free_index.numel() == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.is_not_in_free_group:
|
|
||||||
self.free_slots = torch.cat((self.free_slots, free_index))
|
|
||||||
else:
|
|
||||||
self.free_group.append(free_index)
|
|
||||||
|
|
||||||
def free_group_begin(self):
|
|
||||||
self.is_not_in_free_group = False
|
|
||||||
self.free_group = []
|
|
||||||
|
|
||||||
def free_group_end(self):
|
|
||||||
self.is_not_in_free_group = True
|
|
||||||
if self.free_group:
|
|
||||||
self.free(torch.cat(self.free_group))
|
|
||||||
|
|
||||||
def backup_state(self):
|
|
||||||
return self.free_slots
|
|
||||||
|
|
||||||
def restore_state(self, free_slots):
|
|
||||||
self.free_slots = free_slots
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
|
||||||
self.free_slots = torch.arange(
|
|
||||||
1, self.size + 1, dtype=torch.int64, device=self.device
|
|
||||||
)
|
|
||||||
self.is_not_in_free_group = True
|
|
||||||
self.free_group = []
|
|
||||||
|
|
||||||
def get_cpu_copy(self, indices):
|
|
||||||
return self._kvcache.get_cpu_copy(indices)
|
|
||||||
|
|
||||||
def load_cpu_copy(self, kv_cache_cpu, indices):
|
|
||||||
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
|
|
||||||
|
|
||||||
|
|
||||||
class MHATokenToKVPool(KVCache):
|
class MHATokenToKVPool(KVCache):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import heapq
|
|||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -31,10 +31,10 @@ from sglang.srt.disaggregation.kv_events import (
|
|||||||
AllBlocksCleared,
|
AllBlocksCleared,
|
||||||
BlockRemoved,
|
BlockRemoved,
|
||||||
BlockStored,
|
BlockStored,
|
||||||
KVCacheEvent,
|
|
||||||
)
|
)
|
||||||
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import Req
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
@@ -98,7 +98,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
req_to_token_pool: ReqToTokenPool,
|
req_to_token_pool: ReqToTokenPool,
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
disable: bool = False,
|
disable: bool = False,
|
||||||
enable_kv_cache_events: bool = False,
|
enable_kv_cache_events: bool = False,
|
||||||
|
|||||||
@@ -71,14 +71,17 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
GLOBAL_SERVER_ARGS_KEYS,
|
GLOBAL_SERVER_ARGS_KEYS,
|
||||||
global_server_args_dict,
|
global_server_args_dict,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.mem_cache.allocator import (
|
||||||
|
BaseTokenToKVPoolAllocator,
|
||||||
|
PagedTokenToKVPoolAllocator,
|
||||||
|
TokenToKVPoolAllocator,
|
||||||
|
)
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
DoubleSparseTokenToKVPool,
|
DoubleSparseTokenToKVPool,
|
||||||
MHATokenToKVPool,
|
MHATokenToKVPool,
|
||||||
MLATokenToKVPool,
|
MLATokenToKVPool,
|
||||||
ReqToTokenPool,
|
ReqToTokenPool,
|
||||||
TokenToKVPoolAllocator,
|
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
|
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||||
from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
|
from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
@@ -152,7 +155,7 @@ class ModelRunner:
|
|||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
is_draft_worker: bool = False,
|
is_draft_worker: bool = False,
|
||||||
req_to_token_pool: Optional[ReqToTokenPool] = None,
|
req_to_token_pool: Optional[ReqToTokenPool] = None,
|
||||||
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
|
token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
|
||||||
):
|
):
|
||||||
# Parse args
|
# Parse args
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
get_last_loc,
|
get_last_loc,
|
||||||
global_server_args_dict,
|
global_server_args_dict,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||||
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
||||||
|
|
||||||
@@ -315,7 +315,7 @@ class EagleVerifyInput:
|
|||||||
self,
|
self,
|
||||||
batch: ScheduleBatch,
|
batch: ScheduleBatch,
|
||||||
logits_output: torch.Tensor,
|
logits_output: torch.Tensor,
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
vocab_mask: Optional[torch.Tensor] = None, # For grammar
|
vocab_mask: Optional[torch.Tensor] = None, # For grammar
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|||||||
Reference in New Issue
Block a user