[Feature] Radix Tree in C++ (#7369)
@@ -569,7 +569,23 @@ class Scheduler(
                 page_size=self.page_size,
             )
         else:
-            if self.enable_hierarchical_cache:
+            if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
+                # lazy import to avoid JIT overhead
+                from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp
+
+                self.tree_cache = RadixCacheCpp(
+                    disable=False,
+                    use_hicache=self.enable_hierarchical_cache,
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool=self.token_to_kv_pool_allocator,
+                    tp_cache_group=self.tp_cpu_group,
+                    page_size=self.page_size,
+                    hicache_ratio=server_args.hicache_ratio,
+                    hicache_size=server_args.hicache_size,
+                    hicache_write_policy=server_args.hicache_write_policy,
+                    enable_kv_cache_events=self.enable_kv_cache_events,
+                )
+            elif self.enable_hierarchical_cache:
                 self.tree_cache = HiRadixCache(
                     req_to_token_pool=self.req_to_token_pool,
                     token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
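As the hunk above shows, the C++ tree is opt-in: it is only constructed when the SGLANG_EXPERIMENTAL_CPP_RADIX_TREE environment variable is "1" at scheduler construction time. A minimal sketch of opting in from a launcher script (the flag name comes from the diff; the rest is illustrative):

    import os

    # Must be set before the Scheduler is constructed, since the check happens
    # in Scheduler.__init__ when the tree cache implementation is chosen.
    os.environ["SGLANG_EXPERIMENTAL_CPP_RADIX_TREE"] = "1"

    # ... then start the server/scheduler as usual via sglang's entrypoint.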
python/sglang/srt/mem_cache/cpp_radix_tree/.clang-format (symbolic link, 1 line)
@@ -0,0 +1 @@
../../../../../sgl-kernel/.clang-format

python/sglang/srt/mem_cache/cpp_radix_tree/common.h (new file, 29 lines)
@@ -0,0 +1,29 @@
#pragma once
#include <cstddef>
#include <cstdint>
#include <source_location>
#include <span>
#include <stdexcept>
#include <string>
#include <vector>

namespace radix_tree_v2 {

using token_t = std::int32_t;
using token_vec_t = std::vector<token_t>;
using token_slice = std::span<const token_t>;
using NodeHandle = std::size_t;
using IOTicket = std::uint32_t;

inline void _assert(
    bool condition,
    const char* message = "Assertion failed",
    std::source_location loc = std::source_location::current()) {
  if (!condition) [[unlikely]] {
    std::string msg = message;
    msg = msg + " at " + loc.file_name() + ":" + std::to_string(loc.line()) + " in " + loc.function_name();
    throw std::runtime_error(msg);
  }
}

}  // namespace radix_tree_v2
python/sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py (new file, 182 lines)
@@ -0,0 +1,182 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch
from torch.utils.cpp_extension import load

_abs_path = os.path.dirname(os.path.abspath(__file__))
radix_tree_cpp = load(
    name="radix_tree_cpp",
    sources=[
        f"{_abs_path}/tree_v2_binding.cpp",
        f"{_abs_path}/tree_v2_debug.cpp",
        f"{_abs_path}/tree_v2.cpp",
    ],
    extra_cflags=["-O3", "-std=c++20"],
)

if TYPE_CHECKING:

    class TreeNodeCpp:
        """
        A placeholder for the TreeNode class. Cannot be constructed elsewhere.
        """

    class IOHandle:
        """
        A placeholder for the IOHandle class. Cannot be constructed elsewhere.
        """

    class RadixTreeCpp:
        def __init__(
            self,
            disabled: bool,
            host_size: Optional[int],
            page_size: int,
            write_through_threshold: int,
        ):
            """
            Initializes the RadixTreeCpp instance.
            Args:
                disabled (bool): If True, the radix tree is disabled.
                host_size (Optional[int]): Size of the radix tree on the CPU. None means no CPU tree.
                page_size (int): Size of the page for the radix tree.
                write_through_threshold (int): Threshold for writing through from GPU to CPU.
            """
            self.tree = radix_tree_cpp.RadixTree(  # type: ignore
                disabled, host_size, page_size, write_through_threshold
            )

        def match_prefix(
            self, prefix: List[int]
        ) -> Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
            """
            Matches a prefix in the radix tree.
            Args:
                prefix (List[int]): The prefix to match.
            Returns:
                Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
                    0. A list of indices that are matched by the prefix on the GPU.
                    1. Total length of the indices matched on the CPU.
                    2. The last node of the prefix matched on the GPU.
                    3. The last node of the prefix matched on the CPU.
            """
            return self.tree.match_prefix(prefix)

        def evict(self, num_tokens: int) -> List[torch.Tensor]:
            """
            Evicts a number of tokens from the radix tree.
            Args:
                num_tokens (int): The number of tokens to evict.
            Returns:
                List[torch.Tensor]: A list of indices that were evicted.
            """
            return self.tree.evict(num_tokens)

        def lock_ref(self, handle: TreeNodeCpp, lock: bool) -> None:
            """
            Locks or unlocks a reference to a tree node.
            After locking, the node will not be evicted from the radix tree.
            Args:
                handle (TreeNodeCpp): The tree node to lock or unlock.
                lock (bool): If True, locks the node; if False, unlocks it.
            """
            return self.tree.lock_ref(handle, lock)

        def writing_through(
            self, key: List[int], indices: torch.Tensor
        ) -> Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
            """
            Inserts a key-value pair into the radix tree and performs a write-through check.
            Args:
                key (List[int]): The key to insert.
                indices (torch.Tensor): The value associated with the key.
            Returns:
                Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
                    0. A list of (IOHandle, device indices, host indices) tuples.
                       These IOHandles require a write-through to the CPU on the Python side.
                    1. The number of indices that are matched on the device.
            """
            return self.tree.writing_through(key, indices)

        def loading_onboard(
            self,
            host_node: TreeNodeCpp,
            new_device_indices: torch.Tensor,
        ) -> Tuple[IOHandle, List[torch.Tensor]]:
            """
            Updates the device indices of tree nodes within a range on the tree.
            Args:
                host_node (TreeNodeCpp): The tree node on the host; it must be a descendant of the
                    corresponding device node.
                new_device_indices (torch.Tensor): The new device indices to set.
                    The length of this tensor must exactly match the length of the host indices.
            Returns:
                Tuple[IOHandle, List[torch.Tensor]]:
                    0. An IOHandle whose load (CPU to GPU) must be performed on the Python side.
                    1. A list of host indices corresponding to the new device indices.
            """
            return self.tree.loading_onboard(host_node, new_device_indices)

        def commit_writing_through(self, handle: IOHandle, success: bool) -> None:
            """
            Commits the write-through process for a tree node.
            Args:
                handle (IOHandle): The IOHandle to commit.
                success (bool): If True, commits the write-through; if False, just indicates failure.
            """
            return self.tree.commit_writing_through(handle, success)

        def commit_loading_onboard(self, handle: IOHandle, success: bool) -> None:
            """
            Commits the load-onboard process for tree nodes within a range on the tree.
            Args:
                handle (IOHandle): The IOHandle to commit.
                success (bool): If True, commits the load-onboard; if False, just indicates failure.
            """
            return self.tree.commit_loading_onboard(handle, success)

        def evictable_size(self) -> int:
            """
            Returns the size of the evictable part of the radix tree.
            This is the size of the part that can be evicted from the GPU (ref_count = 0).
            Returns:
                int: The size of the evictable part.
            """
            return self.tree.evictable_size()

        def protected_size(self) -> int:
            """
            Returns the size of the protected part of the radix tree.
            This is the size of the part that cannot be evicted from the GPU (ref_count > 0).
            Returns:
                int: The size of the protected part.
            """
            return self.tree.protected_size()

        def total_size(self) -> int:
            """
            Returns the total size of the radix tree (including CPU nodes).
            Returns:
                int: The total size of the radix tree.
            """
            return self.tree.total_size()

        def reset(self) -> None:
            """
            Resets the radix tree, clearing all nodes and indices.
            """
            return self.tree.reset()

        def debug_print(self) -> None:
            """
            Prints the internal state of the radix tree for debugging purposes.
            """
            return self.tree.debug_print()

else:
    # Real implementation of the classes for runtime
    RadixTreeCpp = radix_tree_cpp.RadixTree
    TreeNodeCpp = object
    IOHandle = object
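The wrapper above is the whole Python surface of the extension. A minimal usage sketch of the device-only configuration (host_size=None); it assumes the JIT build succeeds on import, and the token/index values are invented:

    import torch

    from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import RadixTreeCpp

    # Device-only tree: host_size=None disables the CPU tier entirely.
    tree = RadixTreeCpp(
        disabled=False, host_size=None, page_size=2, write_through_threshold=1
    )

    # Insert a 4-token key with its KV-pool indices (illustrative values).
    ongoing_io, matched = tree.writing_through(
        [10, 11, 12, 13], torch.tensor([100, 101, 102, 103])
    )
    assert ongoing_io == [] and matched == 0  # no host tier, empty tree before insert

    # A longer request sharing the first 4 tokens hits the cached prefix;
    # the unaligned tail is clipped to the page size automatically.
    device_parts, host_len, gpu_node, cpu_node = tree.match_prefix([10, 11, 12, 13, 14, 15])
    print(torch.cat(device_parts))  # -> tensor([100, 101, 102, 103])

    # Pin the matched prefix while the request runs, then release it.
    tree.lock_ref(gpu_node, True)
    # ... run the request ...
    tree.lock_ref(gpu_node, False)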
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.cpp (new file, 143 lines)
@@ -0,0 +1,143 @@
#include "tree_v2.h"

#include <ATen/core/TensorBody.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/tensor.h>
#include <ATen/ops/zeros.h>
#include <c10/util/irange.h>

#include <cstddef>
#include <memory>
#include <queue>
#include <stdexcept>
#include <utility>
#include <vector>

#include "common.h"
#include "tree_v2_impl.h"
#include "tree_v2_node.h"

namespace radix_tree_v2 {

static NodeHandle node2id(TreeNode* node) {
  return node->node_id;
}

// compare function for the TreeNode pointers based on their time
// we use LRU, so we want to evict the least recently used nodes
// since std::priority_queue is a max-heap, we need to reverse the comparison
static constexpr auto cmp = [](TreeNode* lhs, TreeNode* rhs) { return lhs->time() > rhs->time(); };

RadixTree::RadixTree(bool disabled, std::optional<std::size_t> host_size, std::size_t page_size, std::size_t threshold)
    : m_impl(std::make_unique<Impl>(disabled, host_size.has_value(), page_size, host_size.value_or(0), threshold)) {}

RadixTree::~RadixTree() = default;

std::tuple<std::vector<at::Tensor>, std::size_t, NodeHandle, NodeHandle>
RadixTree::match_prefix(const token_vec_t& _key) {
  if (m_impl->disabled) return {};

  const auto key = token_slice{_key.data(), m_impl->align(_key.size())};
  const auto [host_node, _] = m_impl->tree_walk(key);

  // walk up to the first non-evicted node
  std::size_t host_hit_length = 0;
  const auto device_node = host_node;

  // collect all the device indices
  std::vector<at::Tensor> indices{};
  walk_to_root(device_node, [&](TreeNode* n) { indices.push_back(n->device_indices()); });
  std::reverse(indices.begin(), indices.end());

  return {std::move(indices), host_hit_length, node2id(device_node), node2id(host_node)};
}

std::vector<at::Tensor> RadixTree::evict(std::size_t num_tokens) {
  if (m_impl->disabled || num_tokens == 0) return {};
  auto heap = std::priority_queue{cmp, m_impl->collect_leaves_device()};
  std::vector<at::Tensor> evicted_values;
  // evict nodes until we reach the desired number of tokens
  std::size_t num_evict = 0;
  while (num_evict < num_tokens && !heap.empty()) {
    const auto node = heap.top();
    heap.pop();
    // when ref_count == 0, the node can't be writing through
    _assert(node->on_gpu() && node->ref_count == 0);
    if (!node->is_io_free()) continue;  // skip nodes that are undergoing IO (i.e. indices protected)
    evicted_values.push_back(node->device_indices());
    num_evict += node->length();
    const auto parent = node->parent();
    m_impl->remove_device_node(node);
    if (parent->is_leaf_device() && parent->ref_count == 0)
      heap.push(parent);  // push parent to the heap if it is now a free leaf
  }

  return evicted_values;
}

std::tuple<std::vector<std::tuple<IOTicket, at::Tensor, at::Tensor>>, std::size_t>
RadixTree::writing_through(const token_vec_t& _key, at::Tensor value) {
  if (m_impl->disabled) return {};
  _assert(_key.size() == std::size_t(value.size(0)), "Key and value must have the same size");

  // just align the key to the page size, clipping the unaligned tail
  const auto key = token_slice{_key.data(), m_impl->align(_key.size())};

  // walk the tree to find the right place to insert
  const auto [host_node, host_prefix_length] = m_impl->tree_walk(key);

  // insert and create a new node if the remaining part of the key is not empty
  if (host_prefix_length != key.size()) {
    m_impl->create_device_node(
        host_node,
        {key.begin() + host_prefix_length, key.end()},
        value.slice(/*dim=*/0, host_prefix_length, key.size()));
  }

  // add the hit count for the device node
  walk_to_root(host_node, [&](TreeNode* n) { n->hit_count++; });

  std::vector<std::tuple<IOTicket, at::Tensor, at::Tensor>> result;

  // don't write through if hicache is disabled (no host memory), fast path
  if (!m_impl->use_hicache) return {std::move(result), host_prefix_length};
  throw std::runtime_error("Not implemented yet");
}

std::tuple<IOTicket, std::vector<at::Tensor>> RadixTree::loading_onboard(NodeHandle, at::Tensor) {
  if (m_impl->disabled) return {};
  throw std::runtime_error("Not implemented yet");
}

void RadixTree::commit_writing_through(IOTicket, bool) {
  if (m_impl->disabled) return;
  throw std::runtime_error("Not implemented yet");
}

void RadixTree::commit_loading_onboard(IOTicket, bool) {
  if (m_impl->disabled) return;
  throw std::runtime_error("Not implemented yet");
}

void RadixTree::reset() {
  m_impl->reset();
}

void RadixTree::lock_ref(NodeHandle node_id, bool increment) {
  if (m_impl->disabled) return;
  m_impl->lock_ref(node_id, increment);
}

std::size_t RadixTree::evictable_size() const {
  return m_impl->evictable_size();
}

std::size_t RadixTree::protected_size() const {
  return m_impl->protected_size();
}

std::size_t RadixTree::total_size() const {
  return m_impl->total_size();
}

}  // namespace radix_tree_v2
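The eviction loop above is a classic LRU sweep over device leaves: pop the stalest unreferenced leaf, free its indices, and re-offer its parent once it becomes a free leaf itself. A Python sketch of the same ordering trick (heapq is a min-heap, so unlike the reversed C++ comparator it can order timestamps directly; the node API here is hypothetical):

    import heapq

    def evict_lru(leaves, num_tokens):
        """leaves: device leaves with ref_count == 0; returns freed index tensors."""
        # (timestamp, unique id) keys keep heap comparisons total even for equal times
        heap = [(n.last_access_time, id(n), n) for n in leaves]
        heapq.heapify(heap)  # min-heap: the least recently used leaf pops first
        evicted, freed = [], 0
        while freed < num_tokens and heap:
            _, _, node = heapq.heappop(heap)
            evicted.append(node.device_indices)
            freed += len(node.device_indices)
            parent = node.remove_from_parent()  # hypothetical helper
            if parent.is_leaf() and parent.ref_count == 0:
                heapq.heappush(heap, (parent.last_access_time, id(parent), parent))
        return evicted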
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.h (new file, 59 lines)
@@ -0,0 +1,59 @@
#pragma once
#include <ATen/core/TensorBody.h>
#include <c10/core/Device.h>

#include <cstddef>
#include <memory>
#include <optional>
#include <tuple>
#include <vector>

#include "common.h"

namespace radix_tree_v2 {

struct RadixTree {
 public:
  RadixTree(bool disabled, std::optional<std::size_t> host_size, std::size_t page_size, std::size_t threshold);
  ~RadixTree();

  // Trees should not be copied or moved, as they manage their own memory and state.
  RadixTree(const RadixTree&) = delete;
  RadixTree(RadixTree&&) = delete;
  RadixTree& operator=(const RadixTree&) = delete;
  RadixTree& operator=(RadixTree&&) = delete;

  /// @return (device indices that are matched, host indices length, device node, host node)
  std::tuple<std::vector<at::Tensor>, std::size_t, NodeHandle, NodeHandle> match_prefix(const token_vec_t& key);
  /// @return Device indices that need to be freed (on the Python side).
  std::vector<at::Tensor> evict(std::size_t num_tokens);
  /// @brief (Un-)Lock a node.
  void lock_ref(NodeHandle node_id, bool increment /* increment or decrement */);
  /// @brief Insert a new key-value pair and try to perform a write-through.
  std::tuple<std::vector<std::tuple<IOTicket, at::Tensor, at::Tensor>>, std::size_t>
  writing_through(const token_vec_t& key, at::Tensor value);
  /// @brief Load to device from host within a range of nodes.
  std::tuple<IOTicket, std::vector<at::Tensor>> loading_onboard(NodeHandle host_id, at::Tensor indices);
  /// @brief Commit a write-through transaction.
  void commit_writing_through(IOTicket ticket, bool success);
  /// @brief Commit a load-onboard transaction.
  void commit_loading_onboard(IOTicket ticket, bool success);
  /// @brief Clear and reset the tree.
  void reset();

  /// @return How many tokens are still evictable (on device + not locked).
  std::size_t evictable_size() const;
  /// @return How many tokens are protected (locked).
  std::size_t protected_size() const;
  /// @return How many tokens the tree holds in total (including host-only nodes).
  std::size_t total_size() const;

  /// @brief Print debug information of the tree.
  void debug_print() const;

 private:
  struct Impl;
  std::unique_ptr<Impl> m_impl;
};

}  // namespace radix_tree_v2
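The header splits every host/device transfer into two phases: an issue step (writing_through / loading_onboard, which returns an IOTicket) and a commit step (commit_writing_through / commit_loading_onboard, reporting success or failure). A sketch of how the Python side is presumably meant to drive that protocol once the host tier lands (the pool objects and the .to("cpu") copy are stand-ins for whatever transfer engine is used):

    def write_through(tree, key, kv_indices, device_pool, host_pool):
        """Sketch of the two-phase write-through driver implied by the API above."""
        ongoing, matched_len = tree.writing_through(key, kv_indices)
        for handle, device_idx, host_idx in ongoing:
            try:
                # back up the device KV entries into the host pool
                host_pool[host_idx] = device_pool[device_idx].to("cpu")
                tree.commit_writing_through(handle, True)
            except RuntimeError:
                # on failure, tell the tree to roll the node back
                tree.commit_writing_through(handle, False)
        return matched_len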
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_binding.cpp (new file, 32 lines)
@@ -0,0 +1,32 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>

#include <cstddef>
#include <optional>

#include "tree_v2.h"

PYBIND11_MODULE(radix_tree_cpp, m) {
  using namespace radix_tree_v2;
  namespace py = pybind11;
  py::class_<RadixTree>(m, "RadixTree")
      .def(
          py::init<bool, std::optional<std::size_t>, std::size_t, std::size_t>(),
          py::arg("disabled"),
          py::arg("host_size"),
          py::arg("page_size"),
          py::arg("write_through_threshold"))
      .def("match_prefix", &RadixTree::match_prefix)
      .def("evict", &RadixTree::evict)
      .def("lock_ref", &RadixTree::lock_ref)
      .def("evictable_size", &RadixTree::evictable_size)
      .def("protected_size", &RadixTree::protected_size)
      .def("total_size", &RadixTree::total_size)
      .def("writing_through", &RadixTree::writing_through)
      .def("loading_onboard", &RadixTree::loading_onboard)
      .def("commit_writing_through", &RadixTree::commit_writing_through)
      .def("commit_loading_onboard", &RadixTree::commit_loading_onboard)
      .def("reset", &RadixTree::reset)
      .def("debug_print", &RadixTree::debug_print);
}
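Because the constructor is bound with py::arg names, the raw module accepts keyword arguments even without the typed wrapper. A small sketch (assuming the module was already JIT-built by radix_tree.py):

    from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import radix_tree_cpp

    # the py::arg(...) declarations above are what make these keywords work;
    # host_size=None maps onto std::optional via pybind11/stl.h
    tree = radix_tree_cpp.RadixTree(
        disabled=False, host_size=None, page_size=1, write_through_threshold=1
    )
    print(tree.total_size())  # -> 0 for an empty tree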
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_debug.cpp (new file, 194 lines)
@@ -0,0 +1,194 @@
#include <c10/core/DeviceType.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>

#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>

#include "tree_v2.h"
#include "tree_v2_impl.h"

namespace radix_tree_v2 {

void RadixTree::debug_print() const {
  m_impl->debug_print(std::clog);
}

static constexpr auto npos = std::size_t(-1);

void RadixTree::Impl::debug_print(std::ostream& os) const {
  static constexpr auto _check = [](bool condition, auto msg, std::size_t id = npos) {
    if (!condition) {
      std::string suffix = id == npos ? "" : " [id = " + std::to_string(id) + "]";
      throw std::runtime_error(std::string("RadixTree::debug_print failed: ") + msg + suffix);
    }
  };

  static constexpr auto _print_node = [](TreeNode* node, std::size_t depth, std::ostream& os) {
    const auto length = node->length();
    os << node->node_id << " [depth = " << depth << "] [len = " << length << "]";

    // placement status
    if (node->on_both()) {
      os << " [cpu + gpu]";
    } else if (node->on_gpu()) {
      os << " [gpu]";
    } else if (node->on_cpu()) {
      os << " [cpu]";
    } else {
      _check(false, "Node is not on GPU or CPU", node->node_id);
    }

    // IO status
    if (node->is_io_free()) {
      os << " [io = free]";
    } else if (node->is_io_device_to_host()) {
      os << " [io = gpu -> cpu]";
    } else if (node->is_io_host_to_device()) {
      os << " [io = cpu -> gpu]";
    } else {
      _check(false, "Node is in unknown IO state", node->node_id);
    }

    os << " [rc = " << node->ref_count << "]";
    os << " [hit = " << node->hit_count << "]";
  };

  static constexpr auto _print_indices = [](at::Tensor indices, std::ostream& os) {
    if (!indices.defined()) {
      os << "[[N/A]]";
      return indices;
    }
    indices = indices.to(c10::kCPU, c10::kLong, false, false, c10::MemoryFormat::Contiguous);
    const auto length = indices.numel();
    os << "[";
    auto* data_ptr = indices.data_ptr<int64_t>();
    for (const auto i : c10::irange(indices.size(0))) {
      os << data_ptr[i];
      if (i != length - 1) os << ", ";
    }
    os << "]";
    return indices;
  };

  os << "Evictable size: " << evictable_size() << std::endl;
  os << "Protected size: " << protected_size() << std::endl;
  os << "Total size: " << const_cast<Impl*>(this)->total_size() << std::endl;
  std::vector<std::tuple<TreeNode*, TreeNode*, token_slice>> stack;
  auto root = const_cast<TreeNode*>(&m_root);
  os << root->node_id << " [root]" << std::endl;
  for (const auto& [key, child] : *root) {
    stack.push_back({child.get(), root, key});
  }

  std::unordered_map<TreeNode*, std::size_t> depth_map;
  std::string indent_buffer;
  depth_map[root] = 0;
  std::vector<NodeHandle> visited_id;
  std::size_t evictable_size_real = 0;
  while (!stack.empty()) {
    const auto [node, parent, key] = stack.back();
    stack.pop_back();
    visited_id.push_back(node->node_id);
    const auto nid = node->node_id;
    _check(node != nullptr, "Node is null", nid);
    _check(node->on_gpu() || node->on_cpu(), "Node is not on GPU or CPU", nid);
    _check(node->parent() == parent, "Parent is not correct", nid);
    _check(key.size() == page_size && node->diff_key(key, 0) == page_size, "Key is not correct", nid);
    _check(depth_map.count(node) == 0, "Node is visited twice", nid);
    _check(m_node_map.count(nid) == 1, "Node is not in the map", nid);
    _check(m_node_map.at(nid) == node, "Node in the map is not the same as the one in the stack", nid);
    _check(!node->on_gpu() || parent->is_root() || parent->on_gpu(), "Node on GPU must have a GPU/root parent", nid);
    if (!node->is_io_free()) {
      _check(node->ref_count > 0, "Node is in IO state but not protected", nid);
      _check(node->on_both(), "Node in IO state must be on both CPU and GPU", nid);
    }

    if (node->on_gpu() && node->ref_count == 0) {
      evictable_size_real += node->length();
    }

    const auto depth = (depth_map[node] = depth_map[parent] + 1);
    indent_buffer.resize(depth * 2, ' ');
    os << indent_buffer;
    _print_node(node, depth, os);
    os << std::endl;
    for (const auto& [key, child] : *node) {
      stack.push_back({child.get(), node, key});
    }
  }

  _check(evictable_size_real == evictable_size(), "Evictable size is wrong");
  _check(m_node_map.count(root->node_id) == 1, "Root node is not in the map");
  _check(m_node_map.at(root->node_id) == root, "Root node in the map is not correct");

  std::sort(visited_id.begin(), visited_id.end());
  if (visited_id.size() != m_node_map.size() - 1) {
    // Some error in the tree, not all nodes were visited
    std::string id_list;
    id_list += "(visited: ";
    id_list += std::to_string(root->node_id) + " ";
    for (const auto& id : visited_id) {
      id_list += std::to_string(id) + " ";
    }
    id_list += "), (in map: ";
    for (const auto& [id, _] : m_node_map) {
      id_list += std::to_string(id) + " ";
    }
    id_list += ")";
    _check(false, "Not all nodes are visited " + id_list);
  }

  static const auto kSGLANG_RADIX_CPP_DEBUG_LIMIT = [] {
    const char* env = std::getenv("SGLANG_RADIX_CPP_DEBUG_LIMIT");
    const std::size_t default_limit = 16;
    if (env != nullptr) {
      try {
        return static_cast<std::size_t>(std::stoull(env));
      } catch (const std::exception& e) {
        std::cerr << "Invalid SGLANG_RADIX_CPP_DEBUG_LIMIT value: " << env  //
                  << ". Using default value = " << default_limit << std::endl;
      }
    }
    return default_limit;
  }();

  for (const auto nid : visited_id) {
    const auto node = m_node_map.at(nid);
    // print key and indices
    const auto& key = node->_unsafe_tokens();
    if (key.size() > kSGLANG_RADIX_CPP_DEBUG_LIMIT) {
      os << "Node " << nid << ": key is too long (" << key.size() << " tokens), skipping..." << std::endl;
      continue;
    }
    os << "Node " << nid << ": key = [";
    for (const auto& i : c10::irange(key.size())) {
      os << key[i];
      if (i != key.size() - 1) os << ", ";
    }

    _check(key.size() % page_size == 0, "Misaligned key", nid);

    os << "] device_indices = ";
    const auto device_indices = _print_indices(node->device_indices(), os);
    if (device_indices.defined()) {
      std::size_t length = device_indices.numel();
      _check(device_indices.dim() == 1, "Device indices must be 1D tensor", nid);
      _check(length == node->length(), "Wrong device indices size", nid);
    }

    os << " host_indices = ";
    const auto host_indices = _print_indices(node->host_indices(), os);
    if (host_indices.defined()) {
      std::size_t length = host_indices.numel();
      _check(host_indices.dim() == 1, "Host indices must be 1D tensor", nid);
      _check(length == node->length(), "Wrong host indices size", nid);
    }
    os << std::endl;
  }
}

}  // namespace radix_tree_v2
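Note that the per-node key dump is truncated past a limit read from SGLANG_RADIX_CPP_DEBUG_LIMIT (default 16 tokens). A usage sketch, reusing the `tree` from the earlier example:

    import os

    # The limit is read once (a C++ static local), so it must be set before
    # the first debug_print call in the process.
    os.environ["SGLANG_RADIX_CPP_DEBUG_LIMIT"] = "256"
    tree.debug_print()  # dumps placement/IO/ref-count state plus full keys up to 256 tokens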
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_impl.h (new file, 276 lines)
@@ -0,0 +1,276 @@
#pragma once
#include <c10/util/irange.h>

#include <chrono>
#include <cstddef>
#include <iosfwd>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "common.h"
#include "tree_v2.h"
#include "tree_v2_node.h"

namespace radix_tree_v2 {

using node_iterator_t = typename TreeNode::iterator_t;

struct RadixTree::Impl {
 public:
  Impl(bool disabled, bool use_hicache, std::size_t page_size, std::size_t host_size, std::size_t threshold)
      : m_root(/*node_id_=*/0),
        m_evictable_size(0),
        m_protected_size(0),
        m_cached_vec(),
        m_node_map(),
        m_node_counter(1),  // start from 1 to avoid confusion with the root node
        disabled(disabled),
        use_hicache(use_hicache),
        page_size(page_size),
        threshold(threshold) {
    _assert(page_size > 0, "Page size must be greater than zero");
    _assert(use_hicache == (host_size > 0), "Hierarchical cache is enabled iff host size > 0");
    m_root.ref_count = 1;                  // root node is always protected
    m_cached_vec.reserve(page_size);       // to avoid repeated allocations
    m_node_map[m_root.node_id] = &m_root;  // add root to the map
  }

  TreeNode* split_node(node_iterator_t iterator, std::size_t prefix_length) {
    // from `parent -> old_node` to `parent -> new_node -> old_node`
    // the prefix part of the old node is moved to the new node
    auto old_node_ptr = std::move(iterator->second);
    auto new_node_ptr = std::make_unique<TreeNode>(m_node_counter++);
    auto* old_node = old_node_ptr.get();
    auto* new_node = new_node_ptr.get();
    auto* parent = old_node->parent();
    // set up data structures
    split_prefix(new_node, old_node, prefix_length);
    // set up parent-child relationship
    add_child(new_node, std::move(old_node_ptr));
    add_child(parent, std::move(new_node_ptr), iterator);
    m_node_map[new_node->node_id] = new_node;  // add to the map
    return new_node;
  }

  // node: x -> [GPU]
  TreeNode* create_device_node(TreeNode* parent, token_vec_t vec, at::Tensor indices) {
    auto new_node_ptr = std::make_unique<TreeNode>(m_node_counter++);
    auto new_node = new_node_ptr.get();
    new_node_ptr->_unsafe_tokens() = std::move(vec);
    new_node_ptr->_unsafe_device_indices() = std::move(indices);
    m_evictable_size += new_node_ptr->length();
    add_child(parent, std::move(new_node_ptr));
    m_node_map[new_node->node_id] = new_node;  // add to the map
    return new_node;
  }

  // node: [GPU] -> x
  void remove_device_node(TreeNode* node) {
    _assert(node->on_gpu_only() && node->ref_count == 0);
    m_evictable_size -= node->length();
    node->parent()->erase_child(get_key(node));
    m_node_map.erase(node->node_id);  // remove from the map
  }

  /**
   * @brief Walk the tree to find the node that matches the key.
   * If the key partially matches a node, it will split that node.
   * @return A pair containing the last node that matches the key and
   * the total prefix length matched (on GPU and CPU) so far.
   */
  std::pair<TreeNode*, std::size_t> tree_walk(token_slice key) {
    _assert(key.size() % page_size == 0, "Key should be page-aligned");

    std::size_t total_prefix_length = 0;
    TreeNode* node = &m_root;

    const auto now = std::chrono::steady_clock::now();
    while (key.size() > 0) {
      const auto iterator = node->find_child(get_key(key));
      if (iterator == node->end()) break;

      // walk to the child node
      node = iterator->second.get();

      // at least `page_size` tokens are matched, and there may be more tokens to match
      // the return value prefix_length is no less than `page_size`
      const auto prefix_length = align(node->diff_key(key, page_size) + page_size);
      total_prefix_length += prefix_length;

      // split the node if the prefix is not the whole token vector
      if (prefix_length < node->length()) {
        return {split_node(iterator, prefix_length), total_prefix_length};
      }

      // we have matched the whole node, continue to the next one
      node->access(now);
      key = key.subspan(prefix_length);
    }

    return {node, total_prefix_length};
  }

  std::vector<TreeNode*> collect_leaves() const {
    std::vector<TreeNode*> leaves;
    std::vector<TreeNode*> stack = {};
    for (const auto& [_, child] : m_root) {
      stack.push_back(child.get());
    }
    while (!stack.empty()) {
      const auto node = stack.back();
      stack.pop_back();
      if (node->is_leaf()) {
        if (node->ref_count == 0) {
          leaves.push_back(node);
        }
      } else {
        for (const auto& [_, child] : *node) {
          stack.push_back(child.get());
        }
      }
    }
    return leaves;
  }

  std::vector<TreeNode*> collect_leaves_device() const {
    // for non-hicache, every device leaf is a plain leaf (since there is no backup on the host)
    if (!use_hicache) return collect_leaves();
    std::vector<TreeNode*> leaves;
    std::vector<TreeNode*> stack = {};
    for (const auto& [_, child] : m_root) {
      stack.push_back(child.get());
    }
    while (!stack.empty()) {
      const auto node = stack.back();
      stack.pop_back();
      if (!node->on_gpu()) continue;  // skip nodes that are not on GPU
      if (node->is_leaf_device()) {
        if (node->ref_count == 0) {
          leaves.push_back(node);
        }
      } else {
        for (const auto& [_, child] : *node) {
          stack.push_back(child.get());
        }
      }
    }
    return leaves;
  }

  void lock_ref(TreeNode* node, bool increment) {
    if (node->is_root()) return;  // skip the root node
    _assert(node->on_gpu(), "Cannot lock reference on an evicted node");
    if (increment)
      walk_to_root(node, [this](TreeNode* n) {
        if (n->ref_count == 0) {
          m_evictable_size -= n->length();
          m_protected_size += n->length();
        }
        n->ref_count++;
      });
    else
      walk_to_root(node, [this](TreeNode* n) {
        _assert(n->ref_count != 0, "Cannot decrement a reference count of zero");
        n->ref_count--;
        if (n->ref_count == 0) {
          m_protected_size -= n->length();
          m_evictable_size += n->length();
        }
      });
  }

  void lock_ref(NodeHandle node_ptr, bool increment) {
    return lock_ref(id2node(node_ptr), increment);
  }

  void lock(TreeNode* node) {
    return lock_ref(node, /*increment=*/true);
  }

  void unlock(TreeNode* node) {
    return lock_ref(node, /*increment=*/false);
  }

  std::size_t total_size() const {
    std::size_t size = 0;
    std::vector<const TreeNode*> stack = {&m_root};
    while (!stack.empty()) {
      auto* node = stack.back();
      stack.pop_back();
      size += node->length();
      for (const auto& [_, child] : *node)
        stack.push_back(child.get());
    }
    return size;
  }

  std::size_t evictable_size() const {
    return m_evictable_size;
  }

  std::size_t protected_size() const {
    return m_protected_size;
  }

  std::size_t align(std::size_t size) const {
    return (size / page_size) * page_size;  // round down to a multiple of the page size
  }

  TreeNode* id2node(NodeHandle node_id) const {
    const auto iterator = m_node_map.find(node_id);
    _assert(iterator != m_node_map.end(), "Node not found in the map");
    return iterator->second;
  }

  void reset() {
    _assert(m_root.ref_count == 1, "Root node must be protected during reset");
    m_node_counter = 1;  // reset node counter
    m_root.root_reset();
    m_evictable_size = 0;
    m_protected_size = 0;
    m_node_map.clear();
    m_node_map[m_root.node_id] = &m_root;  // re-add root to the map
  }

  void debug_print(std::ostream& os) const;

 private:
  // some auxiliary functions
  token_vec_t& get_key(token_slice tokens) {
    _assert(tokens.size() >= page_size, "Key should be at least page-sized");
    tokens = tokens.subspan(0, page_size);
    m_cached_vec.assign(tokens.begin(), tokens.end());
    return m_cached_vec;
  }

  // justification for the _unsafe call: we only need to read the key part of the tokens
  token_vec_t& get_key(TreeNode* node) {
    return get_key(node->_unsafe_tokens());
  }

  void add_child(TreeNode* parent, std::unique_ptr<TreeNode>&& child) {
    parent->add_child(get_key(child.get()), std::move(child));
  }

  void add_child(TreeNode* parent, std::unique_ptr<TreeNode>&& child, node_iterator_t it) {
    parent->add_child(it, std::move(child));
  }

  TreeNode m_root;               // root node of the tree
  std::size_t m_evictable_size;  // number of evictable tokens on the GPU (lock ref = 0)
  std::size_t m_protected_size;  // number of protected tokens on the GPU (lock ref > 0)
  token_vec_t m_cached_vec;      // cached vector of tokens for the current operation
  std::unordered_map<std::size_t, TreeNode*> m_node_map;  // map of node IDs to nodes
  std::size_t m_node_counter;    // counter for node IDs

 public:
  // some public constant configurations (without the m_ prefix)
  const bool disabled;          // whether the cache is disabled (i.e. used only as a temporary cache)
  const bool use_hicache;       // whether to use the HiCache (host tier) for this tree
  const std::size_t page_size;  // size of each page in the cache
  const std::size_t threshold;  // threshold for write-through
};

}  // namespace radix_tree_v2
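tree_walk is the core lookup: descend one page-keyed child at a time, extend the match inside the node, round the match down to a page boundary, and split the node when the match ends mid-node. A compact Python sketch of the same walk over a dict-based tree (a mirror of the logic under a hypothetical node type, not the exact code):

    def tree_walk(root, key, page_size):
        """Descend the paged radix tree; split a node if the match ends inside it."""
        matched = 0
        node = root
        while len(key) >= page_size:
            child = node.children.get(tuple(key[:page_size]))
            if child is None:
                break
            node = child
            # extend the match past the first page, then round down to a page boundary
            common = page_size
            while common < min(len(node.tokens), len(key)) and node.tokens[common] == key[common]:
                common += 1
            common = common // page_size * page_size
            matched += common
            if common < len(node.tokens):
                return node.split(common), matched  # match ended mid-node: split it
            key = key[common:]
        return node, matched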
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_node.h (new file, 257 lines)
@@ -0,0 +1,257 @@
#pragma once
#include <ATen/core/TensorBody.h>

#include <algorithm>
#include <array>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <ranges>
#include <unordered_map>

#include "common.h"

namespace radix_tree_v2 {

struct std_vector_hash {
  // see https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector
  std::size_t operator()(const token_vec_t& vec) const {
    std::size_t hash = 0;
    for (const auto& token : vec) {
      hash ^= token + 0x9e3779b9 + (hash << 6) + (hash >> 2);
    }
    return hash;
  }
};

struct TreeNode {
 public:
  using children_map_t = std::unordered_map<token_vec_t, std::unique_ptr<TreeNode>, std_vector_hash>;
  using iterator_t = typename children_map_t::iterator;
  using const_iterator_t = typename children_map_t::const_iterator;
  using timestamp_t = std::chrono::steady_clock::time_point;

  TreeNode(std::size_t node_id_)
      : ref_count(0),
        hit_count(0),
        m_io_locked(std::nullopt),
        m_io_status(IOStatus::None),
        m_io_ticket(),
        m_tokens(),
        m_device_indices(),
        m_host_indices(),
        m_parent(),
        m_children(),
        m_last_access_time(std::chrono::steady_clock::now()),
        node_id(node_id_) {}

  void access(timestamp_t time = std::chrono::steady_clock::now()) {
    m_last_access_time = time;
  }

  bool is_root() const {
    return m_parent == nullptr;
  }

  timestamp_t time() const {
    return m_last_access_time;
  }

  bool on_gpu() const {
    return m_device_indices.defined();
  }

  bool on_cpu() const {
    return m_host_indices.defined();
  }

  bool on_gpu_only() const {
    return on_gpu() && !on_cpu();
  }

  bool on_cpu_only() const {
    return !on_gpu() && on_cpu();
  }

  bool on_both() const {
    return on_gpu() && on_cpu();
  }

  std::size_t length() const {
    return m_tokens.size();
  }

  bool is_leaf() const {
    return m_children.empty();
  }

  bool is_leaf_device() const {
    for (const auto& [_, child] : m_children)
      if (child->on_gpu()) return false;  // at least one child is on the device
    return true;
  }

  void add_child(const token_vec_t& v, std::unique_ptr<TreeNode>&& child) {
    child->m_parent = this;
    m_children[v] = std::move(child);
  }

  void add_child(iterator_t it, std::unique_ptr<TreeNode>&& child) {
    child->m_parent = this;
    it->second = std::move(child);
  }

  void erase_child(const token_vec_t& v) {
    _assert(m_children.erase(v) > 0, "Child node not found");
  }

  iterator_t find_child(const token_vec_t& v) {
    return m_children.find(v);
  }

  iterator_t begin() {
    return m_children.begin();
  }

  iterator_t end() {
    return m_children.end();
  }

  const_iterator_t begin() const {
    return m_children.begin();
  }

  const_iterator_t end() const {
    return m_children.end();
  }

  TreeNode* parent() {
    return m_parent;
  }

  // set up all data structures except for the parent-child relationship
  friend void split_prefix(TreeNode* new_node, TreeNode* old_node, std::size_t prefix_length) {
    auto tokens = std::move(old_node->m_tokens);
    _assert(0 < prefix_length && prefix_length < tokens.size(), "Invalid prefix size for split");

    // set up tokens
    old_node->m_tokens = token_vec_t(tokens.begin() + prefix_length, tokens.end());
    new_node->m_tokens = std::move(tokens);
    new_node->m_tokens.resize(prefix_length);

    // set up values
    const int64_t new_size = new_node->length();
    const int64_t old_size = old_node->length();
    if (old_node->m_device_indices.defined()) {
      auto new_indices = old_node->m_device_indices.split_with_sizes({new_size, old_size});
      new_node->m_device_indices = std::move(new_indices[0]);
      old_node->m_device_indices = std::move(new_indices[1]);
    }
    if (old_node->m_host_indices.defined()) {
      auto new_indices = old_node->m_host_indices.split_with_sizes({new_size, old_size});
      new_node->m_host_indices = std::move(new_indices[0]);
      old_node->m_host_indices = std::move(new_indices[1]);
    }

    // set up ref counts and hit counts
    new_node->ref_count = old_node->ref_count;
    new_node->hit_count = old_node->hit_count;

    // If the old node (child) was locked for IO, the new node (parent) does not need
    // to be locked, since it is naturally protected by the child node's lock.
    if (old_node->m_io_locked.has_value()) {
      new_node->m_io_locked = false;
      new_node->m_io_status = old_node->m_io_status;
      new_node->m_io_ticket = old_node->m_io_ticket;
    }
  }

  /// @return The first index in `m_tokens` that differs from `key`.
  std::size_t diff_key(token_slice key, std::size_t offset) const {
    const auto a = token_slice{key}.subspan(offset);
    const auto b = token_slice{m_tokens}.subspan(offset);
    const auto [it_a, it_b] = std::ranges::mismatch(a, b);
    return it_a - a.begin();  // return the index of the first differing token
  }

  at::Tensor device_indices() const {
    return m_device_indices;
  }
  at::Tensor host_indices() const {
    return m_host_indices;
  }

  // visiting the tokens directly is always unsafe (use `diff_key` instead)
  token_vec_t& _unsafe_tokens() {
    return m_tokens;
  }
  at::Tensor& _unsafe_device_indices() {
    return m_device_indices;
  }
  at::Tensor& _unsafe_host_indices() {
    return m_host_indices;
  }

  bool is_io_free() const {
    return m_io_status == IOStatus::None;
  }

  bool is_io_device_to_host() const {
    return m_io_status == IOStatus::DeviceToHost;
  }

  bool is_io_host_to_device() const {
    return m_io_status == IOStatus::HostToDevice;
  }

  void root_reset() {
    _assert(is_root(), "Only the root node can call root_reset");
    _assert(
        m_io_status == IOStatus::None && m_io_locked == std::nullopt,
        "IO operation in progress, cannot reset root node");
    _assert(this->m_tokens.empty(), "Root node tokens should be empty on reset");
    _assert(
        !this->m_device_indices.defined() && !this->m_host_indices.defined(),
        "Root node indices should always be empty and never assigned");
    m_children.clear();
    this->access();
  }

 public:
  std::size_t ref_count;
  std::size_t hit_count;

 private:
  enum class IOStatus : std::uint8_t {
    None,
    HostToDevice,
    DeviceToHost,
  };

  std::optional<bool> m_io_locked;  // whether the node is locked in an IO operation
  IOStatus m_io_status;
  IOTicket m_io_ticket;

  token_vec_t m_tokens;
  at::Tensor m_device_indices;  // indices of the device value
  at::Tensor m_host_indices;    // indices of the host value
  TreeNode* m_parent;
  children_map_t m_children;
  timestamp_t m_last_access_time;

 public:
  const std::size_t node_id;  // unique ID for the node
};

template <typename F>
inline TreeNode* walk_to_root(TreeNode* t, const F& f) {
  while (!t->is_root()) {
    f(t);
    t = t->parent();
  }
  return t;  // return the root node
}

}  // namespace radix_tree_v2
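split_prefix is the node-splitting primitive: the first prefix_length tokens, plus the matching slice of each defined index tensor, move to the new parent node, while the remainder stays in the old child. A Python sketch of the same slicing (function and variable names are illustrative):

    import torch

    def split_prefix(tokens, indices, prefix_length):
        """Mirror of the C++ split: returns (parent_part, child_part) for one node."""
        assert 0 < prefix_length < len(tokens)
        parent_tokens, child_tokens = tokens[:prefix_length], tokens[prefix_length:]
        # torch.split_with_sizes returns views over the same storage, so no copy is made
        parent_idx, child_idx = torch.split_with_sizes(
            indices, [prefix_length, len(tokens) - prefix_length]
        )
        return (parent_tokens, parent_idx), (child_tokens, child_idx)

    (p_tok, p_idx), (c_tok, c_idx) = split_prefix(
        [10, 11, 12, 13], torch.tensor([100, 101, 102, 103]), 2
    )
    # p_tok == [10, 11], p_idx == tensor([100, 101]); the rest stays in the child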
python/sglang/srt/mem_cache/radix_cache_cpp.py (new file, 229 lines)
@@ -0,0 +1,229 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, List, Set

import torch

from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
    IOHandle,
    RadixTreeCpp,
    TreeNodeCpp,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req


logger = logging.getLogger(__name__)


class RadixCacheCpp(BasePrefixCache):
    def _merge_tensor(self, l: List[torch.Tensor]) -> torch.Tensor:
        """
        Merge a list of tensors into a single tensor.
        Args:
            l (List[torch.Tensor]): List of tensors to merge.
        Returns:
            torch.Tensor: Merged tensor.
        """
        if len(l) == 0:
            return torch.empty(0, dtype=torch.int64, device=self.device)
        elif len(l) == 1:
            return l[0]
        else:
            return torch.cat(l)

    def __init__(
        self,
        disable: bool,
        use_hicache: bool,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: BaseTokenToKVPoolAllocator,
        tp_cache_group: torch.distributed.ProcessGroup,
        page_size: int,
        hicache_ratio: float,
        hicache_size: int,
        hicache_write_policy: str,
        enable_kv_cache_events: bool = False,
        hicache_oracle: bool = False,
        enable_write_cancel: bool = False,
    ):
        self.disable = disable
        self.enable_write_cancel = enable_write_cancel

        assert (
            enable_kv_cache_events is False
        ), "RadixCacheCpp does not support kv cache events yet"
        self.kv_cache = token_to_kv_pool.get_kvcache()

        # record the nodes with an ongoing write-through
        self.ongoing_write_through: Set[IOHandle] = set()
        # record the node segments with an ongoing load back
        self.ongoing_load_back: Set[IOHandle] = set()
        # todo: dynamically adjust the threshold
        self.write_through_threshold = (
            1 if hicache_write_policy == "write_through" else 2
        )
        self.device = token_to_kv_pool.device
        self.token_to_kv_pool = token_to_kv_pool
        self.req_to_token_pool = req_to_token_pool
        self.page_size = page_size

        self.tp_group = tp_cache_group

        if not use_hicache:
            self.tree = RadixTreeCpp(
                disabled=self.disable,
                page_size=page_size,
                host_size=None,  # no host cache; this should be removed in the future
                write_through_threshold=self.write_through_threshold,
            )
            self.cache_controller = None
            return  # early return if hicache is not used

        raise NotImplementedError("Host cache is not supported yet")

    def reset(self):
        if self.cache_controller is not None:
            # need to clear the acks before resetting the cache controller
            raise NotImplementedError("Host cache is not supported yet")
        self.tree.reset()

    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
        device_indices_vec, host_indices_length, node_gpu, node_cpu = (
            self.tree.match_prefix(key)
        )
        return MatchResult(
            device_indices=self._merge_tensor(device_indices_vec),
            last_device_node=node_gpu,
            last_host_node=node_cpu,
            host_hit_length=host_indices_length,
        )

    def _insert(self, key: List[int], value: torch.Tensor) -> int:
        """
        Insert a key-value pair into the radix tree.
        Args:
            key (List[int]): The key to insert, represented as a list of integers.
            value (torch.Tensor): The value to associate with the key.
        Returns:
            int: Number of device indices that were already present in the tree before the insertion.
        """
        ongoing_write, length = self.tree.writing_through(key, value)
        if self.cache_controller is None:
            assert len(ongoing_write) == 0, "Implementation error"
            return length

        raise NotImplementedError("Host cache is not supported yet")

    def dec_lock_ref(self, node: TreeNodeCpp):
        """
        Decrement the reference count on the path from a node to the root of the radix tree.
        Args:
            node (TreeNodeCpp): The handle of the node to decrement the reference count for.
        """
        self.tree.lock_ref(node, False)  # do not increment

    def inc_lock_ref(self, node: TreeNodeCpp):
        """
        Increment the reference count on the path from a node to the root of the radix tree.
        Args:
            node (TreeNodeCpp): The handle of the node to increment the reference count for.
        """
        self.tree.lock_ref(node, True)

    def evict(self, num_tokens: int):
        evicted_device_indices = self.tree.evict(num_tokens)
        for indices in evicted_device_indices:
            self.token_to_kv_pool.free(indices)

    def evictable_size(self):
        return self.tree.evictable_size()

    def protected_size(self):
        return self.tree.protected_size()

    def total_size(self):
        return self.tree.total_size()

    def cache_finished_req(self, req: Req):
        """Cache request when it finishes."""
        assert req.req_pool_idx is not None
        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        overall_len = len(token_ids)  # prefill + decode
        kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]

        # NOTE: our C++ implementation doesn't need `token_ids` and `kv_indices` to be
        # page-aligned; it aligns them automatically, but their lengths must be equal
        old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
        new_prefix_len = self._insert(token_ids, kv_indices)

        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"

        # the KV cache between the old and new prefix lengths was newly generated by this
        # request but already exists in the pool, so we free the newly generated kv indices
        if old_prefix_len < new_prefix_len:
            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])

        # we also need to free the unaligned tail, since it cannot be inserted into the radix tree
        if self.page_size != 1 and (  # an unaligned tail only exists when page_size > 1
            (unaligned_len := overall_len % self.page_size) > 0
        ):
            # NOTE: sglang's PagedAllocator supports unaligned frees (it aligns them automatically)
            self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :])

        # release the cache lock and remove the req slot
        self.dec_lock_ref(req.last_node)
        self.req_to_token_pool.free(req.req_pool_idx)

    def cache_unfinished_req(self, req: Req):
        """Cache request when it is unfinished."""
        assert req.req_pool_idx is not None
        token_ids = req.fill_ids
        prefill_len = len(token_ids)  # prefill only (maybe chunked)
        kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :prefill_len]

        # NOTE: our C++ implementation doesn't need `token_ids` and `kv_indices` to be
        # page-aligned; it aligns them automatically, but their lengths must be equal
        old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
        new_prefix_len = self._insert(token_ids, kv_indices)

        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"

        # TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
        # the prefix indices need to be updated to reuse the kv indices in the pool
        new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids)
        new_indices = self._merge_tensor(new_indices_vec)
        assert new_prefix_len <= len(new_indices)

        # the KV cache between the old and new prefix lengths was newly generated but
        # already exists in the pool, so we free the newly generated kv indices and
        # reuse the indices that are already in the pool
        if old_prefix_len < new_prefix_len:
            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
            reused_indices = new_indices[old_prefix_len:new_prefix_len]
            self.req_to_token_pool.req_to_token[
                req.req_pool_idx, old_prefix_len:new_prefix_len
            ] = reused_indices

        if req.last_node != new_last_node:
            self.dec_lock_ref(req.last_node)
            self.inc_lock_ref(new_last_node)

        # NOTE: there might be an unaligned tail, so we may need to append it
        assert len(new_indices) <= prefill_len < len(new_indices) + self.page_size
        if self.page_size != 1 and len(new_indices) < prefill_len:
            req.prefix_indices = torch.cat(
                [new_indices, kv_indices[len(new_indices) :]]
            )
        else:
            req.prefix_indices = new_indices
        req.last_node = new_last_node

    def pretty_print(self):
        return self.tree.debug_print()
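The old/new prefix bookkeeping in cache_finished_req is easiest to see with numbers. A worked sketch (values invented, page_size = 2):

    # A finished request produced 7 tokens total; 4 of them were its cached prefix.
    page_size = 2
    overall_len = 7
    old_prefix_len = 4 // page_size * page_size  # -> 4 (already page-aligned)

    # Suppose a concurrent request cached 6 tokens of the same sequence in the
    # meantime, so _insert reports that 6 aligned tokens were already present:
    new_prefix_len = 6

    # kv_indices[4:6] were freshly computed by this request, but identical entries
    # already live in the tree, so those two duplicate slots are freed.
    assert old_prefix_len <= new_prefix_len

    # The unaligned tail (token 6) can never enter a page_size=2 tree,
    # so its slot is freed as well:
    unaligned_len = overall_len % page_size  # -> 1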