[Refactor] Clean up radix cache related API (#7303)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -38,7 +38,7 @@ import logging
|
||||
import threading
|
||||
from enum import Enum, auto
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -436,7 +436,7 @@ class Req:
|
||||
self,
|
||||
rid: str,
|
||||
origin_input_text: str,
|
||||
origin_input_ids: Tuple[int],
|
||||
origin_input_ids: List[int],
|
||||
sampling_params: SamplingParams,
|
||||
return_logprob: bool = False,
|
||||
top_logprobs_num: int = 0,
|
||||
@@ -467,7 +467,7 @@ class Req:
|
||||
# Each decode stage's output ids
|
||||
self.output_ids = []
|
||||
# fill_ids = origin_input_ids + output_ids. Updated if chunked.
|
||||
self.fill_ids = None
|
||||
self.fill_ids = []
|
||||
self.session_id = session_id
|
||||
self.input_embeds = input_embeds
|
||||
|
||||
@@ -519,13 +519,14 @@ class Req:
|
||||
|
||||
# Prefix info
|
||||
# The indices to kv cache for the shared prefix.
|
||||
self.prefix_indices = []
|
||||
self.prefix_indices: torch.Tensor = []
|
||||
# Number of tokens to run prefill.
|
||||
self.extend_input_len = 0
|
||||
# The relative logprob_start_len in an extend batch
|
||||
self.extend_logprob_start_len = 0
|
||||
self.last_node = None
|
||||
self.last_node_global = None
|
||||
self.last_node: Any = None
|
||||
self.last_host_node: Any = None
|
||||
self.host_hit_length = 0
|
||||
|
||||
# Whether or not if it is chunked. It increments whenever
|
||||
# it is chunked, and decrement whenever chunked request is
|
||||
@@ -644,21 +645,17 @@ class Req:
|
||||
def init_next_round_input(
|
||||
self,
|
||||
tree_cache: Optional[BasePrefixCache] = None,
|
||||
enable_hierarchical_cache=False,
|
||||
):
|
||||
self.fill_ids = self.origin_input_ids + self.output_ids
|
||||
if tree_cache is not None:
|
||||
# tree cache is None if the prefix is not computed with tree cache.
|
||||
if enable_hierarchical_cache:
|
||||
self.prefix_indices, self.last_node, self.last_node_global = (
|
||||
tree_cache.match_prefix(
|
||||
key=self.adjust_max_prefix_ids(), include_evicted=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.prefix_indices, self.last_node = tree_cache.match_prefix(
|
||||
rid=self.rid, key=self.adjust_max_prefix_ids()
|
||||
)
|
||||
(
|
||||
self.prefix_indices,
|
||||
self.last_node,
|
||||
self.last_host_node,
|
||||
self.host_hit_length,
|
||||
) = tree_cache.match_prefix(
|
||||
key=self.adjust_max_prefix_ids(),
|
||||
)
|
||||
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
||||
|
||||
def adjust_max_prefix_ids(self):
|
||||
|
||||
@@ -90,7 +90,7 @@ class SchedulePolicy:
|
||||
def calc_priority(self, waiting_queue: List[Req]) -> bool:
|
||||
if self.policy == CacheAgnosticPolicy.FCFS:
|
||||
# A shortcut for FCFS
|
||||
return
|
||||
return False
|
||||
|
||||
policy = self._determine_active_policy(waiting_queue)
|
||||
|
||||
@@ -134,7 +134,7 @@ class SchedulePolicy:
|
||||
"""
|
||||
try:
|
||||
policy_enum = CacheAwarePolicy(policy)
|
||||
if tree_cache.disable:
|
||||
if getattr(tree_cache, "disable", True):
|
||||
# If tree_cache is disabled, using CacheAgnosticPolicy policy
|
||||
return CacheAgnosticPolicy.FCFS
|
||||
return policy_enum
|
||||
@@ -158,14 +158,9 @@ class SchedulePolicy:
|
||||
prefix_ids = r.adjust_max_prefix_ids()
|
||||
|
||||
# NOTE: the prefix_indices must always be aligned with last_node
|
||||
if self.enable_hierarchical_cache:
|
||||
r.prefix_indices, r.last_node, r.last_node_global = (
|
||||
self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
|
||||
)
|
||||
else:
|
||||
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
||||
rid=r.rid, key=prefix_ids
|
||||
)
|
||||
r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = (
|
||||
self.tree_cache.match_prefix(rid=r.rid, key=prefix_ids)
|
||||
)
|
||||
|
||||
# NOTE(sang): This logic is for in-batch prefix caching;
|
||||
# If there are more than 1 request that have small matching prefix from
|
||||
@@ -175,7 +170,7 @@ class SchedulePolicy:
|
||||
# threshold means we cannot use in-batch prefix caching for short prefixes.
|
||||
# It is kind of common when the engine is long running (e.g., imagine the prefix "the").
|
||||
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
|
||||
in_batch_matching_prefixes, _ = (
|
||||
in_batch_matching_prefixes, _, _, _ = (
|
||||
self.waiting_queue_radix_tree.match_prefix(
|
||||
rid=r.rid, key=prefix_ids
|
||||
)
|
||||
@@ -268,6 +263,7 @@ class AddReqResult(Enum):
|
||||
class PrefillAdder:
|
||||
def __init__(
|
||||
self,
|
||||
page_size: int,
|
||||
tree_cache: BasePrefixCache,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
running_batch: ScheduleBatch,
|
||||
@@ -276,6 +272,7 @@ class PrefillAdder:
|
||||
rem_chunk_tokens: Optional[int],
|
||||
mixed_with_decode_tokens: int = 0,
|
||||
):
|
||||
self.page_size = page_size
|
||||
self.tree_cache = tree_cache
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.running_batch = running_batch
|
||||
@@ -442,46 +439,43 @@ class PrefillAdder:
|
||||
|
||||
return self.budget_state()
|
||||
|
||||
def add_one_req(
|
||||
self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
|
||||
):
|
||||
def add_one_req(self, req: Req, has_chunked_req: bool):
|
||||
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
|
||||
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
||||
|
||||
total_tokens = req.extend_input_len + min(
|
||||
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
|
||||
)
|
||||
input_tokens = (
|
||||
-(-req.extend_input_len // self.tree_cache.page_size)
|
||||
* self.tree_cache.page_size
|
||||
)
|
||||
|
||||
# adjusting the input_tokens based on host_hit_length and page_size
|
||||
real_input_tokens = req.extend_input_len - req.host_hit_length
|
||||
real_input_tokens = -(-real_input_tokens // self.page_size) * self.page_size
|
||||
prefix_len = len(req.prefix_indices)
|
||||
|
||||
if total_tokens >= self.rem_total_tokens:
|
||||
return AddReqResult.NO_TOKEN
|
||||
|
||||
if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
|
||||
if real_input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
|
||||
return AddReqResult.OTHER
|
||||
|
||||
with self._lock_node(req.last_node):
|
||||
if total_tokens > self.rem_total_tokens:
|
||||
# self.rem_total_tokens may decrease after the lock acquisition
|
||||
if total_tokens >= self.rem_total_tokens:
|
||||
return AddReqResult.NO_TOKEN
|
||||
|
||||
if (
|
||||
enable_hierarchical_cache
|
||||
and req.last_node_global is not None
|
||||
and req.last_node_global.evicted
|
||||
):
|
||||
req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
|
||||
req.last_node_global, req.prefix_indices
|
||||
if req.host_hit_length > 0:
|
||||
new_indices, req.last_node = self.tree_cache.init_load_back(
|
||||
req.last_host_node, req.host_hit_length
|
||||
)
|
||||
req.prefix_indices = torch.cat([req.prefix_indices, new_indices])
|
||||
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
||||
input_tokens = (
|
||||
-(-req.extend_input_len // self.tree_cache.page_size)
|
||||
* self.tree_cache.page_size
|
||||
)
|
||||
prefix_len = len(req.prefix_indices)
|
||||
|
||||
input_tokens = -(-req.extend_input_len // self.page_size) * self.page_size
|
||||
|
||||
if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
|
||||
return AddReqResult.OTHER
|
||||
|
||||
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
|
||||
# Non-chunked prefill
|
||||
self.can_run_list.append(req)
|
||||
@@ -496,7 +490,7 @@ class PrefillAdder:
|
||||
)
|
||||
else:
|
||||
# Make sure at least one page is available
|
||||
trunc_len = self.rem_chunk_tokens - self.tree_cache.page_size + 1
|
||||
trunc_len = self.rem_chunk_tokens - self.page_size + 1
|
||||
if trunc_len <= 0:
|
||||
return AddReqResult.OTHER
|
||||
|
||||
|
||||
@@ -1467,15 +1467,14 @@ class Scheduler(
|
||||
return None
|
||||
|
||||
if self.enable_hierarchical_cache:
|
||||
# check for completion of hierarchical cache activities to release memory
|
||||
self.tree_cache.writing_check()
|
||||
self.tree_cache.loading_check()
|
||||
self.tree_cache.check_hicache_events()
|
||||
|
||||
# Get priority queue
|
||||
prefix_computed = self.policy.calc_priority(self.waiting_queue)
|
||||
self.policy.calc_priority(self.waiting_queue)
|
||||
|
||||
# Prefill policy
|
||||
adder = PrefillAdder(
|
||||
self.page_size,
|
||||
self.tree_cache,
|
||||
self.token_to_kv_pool_allocator,
|
||||
self.running_batch,
|
||||
@@ -1517,19 +1516,8 @@ class Scheduler(
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
# bypass prefix_computed if enable_hierarchical_cache
|
||||
req.init_next_round_input(
|
||||
(
|
||||
None
|
||||
if (prefix_computed and not self.enable_hierarchical_cache)
|
||||
else self.tree_cache
|
||||
),
|
||||
self.enable_hierarchical_cache,
|
||||
)
|
||||
|
||||
res = adder.add_one_req(
|
||||
req, self.chunked_req, self.enable_hierarchical_cache
|
||||
)
|
||||
req.init_next_round_input(self.tree_cache)
|
||||
res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
|
||||
|
||||
if res != AddReqResult.CONTINUE:
|
||||
if res == AddReqResult.NO_TOKEN:
|
||||
@@ -1581,7 +1569,9 @@ class Scheduler(
|
||||
)
|
||||
if self.enable_hierarchical_cache:
|
||||
# todo (zhiqiang): disable cuda graph execution if hicache loading triggered
|
||||
new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache()
|
||||
new_batch.hicache_consumer_index = (
|
||||
self.tree_cache.ready_to_load_host_cache()
|
||||
)
|
||||
|
||||
new_batch.prepare_for_extend()
|
||||
|
||||
|
||||
@@ -1,5 +1,31 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
else:
|
||||
Req = Any # Placeholder for Req type when not type checking
|
||||
|
||||
|
||||
class MatchResult(NamedTuple):
|
||||
"""Result of a prefix match operation.
|
||||
|
||||
Attributes:
|
||||
device_indices : Indices of the KV cache on the device matched by common prefix.
|
||||
last_device_node: The last TreeNode on the device that was matched.
|
||||
last_host_node : The last TreeNode on the host that was matched.
|
||||
Note that if HiCache is not enabled,
|
||||
this **must** be the same as `last_device_node`.
|
||||
host_hit_length : Length of the KV cache hit on the host, if applicable.
|
||||
0 if HiCache is not enabled.
|
||||
"""
|
||||
|
||||
device_indices: torch.Tensor
|
||||
last_device_node: Any
|
||||
last_host_node: Any
|
||||
host_hit_length: int = 0
|
||||
|
||||
|
||||
class BasePrefixCache(ABC):
|
||||
@@ -10,19 +36,15 @@ class BasePrefixCache(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
|
||||
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, **kwargs):
|
||||
def cache_finished_req(self, req: Req, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_finished_req(self, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_unfinished_req(self, **kwargs):
|
||||
def cache_unfinished_req(self, req: Req, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -49,5 +71,27 @@ class BasePrefixCache(ABC):
|
||||
def pretty_print(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def init_load_back(
|
||||
self,
|
||||
last_host_node: Any,
|
||||
host_hit_length: int,
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
"""
|
||||
Preparing KV cache loading from host to device.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def ready_to_load_host_cache(self) -> Any:
|
||||
"""
|
||||
Notify the cache controller to start the KV cache loading
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def check_hicache_events(self) -> Any:
|
||||
"""
|
||||
Check HiCache related activities to update radix tree and synchronize across TP workers if needed
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def take_events(self):
|
||||
return []
|
||||
|
||||
@@ -6,19 +6,13 @@ from typing import TYPE_CHECKING, Any, Callable, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
|
||||
|
||||
class ChunkCacheEntry:
|
||||
def __init__(self, rid: str, value: torch.Tensor):
|
||||
self.rid = rid
|
||||
self.value = value
|
||||
|
||||
|
||||
class ChunkCache(BasePrefixCache):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -29,13 +23,16 @@ class ChunkCache(BasePrefixCache):
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.page_size = page_size
|
||||
self.disable = True
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
|
||||
return [], None
|
||||
def match_prefix(self, **unused_kwargs) -> MatchResult:
|
||||
return MatchResult(
|
||||
device_indices=torch.empty((0,), dtype=torch.int64),
|
||||
last_device_node=None,
|
||||
last_host_node=None,
|
||||
)
|
||||
|
||||
def cache_finished_req(self, req: Req):
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
@@ -54,9 +51,6 @@ class ChunkCache(BasePrefixCache):
|
||||
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
||||
req.prefix_indices = kv_indices
|
||||
|
||||
def insert(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def evict(self, num_tokens: int):
|
||||
pass
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import List, Optional
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
MLATokenToKVPool,
|
||||
@@ -283,41 +284,44 @@ class HiRadixCache(RadixCache):
|
||||
def init_load_back(
|
||||
self,
|
||||
last_node: TreeNode,
|
||||
prefix_indices: torch.Tensor,
|
||||
host_hit_length: int,
|
||||
mem_quota: Optional[int] = None,
|
||||
):
|
||||
assert (
|
||||
len(prefix_indices) == 0 or prefix_indices.is_cuda
|
||||
), "indices of device kV caches should be on GPU"
|
||||
_ = host_hit_length # unused, but kept for compatibility
|
||||
if last_node.evicted:
|
||||
loading_values = self.load_back(last_node, mem_quota)
|
||||
if loading_values is not None:
|
||||
prefix_indices = (
|
||||
loading_values
|
||||
if len(prefix_indices) == 0
|
||||
else torch.cat([prefix_indices, loading_values])
|
||||
)
|
||||
logger.debug(
|
||||
f"loading back {len(loading_values)} tokens for node {last_node.id}"
|
||||
)
|
||||
return loading_values, last_node
|
||||
|
||||
while last_node.evicted:
|
||||
last_node = last_node.parent
|
||||
|
||||
return last_node, prefix_indices
|
||||
return (
|
||||
torch.empty((0,), dtype=torch.int64, device=self.device),
|
||||
last_node,
|
||||
)
|
||||
|
||||
def ready_to_load_cache(self):
|
||||
def ready_to_load_host_cache(self):
|
||||
producer_index = self.cache_controller.layer_done_counter.next_producer()
|
||||
self.load_cache_event.set()
|
||||
return producer_index
|
||||
|
||||
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
|
||||
def check_hicache_events(self):
|
||||
self.writing_check()
|
||||
self.loading_check()
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs):
|
||||
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
if self.disable or len(key) == 0:
|
||||
if include_evicted:
|
||||
return empty_value, self.root_node, self.root_node
|
||||
else:
|
||||
return empty_value, self.root_node
|
||||
return MatchResult(
|
||||
device_indices=empty_value,
|
||||
last_device_node=self.root_node,
|
||||
last_host_node=self.root_node,
|
||||
host_hit_length=0,
|
||||
)
|
||||
|
||||
if self.page_size != 1:
|
||||
page_aligned_len = len(key) // self.page_size * self.page_size
|
||||
@@ -329,14 +333,18 @@ class HiRadixCache(RadixCache):
|
||||
else:
|
||||
value = empty_value
|
||||
|
||||
last_node_global = last_node
|
||||
host_hit_length = 0
|
||||
last_host_node = last_node
|
||||
while last_node.evicted:
|
||||
host_hit_length += len(last_node.host_value)
|
||||
last_node = last_node.parent
|
||||
|
||||
if include_evicted:
|
||||
return value, last_node, last_node_global
|
||||
else:
|
||||
return value, last_node
|
||||
return MatchResult(
|
||||
device_indices=value,
|
||||
last_device_node=last_node,
|
||||
last_host_node=last_host_node,
|
||||
host_hit_length=host_hit_length,
|
||||
)
|
||||
|
||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||
node.last_access_time = time.monotonic()
|
||||
|
||||
@@ -33,8 +33,7 @@ from sglang.srt.disaggregation.kv_events import (
|
||||
BlockStored,
|
||||
KVCacheEvent,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -47,9 +46,9 @@ class TreeNode:
|
||||
|
||||
def __init__(self, id: Optional[int] = None):
|
||||
self.children = defaultdict(TreeNode)
|
||||
self.parent = None
|
||||
self.key = None
|
||||
self.value = None
|
||||
self.parent: TreeNode = None
|
||||
self.key: List[int] = None
|
||||
self.value: Optional[torch.Tensor] = None
|
||||
self.lock_ref = 0
|
||||
self.last_access_time = time.monotonic()
|
||||
|
||||
@@ -57,7 +56,7 @@ class TreeNode:
|
||||
# indicating the node is loading KV cache from host
|
||||
self.loading = False
|
||||
# store the host indices of KV cache
|
||||
self.host_value = None
|
||||
self.host_value: Optional[torch.Tensor] = None
|
||||
|
||||
self.id = TreeNode.counter if id is None else id
|
||||
TreeNode.counter += 1
|
||||
@@ -135,7 +134,7 @@ class RadixCache(BasePrefixCache):
|
||||
self.protected_size_ = 0
|
||||
self._record_all_cleared_event()
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
|
||||
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
|
||||
"""Find the matching prefix from the radix tree.
|
||||
Args:
|
||||
key: A list of token IDs to find a matching prefix.
|
||||
@@ -147,13 +146,14 @@ class RadixCache(BasePrefixCache):
|
||||
than the last node's value.
|
||||
"""
|
||||
if self.disable or len(key) == 0:
|
||||
return (
|
||||
torch.empty(
|
||||
return MatchResult(
|
||||
device_indices=torch.empty(
|
||||
(0,),
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
),
|
||||
self.root_node,
|
||||
last_device_node=self.root_node,
|
||||
last_host_node=self.root_node,
|
||||
)
|
||||
|
||||
if self.page_size != 1:
|
||||
@@ -165,7 +165,11 @@ class RadixCache(BasePrefixCache):
|
||||
value = torch.cat(value)
|
||||
else:
|
||||
value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
return value, last_node
|
||||
return MatchResult(
|
||||
device_indices=value,
|
||||
last_device_node=last_node,
|
||||
last_host_node=last_node,
|
||||
)
|
||||
|
||||
def insert(self, key: List, value=None):
|
||||
if self.disable:
|
||||
@@ -235,7 +239,7 @@ class RadixCache(BasePrefixCache):
|
||||
)
|
||||
|
||||
# The prefix indices could be updated, reuse it
|
||||
new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
|
||||
new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
|
||||
new_indices[len(req.prefix_indices) :],
|
||||
|
||||
Reference in New Issue
Block a user