From 05c9bc8956912060968635ba90a140314cde260d Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Sun, 22 Jun 2025 12:37:18 +0800 Subject: [PATCH] [minor] simplify the `TokenToKVPoolAllocator` (#7414) --- python/sglang/srt/disaggregation/decode.py | 11 +- python/sglang/srt/disaggregation/prefill.py | 1 - .../sglang/srt/managers/cache_controller.py | 9 +- python/sglang/srt/managers/schedule_batch.py | 7 +- python/sglang/srt/managers/schedule_policy.py | 10 +- python/sglang/srt/managers/scheduler.py | 1 - python/sglang/srt/managers/tp_worker.py | 5 +- .../{paged_allocator.py => allocator.py} | 159 ++++++++++++++---- python/sglang/srt/mem_cache/chunk_cache.py | 7 +- python/sglang/srt/mem_cache/hiradix_cache.py | 4 +- python/sglang/srt/mem_cache/memory_pool.py | 79 --------- python/sglang/srt/mem_cache/radix_cache.py | 8 +- .../sglang/srt/model_executor/model_runner.py | 9 +- python/sglang/srt/speculative/eagle_utils.py | 4 +- 14 files changed, 165 insertions(+), 149 deletions(-) rename python/sglang/srt/mem_cache/{paged_allocator.py => allocator.py} (77%) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index f625fe171..c8a0067c0 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -21,13 +21,11 @@ Life cycle of a request in the decode server from __future__ import annotations import logging -import os from collections import deque from dataclasses import dataclass from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional, Tuple, Union -import numpy as np import torch from torch.distributed import ProcessGroup @@ -47,12 +45,9 @@ from sglang.srt.disaggregation.utils import ( prepare_abort, ) from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch +from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache -from sglang.srt.mem_cache.memory_pool import ( - KVCache, - ReqToTokenPool, - TokenToKVPoolAllocator, -) +from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import require_mlp_sync @@ -141,7 +136,7 @@ class DecodePreallocQueue: def __init__( self, req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: TokenToKVPoolAllocator, + token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, draft_token_to_kv_pool: Optional[KVCache], req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator, metadata_buffers: MetadataBuffers, diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index adeadfe1b..8900731f7 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -25,7 +25,6 @@ from collections import deque from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional -import numpy as np import torch from sglang.srt.disaggregation.base import BaseKVManager, KVPoll diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 0fd102b6b..bd2ddcc5c 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -18,12 +18,13 @@ import logging import math import threading from queue import Empty, Full, PriorityQueue, Queue -from typing import List, Optional +from typing 
import TYPE_CHECKING, List, Optional import torch -from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator -from sglang.srt.mem_cache.memory_pool_host import HostKVCache +if TYPE_CHECKING: + from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator + from sglang.srt.mem_cache.memory_pool_host import HostKVCache logger = logging.getLogger(__name__) @@ -163,7 +164,7 @@ class HiCacheController: def __init__( self, - token_to_kv_pool_allocator: TokenToKVPoolAllocator, + token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, mem_pool_host: HostKVCache, page_size: int, load_cache_event: threading.Event = None, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 7ad3ee3c8..3e9039e8c 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -54,9 +54,10 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import ( ) from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank from sglang.srt.layers.multimodal import gpu_tensor_hash +from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache -from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.metrics.collector import TimeStats from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo @@ -810,7 +811,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Request, memory pool, and cache reqs: List[Req] req_to_token_pool: ReqToTokenPool = None - token_to_kv_pool_allocator: TokenToKVPoolAllocator = None + token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None tree_cache: BasePrefixCache = None # Batch configs @@ -907,7 +908,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): cls, reqs: List[Req], req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: TokenToKVPoolAllocator, + token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, tree_cache: BasePrefixCache, model_config: ModelConfig, enable_overlap: bool, diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 08e8b2a8a..87a6a145b 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -1,3 +1,5 @@ +from __future__ import annotations + # Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,15 +20,17 @@ import random from collections import defaultdict from contextlib import contextmanager from enum import Enum, auto -from typing import Dict, List, Optional, Set, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union import torch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache -from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode +if TYPE_CHECKING: + from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator + # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large. 
# This can prevent the server from being too conservative. # Note that this only clips the estimation in the scheduler but does not change the stop @@ -265,7 +269,7 @@ class PrefillAdder: self, page_size: int, tree_cache: BasePrefixCache, - token_to_kv_pool_allocator: TokenToKVPoolAllocator, + token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, running_batch: ScheduleBatch, new_token_ratio: float, rem_input_tokens: int, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index b030f5545..6dbb75182 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -23,7 +23,6 @@ import time from collections import defaultdict, deque from concurrent import futures from dataclasses import dataclass -from http import HTTPStatus from pathlib import Path from types import SimpleNamespace from typing import Dict, List, Optional, Tuple, Union diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 88bbde1b6..73a12e285 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -35,7 +35,8 @@ from sglang.srt.managers.io_struct import ( UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict -from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator +from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs @@ -57,7 +58,7 @@ class TpModelWorker: nccl_port: int, is_draft_worker: bool = False, req_to_token_pool: Optional[ReqToTokenPool] = None, - token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None, + token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None, ): # Parse args self.tp_size = server_args.tp_size diff --git a/python/sglang/srt/mem_cache/paged_allocator.py b/python/sglang/srt/mem_cache/allocator.py similarity index 77% rename from python/sglang/srt/mem_cache/paged_allocator.py rename to python/sglang/srt/mem_cache/allocator.py index f6f3e23d8..61f1c842b 100644 --- a/python/sglang/srt/mem_cache/paged_allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Copyright 2025 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,13 +19,132 @@ limitations under the License. Page-aligned memory pool. 
""" +import abc +from typing import TYPE_CHECKING + import torch import triton import triton.language as tl -from sglang.srt.mem_cache.memory_pool import KVCache from sglang.srt.utils import get_bool_env_var, next_power_of_2 +if TYPE_CHECKING: + from sglang.srt.mem_cache.memory_pool import KVCache + + +class BaseTokenToKVPoolAllocator(abc.ABC): + @abc.abstractmethod + def __init__( + self, + size: int, + page_size: int, + dtype: torch.dtype, + device: str, + kvcache: KVCache, + ): + self.size = size + self.page_size = page_size + self.dtype = dtype + self.device = device + self._kvcache = kvcache + + self.free_pages = None + self.is_not_in_free_group = True + self.free_group = [] + + def debug_print(self) -> str: + return "" + + def available_size(self): + return len(self.free_pages) * self.page_size + + def get_kvcache(self): + return self._kvcache + + def restore_state(self, free_pages): + self.free_pages = free_pages + + def backup_state(self): + return self.free_pages + + def free_group_begin(self): + self.is_not_in_free_group = False + self.free_group = [] + + def free_group_end(self): + self.is_not_in_free_group = True + if self.free_group: + self.free(torch.cat(self.free_group)) + + def get_cpu_copy(self, *args, **kwargs): + # FIXME: reuse the get_cpu_copy after paged allocator is implemented + raise NotImplementedError() + + def load_cpu_copy(self, *args, **kwargs): + # FIXME: reuse the load_cpu_copy after paged allocator is implemented + raise NotImplementedError() + + def alloc_extend(self, *args, **kwargs): + raise NotImplementedError("alloc_extend is only for paged allocator") + + def alloc_decode(self, *args, **kwargs): + raise NotImplementedError("alloc_decode is only for paged allocator") + + @abc.abstractmethod + def clear(self): + raise NotImplementedError() + + @abc.abstractmethod + def alloc(self, need_size: int): + raise NotImplementedError() + + @abc.abstractmethod + def free(self, free_index: torch.Tensor): + raise NotImplementedError() + + +class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): + """An allocator managing the indices to kv cache data.""" + + def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache): + super().__init__(size, 1, dtype, device, kvcache) + self.clear() + + def clear(self): + # The padded slot 0 is used for writing dummy outputs from padded tokens. + self.free_pages = torch.arange( + 1, self.size + 1, dtype=torch.int64, device=self.device + ) + self.is_not_in_free_group = True + self.free_group = [] + + def available_size(self): + # To avoid minor "len(free_pages) * 1" overhead + return len(self.free_pages) + + def alloc(self, need_size: int): + if need_size > len(self.free_pages): + return None + + select_index = self.free_pages[:need_size] + self.free_pages = self.free_pages[need_size:] + return select_index + + def free(self, free_index: torch.Tensor): + if free_index.numel() == 0: + return + + if self.is_not_in_free_group: + self.free_pages = torch.cat((self.free_pages, free_index)) + else: + self.free_group.append(free_index) + + def get_cpu_copy(self, indices): + return self._kvcache.get_cpu_copy(indices) + + def load_cpu_copy(self, kv_cache_cpu, indices): + return self._kvcache.load_cpu_copy(kv_cache_cpu, indices) + @triton.jit def alloc_extend_kernel( @@ -154,7 +275,7 @@ def alloc_decode_kernel( tl.store(out_indices + pid, page * page_size) -class PagedTokenToKVPoolAllocator: +class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): """ An allocator managing the indices to kv cache data. 
@@ -172,26 +293,11 @@ class PagedTokenToKVPoolAllocator: device: str, kvcache: KVCache, ): - self.size = size - self.dtype = dtype - self.device = device - self.page_size = page_size + super().__init__(size, page_size, dtype, device, kvcache) self.num_pages = size // page_size - - self.free_pages = None - self.is_not_in_free_group = True - self.free_group = [] - self.clear() self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL") - - self._kvcache = kvcache self.ret_values = torch.empty((), dtype=torch.int64, device=self.device) - - def available_size(self): - return len(self.free_pages) * self.page_size - - def get_kvcache(self): - return self._kvcache + self.clear() def alloc(self, need_size: int): # page-aligned allocation, returning contiguous indices of pages @@ -298,21 +404,6 @@ class PagedTokenToKVPoolAllocator: if self.debug_mode: assert len(torch.unique(self.free_pages)) == len(self.free_pages) - def free_group_begin(self): - self.is_not_in_free_group = False - self.free_group = [] - - def free_group_end(self): - self.is_not_in_free_group = True - if self.free_group: - self.free(torch.cat(self.free_group)) - - def backup_state(self): - return self.free_pages - - def restore_state(self, free_pages): - self.free_pages = free_pages - def clear(self): # The padded slot 0 is used for writing dummy outputs from padded tokens. self.free_pages = torch.arange( diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 80bdb9690..68a993b51 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -2,12 +2,13 @@ from __future__ import annotations """Cache for chunked prefill, used when RadixCache is disabled.""" -from typing import TYPE_CHECKING, Any, Callable, List, Tuple +from typing import TYPE_CHECKING, Any import torch +from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult -from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req @@ -17,7 +18,7 @@ class ChunkCache(BasePrefixCache): def __init__( self, req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: TokenToKVPoolAllocator, + token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, page_size: int, ): self.req_to_token_pool = req_to_token_pool diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index cf7357bc0..31918b150 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -7,12 +7,12 @@ from typing import List, Optional import torch from sglang.srt.managers.cache_controller import HiCacheController +from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import MatchResult from sglang.srt.mem_cache.memory_pool import ( MHATokenToKVPool, MLATokenToKVPool, ReqToTokenPool, - TokenToKVPoolAllocator, ) from sglang.srt.mem_cache.memory_pool_host import ( MHATokenToKVPoolHost, @@ -28,7 +28,7 @@ class HiRadixCache(RadixCache): def __init__( self, req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: TokenToKVPoolAllocator, + token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, tp_cache_group: torch.distributed.ProcessGroup, page_size: int, hicache_ratio: float, diff --git 
a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 5306ce175..c7580c622 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -26,7 +26,6 @@ KVCache actually holds the physical kv cache. import abc import logging -import os from contextlib import nullcontext from typing import List, Optional, Tuple, Union @@ -167,84 +166,6 @@ class KVCache(abc.ABC): raise NotImplementedError() -class TokenToKVPoolAllocator: - """An allocator managing the indices to kv cache data.""" - - def __init__( - self, - size: int, - dtype: torch.dtype, - device: str, - kvcache: KVCache, - ): - self.size = size - self.dtype = dtype - self.device = device - self.page_size = 1 - - self.free_slots = None - self.is_not_in_free_group = True - self.free_group = [] - self.clear() - - self._kvcache = kvcache - - def available_size(self): - return len(self.free_slots) - - def debug_print(self) -> str: - return "" - - def get_kvcache(self): - return self._kvcache - - def alloc(self, need_size: int): - if need_size > len(self.free_slots): - return None - - select_index = self.free_slots[:need_size] - self.free_slots = self.free_slots[need_size:] - return select_index - - def free(self, free_index: torch.Tensor): - if free_index.numel() == 0: - return - - if self.is_not_in_free_group: - self.free_slots = torch.cat((self.free_slots, free_index)) - else: - self.free_group.append(free_index) - - def free_group_begin(self): - self.is_not_in_free_group = False - self.free_group = [] - - def free_group_end(self): - self.is_not_in_free_group = True - if self.free_group: - self.free(torch.cat(self.free_group)) - - def backup_state(self): - return self.free_slots - - def restore_state(self, free_slots): - self.free_slots = free_slots - - def clear(self): - # The padded slot 0 is used for writing dummy outputs from padded tokens. 
- self.free_slots = torch.arange( - 1, self.size + 1, dtype=torch.int64, device=self.device - ) - self.is_not_in_free_group = True - self.free_group = [] - - def get_cpu_copy(self, indices): - return self._kvcache.get_cpu_copy(indices) - - def load_cpu_copy(self, kv_cache_cpu, indices): - return self._kvcache.load_cpu_copy(kv_cache_cpu, indices) - - class MHATokenToKVPool(KVCache): def __init__( diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 256595b7a..72241b829 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -23,7 +23,7 @@ import heapq import time from collections import defaultdict from functools import partial -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional import torch @@ -31,10 +31,10 @@ from sglang.srt.disaggregation.kv_events import ( AllBlocksCleared, BlockRemoved, BlockStored, - KVCacheEvent, ) +from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult -from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req @@ -98,7 +98,7 @@ class RadixCache(BasePrefixCache): def __init__( self, req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: TokenToKVPoolAllocator, + token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, page_size: int, disable: bool = False, enable_kv_cache_events: bool = False, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 32d1e18da..2743fe51e 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -71,14 +71,17 @@ from sglang.srt.managers.schedule_batch import ( GLOBAL_SERVER_ARGS_KEYS, global_server_args_dict, ) +from sglang.srt.mem_cache.allocator import ( + BaseTokenToKVPoolAllocator, + PagedTokenToKVPoolAllocator, + TokenToKVPoolAllocator, +) from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, MHATokenToKVPool, MLATokenToKVPool, ReqToTokenPool, - TokenToKVPoolAllocator, ) -from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors @@ -152,7 +155,7 @@ class ModelRunner: server_args: ServerArgs, is_draft_worker: bool = False, req_to_token_pool: Optional[ReqToTokenPool] = None, - token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None, + token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None, ): # Parse args self.model_config = model_config diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index b69e2939c..cb9d86cf7 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -21,7 +21,7 @@ from sglang.srt.managers.schedule_batch import ( get_last_loc, global_server_args_dict, ) -from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator +from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, 
ForwardMode from sglang.srt.utils import is_cuda, is_hip, next_power_of_2 @@ -315,7 +315,7 @@ class EagleVerifyInput: self, batch: ScheduleBatch, logits_output: torch.Tensor, - token_to_kv_pool_allocator: TokenToKVPoolAllocator, + token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, page_size: int, vocab_mask: Optional[torch.Tensor] = None, # For grammar ) -> torch.Tensor:
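
Usage sketch (reviewer note, not part of the patch): after this change, every consumer types against BaseTokenToKVPoolAllocator, and the concrete allocator is chosen by page size — TokenToKVPoolAllocator for page_size == 1, PagedTokenToKVPoolAllocator otherwise, both now living in sglang.srt.mem_cache.allocator. The snippet below is a minimal illustration under those assumptions; make_allocator is a hypothetical helper (this patch adds no such function), and kvcache=None is only to keep the example self-contained, since alloc/free never touch the backing KVCache (a real setup passes an MHATokenToKVPool or similar).

import torch

from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    PagedTokenToKVPoolAllocator,
    TokenToKVPoolAllocator,
)


def make_allocator(size, page_size, dtype, device, kvcache) -> BaseTokenToKVPoolAllocator:
    # Hypothetical helper: page_size == 1 keeps the simple token-granular
    # allocator; larger pages use the Triton-backed paged allocator.
    if page_size == 1:
        return TokenToKVPoolAllocator(size, dtype, device, kvcache)
    return PagedTokenToKVPoolAllocator(size, page_size, dtype, device, kvcache)


# Both subclasses share the base API: alloc / free / free_group_* / backup_state.
allocator = make_allocator(
    size=16, page_size=1, dtype=torch.bfloat16, device="cpu", kvcache=None
)
idx = allocator.alloc(4)        # contiguous slot indices, or None if out of memory
allocator.free_group_begin()    # batch several frees into a single torch.cat
allocator.free(idx)
allocator.free_group_end()
assert allocator.available_size() == 16

One note on the design this illustrates: the free-group bookkeeping, state backup/restore, and the get_kvcache accessor are defined once on the base class, so callers such as the scheduler and the radix caches no longer need to know which concrete allocator they hold.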