[minor] simplify the TokenToKVPoolAllocator (#7414)
This commit is contained in:
@@ -21,13 +21,11 @@ Life cycle of a request in the decode server
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
@@ -47,12 +45,9 @@ from sglang.srt.disaggregation.utils import (
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
KVCache,
|
||||
ReqToTokenPool,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.srt.utils import require_mlp_sync
|
||||
@@ -141,7 +136,7 @@ class DecodePreallocQueue:
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||
draft_token_to_kv_pool: Optional[KVCache],
|
||||
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
||||
metadata_buffers: MetadataBuffers,
|
||||
|
||||
@@ -25,7 +25,6 @@ from collections import deque
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
|
||||
|
||||
@@ -18,12 +18,13 @@ import logging
|
||||
import math
|
||||
import threading
|
||||
from queue import Empty, Full, PriorityQueue, Queue
|
||||
from typing import List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -163,7 +164,7 @@ class HiCacheController:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||
mem_pool_host: HostKVCache,
|
||||
page_size: int,
|
||||
load_cache_event: threading.Event = None,
|
||||
|
||||
@@ -54,9 +54,10 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.metrics.collector import TimeStats
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
@@ -810,7 +811,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# Request, memory pool, and cache
|
||||
reqs: List[Req]
|
||||
req_to_token_pool: ReqToTokenPool = None
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
|
||||
tree_cache: BasePrefixCache = None
|
||||
|
||||
# Batch configs
|
||||
@@ -907,7 +908,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
cls,
|
||||
reqs: List[Req],
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||
tree_cache: BasePrefixCache,
|
||||
model_config: ModelConfig,
|
||||
enable_overlap: bool,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -18,15 +20,17 @@ import random
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
from typing import Dict, List, Optional, Set, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
|
||||
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
|
||||
# This can prevent the server from being too conservative.
|
||||
# Note that this only clips the estimation in the scheduler but does not change the stop
|
||||
@@ -265,7 +269,7 @@ class PrefillAdder:
|
||||
self,
|
||||
page_size: int,
|
||||
tree_cache: BasePrefixCache,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||
running_batch: ScheduleBatch,
|
||||
new_token_ratio: float,
|
||||
rem_input_tokens: int,
|
||||
|
||||
@@ -23,7 +23,6 @@ import time
|
||||
from collections import defaultdict, deque
|
||||
from concurrent import futures
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
@@ -35,7 +35,8 @@ from sglang.srt.managers.io_struct import (
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -57,7 +58,7 @@ class TpModelWorker:
|
||||
nccl_port: int,
|
||||
is_draft_worker: bool = False,
|
||||
req_to_token_pool: Optional[ReqToTokenPool] = None,
|
||||
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
|
||||
token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
|
||||
):
|
||||
# Parse args
|
||||
self.tp_size = server_args.tp_size
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
Copyright 2025 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -17,13 +19,132 @@ limitations under the License.
|
||||
Page-aligned memory pool.
|
||||
"""
|
||||
|
||||
import abc
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||
from sglang.srt.utils import get_bool_env_var, next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||
|
||||
|
||||
class BaseTokenToKVPoolAllocator(abc.ABC):
    """Shared state and helpers for allocators of kv cache indices.

    The base class stores the pool configuration, the list of free pages and
    the bookkeeping for grouped frees. Concrete allocators must provide
    ``clear``, ``alloc`` and ``free``.
    """

    @abc.abstractmethod
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        # Pool configuration, stored as given.
        self.size = size
        self.page_size = page_size
        self.dtype = dtype
        self.device = device
        self._kvcache = kvcache

        # Left as None here; nothing in this base class populates it except
        # ``restore_state``.
        self.free_pages = None

        # While a free group is open (flag is False), freed indices are meant
        # to be collected in ``free_group`` and released together.
        self.is_not_in_free_group = True
        self.free_group = []

    def debug_print(self) -> str:
        """Return a debug string; the base implementation returns ""."""
        return ""

    def available_size(self):
        """Return the number of free pages multiplied by the page size."""
        num_free_pages = len(self.free_pages)
        return num_free_pages * self.page_size

    def get_kvcache(self):
        """Return the kv cache object passed to the constructor."""
        return self._kvcache

    def restore_state(self, free_pages):
        """Replace the free-page list with ``free_pages``."""
        self.free_pages = free_pages

    def backup_state(self):
        """Return the current free-page list (the same object, not a copy)."""
        return self.free_pages

    def free_group_begin(self):
        """Open a free group and start with an empty pending list."""
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        """Close the free group and release everything collected in it."""
        # The flag is flipped first so the ``free`` call below is made with
        # the group already closed.
        self.is_not_in_free_group = True
        pending = self.free_group
        if not pending:
            return
        self.free(torch.cat(pending))

    def get_cpu_copy(self, *args, **kwargs):
        """Not supported by the base class; always raises NotImplementedError."""
        # FIXME: reuse the get_cpu_copy after paged allocator is implemented
        raise NotImplementedError()

    def load_cpu_copy(self, *args, **kwargs):
        """Not supported by the base class; always raises NotImplementedError."""
        # FIXME: reuse the load_cpu_copy after paged allocator is implemented
        raise NotImplementedError()

    def alloc_extend(self, *args, **kwargs):
        """Raise NotImplementedError unless overridden by a subclass."""
        raise NotImplementedError("alloc_extend is only for paged allocator")

    def alloc_decode(self, *args, **kwargs):
        """Raise NotImplementedError unless overridden by a subclass."""
        raise NotImplementedError("alloc_decode is only for paged allocator")

    @abc.abstractmethod
    def clear(self):
        """Reset the allocator state; must be implemented by subclasses."""
        raise NotImplementedError()

    @abc.abstractmethod
    def alloc(self, need_size: int):
        """Allocate ``need_size`` indices; must be implemented by subclasses."""
        raise NotImplementedError()

    @abc.abstractmethod
    def free(self, free_index: torch.Tensor):
        """Release the indices in ``free_index``; must be implemented by subclasses."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    """An allocator managing the indices to kv cache data (page size 1)."""

    def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
        # This allocator always passes a page size of 1 to the base class.
        super().__init__(size, 1, dtype, device, kvcache)
        self.clear()

    def clear(self):
        """Mark indices 1..size as free and close any open free group."""
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        first_index, end_index = 1, self.size + 1
        self.free_pages = torch.arange(
            first_index, end_index, dtype=torch.int64, device=self.device
        )
        self.is_not_in_free_group = True
        self.free_group = []

    def available_size(self):
        """Return the number of free indices."""
        # To avoid minor "len(free_pages) * 1" overhead
        return len(self.free_pages)

    def alloc(self, need_size: int):
        """Take ``need_size`` indices from the front of the free list.

        Returns None when fewer than ``need_size`` indices are free.
        """
        if len(self.free_pages) < need_size:
            return None

        taken = self.free_pages[:need_size]
        remaining = self.free_pages[need_size:]
        self.free_pages = remaining
        return taken

    def free(self, free_index: torch.Tensor):
        """Release ``free_index``, or queue it while a free group is open."""
        if free_index.numel() == 0:
            return

        if not self.is_not_in_free_group:
            # A free group is open: defer until free_group_end().
            self.free_group.append(free_index)
            return

        self.free_pages = torch.cat((self.free_pages, free_index))

    def get_cpu_copy(self, indices):
        """Delegate to the kv cache's ``get_cpu_copy``."""
        return self._kvcache.get_cpu_copy(indices)

    def load_cpu_copy(self, kv_cache_cpu, indices):
        """Delegate to the kv cache's ``load_cpu_copy``."""
        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def alloc_extend_kernel(
|
||||
@@ -154,7 +275,7 @@ def alloc_decode_kernel(
|
||||
tl.store(out_indices + pid, page * page_size)
|
||||
|
||||
|
||||
class PagedTokenToKVPoolAllocator:
|
||||
class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
"""
|
||||
An allocator managing the indices to kv cache data.
|
||||
|
||||
@@ -172,26 +293,11 @@ class PagedTokenToKVPoolAllocator:
|
||||
device: str,
|
||||
kvcache: KVCache,
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.page_size = page_size
|
||||
super().__init__(size, page_size, dtype, device, kvcache)
|
||||
self.num_pages = size // page_size
|
||||
|
||||
self.free_pages = None
|
||||
self.is_not_in_free_group = True
|
||||
self.free_group = []
|
||||
self.clear()
|
||||
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
||||
|
||||
self._kvcache = kvcache
|
||||
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
||||
|
||||
def available_size(self):
|
||||
return len(self.free_pages) * self.page_size
|
||||
|
||||
def get_kvcache(self):
|
||||
return self._kvcache
|
||||
self.clear()
|
||||
|
||||
def alloc(self, need_size: int):
|
||||
# page-aligned allocation, returning contiguous indices of pages
|
||||
@@ -298,21 +404,6 @@ class PagedTokenToKVPoolAllocator:
|
||||
if self.debug_mode:
|
||||
assert len(torch.unique(self.free_pages)) == len(self.free_pages)
|
||||
|
||||
def free_group_begin(self):
|
||||
self.is_not_in_free_group = False
|
||||
self.free_group = []
|
||||
|
||||
def free_group_end(self):
|
||||
self.is_not_in_free_group = True
|
||||
if self.free_group:
|
||||
self.free(torch.cat(self.free_group))
|
||||
|
||||
def backup_state(self):
|
||||
return self.free_pages
|
||||
|
||||
def restore_state(self, free_pages):
|
||||
self.free_pages = free_pages
|
||||
|
||||
def clear(self):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.free_pages = torch.arange(
|
||||
@@ -2,12 +2,13 @@ from __future__ import annotations
|
||||
|
||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
@@ -17,7 +18,7 @@ class ChunkCache(BasePrefixCache):
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||
page_size: int,
|
||||
):
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
|
||||
@@ -7,12 +7,12 @@ from typing import List, Optional
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
MLATokenToKVPool,
|
||||
ReqToTokenPool,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool_host import (
|
||||
MHATokenToKVPoolHost,
|
||||
@@ -28,7 +28,7 @@ class HiRadixCache(RadixCache):
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||
tp_cache_group: torch.distributed.ProcessGroup,
|
||||
page_size: int,
|
||||
hicache_ratio: float,
|
||||
|
||||
@@ -26,7 +26,6 @@ KVCache actually holds the physical kv cache.
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
@@ -167,84 +166,6 @@ class KVCache(abc.ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TokenToKVPoolAllocator:
    """An allocator managing the indices to kv cache data."""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        self._kvcache = kvcache

        self.size = size
        self.dtype = dtype
        self.device = device
        self.page_size = 1

        # Placeholders; clear() below fills in the real initial state.
        self.free_slots = None
        self.is_not_in_free_group = True
        self.free_group = []
        self.clear()

    def available_size(self):
        """Return the number of free slots."""
        return len(self.free_slots)

    def debug_print(self) -> str:
        """Return a debug string; this implementation returns ""."""
        return ""

    def get_kvcache(self):
        """Return the kv cache object passed to the constructor."""
        return self._kvcache

    def alloc(self, need_size: int):
        """Take ``need_size`` slots from the front of the free list.

        Returns None when fewer than ``need_size`` slots are free.
        """
        if len(self.free_slots) < need_size:
            return None

        taken = self.free_slots[:need_size]
        remaining = self.free_slots[need_size:]
        self.free_slots = remaining
        return taken

    def free(self, free_index: torch.Tensor):
        """Release ``free_index``, or queue it while a free group is open."""
        if free_index.numel() == 0:
            return

        if not self.is_not_in_free_group:
            # A free group is open: defer until free_group_end().
            self.free_group.append(free_index)
            return

        self.free_slots = torch.cat((self.free_slots, free_index))

    def free_group_begin(self):
        """Open a free group and start with an empty pending list."""
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        """Close the free group and release everything collected in it."""
        # Flip the flag first so free() appends to the free list directly.
        self.is_not_in_free_group = True
        pending = self.free_group
        if not pending:
            return
        self.free(torch.cat(pending))

    def backup_state(self):
        """Return the current free-slot list (the same object, not a copy)."""
        return self.free_slots

    def restore_state(self, free_slots):
        """Replace the free-slot list with ``free_slots``."""
        self.free_slots = free_slots

    def clear(self):
        """Mark slots 1..size as free and close any open free group."""
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        first_index, end_index = 1, self.size + 1
        self.free_slots = torch.arange(
            first_index, end_index, dtype=torch.int64, device=self.device
        )
        self.is_not_in_free_group = True
        self.free_group = []

    def get_cpu_copy(self, indices):
        """Delegate to the kv cache's ``get_cpu_copy``."""
        return self._kvcache.get_cpu_copy(indices)

    def load_cpu_copy(self, kv_cache_cpu, indices):
        """Delegate to the kv cache's ``load_cpu_copy``."""
        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
|
||||
|
||||
|
||||
class MHATokenToKVPool(KVCache):
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -23,7 +23,7 @@ import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -31,10 +31,10 @@ from sglang.srt.disaggregation.kv_events import (
|
||||
AllBlocksCleared,
|
||||
BlockRemoved,
|
||||
BlockStored,
|
||||
KVCacheEvent,
|
||||
)
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
@@ -98,7 +98,7 @@ class RadixCache(BasePrefixCache):
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||
page_size: int,
|
||||
disable: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
|
||||
@@ -71,14 +71,17 @@ from sglang.srt.managers.schedule_batch import (
|
||||
GLOBAL_SERVER_ARGS_KEYS,
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.mem_cache.allocator import (
|
||||
BaseTokenToKVPoolAllocator,
|
||||
PagedTokenToKVPoolAllocator,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
DoubleSparseTokenToKVPool,
|
||||
MHATokenToKVPool,
|
||||
MLATokenToKVPool,
|
||||
ReqToTokenPool,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
@@ -152,7 +155,7 @@ class ModelRunner:
|
||||
server_args: ServerArgs,
|
||||
is_draft_worker: bool = False,
|
||||
req_to_token_pool: Optional[ReqToTokenPool] = None,
|
||||
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
|
||||
token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
|
||||
):
|
||||
# Parse args
|
||||
self.model_config = model_config
|
||||
|
||||
@@ -21,7 +21,7 @@ from sglang.srt.managers.schedule_batch import (
|
||||
get_last_loc,
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
||||
|
||||
@@ -315,7 +315,7 @@ class EagleVerifyInput:
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
logits_output: torch.Tensor,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||
page_size: int,
|
||||
vocab_mask: Optional[torch.Tensor] = None, # For grammar
|
||||
) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user