[Feature] Radix Tree in C++ (#7369)
This commit is contained in:
@@ -569,7 +569,23 @@ class Scheduler(
|
|||||||
page_size=self.page_size,
|
page_size=self.page_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if self.enable_hierarchical_cache:
|
if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
|
||||||
|
# lazy import to avoid JIT overhead
|
||||||
|
from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp
|
||||||
|
|
||||||
|
self.tree_cache = RadixCacheCpp(
|
||||||
|
disable=False,
|
||||||
|
use_hicache=self.enable_hierarchical_cache,
|
||||||
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
|
token_to_kv_pool=self.token_to_kv_pool_allocator,
|
||||||
|
tp_cache_group=self.tp_cpu_group,
|
||||||
|
page_size=self.page_size,
|
||||||
|
hicache_ratio=server_args.hicache_ratio,
|
||||||
|
hicache_size=server_args.hicache_size,
|
||||||
|
hicache_write_policy=server_args.hicache_write_policy,
|
||||||
|
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||||
|
)
|
||||||
|
elif self.enable_hierarchical_cache:
|
||||||
self.tree_cache = HiRadixCache(
|
self.tree_cache = HiRadixCache(
|
||||||
req_to_token_pool=self.req_to_token_pool,
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
|
|||||||
1
python/sglang/srt/mem_cache/cpp_radix_tree/.clang-format
Symbolic link
1
python/sglang/srt/mem_cache/cpp_radix_tree/.clang-format
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../../../../../sgl-kernel/.clang-format
|
||||||
29
python/sglang/srt/mem_cache/cpp_radix_tree/common.h
Normal file
29
python/sglang/srt/mem_cache/cpp_radix_tree/common.h
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <source_location>
|
||||||
|
#include <span>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace radix_tree_v2 {
|
||||||
|
|
||||||
|
using token_t = std::int32_t;
|
||||||
|
using token_vec_t = std::vector<token_t>;
|
||||||
|
using token_slice = std::span<const token_t>;
|
||||||
|
using NodeHandle = std::size_t;
|
||||||
|
using IOTicket = std::uint32_t;
|
||||||
|
|
||||||
|
inline void _assert(
|
||||||
|
bool condition,
|
||||||
|
const char* message = "Assertion failed",
|
||||||
|
std::source_location loc = std::source_location::current()) {
|
||||||
|
if (!condition) [[unlikely]] {
|
||||||
|
std::string msg = message;
|
||||||
|
msg = msg + " at " + loc.file_name() + ":" + std::to_string(loc.line()) + " in " + loc.function_name();
|
||||||
|
throw std::runtime_error(msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace radix_tree_v2
|
||||||
182
python/sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py
Normal file
182
python/sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.cpp_extension import load
|
||||||
|
|
||||||
|
_abs_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
radix_tree_cpp = load(
|
||||||
|
name="radix_tree_cpp",
|
||||||
|
sources=[
|
||||||
|
f"{_abs_path}/tree_v2_binding.cpp",
|
||||||
|
f"{_abs_path}/tree_v2_debug.cpp",
|
||||||
|
f"{_abs_path}/tree_v2.cpp",
|
||||||
|
],
|
||||||
|
extra_cflags=["-O3", "-std=c++20"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
class TreeNodeCpp:
|
||||||
|
"""
|
||||||
|
A placeholder for the TreeNode class. Cannot be constructed elsewhere.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class IOHandle:
|
||||||
|
"""
|
||||||
|
A placeholder for the IOHandle class. Cannot be constructed elsewhere.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class RadixTreeCpp:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
disabled: bool,
|
||||||
|
host_size: Optional[int],
|
||||||
|
page_size: int,
|
||||||
|
write_through_threshold: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initializes the RadixTreeCpp instance.
|
||||||
|
Args:
|
||||||
|
disabled (bool): If True, the radix tree is disabled.
|
||||||
|
host_size (Optional[int]): Size of the radix tree on the CPU. None means no CPU tree.
|
||||||
|
page_size (int): Size of the page for the radix tree.
|
||||||
|
write_through_threshold (int): Threshold for writing through from GPU to CPU.
|
||||||
|
"""
|
||||||
|
self.tree = radix_tree_cpp.RadixTree( # type: ignore
|
||||||
|
disabled, host_size, page_size, write_through_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
def match_prefix(
|
||||||
|
self, prefix: List[int]
|
||||||
|
) -> Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
|
||||||
|
"""
|
||||||
|
Matches a prefix in the radix tree.
|
||||||
|
Args:
|
||||||
|
prefix (List[int]): The prefix to match.
|
||||||
|
Returns:
|
||||||
|
Tuple[List[torch.Tensor], TreeNodeCpp, TreeNodeCpp]:
|
||||||
|
0. A list of indices that is matched by the prefix on the GPU.
|
||||||
|
1. Sum length of the indices matched on the CPU.
|
||||||
|
2. The last node of the prefix matched on the GPU.
|
||||||
|
3. The last node of the prefix matched on the CPU.
|
||||||
|
"""
|
||||||
|
return self.tree.match_prefix(prefix)
|
||||||
|
|
||||||
|
def evict(self, num_tokens: int) -> List[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Evicts a number of tokens from the radix tree.
|
||||||
|
Args:
|
||||||
|
num_tokens (int): The number of tokens to evict.
|
||||||
|
Returns:
|
||||||
|
List[torch.Tensor]: A list of indices that were evicted.
|
||||||
|
"""
|
||||||
|
return self.tree.evict(num_tokens)
|
||||||
|
|
||||||
|
def lock_ref(self, handle: TreeNodeCpp, lock: bool) -> None:
|
||||||
|
"""
|
||||||
|
Locks or unlocks a reference to a tree node.
|
||||||
|
After locking, the node will not be evicted from the radix tree.
|
||||||
|
Args:
|
||||||
|
handle (TreeNodeCpp): The tree node to lock or unlock.
|
||||||
|
lock (bool): If True, locks the node; if False, unlocks it.
|
||||||
|
"""
|
||||||
|
return self.tree.lock_ref(handle, lock)
|
||||||
|
|
||||||
|
def writing_through(
|
||||||
|
self, key: List[int], indices: torch.Tensor
|
||||||
|
) -> Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
|
||||||
|
"""
|
||||||
|
Inserts a key-value pair into the radix tree and perform write-through check.
|
||||||
|
Args:
|
||||||
|
key (List[int]): The key to insert.
|
||||||
|
indices (torch.Tensor): The value associated with the key.
|
||||||
|
Returns:
|
||||||
|
Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
|
||||||
|
0. A list of (IOHandle, device indices, host indices) tuples.
|
||||||
|
These IOhandles require write-through to the CPU in python side.
|
||||||
|
1. The number of indices that are matched on device.
|
||||||
|
"""
|
||||||
|
return self.tree.writing_through(key, indices)
|
||||||
|
|
||||||
|
def loading_onboard(
|
||||||
|
self,
|
||||||
|
host_node: TreeNodeCpp,
|
||||||
|
new_device_indices: torch.Tensor,
|
||||||
|
) -> Tuple[IOHandle, List[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Updates the device indices of tree nodes within a range on the tree.
|
||||||
|
Args:
|
||||||
|
host_node (TreeNodeCpp): The tree node on the host, must be descendant of device_node.
|
||||||
|
new_device_indices (torch.Tensor): The new device indices to set.
|
||||||
|
The length of this tensor must be exactly host indices length.
|
||||||
|
Returns:
|
||||||
|
Tuple[IOHandle, List[torch.Tensor]]:
|
||||||
|
0. An IOHandle that requires loading to the CPU in python side.
|
||||||
|
1. A list of host indices corresponding to the new device indices.
|
||||||
|
"""
|
||||||
|
return self.tree.loading_onboard(host_node, new_device_indices)
|
||||||
|
|
||||||
|
def commit_writing_through(self, handle: IOHandle, success: bool) -> None:
|
||||||
|
"""
|
||||||
|
Commits the write-through process for a tree node.
|
||||||
|
Args:
|
||||||
|
handle (IOHandle): The IOHandle to commit.
|
||||||
|
success (bool): If True, commits the write-through; if False, just indicates failure.
|
||||||
|
"""
|
||||||
|
return self.tree.commit_writing_through(handle, success)
|
||||||
|
|
||||||
|
def commit_loading_onboard(self, handle: IOHandle, success: bool) -> None:
|
||||||
|
"""
|
||||||
|
Commits the load onboard process for tree nodes within a range on the tree.
|
||||||
|
Args:
|
||||||
|
handle (IOHandle): The IOHandle to commit.
|
||||||
|
success (bool): If True, commits the load-onboard; if False, just indicates failure.
|
||||||
|
"""
|
||||||
|
return self.tree.commit_loading_onboard(handle, success)
|
||||||
|
|
||||||
|
def evictable_size(self) -> int:
|
||||||
|
"""
|
||||||
|
Returns the size of the evictable part of the radix tree.
|
||||||
|
This is the size of the part that can be evicted from the GPU (ref_count = 0).
|
||||||
|
Returns:
|
||||||
|
int: The size of the evictable part.
|
||||||
|
"""
|
||||||
|
return self.tree.evictable_size()
|
||||||
|
|
||||||
|
def protected_size(self) -> int:
|
||||||
|
"""
|
||||||
|
Returns the size of the protected part of the radix tree.
|
||||||
|
This is the size of the part that cannot be evicted from the GPU (ref_count > 0).
|
||||||
|
Returns:
|
||||||
|
int: The size of the protected part.
|
||||||
|
"""
|
||||||
|
return self.tree.protected_size()
|
||||||
|
|
||||||
|
def total_size(self) -> int:
|
||||||
|
"""
|
||||||
|
Returns the total size of the radix tree (including CPU nodes).
|
||||||
|
Returns:
|
||||||
|
int: The total size of the radix tree.
|
||||||
|
"""
|
||||||
|
return self.tree.total_size()
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""
|
||||||
|
Resets the radix tree, clearing all nodes and indices.
|
||||||
|
"""
|
||||||
|
return self.tree.reset()
|
||||||
|
|
||||||
|
def debug_print(self) -> None:
|
||||||
|
"""
|
||||||
|
Prints the internal state of the radix tree for debugging purposes.
|
||||||
|
"""
|
||||||
|
return self.tree.debug_print()
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Real implementation of the classes for runtime
|
||||||
|
RadixTreeCpp = radix_tree_cpp.RadixTree
|
||||||
|
TreeNodeCpp = object
|
||||||
|
IOHandle = object
|
||||||
143
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.cpp
Normal file
143
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.cpp
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
#include "tree_v2.h"
|
||||||
|
|
||||||
|
#include <ATen/core/TensorBody.h>
|
||||||
|
#include <ATen/ops/empty.h>
|
||||||
|
#include <ATen/ops/tensor.h>
|
||||||
|
#include <ATen/ops/zeros.h>
|
||||||
|
#include <c10/util/irange.h>
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <memory>
|
||||||
|
#include <queue>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include "tree_v2_impl.h"
|
||||||
|
#include "tree_v2_node.h"
|
||||||
|
|
||||||
|
namespace radix_tree_v2 {
|
||||||
|
|
||||||
|
static NodeHandle node2id(TreeNode* node) {
|
||||||
|
return node->node_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
// compare function for the TreeNode pointers based on their time
|
||||||
|
// we use LRU, so we want to evict the least recently used nodes
|
||||||
|
// since std::priority_queue is a max-heap, we need to reverse the comparison
|
||||||
|
static constexpr auto cmp = [](TreeNode* lhs, TreeNode* rhs) { return lhs->time() > rhs->time(); };
|
||||||
|
|
||||||
|
RadixTree::RadixTree(bool disabled, std::optional<std::size_t> host_size, std::size_t page_size, std::size_t threshold)
|
||||||
|
: m_impl(std::make_unique<Impl>(disabled, host_size.has_value(), page_size, host_size.value_or(0), threshold)) {}
|
||||||
|
|
||||||
|
RadixTree::~RadixTree() = default;
|
||||||
|
|
||||||
|
std::tuple<std::vector<at::Tensor>, std::size_t, NodeHandle, NodeHandle>
|
||||||
|
RadixTree::match_prefix(const token_vec_t& _key) {
|
||||||
|
if (m_impl->disabled) return {};
|
||||||
|
|
||||||
|
const auto key = token_slice{_key.data(), m_impl->align(_key.size())};
|
||||||
|
const auto [host_node, _] = m_impl->tree_walk(key);
|
||||||
|
|
||||||
|
// walk up to the first non-evicted node
|
||||||
|
std::size_t host_hit_length = 0;
|
||||||
|
const auto device_node = host_node;
|
||||||
|
|
||||||
|
// collect all the device indices
|
||||||
|
std::vector<at::Tensor> indices{};
|
||||||
|
walk_to_root(device_node, [&](TreeNode* n) { indices.push_back(n->device_indices()); });
|
||||||
|
std::reverse(indices.begin(), indices.end());
|
||||||
|
|
||||||
|
return {std::move(indices), host_hit_length, node2id(device_node), node2id(host_node)};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<at::Tensor> RadixTree::evict(std::size_t num_tokens) {
|
||||||
|
if (m_impl->disabled || num_tokens == 0) return {};
|
||||||
|
auto heap = std::priority_queue{cmp, m_impl->collect_leaves_device()};
|
||||||
|
std::vector<at::Tensor> evicted_values;
|
||||||
|
// evict nodes until we reach the desired number of tokens
|
||||||
|
std::size_t num_evict = 0;
|
||||||
|
while (num_evict < num_tokens && !heap.empty()) {
|
||||||
|
const auto node = heap.top();
|
||||||
|
heap.pop();
|
||||||
|
// when ref_count == 0, can't be writing through
|
||||||
|
_assert(node->on_gpu() && node->ref_count == 0);
|
||||||
|
if (!node->is_io_free()) continue; // skip nodes that are undergoing IO (i.e. indices protected)
|
||||||
|
evicted_values.push_back(node->device_indices());
|
||||||
|
num_evict += node->length();
|
||||||
|
const auto parent = node->parent();
|
||||||
|
m_impl->remove_device_node(node);
|
||||||
|
if (parent->is_leaf_device() && parent->ref_count == 0)
|
||||||
|
heap.push(parent); // push parent to the heap if it is now a free leaf
|
||||||
|
}
|
||||||
|
|
||||||
|
return evicted_values;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<std::vector<std::tuple<IOTicket, at::Tensor, at::Tensor>>, std::size_t>
|
||||||
|
RadixTree::writing_through(const token_vec_t& _key, at::Tensor value) {
|
||||||
|
if (m_impl->disabled) return {};
|
||||||
|
_assert(_key.size() == std::size_t(value.size(0)), "Key and value must have the same size");
|
||||||
|
|
||||||
|
// just align the key to the page size, clip the unaligned tail
|
||||||
|
const auto key = token_slice{_key.data(), m_impl->align(_key.size())};
|
||||||
|
|
||||||
|
// walk the tree to find the right place to insert
|
||||||
|
const auto [host_node, host_prefix_length] = m_impl->tree_walk(key);
|
||||||
|
|
||||||
|
// insert and create a new node if the remaining part of the key is not empty
|
||||||
|
if (host_prefix_length != key.size()) {
|
||||||
|
m_impl->create_device_node(
|
||||||
|
host_node,
|
||||||
|
{key.begin() + host_prefix_length, key.end()},
|
||||||
|
value.slice(/*dim=*/0, host_prefix_length, key.size()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// add the hit count for the device node
|
||||||
|
walk_to_root(host_node, [&](TreeNode* n) { n->hit_count++; });
|
||||||
|
|
||||||
|
std::vector<std::tuple<IOTicket, at::Tensor, at::Tensor>> result;
|
||||||
|
|
||||||
|
// don't write through if hicache is disabled (no host memory), fast path
|
||||||
|
if (!m_impl->use_hicache) return {std::move(result), host_prefix_length};
|
||||||
|
throw std::runtime_error("Not implemented yet");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<IOTicket, std::vector<at::Tensor>> RadixTree::loading_onboard(NodeHandle, at::Tensor) {
|
||||||
|
if (m_impl->disabled) return {};
|
||||||
|
throw std::runtime_error("Not implemented yet");
|
||||||
|
}
|
||||||
|
|
||||||
|
void RadixTree::commit_writing_through(IOTicket, bool) {
|
||||||
|
if (m_impl->disabled) return;
|
||||||
|
throw std::runtime_error("Not implemented yet");
|
||||||
|
}
|
||||||
|
|
||||||
|
void RadixTree::commit_loading_onboard(IOTicket, bool) {
|
||||||
|
if (m_impl->disabled) return;
|
||||||
|
throw std::runtime_error("Not implemented yet");
|
||||||
|
}
|
||||||
|
|
||||||
|
void RadixTree::reset() {
|
||||||
|
m_impl->reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
void RadixTree::lock_ref(NodeHandle node_id, bool increment) {
|
||||||
|
if (m_impl->disabled) return;
|
||||||
|
m_impl->lock_ref(node_id, increment);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t RadixTree::evictable_size() const {
|
||||||
|
return m_impl->evictable_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t RadixTree::protected_size() const {
|
||||||
|
return m_impl->protected_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t RadixTree::total_size() const {
|
||||||
|
return m_impl->total_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace radix_tree_v2
|
||||||
59
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.h
Normal file
59
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.h
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <ATen/core/TensorBody.h>
|
||||||
|
#include <c10/core/Device.h>
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
|
#include <tuple>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
|
||||||
|
namespace radix_tree_v2 {
|
||||||
|
|
||||||
|
struct RadixTree {
|
||||||
|
public:
|
||||||
|
RadixTree(bool disabled, std::optional<std::size_t> host_size, std::size_t page_size, std::size_t threshold);
|
||||||
|
~RadixTree();
|
||||||
|
|
||||||
|
// Trees should not be copied or moved, as they manage their own memory and state.
|
||||||
|
RadixTree(const RadixTree&) = delete;
|
||||||
|
RadixTree(RadixTree&&) = delete;
|
||||||
|
RadixTree& operator=(const RadixTree&) = delete;
|
||||||
|
RadixTree& operator=(RadixTree&&) = delete;
|
||||||
|
|
||||||
|
/// @return (device indices that are matched, host indices length, device node, host node)
|
||||||
|
std::tuple<std::vector<at::Tensor>, std::size_t, NodeHandle, NodeHandle> match_prefix(const token_vec_t& key);
|
||||||
|
/// @return Device indices that need to be evicted (on python side).
|
||||||
|
std::vector<at::Tensor> evict(std::size_t num_tokens);
|
||||||
|
/// @brief (Un-)Lock a node.
|
||||||
|
void lock_ref(NodeHandle node_id, bool increment /* increment or decrement */);
|
||||||
|
/// @brief Update new key-value pair and try to perform write-through.
|
||||||
|
std::tuple<std::vector<std::tuple<IOTicket, at::Tensor, at::Tensor>>, std::size_t>
|
||||||
|
writing_through(const token_vec_t& key, at::Tensor value);
|
||||||
|
/// @brief Load to device from host within a range of nodes.
|
||||||
|
std::tuple<IOTicket, std::vector<at::Tensor>> loading_onboard(NodeHandle host_id, at::Tensor indices);
|
||||||
|
/// @brief Commit a transaction of write-through.
|
||||||
|
void commit_writing_through(IOTicket ticket, bool success);
|
||||||
|
/// @brief Commit a transaction of load onboard.
|
||||||
|
void commit_loading_onboard(IOTicket ticket, bool success);
|
||||||
|
/// @brief Clear and reset the tree.
|
||||||
|
void reset();
|
||||||
|
|
||||||
|
/// @return How many size are still evictable (on device + not locked).
|
||||||
|
std::size_t evictable_size() const;
|
||||||
|
/// @return How many size are protected (locked).
|
||||||
|
std::size_t protected_size() const;
|
||||||
|
/// @return How many size are used on device.
|
||||||
|
std::size_t total_size() const;
|
||||||
|
|
||||||
|
/// @brief Print debug information of the tree.
|
||||||
|
void debug_print() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct Impl;
|
||||||
|
std::unique_ptr<Impl> m_impl;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace radix_tree_v2
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
#include "tree_v2.h"
|
||||||
|
|
||||||
|
PYBIND11_MODULE(radix_tree_cpp, m) {
|
||||||
|
using namespace radix_tree_v2;
|
||||||
|
namespace py = pybind11;
|
||||||
|
py::class_<RadixTree>(m, "RadixTree")
|
||||||
|
.def(
|
||||||
|
py::init<bool, std::optional<std::size_t>, std::size_t, std::size_t>(),
|
||||||
|
py::arg("disabled"),
|
||||||
|
py::arg("host_size"),
|
||||||
|
py::arg("page_size"),
|
||||||
|
py::arg("write_through_threshold"))
|
||||||
|
.def("match_prefix", &RadixTree::match_prefix)
|
||||||
|
.def("evict", &RadixTree::evict)
|
||||||
|
.def("lock_ref", &RadixTree::lock_ref)
|
||||||
|
.def("evictable_size", &RadixTree::evictable_size)
|
||||||
|
.def("protected_size", &RadixTree::protected_size)
|
||||||
|
.def("total_size", &RadixTree::total_size)
|
||||||
|
.def("writing_through", &RadixTree::writing_through)
|
||||||
|
.def("loading_onboard", &RadixTree::loading_onboard)
|
||||||
|
.def("commit_writing_through", &RadixTree::commit_writing_through)
|
||||||
|
.def("commit_loading_onboard", &RadixTree::commit_loading_onboard)
|
||||||
|
.def("reset", &RadixTree::reset)
|
||||||
|
.def("debug_print", &RadixTree::debug_print);
|
||||||
|
}
|
||||||
194
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_debug.cpp
Normal file
194
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_debug.cpp
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
#include <c10/core/DeviceType.h>
|
||||||
|
#include <c10/core/MemoryFormat.h>
|
||||||
|
#include <c10/core/ScalarType.h>
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <iostream>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "tree_v2.h"
|
||||||
|
#include "tree_v2_impl.h"
|
||||||
|
|
||||||
|
namespace radix_tree_v2 {
|
||||||
|
|
||||||
|
void RadixTree::debug_print() const {
|
||||||
|
m_impl->debug_print(std::clog);
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr auto npos = std::size_t(-1);
|
||||||
|
|
||||||
|
void RadixTree::Impl::debug_print(std::ostream& os) const {
|
||||||
|
static constexpr auto _check = [](bool condition, auto msg, std::size_t id = npos) {
|
||||||
|
if (!condition) {
|
||||||
|
std::string suffix = id == npos ? "" : " [id = " + std::to_string(id) + "]";
|
||||||
|
throw std::runtime_error(std::string("RadixTree::debug_print failed: ") + msg + suffix);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr auto _print_node = [](TreeNode* node, std::size_t depth, std::ostream& os) {
|
||||||
|
const auto length = node->length();
|
||||||
|
os << node->node_id << " [depth = " << depth << "] [len = " << length << "]";
|
||||||
|
|
||||||
|
// placement status
|
||||||
|
if (node->on_both()) {
|
||||||
|
os << " [cpu + gpu]";
|
||||||
|
} else if (node->on_gpu()) {
|
||||||
|
os << " [gpu]";
|
||||||
|
} else if (node->on_cpu()) {
|
||||||
|
os << " [cpu]";
|
||||||
|
} else {
|
||||||
|
_check(false, "Node is not on GPU or CPU", node->node_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IO status
|
||||||
|
if (node->is_io_free()) {
|
||||||
|
os << " [io = free]";
|
||||||
|
} else if (node->is_io_device_to_host()) {
|
||||||
|
os << " [io = gpu -> cpu]";
|
||||||
|
} else if (node->is_io_host_to_device()) {
|
||||||
|
os << " [io = cpu -> gpu]";
|
||||||
|
} else {
|
||||||
|
_check(false, "Node is in unknown IO state", node->node_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
os << " [rc = " << node->ref_count << "]";
|
||||||
|
os << " [hit = " << node->hit_count << "]";
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr auto _print_indices = [](at::Tensor indices, std::ostream& os) {
|
||||||
|
if (!indices.defined()) {
|
||||||
|
os << "[[N/A]]";
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
|
indices = indices.to(c10::kCPU, c10::kLong, false, false, c10::MemoryFormat::Contiguous);
|
||||||
|
const auto length = indices.numel();
|
||||||
|
os << "[";
|
||||||
|
auto* data_ptr = indices.data_ptr<int64_t>();
|
||||||
|
for (const auto i : c10::irange(indices.size(0))) {
|
||||||
|
os << data_ptr[i];
|
||||||
|
if (i != length - 1) os << ", ";
|
||||||
|
}
|
||||||
|
os << "]";
|
||||||
|
return indices;
|
||||||
|
};
|
||||||
|
|
||||||
|
os << "Evictable size: " << evictable_size() << std::endl;
|
||||||
|
os << "Protected size: " << protected_size() << std::endl;
|
||||||
|
os << "Total size: " << const_cast<Impl*>(this)->total_size() << std::endl;
|
||||||
|
std::vector<std::tuple<TreeNode*, TreeNode*, token_slice>> stack;
|
||||||
|
auto root = const_cast<TreeNode*>(&m_root);
|
||||||
|
os << root->node_id << " [root]" << std::endl;
|
||||||
|
for (const auto& [key, child] : *root) {
|
||||||
|
stack.push_back({child.get(), root, key});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unordered_map<TreeNode*, std::size_t> depth_map;
|
||||||
|
std::string indent_buffer;
|
||||||
|
depth_map[root] = 0;
|
||||||
|
std::vector<NodeHandle> visited_id;
|
||||||
|
std::size_t evictable_size_real = 0;
|
||||||
|
while (!stack.empty()) {
|
||||||
|
const auto [node, parent, key] = stack.back();
|
||||||
|
stack.pop_back();
|
||||||
|
visited_id.push_back(node->node_id);
|
||||||
|
const auto nid = node->node_id;
|
||||||
|
_check(node != nullptr, "Node is null", nid);
|
||||||
|
_check(node->on_gpu() || node->on_cpu(), "Node is not on GPU or CPU", nid);
|
||||||
|
_check(node->parent() == parent, "Parent is not correct", nid);
|
||||||
|
_check(key.size() == page_size && node->diff_key(key, 0) == page_size, "Key is not correct", nid);
|
||||||
|
_check(depth_map.count(node) == 0, "Node is visited twice", nid);
|
||||||
|
_check(m_node_map.count(nid) == 1, "Node is not in the map", nid);
|
||||||
|
_check(m_node_map.at(nid) == node, "Node in the map is not the same as the one in the stack", nid);
|
||||||
|
_check(!node->on_gpu() || parent->is_root() || parent->on_gpu(), "Node on GPU must have a GPU/root parent", nid);
|
||||||
|
if (!node->is_io_free()) {
|
||||||
|
_check(node->ref_count > 0, "Node is in IO state but not protected", nid);
|
||||||
|
_check(node->on_both(), "Node in IO state must be on both CPU and GPU", nid);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node->on_gpu() && node->ref_count == 0) {
|
||||||
|
evictable_size_real += node->length();
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto depth = (depth_map[node] = depth_map[parent] + 1);
|
||||||
|
indent_buffer.resize(depth * 2, ' ');
|
||||||
|
os << indent_buffer;
|
||||||
|
_print_node(node, depth, os);
|
||||||
|
os << std::endl;
|
||||||
|
for (const auto& [key, child] : *node) {
|
||||||
|
stack.push_back({child.get(), node, key});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_check(evictable_size_real == evictable_size(), "Evictable size is wrong");
|
||||||
|
_check(m_node_map.count(root->node_id) == 1, "Root node is not in the map");
|
||||||
|
_check(m_node_map.at(root->node_id) == root, "Root node in the map is not correct");
|
||||||
|
|
||||||
|
std::sort(visited_id.begin(), visited_id.end());
|
||||||
|
if (visited_id.size() != m_node_map.size() - 1) {
|
||||||
|
// Some error in the tree, not all nodes are visited
|
||||||
|
std::string id_list;
|
||||||
|
id_list += "(visited: ";
|
||||||
|
id_list += std::to_string(root->node_id) + " ";
|
||||||
|
for (const auto& id : visited_id) {
|
||||||
|
id_list += std::to_string(id) + " ";
|
||||||
|
}
|
||||||
|
id_list += "), (in map: ";
|
||||||
|
for (const auto& [id, _] : m_node_map) {
|
||||||
|
id_list += std::to_string(id) + " ";
|
||||||
|
}
|
||||||
|
id_list += ")";
|
||||||
|
_check(false, "Not all nodes are visited " + id_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
static const auto kSGLANG_RADIX_CPP_DEBUG_LIMIT = [] {
|
||||||
|
const char* env = std::getenv("SGLANG_RADIX_CPP_DEBUG_LIMIT");
|
||||||
|
const std::size_t default_limit = 16;
|
||||||
|
if (env != nullptr) {
|
||||||
|
try {
|
||||||
|
return static_cast<std::size_t>(std::stoull(env));
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
std::cerr << "Invalid SGLANG_RADIX_CPP_DEBUG_LIMIT value: " << env //
|
||||||
|
<< ". Using default value =" << default_limit << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return default_limit;
|
||||||
|
}();
|
||||||
|
|
||||||
|
for (const auto nid : visited_id) {
|
||||||
|
const auto node = m_node_map.at(nid);
|
||||||
|
// print key and indices
|
||||||
|
const auto& key = node->_unsafe_tokens();
|
||||||
|
if (key.size() > kSGLANG_RADIX_CPP_DEBUG_LIMIT) {
|
||||||
|
os << "Node " << nid << ": key is too long (" << key.size() << " tokens), skipping..." << std::endl;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
os << "Node " << nid << ": key = [";
|
||||||
|
for (const auto& i : c10::irange(key.size())) {
|
||||||
|
os << key[i];
|
||||||
|
if (i != key.size() - 1) os << ", ";
|
||||||
|
}
|
||||||
|
|
||||||
|
_check(key.size() % page_size == 0, "Misaligned key", nid);
|
||||||
|
|
||||||
|
os << "] device_indices = ";
|
||||||
|
const auto device_indices = _print_indices(node->device_indices(), os);
|
||||||
|
if (device_indices.defined()) {
|
||||||
|
std::size_t length = device_indices.numel();
|
||||||
|
_check(device_indices.dim() == 1, "Device indices must be 1D tensor", nid);
|
||||||
|
_check(length == node->length(), "Wrong device indices size", nid);
|
||||||
|
}
|
||||||
|
|
||||||
|
os << " host_indices = ";
|
||||||
|
const auto host_indices = _print_indices(node->host_indices(), os);
|
||||||
|
if (host_indices.defined()) {
|
||||||
|
std::size_t length = host_indices.numel();
|
||||||
|
_check(host_indices.dim() == 1, "Host indices must be 1D tensor", nid);
|
||||||
|
_check(length == node->length(), "Wrong host indices size", nid);
|
||||||
|
}
|
||||||
|
os << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace radix_tree_v2
|
||||||
276
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_impl.h
Normal file
276
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_impl.h
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <c10/util/irange.h>
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <iosfwd>
|
||||||
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include "tree_v2.h"
|
||||||
|
#include "tree_v2_node.h"
|
||||||
|
|
||||||
|
namespace radix_tree_v2 {
|
||||||
|
|
||||||
|
using node_iterator_t = typename TreeNode::iterator_t;
|
||||||
|
|
||||||
|
struct RadixTree::Impl {
 public:
  /**
   * @brief Core implementation of the radix tree.
   * @param disabled If true, the tree acts as a disabled / temporary cache.
   * @param use_hicache Whether a CPU-backed hierarchical cache is enabled.
   * @param page_size Token granularity of every key (must be > 0).
   * @param host_size Host cache capacity; must be > 0 iff use_hicache.
   * @param threshold Threshold for write-through.
   */
  Impl(bool disabled, bool use_hicache, std::size_t page_size, std::size_t host_size, std::size_t threshold)
      : m_root(/*node_id_=*/0),
        m_evictable_size(0),
        m_protected_size(0),
        m_cached_vec(),
        m_node_map(),
        m_node_counter(1),  // start from 1 to avoid confusion with root node
        disabled(disabled),
        use_hicache(use_hicache),
        page_size(page_size),
        threshold(threshold) {
    _assert(page_size > 0, "Page size must be greater than zero");
    _assert(use_hicache == (host_size > 0), "Hierarchical cache is enabled iff host size > 0");
    m_root.ref_count = 1;                  // root node is always protected
    m_cached_vec.reserve(page_size);       // to avoid repeated allocations
    m_node_map[m_root.node_id] = &m_root;  // add root to the map
  }

  // Split the node at `iterator` so that its first `prefix_length` tokens are
  // moved into a freshly allocated parent:
  // from `parent -> old_node` to `parent -> new_node -> old_node`.
  TreeNode* split_node(node_iterator_t iterator, std::size_t prefix_length) {
    auto old_node_ptr = std::move(iterator->second);
    auto new_node_ptr = std::make_unique<TreeNode>(m_node_counter++);
    auto* old_node = old_node_ptr.get();
    auto* new_node = new_node_ptr.get();
    auto* parent = old_node->parent();
    // set up data structures (tokens, indices, ref/hit counts, IO state)
    split_prefix(new_node, old_node, prefix_length);
    // set up parent-child relationship
    add_child(new_node, std::move(old_node_ptr));
    add_child(parent, std::move(new_node_ptr), iterator);
    m_node_map[new_node->node_id] = new_node;  // add to the map
    return new_node;
  }

  // node: x -> [GPU]
  // Create a fresh device-resident node under `parent`. New nodes start with
  // ref_count == 0, so their tokens immediately count as evictable.
  TreeNode* create_device_node(TreeNode* parent, token_vec_t vec, at::Tensor indices) {
    auto new_node_ptr = std::make_unique<TreeNode>(m_node_counter++);
    auto new_node = new_node_ptr.get();
    new_node_ptr->_unsafe_tokens() = std::move(vec);
    new_node_ptr->_unsafe_device_indices() = std::move(indices);
    m_evictable_size += new_node_ptr->length();
    add_child(parent, std::move(new_node_ptr));
    m_node_map[new_node->node_id] = new_node;  // add to the map
    return new_node;
  }

  // node: [GPU] -> x
  void remove_device_node(TreeNode* node) {
    _assert(node->on_gpu_only() && node->ref_count == 0);
    m_evictable_size -= node->length();
    // BUGFIX: erase the id-map entry *before* erase_child. erase_child drops
    // the unique_ptr that owns `node` (destroying it), so reading
    // `node->node_id` after that call would be a use-after-free.
    m_node_map.erase(node->node_id);  // remove from the map
    node->parent()->erase_child(get_key(node));
  }

  /**
   * @brief Walk the tree to find the node that matches the key.
   * If the key partially matches a node, it will split that node.
   * @return A pair containing the last node that matches the key and
   * the total prefix length matched (on gpu and cpu) so far.
   */
  std::pair<TreeNode*, std::size_t> tree_walk(token_slice key) {
    _assert(key.size() % page_size == 0, "Key should be page-aligned");

    std::size_t total_prefix_length = 0;
    TreeNode* node = &m_root;

    // one timestamp for the whole walk, so all touched nodes share it
    const auto now = std::chrono::steady_clock::now();
    while (key.size() > 0) {
      const auto iterator = node->find_child(get_key(key));
      if (iterator == node->end()) break;

      // walk to the child node
      node = iterator->second.get();

      // at least `page_size` tokens are matched, and there may be more tokens to match
      // the return value prefix_length is no less than `page_size`
      const auto prefix_length = align(node->diff_key(key, page_size) + page_size);
      total_prefix_length += prefix_length;

      // split the node if the prefix is not the whole token vector
      if (prefix_length < node->length()) {
        return {split_node(iterator, prefix_length), total_prefix_length};
      }

      // we have matched the whole key, continue to the next node
      node->access(now);
      key = key.subspan(prefix_length);
    }

    return {node, total_prefix_length};
  }

  // Collect every unlocked leaf node (DFS over the whole tree).
  std::vector<TreeNode*> collect_leaves() const {
    std::vector<TreeNode*> leaves;
    std::vector<TreeNode*> stack = {};
    for (const auto& [_, child] : m_root) {
      stack.push_back(child.get());
    }
    while (!stack.empty()) {
      const auto node = stack.back();
      stack.pop_back();
      if (node->is_leaf()) {
        if (node->ref_count == 0) {
          leaves.push_back(node);
        }
      } else {
        for (const auto& [_, child] : *node) {
          stack.push_back(child.get());
        }
      }
    }
    return leaves;
  }

  // Collect every unlocked node that is a leaf of the *device-resident*
  // subtree (no GPU-resident children).
  std::vector<TreeNode*> collect_leaves_device() const {
    // for non-hicache, every leaf device node is a leaf node (since no backup on host)
    if (!use_hicache) return collect_leaves();
    std::vector<TreeNode*> leaves;
    std::vector<TreeNode*> stack = {};
    for (const auto& [_, child] : m_root) {
      stack.push_back(child.get());
    }
    while (!stack.empty()) {
      const auto node = stack.back();
      stack.pop_back();
      if (!node->on_gpu()) continue;  // skip nodes that are not on GPU
      if (node->is_leaf_device()) {
        if (node->ref_count == 0) {
          leaves.push_back(node);
        }
      } else {
        for (const auto& [_, child] : *node) {
          stack.push_back(child.get());
        }
      }
    }
    return leaves;
  }

  // Increment (or decrement) the lock reference count on the path from
  // `node` up to the root, keeping the evictable/protected accounting
  // consistent (tokens move between the two pools at the 0 <-> 1 boundary).
  void lock_ref(TreeNode* node, bool increment) {
    if (node->is_root()) return;  // skip root node
    _assert(node->on_gpu(), "Cannot lock reference on an evicted node");
    if (increment)
      walk_to_root(node, [this](TreeNode* n) {
        if (n->ref_count == 0) {
          m_evictable_size -= n->length();
          m_protected_size += n->length();
        }
        n->ref_count++;
      });
    else
      walk_to_root(node, [this](TreeNode* n) {
        _assert(n->ref_count != 0, "Cannot decrement reference count = zero");
        n->ref_count--;
        if (n->ref_count == 0) {
          m_protected_size -= n->length();
          m_evictable_size += n->length();
        }
      });
  }

  void lock_ref(NodeHandle node_ptr, bool increment) {
    return lock_ref(id2node(node_ptr), increment);
  }

  void lock(TreeNode* node) {
    return lock_ref(node, /*increment=*/true);
  }

  void unlock(TreeNode* node) {
    return lock_ref(node, /*increment=*/false);
  }

  // Total number of tokens stored in the tree (root included; root is empty).
  std::size_t total_size() const {
    std::size_t size = 0;
    std::vector<const TreeNode*> stack = {&m_root};
    while (!stack.empty()) {
      auto* node = stack.back();
      stack.pop_back();
      size += node->length();
      for (const auto& [_, child] : *node)
        stack.push_back(child.get());
    }
    return size;
  }

  std::size_t evictable_size() const {
    return m_evictable_size;
  }

  std::size_t protected_size() const {
    return m_protected_size;
  }

  // Round `size` down to a multiple of `page_size`.
  std::size_t align(std::size_t size) const {
    return (size / page_size) * page_size;  // align to page size
  }

  // Resolve an externally-held node handle back to its node pointer.
  TreeNode* id2node(NodeHandle node_id) const {
    const auto iterator = m_node_map.find(node_id);
    _assert(iterator != m_node_map.end(), "Node not found in the map");
    return iterator->second;
  }

  // Drop every node except the root and reset all accounting.
  void reset() {
    _assert(m_root.ref_count == 1, "Root node must be protected during reset");
    m_node_counter = 1;  // reset node counter
    m_root.root_reset();
    m_evictable_size = 0;
    m_protected_size = 0;
    m_node_map.clear();
    m_node_map[m_root.node_id] = &m_root;  // re-add root to the map
  }

  void debug_print(std::ostream& os) const;

 private:
  // some auxiliary functions

  // Copy the first page of `tokens` into the shared scratch vector and
  // return it; this is the hash-map key for a child edge.
  token_vec_t& get_key(token_slice tokens) {
    _assert(tokens.size() >= page_size, "Key should be at least page-sized");
    tokens = tokens.subspan(0, page_size);
    m_cached_vec.assign(tokens.begin(), tokens.end());
    return m_cached_vec;
  }

  // justify for _unsafe call: we need to read the key part of the tokens
  token_vec_t& get_key(TreeNode* node) {
    return get_key(node->_unsafe_tokens());
  }

  void add_child(TreeNode* parent, std::unique_ptr<TreeNode>&& child) {
    parent->add_child(get_key(child.get()), std::move(child));
  }

  void add_child(TreeNode* parent, std::unique_ptr<TreeNode>&& child, node_iterator_t it) {
    parent->add_child(it, std::move(child));
  }

  TreeNode m_root;               // root node of the tree
  std::size_t m_evictable_size;  // number of evictable tokens on GPU (lock ref = 0)
  std::size_t m_protected_size;  // number of protected tokens on GPU (lock ref > 0)
  token_vec_t m_cached_vec;      // cached vector of tokens for the current operation
  std::unordered_map<std::size_t, TreeNode*> m_node_map;  // map of node keys to nodes
  std::size_t m_node_counter;                             // counter for node IDs

 public:
  // some public constant configurations (without m_ prefix)
  const bool disabled;          // whether the cache is enabled, or just a temporary cache
  const bool use_hicache;       // whether to use the HiCache for this tree
  const std::size_t page_size;  // size of each page in the cache
  const std::size_t threshold;  // threshold for write_through
};
|
||||||
|
|
||||||
|
} // namespace radix_tree_v2
|
||||||
257
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_node.h
Normal file
257
python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_node.h
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <ATen/core/TensorBody.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <array>
|
||||||
|
#include <chrono>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
|
#include <ranges>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
|
||||||
|
namespace radix_tree_v2 {
|
||||||
|
|
||||||
|
struct std_vector_hash {
|
||||||
|
// see https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector
|
||||||
|
std::size_t operator()(const token_vec_t& vec) const {
|
||||||
|
std::size_t hash = 0;
|
||||||
|
for (const auto& token : vec) {
|
||||||
|
hash ^= token + 0x9e3779b9 + (hash << 6) + (hash >> 2);
|
||||||
|
}
|
||||||
|
return hash;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// A single node of the radix tree. Each node owns a page-aligned token
// segment plus the matching device/host KV indices, its children (keyed by
// the first page of the child's tokens), and bookkeeping for lock refs,
// LRU timestamps, and in-flight IO transfers.
struct TreeNode {
 public:
  using childern_map_t = std::unordered_map<token_vec_t, std::unique_ptr<TreeNode>, std_vector_hash>;
  using iterator_t = typename childern_map_t::iterator;
  using const_iterator_t = typename childern_map_t::const_iterator;
  using timestamp_t = std::chrono::steady_clock::time_point;

  TreeNode(std::size_t node_id_)
      : ref_count(0),
        hit_count(0),
        m_io_locked(std::nullopt),
        m_io_status(IOStatus::None),
        m_io_ticket(),
        m_tokens(),
        m_device_indices(),
        m_host_indices(),
        m_parent(),
        m_children(),
        m_last_access_time(std::chrono::steady_clock::now()),
        node_id(node_id_) {}

  // Refresh the LRU timestamp (defaults to "now").
  void access(timestamp_t time = std::chrono::steady_clock::now()) {
    m_last_access_time = time;
  }

  bool is_root() const {
    return m_parent == nullptr;
  }

  // Last access time, used by eviction policies.
  timestamp_t time() const {
    return m_last_access_time;
  }

  // "On GPU" / "on CPU" is tracked by whether the index tensor is defined.
  bool on_gpu() const {
    return m_device_indices.defined();
  }

  bool on_cpu() const {
    return m_host_indices.defined();
  }

  bool on_gpu_only() const {
    return on_gpu() && !on_cpu();
  }

  bool on_cpu_only() const {
    return !on_gpu() && on_cpu();
  }

  bool on_both() const {
    return on_gpu() && on_cpu();
  }

  // Number of tokens stored in this node.
  std::size_t length() const {
    return m_tokens.size();
  }

  bool is_leaf() const {
    return m_children.empty();
  }

  // A node is a "device leaf" if none of its children are GPU-resident.
  bool is_leaf_device() const {
    for (const auto& [_, child] : m_children)
      if (child->on_gpu()) return false;  // at least one child is on the device
    return true;
  }

  void add_child(const token_vec_t& v, std::unique_ptr<TreeNode>&& child) {
    child->m_parent = this;
    m_children[v] = std::move(child);
  }

  // Overwrite an existing child slot in place (used when splitting a node).
  void add_child(iterator_t it, std::unique_ptr<TreeNode>&& child) {
    child->m_parent = this;
    it->second = std::move(child);
  }

  void erase_child(const token_vec_t& v) {
    _assert(m_children.erase(v) > 0, "Child node not found");
  }

  iterator_t find_child(const token_vec_t& v) {
    return m_children.find(v);
  }

  iterator_t begin() {
    return m_children.begin();
  }

  iterator_t end() {
    return m_children.end();
  }

  const_iterator_t begin() const {
    return m_children.begin();
  }

  const_iterator_t end() const {
    return m_children.end();
  }

  TreeNode* parent() {
    return m_parent;
  }

  // set up all data structures except for parent-child relationship
  friend void split_prefix(TreeNode* new_node, TreeNode* old_node, std::size_t prefix_length) {
    auto tokens = std::move(old_node->m_tokens);
    _assert(0 < prefix_length && prefix_length < tokens.size(), "Invalid prefix size for split");

    // set up tokens
    old_node->m_tokens = token_vec_t(tokens.begin() + prefix_length, tokens.end());
    new_node->m_tokens = std::move(tokens);
    new_node->m_tokens.resize(prefix_length);

    // set up values
    const int64_t new_size = new_node->length();
    const int64_t old_size = old_node->length();
    if (old_node->m_device_indices.defined()) {
      auto new_indices = old_node->m_device_indices.split_with_sizes({new_size, old_size});
      new_node->m_device_indices = std::move(new_indices[0]);
      old_node->m_device_indices = std::move(new_indices[1]);
    }
    if (old_node->m_host_indices.defined()) {
      auto new_indices = old_node->m_host_indices.split_with_sizes({new_size, old_size});
      new_node->m_host_indices = std::move(new_indices[0]);
      old_node->m_host_indices = std::move(new_indices[1]);
    }

    // set up ref counts and hit counts
    new_node->ref_count = old_node->ref_count;
    new_node->hit_count = old_node->hit_count;

    // If the old node (child) was locked for IO, the new node (parent) does not need
    // to be locked, since it is naturally protected by the child node's lock.
    if (old_node->m_io_locked.has_value()) {
      new_node->m_io_locked = false;
      new_node->m_io_status = old_node->m_io_status;
      new_node->m_io_ticket = old_node->m_io_ticket;
    }
  }

  /// @return The first index in `m_tokens` that differs from `key`.
  std::size_t diff_key(token_slice key, std::size_t offset) const {
    // `offset` lets the caller skip a prefix that is already known to match.
    const auto a = token_slice{key}.subspan(offset);
    const auto b = token_slice{m_tokens}.subspan(offset);
    const auto [it_a, it_b] = std::ranges::mismatch(a, b);
    return it_a - a.begin();  // return the index of the first differing token
  }

  at::Tensor device_indices() const {
    return m_device_indices;
  }
  at::Tensor host_indices() const {
    return m_host_indices;
  }

  // visiting tokens are always unsafe (use `diff_key` instead)
  token_vec_t& _unsafe_tokens() {
    return m_tokens;
  }
  at::Tensor& _unsafe_device_indices() {
    return m_device_indices;
  }
  at::Tensor& _unsafe_host_indices() {
    return m_host_indices;
  }

  bool is_io_free() const {
    return m_io_status == IOStatus::None;
  }

  bool is_io_device_to_host() const {
    return m_io_status == IOStatus::DeviceToHost;
  }

  bool is_io_host_to_device() const {
    return m_io_status == IOStatus::HostToDevice;
  }

  // Clear all children; only legal on the root, with no IO in flight.
  void root_reset() {
    _assert(is_root(), "Only root node can call root_reset");
    _assert(
        m_io_status == IOStatus::None && m_io_locked == std::nullopt,
        "IO operation in progress, cannot reset root node");
    _assert(this->m_tokens.empty(), "Root node tokens should be empty on reset");
    _assert(
        !this->m_device_indices.defined() && !this->m_host_indices.defined(),
        "Root node indices should be always be empty and never assigned");
    m_children.clear();
    this->access();
  }

 public:
  std::size_t ref_count;  // lock reference count (> 0 means protected from eviction)
  std::size_t hit_count;  // number of cache hits on this node

 private:
  enum class IOStatus : std::uint8_t {
    None,
    HostToDevice,
    DeviceToHost,
  };

  std::optional<bool> m_io_locked;  // whether the node is locked in IO operation
  IOStatus m_io_status;             // direction of the in-flight transfer, if any
  IOTicket m_io_ticket;             // ticket identifying the in-flight transfer

  token_vec_t m_tokens;         // tokens stored in this node
  at::Tensor m_device_indices;  // indices of device value
  at::Tensor m_host_indices;    // indices of host value
  TreeNode* m_parent;           // nullptr iff this is the root
  childern_map_t m_children;    // children keyed by their first page of tokens
  timestamp_t m_last_access_time;

 public:
  const std::size_t node_id;  // unique ID for the node
};
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
inline TreeNode* walk_to_root(TreeNode* t, const F& f) {
|
||||||
|
while (!t->is_root()) {
|
||||||
|
f(t);
|
||||||
|
t = t->parent();
|
||||||
|
}
|
||||||
|
return t; // return the root node
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace radix_tree_v2
|
||||||
229
python/sglang/srt/mem_cache/radix_cache_cpp.py
Normal file
229
python/sglang/srt/mem_cache/radix_cache_cpp.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, List, Set
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||||
|
from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
|
||||||
|
IOHandle,
|
||||||
|
RadixTreeCpp,
|
||||||
|
TreeNodeCpp,
|
||||||
|
)
|
||||||
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RadixCacheCpp(BasePrefixCache):
    """Prefix cache backed by the experimental C++ radix tree.

    Adapts :class:`RadixTreeCpp` to the ``BasePrefixCache`` interface used by
    the scheduler. Only the device-only configuration is implemented so far;
    every hierarchical-cache (host) path raises ``NotImplementedError``.
    """

    def _merge_tensor(self, l: List[torch.Tensor]) -> torch.Tensor:
        """
        Merge a list of tensors into a single tensor.

        Args:
            l (List[torch.Tensor]): List of tensors to merge.
        Returns:
            torch.Tensor: Merged tensor (empty int64 tensor on self.device
            when the list is empty; avoids a copy for a single element).
        """
        if len(l) == 0:
            return torch.empty(0, dtype=torch.int64, device=self.device)
        elif len(l) == 1:
            return l[0]
        else:
            return torch.cat(l)

    def __init__(
        self,
        disable: bool,
        use_hicache: bool,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: BaseTokenToKVPoolAllocator,
        tp_cache_group: torch.distributed.ProcessGroup,
        page_size: int,
        hicache_ratio: float,
        hicache_size: int,
        hicache_write_policy: str,
        enable_kv_cache_events: bool = False,
        hicache_oracle: bool = False,
        enable_write_cancel: bool = False,
    ):
        self.disable = disable
        self.enable_write_cancel = enable_write_cancel

        # KV-cache event streaming is not wired up for the C++ tree yet.
        assert (
            enable_kv_cache_events is False
        ), "HiRadixCache does not support kv cache events yet"
        self.kv_cache = token_to_kv_pool.get_kvcache()

        # record the nodes with ongoing write through
        self.ongoing_write_through: Set[IOHandle] = set()
        # record the node segments with ongoing load back
        self.ongoing_load_back: Set[IOHandle] = set()
        # todo: dynamically adjust the threshold
        self.write_through_threshold = (
            1 if hicache_write_policy == "write_through" else 2
        )
        self.device = token_to_kv_pool.device
        self.token_to_kv_pool = token_to_kv_pool
        self.req_to_token_pool = req_to_token_pool
        self.page_size = page_size

        self.tp_group = tp_cache_group

        if not use_hicache:
            self.tree = RadixTreeCpp(
                disabled=self.disable,
                page_size=page_size,
                host_size=None,  # no host cache, this should be removed in the future
                write_through_threshold=self.write_through_threshold,
            )
            self.cache_controller = None
            return  # early return if hicache is not used

        raise NotImplementedError("Host cache is not supported yet")

    def reset(self):
        """Drop all cached entries and restore the tree to its initial state."""
        if self.cache_controller is not None:
            # need to clear the acks before resetting the cache controller
            raise NotImplementedError("Host cache is not supported yet")
        self.tree.reset()

    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
        """Match `key` against the tree and return the cached prefix.

        The C++ tree returns the device-resident indices as a list of tensor
        segments (one per matched node); they are merged here into one tensor.
        """
        device_indices_vec, host_indices_length, node_gpu, node_cpu = (
            self.tree.match_prefix(key)
        )
        return MatchResult(
            device_indices=self._merge_tensor(device_indices_vec),
            last_device_node=node_gpu,
            last_host_node=node_cpu,
            host_hit_length=host_indices_length,
        )

    def _insert(self, key: List[int], value: torch.Tensor) -> int:
        """
        Insert a key-value pair into the radix tree.

        Args:
            key (List[int]): The key to insert, represented as a list of integers.
            value (torch.Tensor): The value to associate with the key.
        Returns:
            int: Number of device indices that were already present in the tree before the insertion.
        """
        ongoing_write, length = self.tree.writing_through(key, value)
        if self.cache_controller is None:
            # without a cache controller there is no host backing store,
            # so no write-through may ever be scheduled
            assert len(ongoing_write) == 0, "Implementation error"
            return length

        raise NotImplementedError("Host cache is not supported yet")

    def dec_lock_ref(self, node: TreeNodeCpp):
        """
        Decrement the reference count of a node to root of the radix tree.

        Args:
            node (TreeNodeCpp): The handle of the node to decrement the reference count for.
        """
        self.tree.lock_ref(node, False)  # do not increment

    def inc_lock_ref(self, node: TreeNodeCpp):
        """
        Increment the reference count of from a node to root of the radix tree.

        Args:
            node (TreeNodeCpp): The handle of the node to increment the reference count for.
        """
        self.tree.lock_ref(node, True)

    def evict(self, num_tokens: int):
        """Evict at least `num_tokens` tokens from the tree, freeing the
        corresponding KV-pool slots segment by segment."""
        evicted_device_indices = self.tree.evict(num_tokens)
        for indice in evicted_device_indices:
            self.token_to_kv_pool.free(indice)

    def evictable_size(self):
        """Number of cached tokens with lock ref == 0 (eligible for eviction)."""
        return self.tree.evictable_size()

    def protected_size(self):
        """Number of cached tokens with lock ref > 0 (pinned by requests)."""
        return self.tree.protected_size()

    def total_size(self):
        """Total number of tokens stored in the tree."""
        return self.tree.total_size()

    def cache_finished_req(self, req: Req):
        """Cache request when it finishes."""
        assert req.req_pool_idx is not None
        # drop the last token: it has no KV entry yet
        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        overall_len = len(token_ids)  # prefill + decode
        kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]

        # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
        # it will automatically align them, but length of them should be equal
        old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
        new_prefix_len = self._insert(token_ids, kv_indices)

        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"

        # KVCache between old & new is newly generated, but already exists in the pool
        # we need to free this newly generated kv indices
        if old_prefix_len < new_prefix_len:
            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])

        # need to free the unaligned part, since it cannot be inserted into the radix tree
        if self.page_size != 1 and (  # unaligned tail only exists when page_size > 1
            (unaligned_len := overall_len % self.page_size) > 0
        ):
            # NOTE: sglang PagedAllocator support unaligned free (which will automatically align it)
            self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :])

        # Remove req slot release the cache lock
        self.dec_lock_ref(req.last_node)
        self.req_to_token_pool.free(req.req_pool_idx)

    def cache_unfinished_req(self, req: Req):
        """Cache request when it is unfinished."""
        assert req.req_pool_idx is not None
        token_ids = req.fill_ids
        prefill_len = len(token_ids)  # prefill only (maybe chunked)
        kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :prefill_len]

        # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
        # it will automatically align them, but length of them should be equal
        old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
        new_prefix_len = self._insert(token_ids, kv_indices)

        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"

        # TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
        # The prefix indices need to updated to reuse the kv indices in the pool
        new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids)
        new_indices = self._merge_tensor(new_indices_vec)
        assert new_prefix_len <= len(new_indices)

        # KVCache between old & new is newly generated, but already exists in the pool
        # we need to free this newly generated kv indices and reuse the indices in the pool
        if old_prefix_len < new_prefix_len:
            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
            reused_indices = new_indices[old_prefix_len:new_prefix_len]
            self.req_to_token_pool.req_to_token[
                req.req_pool_idx, old_prefix_len:new_prefix_len
            ] = reused_indices

        # re-pin the request on the (possibly deeper) matched node
        if req.last_node != new_last_node:
            self.dec_lock_ref(req.last_node)
            self.inc_lock_ref(new_last_node)

        # NOTE: there might be unaligned tail, so we may need to append it
        assert len(new_indices) <= prefill_len < len(new_indices) + self.page_size
        if self.page_size != 1 and len(new_indices) < prefill_len:
            req.prefix_indices = torch.cat(
                [new_indices, kv_indices[len(new_indices) :]]
            )
        else:
            req.prefix_indices = new_indices
        req.last_node = new_last_node

    def pretty_print(self):
        """Return / emit a human-readable dump of the C++ tree (debugging aid)."""
        return self.tree.debug_print()
|
||||||
47
test/srt/test_cpp_radix_cache.py
Normal file
47
test/srt/test_cpp_radix_cache.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.run_eval import run_eval
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCppRadixCache(CustomTestCase):
    """End-to-end smoke test for the experimental C++ radix tree cache.

    Launches a real server with ``SGLANG_EXPERIMENTAL_CPP_RADIX_TREE=1`` and
    checks that generation quality on a small MMLU slice is unaffected.
    """

    @classmethod
    def setUpClass(cls):
        # Opt into the C++ radix tree before the server process is spawned,
        # so the scheduler constructs RadixCacheCpp at startup.
        os.environ["SGLANG_EXPERIMENTAL_CPP_RADIX_TREE"] = "1"
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        )

    @classmethod
    def tearDownClass(cls):
        # Tear down the whole server process tree (including workers).
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        print(metrics)
        # Sanity bar: a correct cache must not drag accuracy below 0.65.
        self.assertGreaterEqual(metrics["score"], 0.65)
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test file directly: `python test_cpp_radix_cache.py`.
if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user