[Refactor] Clean up radix cache related API (#7303)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -38,7 +38,7 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
|
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -436,7 +436,7 @@ class Req:
|
|||||||
self,
|
self,
|
||||||
rid: str,
|
rid: str,
|
||||||
origin_input_text: str,
|
origin_input_text: str,
|
||||||
origin_input_ids: Tuple[int],
|
origin_input_ids: List[int],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
return_logprob: bool = False,
|
return_logprob: bool = False,
|
||||||
top_logprobs_num: int = 0,
|
top_logprobs_num: int = 0,
|
||||||
@@ -467,7 +467,7 @@ class Req:
|
|||||||
# Each decode stage's output ids
|
# Each decode stage's output ids
|
||||||
self.output_ids = []
|
self.output_ids = []
|
||||||
# fill_ids = origin_input_ids + output_ids. Updated if chunked.
|
# fill_ids = origin_input_ids + output_ids. Updated if chunked.
|
||||||
self.fill_ids = None
|
self.fill_ids = []
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
self.input_embeds = input_embeds
|
self.input_embeds = input_embeds
|
||||||
|
|
||||||
@@ -519,13 +519,14 @@ class Req:
|
|||||||
|
|
||||||
# Prefix info
|
# Prefix info
|
||||||
# The indices to kv cache for the shared prefix.
|
# The indices to kv cache for the shared prefix.
|
||||||
self.prefix_indices = []
|
self.prefix_indices: torch.Tensor = []
|
||||||
# Number of tokens to run prefill.
|
# Number of tokens to run prefill.
|
||||||
self.extend_input_len = 0
|
self.extend_input_len = 0
|
||||||
# The relative logprob_start_len in an extend batch
|
# The relative logprob_start_len in an extend batch
|
||||||
self.extend_logprob_start_len = 0
|
self.extend_logprob_start_len = 0
|
||||||
self.last_node = None
|
self.last_node: Any = None
|
||||||
self.last_node_global = None
|
self.last_host_node: Any = None
|
||||||
|
self.host_hit_length = 0
|
||||||
|
|
||||||
# Whether or not if it is chunked. It increments whenever
|
# Whether or not if it is chunked. It increments whenever
|
||||||
# it is chunked, and decrement whenever chunked request is
|
# it is chunked, and decrement whenever chunked request is
|
||||||
@@ -644,21 +645,17 @@ class Req:
|
|||||||
def init_next_round_input(
|
def init_next_round_input(
|
||||||
self,
|
self,
|
||||||
tree_cache: Optional[BasePrefixCache] = None,
|
tree_cache: Optional[BasePrefixCache] = None,
|
||||||
enable_hierarchical_cache=False,
|
|
||||||
):
|
):
|
||||||
self.fill_ids = self.origin_input_ids + self.output_ids
|
self.fill_ids = self.origin_input_ids + self.output_ids
|
||||||
if tree_cache is not None:
|
if tree_cache is not None:
|
||||||
# tree cache is None if the prefix is not computed with tree cache.
|
(
|
||||||
if enable_hierarchical_cache:
|
self.prefix_indices,
|
||||||
self.prefix_indices, self.last_node, self.last_node_global = (
|
self.last_node,
|
||||||
tree_cache.match_prefix(
|
self.last_host_node,
|
||||||
key=self.adjust_max_prefix_ids(), include_evicted=True
|
self.host_hit_length,
|
||||||
)
|
) = tree_cache.match_prefix(
|
||||||
)
|
key=self.adjust_max_prefix_ids(),
|
||||||
else:
|
)
|
||||||
self.prefix_indices, self.last_node = tree_cache.match_prefix(
|
|
||||||
rid=self.rid, key=self.adjust_max_prefix_ids()
|
|
||||||
)
|
|
||||||
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
||||||
|
|
||||||
def adjust_max_prefix_ids(self):
|
def adjust_max_prefix_ids(self):
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ class SchedulePolicy:
|
|||||||
def calc_priority(self, waiting_queue: List[Req]) -> bool:
|
def calc_priority(self, waiting_queue: List[Req]) -> bool:
|
||||||
if self.policy == CacheAgnosticPolicy.FCFS:
|
if self.policy == CacheAgnosticPolicy.FCFS:
|
||||||
# A shortcut for FCFS
|
# A shortcut for FCFS
|
||||||
return
|
return False
|
||||||
|
|
||||||
policy = self._determine_active_policy(waiting_queue)
|
policy = self._determine_active_policy(waiting_queue)
|
||||||
|
|
||||||
@@ -134,7 +134,7 @@ class SchedulePolicy:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
policy_enum = CacheAwarePolicy(policy)
|
policy_enum = CacheAwarePolicy(policy)
|
||||||
if tree_cache.disable:
|
if getattr(tree_cache, "disable", True):
|
||||||
# If tree_cache is disabled, using CacheAgnosticPolicy policy
|
# If tree_cache is disabled, using CacheAgnosticPolicy policy
|
||||||
return CacheAgnosticPolicy.FCFS
|
return CacheAgnosticPolicy.FCFS
|
||||||
return policy_enum
|
return policy_enum
|
||||||
@@ -158,14 +158,9 @@ class SchedulePolicy:
|
|||||||
prefix_ids = r.adjust_max_prefix_ids()
|
prefix_ids = r.adjust_max_prefix_ids()
|
||||||
|
|
||||||
# NOTE: the prefix_indices must always be aligned with last_node
|
# NOTE: the prefix_indices must always be aligned with last_node
|
||||||
if self.enable_hierarchical_cache:
|
r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = (
|
||||||
r.prefix_indices, r.last_node, r.last_node_global = (
|
self.tree_cache.match_prefix(rid=r.rid, key=prefix_ids)
|
||||||
self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
|
||||||
rid=r.rid, key=prefix_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
# NOTE(sang): This logic is for in-batch prefix caching;
|
# NOTE(sang): This logic is for in-batch prefix caching;
|
||||||
# If there are more than 1 request that have small matching prefix from
|
# If there are more than 1 request that have small matching prefix from
|
||||||
@@ -175,7 +170,7 @@ class SchedulePolicy:
|
|||||||
# threshold means we cannot use in-batch prefix caching for short prefixes.
|
# threshold means we cannot use in-batch prefix caching for short prefixes.
|
||||||
# It is kind of common when the engine is long running (e.g., imagine the prefix "the").
|
# It is kind of common when the engine is long running (e.g., imagine the prefix "the").
|
||||||
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
|
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
|
||||||
in_batch_matching_prefixes, _ = (
|
in_batch_matching_prefixes, _, _, _ = (
|
||||||
self.waiting_queue_radix_tree.match_prefix(
|
self.waiting_queue_radix_tree.match_prefix(
|
||||||
rid=r.rid, key=prefix_ids
|
rid=r.rid, key=prefix_ids
|
||||||
)
|
)
|
||||||
@@ -268,6 +263,7 @@ class AddReqResult(Enum):
|
|||||||
class PrefillAdder:
|
class PrefillAdder:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
page_size: int,
|
||||||
tree_cache: BasePrefixCache,
|
tree_cache: BasePrefixCache,
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
running_batch: ScheduleBatch,
|
running_batch: ScheduleBatch,
|
||||||
@@ -276,6 +272,7 @@ class PrefillAdder:
|
|||||||
rem_chunk_tokens: Optional[int],
|
rem_chunk_tokens: Optional[int],
|
||||||
mixed_with_decode_tokens: int = 0,
|
mixed_with_decode_tokens: int = 0,
|
||||||
):
|
):
|
||||||
|
self.page_size = page_size
|
||||||
self.tree_cache = tree_cache
|
self.tree_cache = tree_cache
|
||||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||||
self.running_batch = running_batch
|
self.running_batch = running_batch
|
||||||
@@ -442,46 +439,43 @@ class PrefillAdder:
|
|||||||
|
|
||||||
return self.budget_state()
|
return self.budget_state()
|
||||||
|
|
||||||
def add_one_req(
|
def add_one_req(self, req: Req, has_chunked_req: bool):
|
||||||
self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
|
|
||||||
):
|
|
||||||
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
|
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
|
||||||
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
||||||
|
|
||||||
total_tokens = req.extend_input_len + min(
|
total_tokens = req.extend_input_len + min(
|
||||||
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
|
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
|
||||||
)
|
)
|
||||||
input_tokens = (
|
|
||||||
-(-req.extend_input_len // self.tree_cache.page_size)
|
# adjusting the input_tokens based on host_hit_length and page_size
|
||||||
* self.tree_cache.page_size
|
real_input_tokens = req.extend_input_len - req.host_hit_length
|
||||||
)
|
real_input_tokens = -(-real_input_tokens // self.page_size) * self.page_size
|
||||||
prefix_len = len(req.prefix_indices)
|
prefix_len = len(req.prefix_indices)
|
||||||
|
|
||||||
if total_tokens >= self.rem_total_tokens:
|
if total_tokens >= self.rem_total_tokens:
|
||||||
return AddReqResult.NO_TOKEN
|
return AddReqResult.NO_TOKEN
|
||||||
|
|
||||||
if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
|
if real_input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
|
||||||
return AddReqResult.OTHER
|
return AddReqResult.OTHER
|
||||||
|
|
||||||
with self._lock_node(req.last_node):
|
with self._lock_node(req.last_node):
|
||||||
if total_tokens > self.rem_total_tokens:
|
# self.rem_total_tokens may decrease after the lock acquisition
|
||||||
|
if total_tokens >= self.rem_total_tokens:
|
||||||
return AddReqResult.NO_TOKEN
|
return AddReqResult.NO_TOKEN
|
||||||
|
|
||||||
if (
|
if req.host_hit_length > 0:
|
||||||
enable_hierarchical_cache
|
new_indices, req.last_node = self.tree_cache.init_load_back(
|
||||||
and req.last_node_global is not None
|
req.last_host_node, req.host_hit_length
|
||||||
and req.last_node_global.evicted
|
|
||||||
):
|
|
||||||
req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
|
|
||||||
req.last_node_global, req.prefix_indices
|
|
||||||
)
|
)
|
||||||
|
req.prefix_indices = torch.cat([req.prefix_indices, new_indices])
|
||||||
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
||||||
input_tokens = (
|
|
||||||
-(-req.extend_input_len // self.tree_cache.page_size)
|
|
||||||
* self.tree_cache.page_size
|
|
||||||
)
|
|
||||||
prefix_len = len(req.prefix_indices)
|
prefix_len = len(req.prefix_indices)
|
||||||
|
|
||||||
|
input_tokens = -(-req.extend_input_len // self.page_size) * self.page_size
|
||||||
|
|
||||||
|
if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
|
||||||
|
return AddReqResult.OTHER
|
||||||
|
|
||||||
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
|
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
|
||||||
# Non-chunked prefill
|
# Non-chunked prefill
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
@@ -496,7 +490,7 @@ class PrefillAdder:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Make sure at least one page is available
|
# Make sure at least one page is available
|
||||||
trunc_len = self.rem_chunk_tokens - self.tree_cache.page_size + 1
|
trunc_len = self.rem_chunk_tokens - self.page_size + 1
|
||||||
if trunc_len <= 0:
|
if trunc_len <= 0:
|
||||||
return AddReqResult.OTHER
|
return AddReqResult.OTHER
|
||||||
|
|
||||||
|
|||||||
@@ -1467,15 +1467,14 @@ class Scheduler(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if self.enable_hierarchical_cache:
|
if self.enable_hierarchical_cache:
|
||||||
# check for completion of hierarchical cache activities to release memory
|
self.tree_cache.check_hicache_events()
|
||||||
self.tree_cache.writing_check()
|
|
||||||
self.tree_cache.loading_check()
|
|
||||||
|
|
||||||
# Get priority queue
|
# Get priority queue
|
||||||
prefix_computed = self.policy.calc_priority(self.waiting_queue)
|
self.policy.calc_priority(self.waiting_queue)
|
||||||
|
|
||||||
# Prefill policy
|
# Prefill policy
|
||||||
adder = PrefillAdder(
|
adder = PrefillAdder(
|
||||||
|
self.page_size,
|
||||||
self.tree_cache,
|
self.tree_cache,
|
||||||
self.token_to_kv_pool_allocator,
|
self.token_to_kv_pool_allocator,
|
||||||
self.running_batch,
|
self.running_batch,
|
||||||
@@ -1517,19 +1516,8 @@ class Scheduler(
|
|||||||
self.running_batch.batch_is_full = True
|
self.running_batch.batch_is_full = True
|
||||||
break
|
break
|
||||||
|
|
||||||
# bypass prefix_computed if enable_hierarchical_cache
|
req.init_next_round_input(self.tree_cache)
|
||||||
req.init_next_round_input(
|
res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
|
||||||
(
|
|
||||||
None
|
|
||||||
if (prefix_computed and not self.enable_hierarchical_cache)
|
|
||||||
else self.tree_cache
|
|
||||||
),
|
|
||||||
self.enable_hierarchical_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
res = adder.add_one_req(
|
|
||||||
req, self.chunked_req, self.enable_hierarchical_cache
|
|
||||||
)
|
|
||||||
|
|
||||||
if res != AddReqResult.CONTINUE:
|
if res != AddReqResult.CONTINUE:
|
||||||
if res == AddReqResult.NO_TOKEN:
|
if res == AddReqResult.NO_TOKEN:
|
||||||
@@ -1581,7 +1569,9 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
if self.enable_hierarchical_cache:
|
if self.enable_hierarchical_cache:
|
||||||
# todo (zhiqiang): disable cuda graph execution if hicache loading triggered
|
# todo (zhiqiang): disable cuda graph execution if hicache loading triggered
|
||||||
new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache()
|
new_batch.hicache_consumer_index = (
|
||||||
|
self.tree_cache.ready_to_load_host_cache()
|
||||||
|
)
|
||||||
|
|
||||||
new_batch.prepare_for_extend()
|
new_batch.prepare_for_extend()
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,31 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, List, Tuple
|
from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
|
else:
|
||||||
|
Req = Any # Placeholder for Req type when not type checking
|
||||||
|
|
||||||
|
|
||||||
|
class MatchResult(NamedTuple):
|
||||||
|
"""Result of a prefix match operation.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
device_indices : Indices of the KV cache on the device matched by common prefix.
|
||||||
|
last_device_node: The last TreeNode on the device that was matched.
|
||||||
|
last_host_node : The last TreeNode on the host that was matched.
|
||||||
|
Note that if HiCache is not enabled,
|
||||||
|
this **must** be the same as `last_device_node`.
|
||||||
|
host_hit_length : Length of the KV cache hit on the host, if applicable.
|
||||||
|
0 if HiCache is not enabled.
|
||||||
|
"""
|
||||||
|
|
||||||
|
device_indices: torch.Tensor
|
||||||
|
last_device_node: Any
|
||||||
|
last_host_node: Any
|
||||||
|
host_hit_length: int = 0
|
||||||
|
|
||||||
|
|
||||||
class BasePrefixCache(ABC):
|
class BasePrefixCache(ABC):
|
||||||
@@ -10,19 +36,15 @@ class BasePrefixCache(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
|
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def insert(self, **kwargs):
|
def cache_finished_req(self, req: Req, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def cache_finished_req(self, **kwargs):
|
def cache_unfinished_req(self, req: Req, **kwargs):
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def cache_unfinished_req(self, **kwargs):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -49,5 +71,27 @@ class BasePrefixCache(ABC):
|
|||||||
def pretty_print(self):
|
def pretty_print(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def init_load_back(
|
||||||
|
self,
|
||||||
|
last_host_node: Any,
|
||||||
|
host_hit_length: int,
|
||||||
|
) -> Tuple[torch.Tensor, Any]:
|
||||||
|
"""
|
||||||
|
Preparing KV cache loading from host to device.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def ready_to_load_host_cache(self) -> Any:
|
||||||
|
"""
|
||||||
|
Notify the cache controller to start the KV cache loading
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def check_hicache_events(self) -> Any:
|
||||||
|
"""
|
||||||
|
Check HiCache related activities to update radix tree and synchronize across TP workers if needed
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def take_events(self):
|
def take_events(self):
|
||||||
return []
|
return []
|
||||||
|
|||||||
@@ -6,19 +6,13 @@ from typing import TYPE_CHECKING, Any, Callable, List, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import Req
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
|
|
||||||
|
|
||||||
class ChunkCacheEntry:
|
|
||||||
def __init__(self, rid: str, value: torch.Tensor):
|
|
||||||
self.rid = rid
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkCache(BasePrefixCache):
|
class ChunkCache(BasePrefixCache):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -29,13 +23,16 @@ class ChunkCache(BasePrefixCache):
|
|||||||
self.req_to_token_pool = req_to_token_pool
|
self.req_to_token_pool = req_to_token_pool
|
||||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
self.disable = True
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
|
def match_prefix(self, **unused_kwargs) -> MatchResult:
|
||||||
return [], None
|
return MatchResult(
|
||||||
|
device_indices=torch.empty((0,), dtype=torch.int64),
|
||||||
|
last_device_node=None,
|
||||||
|
last_host_node=None,
|
||||||
|
)
|
||||||
|
|
||||||
def cache_finished_req(self, req: Req):
|
def cache_finished_req(self, req: Req):
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
@@ -54,9 +51,6 @@ class ChunkCache(BasePrefixCache):
|
|||||||
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
||||||
req.prefix_indices = kv_indices
|
req.prefix_indices = kv_indices
|
||||||
|
|
||||||
def insert(self):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def evict(self, num_tokens: int):
|
def evict(self, num_tokens: int):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import List, Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.managers.cache_controller import HiCacheController
|
from sglang.srt.managers.cache_controller import HiCacheController
|
||||||
|
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
MHATokenToKVPool,
|
MHATokenToKVPool,
|
||||||
MLATokenToKVPool,
|
MLATokenToKVPool,
|
||||||
@@ -283,41 +284,44 @@ class HiRadixCache(RadixCache):
|
|||||||
def init_load_back(
|
def init_load_back(
|
||||||
self,
|
self,
|
||||||
last_node: TreeNode,
|
last_node: TreeNode,
|
||||||
prefix_indices: torch.Tensor,
|
host_hit_length: int,
|
||||||
mem_quota: Optional[int] = None,
|
mem_quota: Optional[int] = None,
|
||||||
):
|
):
|
||||||
assert (
|
_ = host_hit_length # unused, but kept for compatibility
|
||||||
len(prefix_indices) == 0 or prefix_indices.is_cuda
|
|
||||||
), "indices of device kV caches should be on GPU"
|
|
||||||
if last_node.evicted:
|
if last_node.evicted:
|
||||||
loading_values = self.load_back(last_node, mem_quota)
|
loading_values = self.load_back(last_node, mem_quota)
|
||||||
if loading_values is not None:
|
if loading_values is not None:
|
||||||
prefix_indices = (
|
|
||||||
loading_values
|
|
||||||
if len(prefix_indices) == 0
|
|
||||||
else torch.cat([prefix_indices, loading_values])
|
|
||||||
)
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"loading back {len(loading_values)} tokens for node {last_node.id}"
|
f"loading back {len(loading_values)} tokens for node {last_node.id}"
|
||||||
)
|
)
|
||||||
|
return loading_values, last_node
|
||||||
|
|
||||||
while last_node.evicted:
|
while last_node.evicted:
|
||||||
last_node = last_node.parent
|
last_node = last_node.parent
|
||||||
|
|
||||||
return last_node, prefix_indices
|
return (
|
||||||
|
torch.empty((0,), dtype=torch.int64, device=self.device),
|
||||||
|
last_node,
|
||||||
|
)
|
||||||
|
|
||||||
def ready_to_load_cache(self):
|
def ready_to_load_host_cache(self):
|
||||||
producer_index = self.cache_controller.layer_done_counter.next_producer()
|
producer_index = self.cache_controller.layer_done_counter.next_producer()
|
||||||
self.load_cache_event.set()
|
self.load_cache_event.set()
|
||||||
return producer_index
|
return producer_index
|
||||||
|
|
||||||
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
|
def check_hicache_events(self):
|
||||||
|
self.writing_check()
|
||||||
|
self.loading_check()
|
||||||
|
|
||||||
|
def match_prefix(self, key: List[int], **kwargs):
|
||||||
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||||
if self.disable or len(key) == 0:
|
if self.disable or len(key) == 0:
|
||||||
if include_evicted:
|
return MatchResult(
|
||||||
return empty_value, self.root_node, self.root_node
|
device_indices=empty_value,
|
||||||
else:
|
last_device_node=self.root_node,
|
||||||
return empty_value, self.root_node
|
last_host_node=self.root_node,
|
||||||
|
host_hit_length=0,
|
||||||
|
)
|
||||||
|
|
||||||
if self.page_size != 1:
|
if self.page_size != 1:
|
||||||
page_aligned_len = len(key) // self.page_size * self.page_size
|
page_aligned_len = len(key) // self.page_size * self.page_size
|
||||||
@@ -329,14 +333,18 @@ class HiRadixCache(RadixCache):
|
|||||||
else:
|
else:
|
||||||
value = empty_value
|
value = empty_value
|
||||||
|
|
||||||
last_node_global = last_node
|
host_hit_length = 0
|
||||||
|
last_host_node = last_node
|
||||||
while last_node.evicted:
|
while last_node.evicted:
|
||||||
|
host_hit_length += len(last_node.host_value)
|
||||||
last_node = last_node.parent
|
last_node = last_node.parent
|
||||||
|
|
||||||
if include_evicted:
|
return MatchResult(
|
||||||
return value, last_node, last_node_global
|
device_indices=value,
|
||||||
else:
|
last_device_node=last_node,
|
||||||
return value, last_node
|
last_host_node=last_host_node,
|
||||||
|
host_hit_length=host_hit_length,
|
||||||
|
)
|
||||||
|
|
||||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||||
node.last_access_time = time.monotonic()
|
node.last_access_time = time.monotonic()
|
||||||
|
|||||||
@@ -33,8 +33,7 @@ from sglang.srt.disaggregation.kv_events import (
|
|||||||
BlockStored,
|
BlockStored,
|
||||||
KVCacheEvent,
|
KVCacheEvent,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -47,9 +46,9 @@ class TreeNode:
|
|||||||
|
|
||||||
def __init__(self, id: Optional[int] = None):
|
def __init__(self, id: Optional[int] = None):
|
||||||
self.children = defaultdict(TreeNode)
|
self.children = defaultdict(TreeNode)
|
||||||
self.parent = None
|
self.parent: TreeNode = None
|
||||||
self.key = None
|
self.key: List[int] = None
|
||||||
self.value = None
|
self.value: Optional[torch.Tensor] = None
|
||||||
self.lock_ref = 0
|
self.lock_ref = 0
|
||||||
self.last_access_time = time.monotonic()
|
self.last_access_time = time.monotonic()
|
||||||
|
|
||||||
@@ -57,7 +56,7 @@ class TreeNode:
|
|||||||
# indicating the node is loading KV cache from host
|
# indicating the node is loading KV cache from host
|
||||||
self.loading = False
|
self.loading = False
|
||||||
# store the host indices of KV cache
|
# store the host indices of KV cache
|
||||||
self.host_value = None
|
self.host_value: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
self.id = TreeNode.counter if id is None else id
|
self.id = TreeNode.counter if id is None else id
|
||||||
TreeNode.counter += 1
|
TreeNode.counter += 1
|
||||||
@@ -135,7 +134,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
self.protected_size_ = 0
|
self.protected_size_ = 0
|
||||||
self._record_all_cleared_event()
|
self._record_all_cleared_event()
|
||||||
|
|
||||||
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
|
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
|
||||||
"""Find the matching prefix from the radix tree.
|
"""Find the matching prefix from the radix tree.
|
||||||
Args:
|
Args:
|
||||||
key: A list of token IDs to find a matching prefix.
|
key: A list of token IDs to find a matching prefix.
|
||||||
@@ -147,13 +146,14 @@ class RadixCache(BasePrefixCache):
|
|||||||
than the last node's value.
|
than the last node's value.
|
||||||
"""
|
"""
|
||||||
if self.disable or len(key) == 0:
|
if self.disable or len(key) == 0:
|
||||||
return (
|
return MatchResult(
|
||||||
torch.empty(
|
device_indices=torch.empty(
|
||||||
(0,),
|
(0,),
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
self.root_node,
|
last_device_node=self.root_node,
|
||||||
|
last_host_node=self.root_node,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.page_size != 1:
|
if self.page_size != 1:
|
||||||
@@ -165,7 +165,11 @@ class RadixCache(BasePrefixCache):
|
|||||||
value = torch.cat(value)
|
value = torch.cat(value)
|
||||||
else:
|
else:
|
||||||
value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||||
return value, last_node
|
return MatchResult(
|
||||||
|
device_indices=value,
|
||||||
|
last_device_node=last_node,
|
||||||
|
last_host_node=last_node,
|
||||||
|
)
|
||||||
|
|
||||||
def insert(self, key: List, value=None):
|
def insert(self, key: List, value=None):
|
||||||
if self.disable:
|
if self.disable:
|
||||||
@@ -235,7 +239,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# The prefix indices could be updated, reuse it
|
# The prefix indices could be updated, reuse it
|
||||||
new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
|
new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
|
||||||
self.req_to_token_pool.write(
|
self.req_to_token_pool.write(
|
||||||
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
|
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
|
||||||
new_indices[len(req.prefix_indices) :],
|
new_indices[len(req.prefix_indices) :],
|
||||||
|
|||||||
Reference in New Issue
Block a user