[Feature] Support mamba radix cache v0 (#11214)
Co-authored-by: hanming-lu <hanming@x.ai> Co-authored-by: hzh0425 <hzh0425@apache.org> Co-authored-by: thalahors <ericalcaide1@gmail.com>
This commit is contained in:
@@ -65,7 +65,8 @@ from sglang.srt.mem_cache.common import (
|
||||
alloc_for_extend,
|
||||
alloc_token_slots,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import RadixKey
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
|
||||
@@ -522,6 +523,7 @@ class Req:
|
||||
|
||||
# Memory pool info
|
||||
self.req_pool_idx: Optional[int] = None
|
||||
self.mamba_pool_idx: Optional[torch.Tensor] = None # shape (1)
|
||||
|
||||
# Check finish
|
||||
self.tokenizer = None
|
||||
@@ -727,7 +729,12 @@ class Req:
|
||||
self.last_host_node,
|
||||
self.host_hit_length,
|
||||
) = tree_cache.match_prefix(
|
||||
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
|
||||
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
|
||||
**(
|
||||
{"req": self, "cow_mamba": True}
|
||||
if isinstance(tree_cache, MambaRadixCache)
|
||||
else {}
|
||||
),
|
||||
)
|
||||
self.last_matched_prefix_len = len(self.prefix_indices)
|
||||
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
||||
@@ -877,6 +884,7 @@ class Req:
|
||||
self.extend_logprob_start_len = 0
|
||||
self.is_chunked = 0
|
||||
self.req_pool_idx = None
|
||||
self.mamba_pool_idx = None
|
||||
self.already_computed = 0
|
||||
|
||||
def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
|
||||
@@ -1071,6 +1079,27 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
def is_empty(self):
|
||||
return len(self.reqs) == 0
|
||||
|
||||
def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
|
||||
if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
|
||||
mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
|
||||
if mamba_available_size < num_reqs:
|
||||
if self.tree_cache is not None and isinstance(
|
||||
self.tree_cache, MambaRadixCache
|
||||
):
|
||||
mamba_num = max(0, num_reqs - mamba_available_size)
|
||||
self.tree_cache.evict_mamba(mamba_num)
|
||||
req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
|
||||
else:
|
||||
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
|
||||
if req_pool_indices is None:
|
||||
raise RuntimeError(
|
||||
"alloc_req_slots runs out of memory. "
|
||||
"Please set a smaller number for `--max-running-requests`. "
|
||||
f"{self.req_to_token_pool.available_size()=}, "
|
||||
f"{num_reqs=}, "
|
||||
)
|
||||
return req_pool_indices
|
||||
|
||||
def allocate_for_eagle_v2(self):
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
|
||||
|
||||
@@ -27,6 +27,7 @@ import torch
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
@@ -357,6 +358,7 @@ class PrefillAdder:
|
||||
self.is_hybrid = isinstance(
|
||||
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
|
||||
)
|
||||
self.is_hybrid_gdn_cache = isinstance(self.tree_cache, MambaRadixCache)
|
||||
|
||||
self.priority_scheduling_preemption_threshold = (
|
||||
priority_scheduling_preemption_threshold
|
||||
@@ -380,6 +382,11 @@ class PrefillAdder:
|
||||
self.token_to_kv_pool_allocator.swa_available_size()
|
||||
+ self.tree_cache.swa_evictable_size(),
|
||||
)
|
||||
elif self.is_hybrid_gdn_cache:
|
||||
available_and_evictable = (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.full_evictable_size()
|
||||
)
|
||||
else:
|
||||
available_and_evictable = (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
@@ -397,6 +404,11 @@ class PrefillAdder:
|
||||
self.token_to_kv_pool_allocator.swa_available_size()
|
||||
+ self.tree_cache.swa_evictable_size(),
|
||||
)
|
||||
elif self.is_hybrid_gdn_cache:
|
||||
available_and_evictable = (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.full_evictable_size()
|
||||
)
|
||||
else:
|
||||
available_and_evictable = (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
|
||||
@@ -146,6 +146,7 @@ from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.utils import validate_input_length
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
@@ -470,6 +471,10 @@ class Scheduler(
|
||||
|
||||
# Hybrid memory pool
|
||||
self.is_hybrid = self.tp_worker.is_hybrid
|
||||
self.is_hybrid_gdn = (
|
||||
self.tp_worker.worker.model_runner.hybrid_gdn_config is not None
|
||||
)
|
||||
|
||||
if self.is_hybrid:
|
||||
self.sliding_window_size = self.tp_worker.sliding_window_size
|
||||
self.full_tokens_per_layer, self.swa_tokens_per_layer = (
|
||||
@@ -816,6 +821,16 @@ class Scheduler(
|
||||
disable=server_args.disable_radix_cache,
|
||||
is_eagle=self.spec_algorithm.is_eagle(),
|
||||
)
|
||||
elif self.is_hybrid_gdn:
|
||||
assert (
|
||||
self.server_args.disaggregation_mode == "null"
|
||||
), "Hybrid GDN mode does not support disaggregation yet"
|
||||
self.tree_cache = MambaRadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
page_size=self.page_size,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
elif server_args.enable_lmcache:
|
||||
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
|
||||
LMCRadixCache,
|
||||
@@ -1689,6 +1704,25 @@ class Scheduler(
|
||||
f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
|
||||
f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
|
||||
)
|
||||
elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
|
||||
(
|
||||
full_num_used,
|
||||
mamba_num_used,
|
||||
_,
|
||||
_,
|
||||
full_available_size,
|
||||
full_evictable_size,
|
||||
mamba_available_size,
|
||||
mamba_evictable_size,
|
||||
) = self._get_mamba_token_info()
|
||||
memory_leak = (
|
||||
full_num_used != self.tree_cache.full_protected_size()
|
||||
or mamba_num_used != self.tree_cache.mamba_protected_size()
|
||||
)
|
||||
token_msg = (
|
||||
f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
|
||||
f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
|
||||
)
|
||||
else:
|
||||
_, _, available_size, evictable_size = self._get_token_info()
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
@@ -1739,6 +1773,17 @@ class Scheduler(
|
||||
) = self._get_swa_token_info()
|
||||
num_used = max(full_num_used, swa_num_used)
|
||||
token_usage = max(full_token_usage, swa_token_usage)
|
||||
elif self.is_hybrid_gdn:
|
||||
(
|
||||
num_used,
|
||||
_,
|
||||
token_usage,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = self._get_mamba_token_info()
|
||||
else:
|
||||
num_used, token_usage, _, _ = self._get_token_info()
|
||||
num_running_reqs = len(self.running_batch.reqs)
|
||||
@@ -1766,7 +1811,9 @@ class Scheduler(
|
||||
self._publish_kv_events()
|
||||
|
||||
def check_tree_cache(self):
|
||||
if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
|
||||
if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or (
|
||||
self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
|
||||
):
|
||||
self.tree_cache.sanity_check()
|
||||
|
||||
def _get_token_info(self):
|
||||
@@ -1776,6 +1823,35 @@ class Scheduler(
|
||||
token_usage = num_used / self.max_total_num_tokens
|
||||
return num_used, token_usage, available_size, evictable_size
|
||||
|
||||
def _get_mamba_token_info(self):
|
||||
is_radix_tree = isinstance(self.tree_cache, MambaRadixCache)
|
||||
full_available_size = self.token_to_kv_pool_allocator.available_size()
|
||||
full_evictable_size = (
|
||||
self.tree_cache.full_evictable_size() if is_radix_tree else 0
|
||||
)
|
||||
mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
|
||||
mamba_evictable_size = (
|
||||
self.tree_cache.mamba_evictable_size() if is_radix_tree else 0
|
||||
)
|
||||
full_num_used = self.token_to_kv_pool_allocator.size - (
|
||||
full_available_size + full_evictable_size
|
||||
)
|
||||
mamba_num_used = self.req_to_token_pool.mamba_pool.size - (
|
||||
mamba_available_size + mamba_evictable_size
|
||||
)
|
||||
full_token_usage = full_num_used / self.token_to_kv_pool_allocator.size
|
||||
mamba_usage = mamba_num_used / self.req_to_token_pool.mamba_pool.size
|
||||
return (
|
||||
full_num_used,
|
||||
mamba_num_used,
|
||||
full_token_usage,
|
||||
mamba_usage,
|
||||
full_available_size,
|
||||
full_evictable_size,
|
||||
mamba_available_size,
|
||||
mamba_evictable_size,
|
||||
)
|
||||
|
||||
def _get_swa_token_info(self):
|
||||
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
|
||||
full_evictable_size = self.tree_cache.full_evictable_size()
|
||||
|
||||
@@ -104,6 +104,23 @@ class SchedulerMetricsMixin:
|
||||
f"full token usage: {full_token_usage:.2f}, "
|
||||
f"swa token usage: {swa_token_usage:.2f}, "
|
||||
)
|
||||
elif self.is_hybrid_gdn:
|
||||
(
|
||||
full_num_used,
|
||||
_,
|
||||
full_token_usage,
|
||||
mamba_usage,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = self._get_mamba_token_info()
|
||||
num_used = full_num_used
|
||||
token_usage = full_token_usage
|
||||
token_usage_msg = (
|
||||
f"full token usage: {full_token_usage:.2f}, "
|
||||
f"mamba usage: {mamba_usage:.2f}, "
|
||||
)
|
||||
else:
|
||||
num_used, token_usage, _, _ = self._get_token_info()
|
||||
token_usage_msg = f"token usage: {token_usage:.2f}, "
|
||||
@@ -203,6 +220,25 @@ class SchedulerMetricsMixin:
|
||||
f"#swa token: {swa_num_used}, "
|
||||
f"swa token usage: {swa_token_usage:.2f}, "
|
||||
)
|
||||
elif self.is_hybrid_gdn:
|
||||
(
|
||||
full_num_used,
|
||||
mamba_used,
|
||||
full_token_usage,
|
||||
mamba_usage,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = self._get_mamba_token_info()
|
||||
num_used = full_num_used
|
||||
token_usage = full_token_usage
|
||||
token_usage_msg = (
|
||||
f"#full token: {full_num_used}, "
|
||||
f"full token usage: {full_token_usage:.2f}, "
|
||||
f"mamba num: {mamba_used}, "
|
||||
f"mamba usage: {mamba_usage:.2f}, "
|
||||
)
|
||||
else:
|
||||
num_used, token_usage, _, _ = self._get_token_info()
|
||||
token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "
|
||||
|
||||
995
python/sglang/srt/mem_cache/mamba_radix_cache.py
Normal file
995
python/sglang/srt/mem_cache/mamba_radix_cache.py
Normal file
@@ -0,0 +1,995 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""
|
||||
The radix tree data structure for managing the hybrid (full and Mamba) KV cache.
|
||||
"""
|
||||
|
||||
import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import (
|
||||
RadixKey,
|
||||
_key_match_page_size1,
|
||||
_key_match_paged,
|
||||
get_child_key,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TreeNode:
|
||||
|
||||
counter = 0
|
||||
|
||||
def __init__(self, id: Optional[int] = None):
|
||||
self.children = defaultdict(TreeNode)
|
||||
self.parent: TreeNode = None
|
||||
self.key: RadixKey = None
|
||||
self.value: Optional[torch.Tensor] = None
|
||||
self.mamba_value: Optional[torch.Tensor] = None
|
||||
# invariant: for any node, if mamba_lock_ref is locked, full_lock_ref must be locked;
|
||||
# if full_lock_ref is locked, mamba_lock_ref doesn't need to be locked. So,
|
||||
# full_lock_ref is always >= mamba_lock_ref.
|
||||
# for full_lock, once it is locked, its parent must be locked as well
|
||||
# for mamba_lock, it only need lock node itself
|
||||
self.full_lock_ref = 0
|
||||
self.mamba_lock_ref = 0
|
||||
# last access time is only used for sanity check. LRU is maintained by the lru list.
|
||||
self.last_access_time = time.monotonic()
|
||||
|
||||
self.hit_count = 0
|
||||
# store the host indices of KV cache
|
||||
self.host_value = None
|
||||
|
||||
# for lru list, invariant:
|
||||
# 1. prev has greater last_access_time
|
||||
# 2. next has smaller last_access_time
|
||||
self.prev = None
|
||||
self.next = None
|
||||
self.mamba_prev = None
|
||||
self.mamba_next = None
|
||||
|
||||
self.id = TreeNode.counter if id is None else id
|
||||
TreeNode.counter += 1
|
||||
|
||||
@property
|
||||
def evicted(self):
|
||||
return self.value is None
|
||||
|
||||
@property
|
||||
def backuped(self):
|
||||
return self.host_value is not None
|
||||
|
||||
def __lt__(self, other: "TreeNode"):
|
||||
return self.last_access_time < other.last_access_time
|
||||
|
||||
|
||||
class LRUList:
|
||||
def __init__(self, mamba: bool = False):
|
||||
self.mamba = mamba
|
||||
if self.mamba:
|
||||
self.prv = "mamba_prev"
|
||||
self.nxt = "mamba_next"
|
||||
self.lock_ref = "mamba_lock_ref"
|
||||
else:
|
||||
self.prv = "prev"
|
||||
self.nxt = "next"
|
||||
self.lock_ref = "full_lock_ref"
|
||||
# Initialize dummy head and tail nodes
|
||||
self.head = TreeNode() # Most recently used side
|
||||
self.tail = TreeNode() # Least recently used side
|
||||
setattr(self.head, self.nxt, self.tail) # self.head.next = self.tail
|
||||
setattr(self.tail, self.prv, self.head) # self.tail.prev = self.head
|
||||
self.cache = {}
|
||||
|
||||
def _add_node(self, node):
|
||||
"""Helper to add node right after head (most recently used)"""
|
||||
self._add_node_after(self.head, node)
|
||||
|
||||
def _add_node_after(self, old_node, new_node):
|
||||
"""Helper to add node right after old_node"""
|
||||
setattr(new_node, self.prv, old_node) # new_node.prev = old_node
|
||||
setattr(
|
||||
new_node, self.nxt, getattr(old_node, self.nxt)
|
||||
) # new_node.next = old_node.next
|
||||
setattr(
|
||||
getattr(old_node, self.nxt), self.prv, new_node
|
||||
) # old_node.next.prev = new_node
|
||||
setattr(old_node, self.nxt, new_node) # old_node.next = new_node
|
||||
|
||||
def _remove_node(self, node):
|
||||
"""Helper to remove node from linked list"""
|
||||
setattr(
|
||||
getattr(node, self.prv), self.nxt, getattr(node, self.nxt)
|
||||
) # node.prev.next = node.next
|
||||
setattr(
|
||||
getattr(node, self.nxt), self.prv, getattr(node, self.prv)
|
||||
) # node.next.prev = node.prev
|
||||
|
||||
def _get_lru(self) -> Optional[TreeNode]:
|
||||
"""
|
||||
Get the least recently used node
|
||||
"""
|
||||
if len(self.cache) == 0:
|
||||
return None
|
||||
return getattr(self.tail, self.prv)
|
||||
|
||||
def reset_node_mru(self, node):
|
||||
"""
|
||||
Move a (existing) node to most recently used position
|
||||
"""
|
||||
assert node.id in self.cache, f"Resetting node {node.id=} not in lru list"
|
||||
assert (
|
||||
not self.mamba or node.mamba_value is not None
|
||||
), f"Resetting mamba tombstone node in mamba lru list: {node.id=}"
|
||||
self._remove_node(node)
|
||||
self._add_node(node)
|
||||
|
||||
def reset_node_and_parents_mru(self, node, root_node):
|
||||
"""
|
||||
Move an (existing) node and its parents to most recently used position. Child node is
|
||||
more recently used than parent node.
|
||||
"""
|
||||
prev_node = self.head
|
||||
while node != root_node:
|
||||
if not self.mamba or node.mamba_value is not None:
|
||||
assert (
|
||||
node.id in self.cache
|
||||
), f"Resetting node {node.id=} not in lru list when resetting node and parents mru"
|
||||
self._remove_node(node)
|
||||
self._add_node_after(prev_node, node)
|
||||
prev_node = node
|
||||
node = node.parent
|
||||
|
||||
def insert_mru(self, node):
|
||||
"""
|
||||
Insert a (new) node as most recently used
|
||||
"""
|
||||
assert (
|
||||
not self.mamba or node.mamba_value is not None
|
||||
), f"Inserting mamba tombstone node in mamba lru list: {node.id=}"
|
||||
assert (
|
||||
node.id not in self.cache
|
||||
), f"Inserting node {node.id=} already in lru list, existing node: {self.cache[node.id].id=}"
|
||||
self.cache[node.id] = node
|
||||
self._add_node(node)
|
||||
|
||||
def remove_node(self, node: TreeNode):
|
||||
"""
|
||||
Remove node from lru list
|
||||
"""
|
||||
assert node.id in self.cache, f"Removing node {node.id=} not in lru list"
|
||||
assert (
|
||||
not self.mamba or node.mamba_value is not None
|
||||
), f"Removing mamba tombstone node from mamba lru list: {node.id=}"
|
||||
del self.cache[node.id]
|
||||
self._remove_node(node)
|
||||
|
||||
def get_lru_no_lock(self) -> Optional[TreeNode]:
|
||||
"""
|
||||
Get the least recently used node that is not locked
|
||||
"""
|
||||
return self.get_prev_no_lock(self.tail, check_id=False)
|
||||
|
||||
def get_leaf_lru_no_lock(self) -> Optional[TreeNode]:
|
||||
"""
|
||||
Get the least recently used leaf node that is not locked
|
||||
"""
|
||||
return self.get_prev_leaf_no_lock(self.tail, check_id=False)
|
||||
|
||||
def get_prev_no_lock(
|
||||
self, node: TreeNode, check_id: bool = True
|
||||
) -> Optional[TreeNode]:
|
||||
"""
|
||||
Get the previous (i.e. more recently used) node that is not locked
|
||||
"""
|
||||
if check_id:
|
||||
assert (
|
||||
node.id in self.cache
|
||||
), f"Getting prev of node {node.id=} not in lru list"
|
||||
x = getattr(node, self.prv) # x = node.prev
|
||||
while getattr(x, self.lock_ref) > 0:
|
||||
x = getattr(x, self.prv) # x = x.prev
|
||||
# if x is the head, it means there is no node in the lru list without lock
|
||||
if x == self.head:
|
||||
return None
|
||||
return x
|
||||
|
||||
def get_prev_leaf_no_lock(self, node: TreeNode, check_id: bool = True):
|
||||
"""
|
||||
Get the previous (i.e. more recently used) leaf node that is not locked
|
||||
"""
|
||||
if check_id:
|
||||
assert (
|
||||
node.id in self.cache
|
||||
), f"Getting prev of node {node.id=} not in lru list"
|
||||
x = getattr(node, self.prv) # x = node.prev
|
||||
while getattr(x, self.lock_ref) > 0 or len(x.children) > 0:
|
||||
x = getattr(x, self.prv) # x = x.prev
|
||||
# if x is the head, it means there is no leaf node in the lru list without lock
|
||||
if x == self.head:
|
||||
return None
|
||||
return x
|
||||
|
||||
def in_list(self, node: Optional[TreeNode]):
|
||||
"""
|
||||
Check if the node is in the lru list
|
||||
"""
|
||||
if not node:
|
||||
return False
|
||||
return node.id in self.cache
|
||||
|
||||
# Note: this is expensive, only use for debug
|
||||
def sanity_check_evictable_size(self):
|
||||
"""
|
||||
Check the evictable size (i.e. the size of the nodes that are not locked)
|
||||
"""
|
||||
node = self.get_lru_no_lock()
|
||||
evictable_size = 0
|
||||
while self.in_list(node):
|
||||
evictable_size += (
|
||||
len(node.value) if not self.mamba else len(node.mamba_value)
|
||||
)
|
||||
node = self.get_prev_no_lock(node)
|
||||
return evictable_size
|
||||
|
||||
# Note: this is expensive, only use for debug or idle check
|
||||
def sanity_check(self, tree_cache: "MambaRadixCache"):
|
||||
"""
|
||||
Check if the lru list is valid by rebuilding the lru list from the tree, heapifying it, and
|
||||
checking if the lru list is valid.
|
||||
"""
|
||||
try:
|
||||
if self.mamba:
|
||||
nodes = tree_cache._collect_nontombstone_nodes()
|
||||
else:
|
||||
nodes = tree_cache._collect_all_nodes()
|
||||
total_nodes = len(nodes)
|
||||
total_lru = len(self.cache)
|
||||
# heapify based on last_access_time
|
||||
heapq.heapify(nodes)
|
||||
# the root node is not in the lru list
|
||||
assert len(nodes) == (
|
||||
total_lru + (0 if self.mamba else 1)
|
||||
), f"len(nodes): {len(nodes)}, total_lru: {total_lru}"
|
||||
|
||||
x_lru = self._get_lru()
|
||||
while len(nodes):
|
||||
x = heapq.heappop(nodes)
|
||||
if x == tree_cache.root_node:
|
||||
# root node is not in the lru list
|
||||
continue
|
||||
assert (
|
||||
x == x_lru
|
||||
), f"Incorrect LRU list, {self.mamba=}, x: {x.id=} != x_lru: {x_lru.id=}"
|
||||
assert (
|
||||
x_lru.full_lock_ref == 0
|
||||
), f"x_lru should not be locked when idle, {x_lru.full_lock_ref=}, {x_lru.id=}"
|
||||
assert (
|
||||
x_lru.mamba_lock_ref == 0
|
||||
), f"x_lru should not be locked when idle, {x_lru.mamba_lock_ref=}, {x_lru.id=}"
|
||||
x_lru = getattr(x, self.prv)
|
||||
|
||||
if self.mamba:
|
||||
evictable_size = tree_cache.mamba_evictable_size()
|
||||
lru_list_evictable_size = tree_cache.mamba_lru_list_evictable_size()
|
||||
else:
|
||||
evictable_size = tree_cache.full_evictable_size()
|
||||
lru_list_evictable_size = tree_cache.full_lru_list_evictable_size()
|
||||
|
||||
assert (
|
||||
evictable_size == lru_list_evictable_size
|
||||
), f"{self.mamba=}, total nodes: {total_nodes}, total lru: {total_lru}, evictable size: {evictable_size} != lru list evictable size: {lru_list_evictable_size}"
|
||||
except Exception as e:
|
||||
msg = f"Mamba Radix tree sanity check failed, ping @yizhang2077: {e}"
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
|
||||
|
||||
class MambaRadixCache(BasePrefixCache):
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: HybridReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
page_size: int,
|
||||
disable: bool = False,
|
||||
):
|
||||
assert isinstance(token_to_kv_pool_allocator, TokenToKVPoolAllocator)
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
|
||||
assert page_size == 1, "Only support page_size=1 in mamba radix cache now."
|
||||
self.page_size = page_size
|
||||
self.disable = disable
|
||||
|
||||
if self.token_to_kv_pool_allocator:
|
||||
self.device = self.token_to_kv_pool_allocator.device
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
self.key_match_fn = _key_match_page_size1
|
||||
self.get_child_key_fn = get_child_key
|
||||
self.reset()
|
||||
|
||||
##### Public API #####
|
||||
|
||||
def reset(self) -> None:
|
||||
self.root_node = TreeNode()
|
||||
self.root_node.key = []
|
||||
self.root_node.value = []
|
||||
self.root_node.full_lock_ref = 1
|
||||
self.root_node.mamba_lock_ref = 1
|
||||
self.full_evictable_size_ = 0
|
||||
self.mamba_evictable_size_ = 0
|
||||
self.full_protected_size_ = 0
|
||||
self.mamba_protected_size_ = 0
|
||||
# LRU lists are used to maintain the order of eviction of the nodes in the tree
|
||||
self.full_lru_list = LRUList(mamba=False)
|
||||
self.mamba_lru_list = LRUList(mamba=True)
|
||||
|
||||
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
|
||||
"""Find the matching prefix from the radix tree.
|
||||
Args:
|
||||
key: A RadixKey contains token IDs to find a matching prefix.
|
||||
Returns:
|
||||
A tuple of a tensor of matching prefix token IDs and
|
||||
the last node that contains the prefix values. Note that
|
||||
this API can modify the internal state of the Radix tree.
|
||||
The last node create a new child if the prefix is shorter
|
||||
than the last node's value.
|
||||
"""
|
||||
cow_mamba: bool = kwargs.get("cow_mamba", False)
|
||||
req: Req = kwargs.get("req", None)
|
||||
|
||||
if self.disable or len(key) == 0:
|
||||
return MatchResult(
|
||||
device_indices=torch.empty(
|
||||
(0,),
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
),
|
||||
last_device_node=self.root_node,
|
||||
last_host_node=self.root_node,
|
||||
)
|
||||
|
||||
value, last_node = self._match_prefix_helper(key)
|
||||
|
||||
# copy mamba state to req local space if cow is true
|
||||
if cow_mamba and last_node.mamba_value is not None:
|
||||
assert req.req_pool_idx is None # req_pool_idx is uninitialed
|
||||
|
||||
# for reqs without mamba cache
|
||||
if req.mamba_pool_idx is None:
|
||||
dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
|
||||
# try to alloc again, protect last_node from eviction
|
||||
if dst_index is None:
|
||||
self.inc_lock_ref(last_node)
|
||||
self.evict_mamba(1)
|
||||
dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
|
||||
self.dec_lock_ref(last_node)
|
||||
assert dst_index is not None, "Can not alloc mamba cache"
|
||||
src_index = last_node.mamba_value
|
||||
self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)
|
||||
req.mamba_pool_idx = dst_index[0]
|
||||
else:
|
||||
src_index = last_node.mamba_value
|
||||
dst_index = req.mamba_pool_idx.unsqueeze(0)
|
||||
self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)
|
||||
|
||||
if value:
|
||||
value = torch.cat(value)
|
||||
else:
|
||||
value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
|
||||
return MatchResult(
|
||||
device_indices=value,
|
||||
last_device_node=last_node,
|
||||
last_host_node=last_node,
|
||||
)
|
||||
|
||||
def insert(self, key: RadixKey, value=None, mamba_value=None) -> Tuple[int, bool]:
|
||||
if self.disable:
|
||||
return 0
|
||||
|
||||
if value is None:
|
||||
value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
|
||||
return self._insert_helper(self.root_node, key, value, mamba_value)
|
||||
|
||||
def cache_finished_req(self, req: Req) -> None:
|
||||
"""Cache request when it finishes."""
|
||||
if self.disable:
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx,
|
||||
: len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
|
||||
]
|
||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
return
|
||||
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
]
|
||||
|
||||
page_aligned_len = len(kv_indices)
|
||||
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
# insert the token_ids and kv_indices into the radix tree
|
||||
# Note: the insert function already frees the overlapped kv_indices
|
||||
mamba_value = (
|
||||
self.req_to_token_pool.get_mamba_indices(req.req_pool_idx)
|
||||
.unsqueeze(-1)
|
||||
.clone()
|
||||
)
|
||||
|
||||
new_prefix_len, mamba_exist = self.insert(
|
||||
RadixKey(token_ids[:page_aligned_len], req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
mamba_value,
|
||||
)
|
||||
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
)
|
||||
|
||||
self.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=mamba_exist)
|
||||
self.dec_lock_ref(req.last_node)
|
||||
|
||||
def cache_unfinished_req(self, req: Req, chunked=False) -> None:
|
||||
"""Cache request when it is unfinished."""
|
||||
if self.disable:
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(req.fill_ids)
|
||||
]
|
||||
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
||||
req.prefix_indices = kv_indices
|
||||
return
|
||||
|
||||
token_ids = req.fill_ids
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
]
|
||||
page_aligned_len = len(kv_indices)
|
||||
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||
page_aligned_token_ids = token_ids[:page_aligned_len]
|
||||
|
||||
mamba_value = self.req_to_token_pool.get_mamba_indices(
|
||||
req.req_pool_idx
|
||||
).unsqueeze(-1)
|
||||
# radix tree mamba value is forked from req space
|
||||
mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(mamba_value)
|
||||
|
||||
# if alloc mamba cache failed, do evict and alloc again
|
||||
if mamba_value_forked is None:
|
||||
self.evict_mamba(1)
|
||||
mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(
|
||||
mamba_value
|
||||
)
|
||||
assert mamba_value_forked is not None, "Can not alloc mamba cache"
|
||||
new_prefix_len, mamba_exist = self.insert(
|
||||
RadixKey(page_aligned_token_ids, req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
mamba_value_forked,
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
)
|
||||
# there is a mamba cache in radix cache, release it
|
||||
if mamba_exist:
|
||||
self.req_to_token_pool.mamba_pool.free(mamba_value_forked)
|
||||
|
||||
# The prefix indices could be updated, reuse it
|
||||
new_indices, new_last_node, _, _ = self.match_prefix(
|
||||
RadixKey(page_aligned_token_ids, req.extra_key)
|
||||
)
|
||||
|
||||
if not mamba_exist:
|
||||
assert torch.equal(new_last_node.mamba_value, mamba_value_forked)
|
||||
|
||||
assert len(req.prefix_indices) <= len(
|
||||
new_indices
|
||||
), f"{req.prefix_indices=}, {new_indices=}"
|
||||
assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
|
||||
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
|
||||
new_indices[len(req.prefix_indices) :],
|
||||
)
|
||||
|
||||
self.dec_lock_ref(req.last_node)
|
||||
self.inc_lock_ref(new_last_node)
|
||||
|
||||
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
||||
req.prefix_indices = new_indices
|
||||
req.last_node = new_last_node
|
||||
|
||||
def pretty_print(self) -> None:
|
||||
self._print_helper(self.root_node, 0)
|
||||
total_size, total_mamba_size = self._total_size_helper()
|
||||
print(f"#full_tokens: {total_size}, #mamba_num: {total_mamba_size}")
|
||||
|
||||
def total_size(self) -> Tuple[int, int]:
|
||||
return self._total_size_helper()
|
||||
|
||||
def _evict_leaf_node(
|
||||
self, x: TreeNode, is_evict_mamba: bool
|
||||
) -> Tuple[int, int, TreeNode, TreeNode]:
|
||||
assert (
|
||||
x.full_lock_ref == 0 and x.mamba_lock_ref == 0
|
||||
), f"evict leaf node invalid with {x.id=} {x.full_lock_ref=} {x.mamba_lock_ref=}"
|
||||
|
||||
assert x.mamba_value is not None, f"leaf node mamba value is not None, {x.id=}"
|
||||
# 1. a leaf node, free full tokens and mamba
|
||||
self.token_to_kv_pool_allocator.free(x.value)
|
||||
full_num_evicted = len(x.value)
|
||||
self.req_to_token_pool.mamba_pool.free(x.mamba_value)
|
||||
mamba_num_evicted = len(x.mamba_value)
|
||||
|
||||
# 2. get the next node, update the lru lists
|
||||
if is_evict_mamba:
|
||||
x_next = self.mamba_lru_list.get_prev_no_lock(x)
|
||||
else:
|
||||
x_next = self.full_lru_list.get_prev_leaf_no_lock(x)
|
||||
self.full_lru_list.remove_node(x)
|
||||
self.mamba_lru_list.remove_node(x)
|
||||
|
||||
# 3. delete the leaf node
|
||||
self._delete_leaf(x)
|
||||
|
||||
# 4. Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone
|
||||
x, leaf_full_num_evicted = self._iteratively_delete_tombstone_leaf(x)
|
||||
full_num_evicted += leaf_full_num_evicted
|
||||
return full_num_evicted, mamba_num_evicted, x, x_next
|
||||
|
||||
def evict_mamba(self, mamba_num: int) -> None:
|
||||
if self.disable or mamba_num <= 0:
|
||||
return
|
||||
# get the least recently used node that is not locked, doesn't have to be a leaf
|
||||
x = self.mamba_lru_list.get_lru_no_lock()
|
||||
mamba_num_evicted = 0
|
||||
# evict lru leaf nodes until mamba_num_tokens is reached
|
||||
while mamba_num_evicted < mamba_num and (self.mamba_lru_list.in_list(x)):
|
||||
assert x.mamba_value is not None, f"node has no mamba value, {x.id=}"
|
||||
assert (
|
||||
len(x.mamba_value) == 1
|
||||
), f"node has abnormal mamba length, {x.id=}, {len(x.mamba_value)=}"
|
||||
assert x != self.root_node, f"root node is not evictable, {x.id=}"
|
||||
assert x.mamba_lock_ref == 0, f"node is in use by mamba kv indices, {x.id=}"
|
||||
|
||||
if len(x.children) > 0:
|
||||
# 1. an internal node, free mamba tokens.
|
||||
self.req_to_token_pool.mamba_pool.free(x.mamba_value)
|
||||
mamba_num_evicted += len(x.mamba_value)
|
||||
|
||||
# 2. get the next node, update the lru lists
|
||||
x_next = self.mamba_lru_list.get_prev_no_lock(x)
|
||||
self.mamba_lru_list.remove_node(x)
|
||||
|
||||
# 3. tombstone the node
|
||||
self._tombstone_internal_node(x)
|
||||
else:
|
||||
_, mamba_evicted_delta, _, x_next = self._evict_leaf_node(x, True)
|
||||
mamba_num_evicted += mamba_evicted_delta
|
||||
|
||||
x = x_next
|
||||
|
||||
def evict(self, full_num_tokens: int) -> None:
|
||||
if self.disable or full_num_tokens <= 0:
|
||||
return
|
||||
|
||||
full_num_evicted = 0
|
||||
# get the least recently used leaf node that is not locked
|
||||
x = self.full_lru_list.get_leaf_lru_no_lock()
|
||||
|
||||
while full_num_evicted < full_num_tokens and self.full_lru_list.in_list(x):
|
||||
assert (
|
||||
x != self.root_node
|
||||
), f"root node should not exist in full lru list, {x.id=}"
|
||||
full_num_evicted_delta, _, x, x_next = self._evict_leaf_node(x, False)
|
||||
full_num_evicted += full_num_evicted_delta
|
||||
|
||||
# if parent has no more children, it is a leaf. It is possible that this node is lru, so
|
||||
# we need to get the first leaf node in the lru list
|
||||
if len(x.parent.children) == 0:
|
||||
x_next = self.full_lru_list.get_leaf_lru_no_lock()
|
||||
|
||||
x = x_next
|
||||
|
||||
def inc_lock_ref(self, node: TreeNode) -> Optional[int]:
|
||||
"""
|
||||
Increment the lock reference count for the node.
|
||||
It locks the full_lock_ref for nodes between the [last node, root), exclusive.
|
||||
It locks the mamba_lock_ref for current node if its mamba_value exists.
|
||||
"""
|
||||
if self.disable:
|
||||
return None
|
||||
|
||||
# protect mamba value in current node if it exists
|
||||
if node.mamba_value is not None:
|
||||
if node.mamba_lock_ref == 0:
|
||||
self.mamba_evictable_size_ -= len(node.mamba_value)
|
||||
self.mamba_protected_size_ += len(node.mamba_value)
|
||||
node.mamba_lock_ref += 1
|
||||
|
||||
while node != self.root_node:
|
||||
# lock full from node to root
|
||||
assert (
|
||||
node.full_lock_ref >= 0
|
||||
), f"inc_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
|
||||
if node.full_lock_ref == 0:
|
||||
self.full_evictable_size_ -= len(node.value)
|
||||
self.full_protected_size_ += len(node.value)
|
||||
node.full_lock_ref += 1
|
||||
node = node.parent
|
||||
return None
|
||||
|
||||
def dec_lock_ref(self, node: TreeNode):
|
||||
"""
|
||||
Decrement the lock reference count for the node.
|
||||
It unlocks the full_lock_ref for nodes between the [last node, root), exclusive.
|
||||
It unlocks the mamba_lock_ref for current node if its mamba_value exists.
|
||||
"""
|
||||
if self.disable:
|
||||
return
|
||||
|
||||
if node.mamba_value is not None:
|
||||
assert (
|
||||
node.mamba_lock_ref > 0
|
||||
), f"dec_lock_ref on node with {node.mamba_lock_ref=}, {node.id=}"
|
||||
if node.mamba_lock_ref == 1:
|
||||
self.mamba_evictable_size_ += len(node.mamba_value)
|
||||
self.mamba_protected_size_ -= len(node.mamba_value)
|
||||
node.mamba_lock_ref -= 1
|
||||
|
||||
while node != self.root_node:
|
||||
assert (
|
||||
node.full_lock_ref > 0
|
||||
), f"dec_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
|
||||
if node.full_lock_ref == 1:
|
||||
self.full_evictable_size_ += len(node.value)
|
||||
self.full_protected_size_ -= len(node.value)
|
||||
node.full_lock_ref -= 1
|
||||
node = node.parent
|
||||
|
||||
def sanity_check(self):
|
||||
self.full_lru_list.sanity_check(self)
|
||||
self.mamba_lru_list.sanity_check(self)
|
||||
|
||||
def evictable_size(self) -> Tuple[int, int]:
|
||||
# Note: use full_evictable_size() and mamba_evictable_size() instead.
|
||||
raise NotImplementedError
|
||||
|
||||
def full_evictable_size(self) -> int:
|
||||
return self.full_evictable_size_
|
||||
|
||||
def mamba_evictable_size(self) -> int:
|
||||
return self.mamba_evictable_size_
|
||||
|
||||
# Note: this is expensive, only use for debug
|
||||
def full_lru_list_evictable_size(self) -> int:
|
||||
return self.full_lru_list.sanity_check_evictable_size()
|
||||
|
||||
# Note: this is expensive, only use for debug
|
||||
def mamba_lru_list_evictable_size(self) -> int:
|
||||
return self.mamba_lru_list.sanity_check_evictable_size()
|
||||
|
||||
def protected_size(self) -> Tuple[int, int]:
|
||||
# Note: use full_protected_size() and mamba_protected_size() instead.
|
||||
raise NotImplementedError
|
||||
|
||||
def full_protected_size(self) -> int:
|
||||
# protected size refers to the size of the full cache that is locked
|
||||
return self.full_protected_size_
|
||||
|
||||
def mamba_protected_size(self) -> int:
|
||||
# protected size refers to the size of the mamba cache that is locked
|
||||
return self.mamba_protected_size_
|
||||
|
||||
def all_values_flatten(self) -> torch.Tensor:
|
||||
values = []
|
||||
|
||||
def _dfs_helper(node: TreeNode):
|
||||
for _, child in node.children.items():
|
||||
values.append(child.value)
|
||||
_dfs_helper(child)
|
||||
|
||||
_dfs_helper(self.root_node)
|
||||
return torch.cat(values)
|
||||
|
||||
##### Internal Helper Functions #####
|
||||
|
||||
def _match_prefix_helper(
|
||||
self, key: RadixKey
|
||||
) -> Tuple[List[torch.Tensor], TreeNode]:
|
||||
"""
|
||||
Mamba prefix matching helper. It factors in the sliding window size such that
|
||||
the matched node is guaranteed to either 1. connected to root without mamba tombstone,
|
||||
or 2. the number of matching tokens from the matched node to the last mamba tombstone
|
||||
node is greater than or equal to the sliding window size.
|
||||
"""
|
||||
node = self.root_node
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
value = []
|
||||
best_value_len = 0
|
||||
best_last_node = node
|
||||
while len(key) > 0 and child_key in node.children.keys():
|
||||
child = node.children[child_key]
|
||||
# update best_value_len and best_last_node if needed
|
||||
if node.mamba_value is not None:
|
||||
best_value_len = len(value)
|
||||
best_last_node = node
|
||||
|
||||
prefix_len = self.key_match_fn(child.key, key)
|
||||
if prefix_len < len(child.key):
|
||||
new_node = self._split_node(child.key, child, prefix_len)
|
||||
value.append(new_node.value)
|
||||
node = new_node
|
||||
break
|
||||
else:
|
||||
value.append(child.value)
|
||||
node = child
|
||||
key = key[prefix_len:]
|
||||
|
||||
if len(key):
|
||||
child_key = self.get_child_key_fn(key)
|
||||
# handle best_value_len and best_last_node, for the case that last node is fully matched
|
||||
if node.mamba_value is not None:
|
||||
best_value_len = len(value)
|
||||
best_last_node = node
|
||||
|
||||
# update time for matched nodes, and make nodes closer to root to be least recently used
|
||||
# this allows mamba to evict nodes closer to root first
|
||||
self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
|
||||
self.mamba_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
|
||||
|
||||
# This last_access_time is for sanity check, can be deleted after validation in production
|
||||
cur_time = time.monotonic()
|
||||
while node:
|
||||
node.last_access_time = cur_time
|
||||
cur_time -= 0.0001
|
||||
node = node.parent
|
||||
|
||||
return value[:best_value_len], best_last_node
|
||||
|
||||
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode:
|
||||
# new_node -> child
|
||||
new_node = TreeNode()
|
||||
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
||||
new_node.parent = child.parent
|
||||
new_node.mamba_value = None # mamba cache can not be split
|
||||
new_node.full_lock_ref = child.full_lock_ref
|
||||
new_node.mamba_lock_ref = 0
|
||||
new_node.key = child.key[:split_len]
|
||||
new_node.value = child.value[:split_len]
|
||||
|
||||
# child time should be later than parent's time for mamba tombstone
|
||||
child.last_access_time = time.monotonic()
|
||||
|
||||
self.full_lru_list.remove_node(child)
|
||||
if child.mamba_value is not None:
|
||||
self.mamba_lru_list.remove_node(child)
|
||||
child.parent = new_node
|
||||
child.key = child.key[split_len:]
|
||||
child.value = child.value[split_len:]
|
||||
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
||||
|
||||
# insert the new node and child into the lru lists, insert
|
||||
# parent first so that parent is after child in the lru list
|
||||
self.full_lru_list.insert_mru(new_node)
|
||||
self.full_lru_list.insert_mru(child)
|
||||
if child.mamba_value is not None:
|
||||
self.mamba_lru_list.insert_mru(child)
|
||||
return new_node
|
||||
|
||||
def _insert_helper(
|
||||
self,
|
||||
node: TreeNode,
|
||||
key: RadixKey,
|
||||
value,
|
||||
mamba_value,
|
||||
) -> Tuple[int, bool]:
|
||||
# Update the last access time from root to leaf, so that
|
||||
# mamba will tombstone the node closer to root first
|
||||
assert mamba_value is not None, "Mamba value should not be None here."
|
||||
node.last_access_time = time.monotonic()
|
||||
if node != self.root_node:
|
||||
self.full_lru_list.reset_node_mru(node)
|
||||
if node.mamba_value is not None:
|
||||
self.mamba_lru_list.reset_node_mru(node)
|
||||
if len(key) == 0:
|
||||
return 0, True
|
||||
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
total_prefix_length = 0
|
||||
while len(key) > 0 and child_key in node.children.keys():
|
||||
node = node.children[child_key]
|
||||
node.last_access_time = time.monotonic()
|
||||
self.full_lru_list.reset_node_mru(node)
|
||||
if node.mamba_value is not None:
|
||||
self.mamba_lru_list.reset_node_mru(node)
|
||||
prefix_len = self.key_match_fn(node.key, key)
|
||||
total_prefix_length += prefix_len
|
||||
key = key[prefix_len:]
|
||||
value = value[prefix_len:]
|
||||
|
||||
if prefix_len < len(node.key):
|
||||
new_node = self._split_node(node.key, node, prefix_len)
|
||||
node = new_node
|
||||
|
||||
if len(key):
|
||||
child_key = self.get_child_key_fn(key)
|
||||
|
||||
mamba_value_exist = False
|
||||
if len(key):
|
||||
new_node = TreeNode()
|
||||
new_node.parent = node
|
||||
new_node.key = key
|
||||
new_node.value = value
|
||||
new_node.mamba_value = mamba_value
|
||||
self.full_lru_list.insert_mru(new_node)
|
||||
self.full_evictable_size_ += len(value)
|
||||
self.mamba_evictable_size_ += len(mamba_value)
|
||||
self.mamba_lru_list.insert_mru(new_node)
|
||||
node.children[child_key] = new_node
|
||||
elif node.mamba_value is None: # add for mamba tombstone
|
||||
node.mamba_value = mamba_value
|
||||
self.mamba_evictable_size_ += len(mamba_value)
|
||||
self.mamba_lru_list.insert_mru(node)
|
||||
else:
|
||||
mamba_value_exist = True
|
||||
self.mamba_lru_list.reset_node_mru(node)
|
||||
|
||||
return total_prefix_length, mamba_value_exist
|
||||
|
||||
def _iteratively_delete_tombstone_leaf(
|
||||
self, node: TreeNode
|
||||
) -> Tuple[TreeNode, int]:
|
||||
full_num_evicted = 0
|
||||
while node.parent.mamba_value is None and len(node.parent.children) == 0:
|
||||
# root node is not evictable
|
||||
if node.parent == self.root_node:
|
||||
break
|
||||
# if locked, means node is in use, skip
|
||||
if node.parent.full_lock_ref > 0:
|
||||
break
|
||||
assert (
|
||||
node.parent.mamba_lock_ref == 0
|
||||
), f"tombstone mamba_lock_ref should always be 0, {node.parent.full_lock_ref=}, {node.parent.mamba_lock_ref=}, {node.parent.id=}"
|
||||
# delete tombstone node evicts full tokens
|
||||
self.token_to_kv_pool_allocator.free(node.parent.value)
|
||||
full_num_evicted += len(node.parent.value)
|
||||
self.full_lru_list.remove_node(node.parent)
|
||||
self._delete_tombstone_leaf(node.parent)
|
||||
node = node.parent
|
||||
|
||||
return node, full_num_evicted
|
||||
|
||||
def _delete_leaf(self, node: TreeNode) -> None:
|
||||
assert (
|
||||
node.mamba_value is not None
|
||||
), f"Invariant violated: leaf node is a tombstone, {node.id=}"
|
||||
assert len(node.children) == 0, f"leaf node has children, {node.id=}"
|
||||
for k, v in node.parent.children.items():
|
||||
if v == node:
|
||||
break
|
||||
del node.parent.children[k]
|
||||
self.full_evictable_size_ -= len(node.key)
|
||||
self.mamba_evictable_size_ -= len(node.mamba_value)
|
||||
|
||||
def _tombstone_internal_node(self, node: TreeNode) -> None:
|
||||
assert len(node.children) != 0, f"Cannot tombstone a leaf node, {node.id=}"
|
||||
self.mamba_evictable_size_ -= len(node.mamba_value)
|
||||
node.mamba_value = None
|
||||
|
||||
def _delete_tombstone_leaf(self, node: TreeNode) -> None:
|
||||
assert (
|
||||
node.mamba_value is None
|
||||
), f"Deleting a unexpected non-tombstone leaf node, {node.id=}"
|
||||
assert len(node.children) == 0, f"leaf node has children, {node.id=}"
|
||||
for k, v in node.parent.children.items():
|
||||
if v == node:
|
||||
break
|
||||
del node.parent.children[k]
|
||||
self.full_evictable_size_ -= len(node.key)
|
||||
|
||||
def _collect_leaves(self) -> List[TreeNode]:
|
||||
ret_list = []
|
||||
stack = [self.root_node]
|
||||
|
||||
while stack:
|
||||
cur_node = stack.pop()
|
||||
if len(cur_node.children) == 0:
|
||||
ret_list.append(cur_node)
|
||||
else:
|
||||
stack.extend(cur_node.children.values())
|
||||
|
||||
return ret_list
|
||||
|
||||
def _collect_nontombstone_nodes(self) -> List[TreeNode]:
|
||||
ret_list = []
|
||||
stack = [self.root_node]
|
||||
|
||||
while stack:
|
||||
cur_node = stack.pop()
|
||||
if cur_node.mamba_value is not None:
|
||||
ret_list.append(cur_node)
|
||||
stack.extend(cur_node.children.values())
|
||||
|
||||
return ret_list
|
||||
|
||||
def _collect_all_nodes(self) -> List[TreeNode]:
|
||||
ret_list = []
|
||||
stack = [self.root_node]
|
||||
while stack:
|
||||
cur_node = stack.pop()
|
||||
ret_list.append(cur_node)
|
||||
stack.extend(cur_node.children.values())
|
||||
return ret_list
|
||||
|
||||
def _print_helper(self, node: TreeNode, indent: int) -> None:
|
||||
"""Prints the radix tree in a human-readable format."""
|
||||
stack = [(node, indent)]
|
||||
while stack:
|
||||
current_node, current_indent = stack.pop()
|
||||
print(
|
||||
" " * current_indent,
|
||||
f"[{current_node.id}]",
|
||||
len(current_node.key),
|
||||
f"fr={current_node.full_lock_ref}",
|
||||
f"mr={current_node.mamba_lock_ref}",
|
||||
f"fll={self.full_lru_list.in_list(current_node)}",
|
||||
f"mll={self.mamba_lru_list.in_list(current_node)}",
|
||||
f"mv={current_node.mamba_value}",
|
||||
)
|
||||
for key, child in current_node.children.items():
|
||||
stack.append((child, current_indent + 2))
|
||||
|
||||
assert key == self.get_child_key_fn(
|
||||
child.key
|
||||
), f"{key=}, {self.get_child_key_fn(child.key)=}"
|
||||
|
||||
def _total_size_helper(self) -> Tuple[int, int]:
|
||||
total_size = 0
|
||||
total_mamba_size = 0
|
||||
stack = [self.root_node]
|
||||
while stack:
|
||||
current_node = stack.pop()
|
||||
total_size += len(current_node.value)
|
||||
if current_node.mamba_value is not None:
|
||||
total_mamba_size += len(current_node.mamba_value)
|
||||
for child in current_node.children.values():
|
||||
if child.evicted:
|
||||
continue
|
||||
stack.append(child)
|
||||
return total_size, total_mamba_size
|
||||
@@ -190,6 +190,7 @@ class MambaPool:
|
||||
)
|
||||
logger.info(
|
||||
f"Mamba Cache is allocated. "
|
||||
f"max_mamba_cache_size: {size}, "
|
||||
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
|
||||
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
|
||||
f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
|
||||
@@ -199,11 +200,13 @@ class MambaPool:
|
||||
self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
|
||||
logger.info(
|
||||
f"Mamba Cache is allocated. "
|
||||
f"max_mamba_cache_size: {size}, "
|
||||
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
|
||||
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
|
||||
)
|
||||
self.size = size
|
||||
self.free_slots = list(range(size))
|
||||
self.device = device
|
||||
self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
|
||||
self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
|
||||
|
||||
def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
|
||||
@@ -216,7 +219,7 @@ class MambaPool:
|
||||
def available_size(self):
|
||||
return len(self.free_slots)
|
||||
|
||||
def alloc(self, need_size: int) -> Optional[List[int]]:
|
||||
def alloc(self, need_size: int) -> Optional[torch.Tensor]:
|
||||
if need_size > len(self.free_slots):
|
||||
return None
|
||||
|
||||
@@ -225,17 +228,30 @@ class MambaPool:
|
||||
|
||||
return select_index
|
||||
|
||||
def free(self, free_index: Union[int, List[int]]):
|
||||
if isinstance(free_index, (int,)):
|
||||
self.free_slots.append(free_index)
|
||||
else:
|
||||
self.free_slots.extend(free_index)
|
||||
def free(self, free_index: torch.Tensor):
|
||||
if free_index.numel() == 0:
|
||||
return
|
||||
self.free_slots = torch.cat((self.free_slots, free_index))
|
||||
self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
|
||||
:, free_index
|
||||
] = 0
|
||||
|
||||
def clear(self):
|
||||
self.free_slots = list(range(self.size))
|
||||
self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
|
||||
|
||||
def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
|
||||
self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
|
||||
self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
|
||||
:, src_index
|
||||
]
|
||||
return
|
||||
|
||||
def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
dst_index = self.alloc(1)
|
||||
if dst_index == None:
|
||||
return None
|
||||
self.copy_from(src_index, dst_index)
|
||||
return dst_index
|
||||
|
||||
|
||||
class HybridReqToTokenPool(ReqToTokenPool):
|
||||
@@ -245,6 +261,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
|
||||
self,
|
||||
*,
|
||||
size: int,
|
||||
mamba_size: int,
|
||||
max_context_len: int,
|
||||
device: str,
|
||||
enable_memory_saver: bool,
|
||||
@@ -259,7 +276,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
|
||||
)
|
||||
|
||||
self.mamba_pool = MambaPool(
|
||||
size=size,
|
||||
size=mamba_size,
|
||||
cache_params=cache_params,
|
||||
device=device,
|
||||
speculative_num_draft_tokens=speculative_num_draft_tokens,
|
||||
@@ -271,9 +288,6 @@ class HybridReqToTokenPool(ReqToTokenPool):
|
||||
size, dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
self.rid_to_mamba_index_mapping: Dict[str, int] = {}
|
||||
self.mamba_index_to_rid_mapping: Dict[int, str] = {}
|
||||
|
||||
# For chunk prefill req, we do not need to allocate mamba cache,
|
||||
# We could use allocated mamba cache instead.
|
||||
def alloc(
|
||||
@@ -285,14 +299,14 @@ class HybridReqToTokenPool(ReqToTokenPool):
|
||||
|
||||
mamba_index = []
|
||||
for req in reqs:
|
||||
rid = req.rid
|
||||
if rid in self.rid_to_mamba_index_mapping:
|
||||
mid = self.rid_to_mamba_index_mapping[rid]
|
||||
elif (mid := self.mamba_pool.alloc(1)) is not None:
|
||||
mid = mid[0]
|
||||
self.rid_to_mamba_index_mapping[rid] = mid
|
||||
self.mamba_index_to_rid_mapping[mid] = rid
|
||||
mamba_index.append(mid)
|
||||
mid = None
|
||||
if req.mamba_pool_idx is not None: # for radix cache
|
||||
mid = req.mamba_pool_idx
|
||||
else:
|
||||
mid = self.mamba_pool.alloc(1)[0]
|
||||
req.mamba_pool_idx = mid
|
||||
if mid is not None:
|
||||
mamba_index.append(mid)
|
||||
assert len(select_index) == len(
|
||||
mamba_index
|
||||
), f"Not enough space for mamba cache, try to increase --max-mamba-cache-size."
|
||||
@@ -313,17 +327,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
|
||||
|
||||
# For chunk prefill, we can not free mamba cache, we need use it in the future
|
||||
def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
|
||||
if isinstance(free_index, (int,)):
|
||||
free_index = [free_index]
|
||||
super().free(free_index)
|
||||
if free_mamba_cache:
|
||||
mamba_index = self.req_index_to_mamba_index_mapping[free_index]
|
||||
mamba_index_list = mamba_index.tolist()
|
||||
if isinstance(mamba_index_list, int):
|
||||
mamba_index_list = [mamba_index_list]
|
||||
self.mamba_pool.free(mamba_index_list)
|
||||
for mid in mamba_index_list:
|
||||
rid = self.mamba_index_to_rid_mapping[mid]
|
||||
self.mamba_index_to_rid_mapping.pop(mid)
|
||||
self.rid_to_mamba_index_mapping.pop(rid)
|
||||
self.mamba_pool.free(mamba_index)
|
||||
|
||||
def clear(self):
|
||||
super().clear()
|
||||
|
||||
@@ -191,6 +191,9 @@ SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
||||
# Detect stragger ranks in model loading
|
||||
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
||||
|
||||
# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
|
||||
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -382,26 +385,10 @@ class ModelRunner:
|
||||
if architectures and not any("Llama4" in arch for arch in architectures):
|
||||
self.is_hybrid = self.model_config.is_hybrid = True
|
||||
|
||||
if config := self.mambaish_config:
|
||||
if config := self.mamba2_config:
|
||||
class_name = config.__class__.__name__
|
||||
logger.warning(f"{class_name} model detected, disable radix cache")
|
||||
self.server_args.disable_radix_cache = True
|
||||
if self.server_args.max_mamba_cache_size is None:
|
||||
if self.server_args.max_running_requests is not None:
|
||||
self.server_args.max_mamba_cache_size = (
|
||||
self.server_args.max_running_requests
|
||||
)
|
||||
else:
|
||||
self.server_args.max_mamba_cache_size = 512
|
||||
if self.hybrid_gdn_config is not None:
|
||||
self.server_args.max_mamba_cache_size = (
|
||||
self.server_args.max_mamba_cache_size
|
||||
// (
|
||||
self.server_args.dp_size
|
||||
if self.server_args.enable_dp_attention
|
||||
else 1
|
||||
)
|
||||
)
|
||||
|
||||
# For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
|
||||
# models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
|
||||
@@ -1330,15 +1317,60 @@ class ModelRunner:
|
||||
rest_memory = available_gpu_memory - total_gpu_memory * (
|
||||
1 - self.mem_fraction_static
|
||||
)
|
||||
if config := self.mambaish_config:
|
||||
rest_memory -= (
|
||||
self.server_args.max_mamba_cache_size
|
||||
* config.mamba2_cache_params.mamba_cache_per_req
|
||||
/ (1 << 30)
|
||||
)
|
||||
if self.mambaish_config is not None:
|
||||
rest_memory = self.handle_max_mamba_cache(rest_memory)
|
||||
max_num_token = int(rest_memory * (1 << 30) // cell_size)
|
||||
return max_num_token
|
||||
|
||||
def handle_max_mamba_cache(self, total_rest_memory):
|
||||
config = self.mambaish_config
|
||||
server_args = self.server_args
|
||||
assert config is not None
|
||||
|
||||
speculativa_ratio = (
|
||||
0
|
||||
if server_args.speculative_num_draft_tokens is None
|
||||
else server_args.speculative_num_draft_tokens
|
||||
)
|
||||
if (
|
||||
server_args.disable_radix_cache
|
||||
or config.mamba2_cache_params.mamba_cache_per_req == 0
|
||||
):
|
||||
# with disable radix cache, sets the max_mamba_cache_size based on the max_running_requests
|
||||
if server_args.max_mamba_cache_size is None:
|
||||
if server_args.max_running_requests is not None:
|
||||
server_args.max_mamba_cache_size = server_args.max_running_requests
|
||||
else:
|
||||
server_args.max_mamba_cache_size = 512
|
||||
else:
|
||||
# allocate the memory based on the ratio between mamba state memory vs. full kv cache memory
|
||||
# solve the equations:
|
||||
# 1. mamba_state_memory + full_kv_cache_memory == total_rest_memory
|
||||
# 2. mamba_state_memory / full_kv_cache_memory == server_args.mamba_full_memory_ratio
|
||||
mamba_state_memory_raw = (
|
||||
total_rest_memory
|
||||
* server_args.mamba_full_memory_ratio
|
||||
/ (1 + server_args.mamba_full_memory_ratio)
|
||||
)
|
||||
# calculate the max_mamba_cache_size based on the given total mamba memory
|
||||
server_args.max_mamba_cache_size = int(
|
||||
(mamba_state_memory_raw * (1 << 30))
|
||||
// config.mamba2_cache_params.mamba_cache_per_req
|
||||
// (1 + speculativa_ratio)
|
||||
)
|
||||
|
||||
if self.hybrid_gdn_config is not None:
|
||||
server_args.max_mamba_cache_size = server_args.max_mamba_cache_size // (
|
||||
server_args.dp_size if server_args.enable_dp_attention else 1
|
||||
)
|
||||
mamba_state_memory = (
|
||||
server_args.max_mamba_cache_size
|
||||
* config.mamba2_cache_params.mamba_cache_per_req
|
||||
* (1 + speculativa_ratio)
|
||||
/ (1 << 30)
|
||||
)
|
||||
return total_rest_memory - mamba_state_memory
|
||||
|
||||
@property
|
||||
def hybrid_gdn_config(self):
|
||||
config = self.model_config.hf_config
|
||||
@@ -1511,8 +1543,16 @@ class ModelRunner:
|
||||
),
|
||||
4096,
|
||||
)
|
||||
|
||||
if self.mambaish_config is not None:
|
||||
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
|
||||
ratio = (
|
||||
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO
|
||||
if not self.server_args.disable_radix_cache
|
||||
else 1
|
||||
)
|
||||
max_num_reqs = min(
|
||||
max_num_reqs, self.server_args.max_mamba_cache_size // ratio
|
||||
)
|
||||
|
||||
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
|
||||
if self.is_draft_worker:
|
||||
@@ -1595,6 +1635,7 @@ class ModelRunner:
|
||||
elif config := self.mambaish_config:
|
||||
self.req_to_token_pool = HybridReqToTokenPool(
|
||||
size=max_num_reqs,
|
||||
mamba_size=self.server_args.max_mamba_cache_size,
|
||||
max_context_len=self.model_config.context_len
|
||||
+ extra_max_context_len,
|
||||
device=self.device,
|
||||
|
||||
@@ -362,6 +362,7 @@ class ServerArgs:
|
||||
# Mamba cache
|
||||
max_mamba_cache_size: Optional[int] = None
|
||||
mamba_ssm_dtype: str = "float32"
|
||||
mamba_full_memory_ratio: float = 0.2
|
||||
|
||||
# Hierarchical cache
|
||||
enable_hierarchical_cache: bool = False
|
||||
@@ -2433,6 +2434,12 @@ class ServerArgs:
|
||||
choices=["float32", "bfloat16"],
|
||||
help="The data type of the SSM states in mamba cache.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mamba-full-memory-ratio",
|
||||
type=float,
|
||||
default=ServerArgs.mamba_full_memory_ratio,
|
||||
help="The ratio of mamba state memory to full kv cache memory.",
|
||||
)
|
||||
# Args for multi-item-scoring
|
||||
parser.add_argument(
|
||||
"--multi-item-scoring-delimiter",
|
||||
|
||||
Reference in New Issue
Block a user