diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 57f12dbe0..d2298cf38 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -569,7 +569,23 @@ class Scheduler( page_size=self.page_size, ) else: - if self.enable_hierarchical_cache: + if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1": + # lazy import to avoid JIT overhead + from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp + + self.tree_cache = RadixCacheCpp( + disable=False, + use_hicache=self.enable_hierarchical_cache, + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool=self.token_to_kv_pool_allocator, + tp_cache_group=self.tp_cpu_group, + page_size=self.page_size, + hicache_ratio=server_args.hicache_ratio, + hicache_size=server_args.hicache_size, + hicache_write_policy=server_args.hicache_write_policy, + enable_kv_cache_events=self.enable_kv_cache_events, + ) + elif self.enable_hierarchical_cache: self.tree_cache = HiRadixCache( req_to_token_pool=self.req_to_token_pool, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, diff --git a/python/sglang/srt/mem_cache/cpp_radix_tree/.clang-format b/python/sglang/srt/mem_cache/cpp_radix_tree/.clang-format new file mode 120000 index 000000000..5a7a8cea7 --- /dev/null +++ b/python/sglang/srt/mem_cache/cpp_radix_tree/.clang-format @@ -0,0 +1 @@ +../../../../../sgl-kernel/.clang-format \ No newline at end of file diff --git a/python/sglang/srt/mem_cache/cpp_radix_tree/common.h b/python/sglang/srt/mem_cache/cpp_radix_tree/common.h new file mode 100644 index 000000000..72c2c78be --- /dev/null +++ b/python/sglang/srt/mem_cache/cpp_radix_tree/common.h @@ -0,0 +1,29 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace radix_tree_v2 { + +using token_t = std::int32_t; +using token_vec_t = std::vector; +using token_slice = std::span; +using NodeHandle = std::size_t; +using IOTicket = std::uint32_t; + +inline void _assert( + bool condition, + const char* message = "Assertion failed", + std::source_location loc = std::source_location::current()) { + if (!condition) [[unlikely]] { + std::string msg = message; + msg = msg + " at " + loc.file_name() + ":" + std::to_string(loc.line()) + " in " + loc.function_name(); + throw std::runtime_error(msg); + } +} + +} // namespace radix_tree_v2 diff --git a/python/sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py b/python/sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py new file mode 100644 index 000000000..592727aac --- /dev/null +++ b/python/sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, List, Optional, Tuple + +import torch +from torch.utils.cpp_extension import load + +_abs_path = os.path.dirname(os.path.abspath(__file__)) +radix_tree_cpp = load( + name="radix_tree_cpp", + sources=[ + f"{_abs_path}/tree_v2_binding.cpp", + f"{_abs_path}/tree_v2_debug.cpp", + f"{_abs_path}/tree_v2.cpp", + ], + extra_cflags=["-O3", "-std=c++20"], +) + +if TYPE_CHECKING: + + class TreeNodeCpp: + """ + A placeholder for the TreeNode class. Cannot be constructed elsewhere. + """ + + class IOHandle: + """ + A placeholder for the IOHandle class. Cannot be constructed elsewhere. + """ + + class RadixTreeCpp: + def __init__( + self, + disabled: bool, + host_size: Optional[int], + page_size: int, + write_through_threshold: int, + ): + """ + Initializes the RadixTreeCpp instance. + Args: + disabled (bool): If True, the radix tree is disabled. + host_size (Optional[int]): Size of the radix tree on the CPU. None means no CPU tree. + page_size (int): Size of the page for the radix tree. + write_through_threshold (int): Threshold for writing through from GPU to CPU. + """ + self.tree = radix_tree_cpp.RadixTree( # type: ignore + disabled, host_size, page_size, write_through_threshold + ) + + def match_prefix( + self, prefix: List[int] + ) -> Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]: + """ + Matches a prefix in the radix tree. + Args: + prefix (List[int]): The prefix to match. + Returns: + Tuple[List[torch.Tensor], TreeNodeCpp, TreeNodeCpp]: + 0. A list of indices that is matched by the prefix on the GPU. + 1. Sum length of the indices matched on the CPU. + 2. The last node of the prefix matched on the GPU. + 3. The last node of the prefix matched on the CPU. + """ + return self.tree.match_prefix(prefix) + + def evict(self, num_tokens: int) -> List[torch.Tensor]: + """ + Evicts a number of tokens from the radix tree. + Args: + num_tokens (int): The number of tokens to evict. + Returns: + List[torch.Tensor]: A list of indices that were evicted. + """ + return self.tree.evict(num_tokens) + + def lock_ref(self, handle: TreeNodeCpp, lock: bool) -> None: + """ + Locks or unlocks a reference to a tree node. + After locking, the node will not be evicted from the radix tree. + Args: + handle (TreeNodeCpp): The tree node to lock or unlock. + lock (bool): If True, locks the node; if False, unlocks it. + """ + return self.tree.lock_ref(handle, lock) + + def writing_through( + self, key: List[int], indices: torch.Tensor + ) -> Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]: + """ + Inserts a key-value pair into the radix tree and perform write-through check. + Args: + key (List[int]): The key to insert. + indices (torch.Tensor): The value associated with the key. + Returns: + Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]: + 0. A list of (IOHandle, device indices, host indices) tuples. + These IOhandles require write-through to the CPU in python side. + 1. The number of indices that are matched on device. + """ + return self.tree.writing_through(key, indices) + + def loading_onboard( + self, + host_node: TreeNodeCpp, + new_device_indices: torch.Tensor, + ) -> Tuple[IOHandle, List[torch.Tensor]]: + """ + Updates the device indices of tree nodes within a range on the tree. + Args: + host_node (TreeNodeCpp): The tree node on the host, must be descendant of device_node. + new_device_indices (torch.Tensor): The new device indices to set. + The length of this tensor must be exactly host indices length. + Returns: + Tuple[IOHandle, List[torch.Tensor]]: + 0. An IOHandle that requires loading to the CPU in python side. + 1. A list of host indices corresponding to the new device indices. + """ + return self.tree.loading_onboard(host_node, new_device_indices) + + def commit_writing_through(self, handle: IOHandle, success: bool) -> None: + """ + Commits the write-through process for a tree node. + Args: + handle (IOHandle): The IOHandle to commit. + success (bool): If True, commits the write-through; if False, just indicates failure. + """ + return self.tree.commit_writing_through(handle, success) + + def commit_loading_onboard(self, handle: IOHandle, success: bool) -> None: + """ + Commits the load onboard process for tree nodes within a range on the tree. + Args: + handle (IOHandle): The IOHandle to commit. + success (bool): If True, commits the load-onboard; if False, just indicates failure. + """ + return self.tree.commit_loading_onboard(handle, success) + + def evictable_size(self) -> int: + """ + Returns the size of the evictable part of the radix tree. + This is the size of the part that can be evicted from the GPU (ref_count = 0). + Returns: + int: The size of the evictable part. + """ + return self.tree.evictable_size() + + def protected_size(self) -> int: + """ + Returns the size of the protected part of the radix tree. + This is the size of the part that cannot be evicted from the GPU (ref_count > 0). + Returns: + int: The size of the protected part. + """ + return self.tree.protected_size() + + def total_size(self) -> int: + """ + Returns the total size of the radix tree (including CPU nodes). + Returns: + int: The total size of the radix tree. + """ + return self.tree.total_size() + + def reset(self) -> None: + """ + Resets the radix tree, clearing all nodes and indices. + """ + return self.tree.reset() + + def debug_print(self) -> None: + """ + Prints the internal state of the radix tree for debugging purposes. + """ + return self.tree.debug_print() + +else: + # Real implementation of the classes for runtime + RadixTreeCpp = radix_tree_cpp.RadixTree + TreeNodeCpp = object + IOHandle = object diff --git a/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.cpp b/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.cpp new file mode 100644 index 000000000..2a5433221 --- /dev/null +++ b/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.cpp @@ -0,0 +1,143 @@ +#include "tree_v2.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "tree_v2_impl.h" +#include "tree_v2_node.h" + +namespace radix_tree_v2 { + +static NodeHandle node2id(TreeNode* node) { + return node->node_id; +} + +// compare function for the TreeNode pointers based on their time +// we use LRU, so we want to evict the least recently used nodes +// since std::priority_queue is a max-heap, we need to reverse the comparison +static constexpr auto cmp = [](TreeNode* lhs, TreeNode* rhs) { return lhs->time() > rhs->time(); }; + +RadixTree::RadixTree(bool disabled, std::optional host_size, std::size_t page_size, std::size_t threshold) + : m_impl(std::make_unique(disabled, host_size.has_value(), page_size, host_size.value_or(0), threshold)) {} + +RadixTree::~RadixTree() = default; + +std::tuple, std::size_t, NodeHandle, NodeHandle> +RadixTree::match_prefix(const token_vec_t& _key) { + if (m_impl->disabled) return {}; + + const auto key = token_slice{_key.data(), m_impl->align(_key.size())}; + const auto [host_node, _] = m_impl->tree_walk(key); + + // walk up to the first non-evicted node + std::size_t host_hit_length = 0; + const auto device_node = host_node; + + // collect all the device indices + std::vector indices{}; + walk_to_root(device_node, [&](TreeNode* n) { indices.push_back(n->device_indices()); }); + std::reverse(indices.begin(), indices.end()); + + return {std::move(indices), host_hit_length, node2id(device_node), node2id(host_node)}; +} + +std::vector RadixTree::evict(std::size_t num_tokens) { + if (m_impl->disabled || num_tokens == 0) return {}; + auto heap = std::priority_queue{cmp, m_impl->collect_leaves_device()}; + std::vector evicted_values; + // evict nodes until we reach the desired number of tokens + std::size_t num_evict = 0; + while (num_evict < num_tokens && !heap.empty()) { + const auto node = heap.top(); + heap.pop(); + // when ref_count == 0, can't be writing through + _assert(node->on_gpu() && node->ref_count == 0); + if (!node->is_io_free()) continue; // skip nodes that are undergoing IO (i.e. indices protected) + evicted_values.push_back(node->device_indices()); + num_evict += node->length(); + const auto parent = node->parent(); + m_impl->remove_device_node(node); + if (parent->is_leaf_device() && parent->ref_count == 0) + heap.push(parent); // push parent to the heap if it is now a free leaf + } + + return evicted_values; +} + +std::tuple>, std::size_t> +RadixTree::writing_through(const token_vec_t& _key, at::Tensor value) { + if (m_impl->disabled) return {}; + _assert(_key.size() == std::size_t(value.size(0)), "Key and value must have the same size"); + + // just align the key to the page size, clip the unaligned tail + const auto key = token_slice{_key.data(), m_impl->align(_key.size())}; + + // walk the tree to find the right place to insert + const auto [host_node, host_prefix_length] = m_impl->tree_walk(key); + + // insert and create a new node if the remaining part of the key is not empty + if (host_prefix_length != key.size()) { + m_impl->create_device_node( + host_node, + {key.begin() + host_prefix_length, key.end()}, + value.slice(/*dim=*/0, host_prefix_length, key.size())); + } + + // add the hit count for the device node + walk_to_root(host_node, [&](TreeNode* n) { n->hit_count++; }); + + std::vector> result; + + // don't write through if hicache is disabled (no host memory), fast path + if (!m_impl->use_hicache) return {std::move(result), host_prefix_length}; + throw std::runtime_error("Not implemented yet"); +} + +std::tuple> RadixTree::loading_onboard(NodeHandle, at::Tensor) { + if (m_impl->disabled) return {}; + throw std::runtime_error("Not implemented yet"); +} + +void RadixTree::commit_writing_through(IOTicket, bool) { + if (m_impl->disabled) return; + throw std::runtime_error("Not implemented yet"); +} + +void RadixTree::commit_loading_onboard(IOTicket, bool) { + if (m_impl->disabled) return; + throw std::runtime_error("Not implemented yet"); +} + +void RadixTree::reset() { + m_impl->reset(); +} + +void RadixTree::lock_ref(NodeHandle node_id, bool increment) { + if (m_impl->disabled) return; + m_impl->lock_ref(node_id, increment); +} + +std::size_t RadixTree::evictable_size() const { + return m_impl->evictable_size(); +} + +std::size_t RadixTree::protected_size() const { + return m_impl->protected_size(); +} + +std::size_t RadixTree::total_size() const { + return m_impl->total_size(); +} + +} // namespace radix_tree_v2 diff --git a/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.h b/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.h new file mode 100644 index 000000000..68da9b9e1 --- /dev/null +++ b/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2.h @@ -0,0 +1,59 @@ +#pragma once +#include +#include + +#include +#include +#include +#include +#include + +#include "common.h" + +namespace radix_tree_v2 { + +struct RadixTree { + public: + RadixTree(bool disabled, std::optional host_size, std::size_t page_size, std::size_t threshold); + ~RadixTree(); + + // Trees should not be copied or moved, as they manage their own memory and state. + RadixTree(const RadixTree&) = delete; + RadixTree(RadixTree&&) = delete; + RadixTree& operator=(const RadixTree&) = delete; + RadixTree& operator=(RadixTree&&) = delete; + + /// @return (device indices that are matched, host indices length, device node, host node) + std::tuple, std::size_t, NodeHandle, NodeHandle> match_prefix(const token_vec_t& key); + /// @return Device indices that need to be evicted (on python side). + std::vector evict(std::size_t num_tokens); + /// @brief (Un-)Lock a node. + void lock_ref(NodeHandle node_id, bool increment /* increment or decrement */); + /// @brief Update new key-value pair and try to perform write-through. + std::tuple>, std::size_t> + writing_through(const token_vec_t& key, at::Tensor value); + /// @brief Load to device from host within a range of nodes. + std::tuple> loading_onboard(NodeHandle host_id, at::Tensor indices); + /// @brief Commit a transaction of write-through. + void commit_writing_through(IOTicket ticket, bool success); + /// @brief Commit a transaction of load onboard. + void commit_loading_onboard(IOTicket ticket, bool success); + /// @brief Clear and reset the tree. + void reset(); + + /// @return How many size are still evictable (on device + not locked). + std::size_t evictable_size() const; + /// @return How many size are protected (locked). + std::size_t protected_size() const; + /// @return How many size are used on device. + std::size_t total_size() const; + + /// @brief Print debug information of the tree. + void debug_print() const; + + private: + struct Impl; + std::unique_ptr m_impl; +}; + +} // namespace radix_tree_v2 diff --git a/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_binding.cpp b/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_binding.cpp new file mode 100644 index 000000000..81069e4fe --- /dev/null +++ b/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_binding.cpp @@ -0,0 +1,32 @@ +#include +#include +#include + +#include +#include + +#include "tree_v2.h" + +PYBIND11_MODULE(radix_tree_cpp, m) { + using namespace radix_tree_v2; + namespace py = pybind11; + py::class_(m, "RadixTree") + .def( + py::init, std::size_t, std::size_t>(), + py::arg("disabled"), + py::arg("host_size"), + py::arg("page_size"), + py::arg("write_through_threshold")) + .def("match_prefix", &RadixTree::match_prefix) + .def("evict", &RadixTree::evict) + .def("lock_ref", &RadixTree::lock_ref) + .def("evictable_size", &RadixTree::evictable_size) + .def("protected_size", &RadixTree::protected_size) + .def("total_size", &RadixTree::total_size) + .def("writing_through", &RadixTree::writing_through) + .def("loading_onboard", &RadixTree::loading_onboard) + .def("commit_writing_through", &RadixTree::commit_writing_through) + .def("commit_loading_onboard", &RadixTree::commit_loading_onboard) + .def("reset", &RadixTree::reset) + .def("debug_print", &RadixTree::debug_print); +} diff --git a/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_debug.cpp b/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_debug.cpp new file mode 100644 index 000000000..89b6290b1 --- /dev/null +++ b/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_debug.cpp @@ -0,0 +1,194 @@ +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "tree_v2.h" +#include "tree_v2_impl.h" + +namespace radix_tree_v2 { + +void RadixTree::debug_print() const { + m_impl->debug_print(std::clog); +} + +static constexpr auto npos = std::size_t(-1); + +void RadixTree::Impl::debug_print(std::ostream& os) const { + static constexpr auto _check = [](bool condition, auto msg, std::size_t id = npos) { + if (!condition) { + std::string suffix = id == npos ? "" : " [id = " + std::to_string(id) + "]"; + throw std::runtime_error(std::string("RadixTree::debug_print failed: ") + msg + suffix); + } + }; + + static constexpr auto _print_node = [](TreeNode* node, std::size_t depth, std::ostream& os) { + const auto length = node->length(); + os << node->node_id << " [depth = " << depth << "] [len = " << length << "]"; + + // placement status + if (node->on_both()) { + os << " [cpu + gpu]"; + } else if (node->on_gpu()) { + os << " [gpu]"; + } else if (node->on_cpu()) { + os << " [cpu]"; + } else { + _check(false, "Node is not on GPU or CPU", node->node_id); + } + + // IO status + if (node->is_io_free()) { + os << " [io = free]"; + } else if (node->is_io_device_to_host()) { + os << " [io = gpu -> cpu]"; + } else if (node->is_io_host_to_device()) { + os << " [io = cpu -> gpu]"; + } else { + _check(false, "Node is in unknown IO state", node->node_id); + } + + os << " [rc = " << node->ref_count << "]"; + os << " [hit = " << node->hit_count << "]"; + }; + + static constexpr auto _print_indices = [](at::Tensor indices, std::ostream& os) { + if (!indices.defined()) { + os << "[[N/A]]"; + return indices; + } + indices = indices.to(c10::kCPU, c10::kLong, false, false, c10::MemoryFormat::Contiguous); + const auto length = indices.numel(); + os << "["; + auto* data_ptr = indices.data_ptr(); + for (const auto i : c10::irange(indices.size(0))) { + os << data_ptr[i]; + if (i != length - 1) os << ", "; + } + os << "]"; + return indices; + }; + + os << "Evictable size: " << evictable_size() << std::endl; + os << "Protected size: " << protected_size() << std::endl; + os << "Total size: " << const_cast(this)->total_size() << std::endl; + std::vector> stack; + auto root = const_cast(&m_root); + os << root->node_id << " [root]" << std::endl; + for (const auto& [key, child] : *root) { + stack.push_back({child.get(), root, key}); + } + + std::unordered_map depth_map; + std::string indent_buffer; + depth_map[root] = 0; + std::vector visited_id; + std::size_t evictable_size_real = 0; + while (!stack.empty()) { + const auto [node, parent, key] = stack.back(); + stack.pop_back(); + visited_id.push_back(node->node_id); + const auto nid = node->node_id; + _check(node != nullptr, "Node is null", nid); + _check(node->on_gpu() || node->on_cpu(), "Node is not on GPU or CPU", nid); + _check(node->parent() == parent, "Parent is not correct", nid); + _check(key.size() == page_size && node->diff_key(key, 0) == page_size, "Key is not correct", nid); + _check(depth_map.count(node) == 0, "Node is visited twice", nid); + _check(m_node_map.count(nid) == 1, "Node is not in the map", nid); + _check(m_node_map.at(nid) == node, "Node in the map is not the same as the one in the stack", nid); + _check(!node->on_gpu() || parent->is_root() || parent->on_gpu(), "Node on GPU must have a GPU/root parent", nid); + if (!node->is_io_free()) { + _check(node->ref_count > 0, "Node is in IO state but not protected", nid); + _check(node->on_both(), "Node in IO state must be on both CPU and GPU", nid); + } + + if (node->on_gpu() && node->ref_count == 0) { + evictable_size_real += node->length(); + } + + const auto depth = (depth_map[node] = depth_map[parent] + 1); + indent_buffer.resize(depth * 2, ' '); + os << indent_buffer; + _print_node(node, depth, os); + os << std::endl; + for (const auto& [key, child] : *node) { + stack.push_back({child.get(), node, key}); + } + } + + _check(evictable_size_real == evictable_size(), "Evictable size is wrong"); + _check(m_node_map.count(root->node_id) == 1, "Root node is not in the map"); + _check(m_node_map.at(root->node_id) == root, "Root node in the map is not correct"); + + std::sort(visited_id.begin(), visited_id.end()); + if (visited_id.size() != m_node_map.size() - 1) { + // Some error in the tree, not all nodes are visited + std::string id_list; + id_list += "(visited: "; + id_list += std::to_string(root->node_id) + " "; + for (const auto& id : visited_id) { + id_list += std::to_string(id) + " "; + } + id_list += "), (in map: "; + for (const auto& [id, _] : m_node_map) { + id_list += std::to_string(id) + " "; + } + id_list += ")"; + _check(false, "Not all nodes are visited " + id_list); + } + + static const auto kSGLANG_RADIX_CPP_DEBUG_LIMIT = [] { + const char* env = std::getenv("SGLANG_RADIX_CPP_DEBUG_LIMIT"); + const std::size_t default_limit = 16; + if (env != nullptr) { + try { + return static_cast(std::stoull(env)); + } catch (const std::exception& e) { + std::cerr << "Invalid SGLANG_RADIX_CPP_DEBUG_LIMIT value: " << env // + << ". Using default value =" << default_limit << std::endl; + } + } + return default_limit; + }(); + + for (const auto nid : visited_id) { + const auto node = m_node_map.at(nid); + // print key and indices + const auto& key = node->_unsafe_tokens(); + if (key.size() > kSGLANG_RADIX_CPP_DEBUG_LIMIT) { + os << "Node " << nid << ": key is too long (" << key.size() << " tokens), skipping..." << std::endl; + continue; + } + os << "Node " << nid << ": key = ["; + for (const auto& i : c10::irange(key.size())) { + os << key[i]; + if (i != key.size() - 1) os << ", "; + } + + _check(key.size() % page_size == 0, "Misaligned key", nid); + + os << "] device_indices = "; + const auto device_indices = _print_indices(node->device_indices(), os); + if (device_indices.defined()) { + std::size_t length = device_indices.numel(); + _check(device_indices.dim() == 1, "Device indices must be 1D tensor", nid); + _check(length == node->length(), "Wrong device indices size", nid); + } + + os << " host_indices = "; + const auto host_indices = _print_indices(node->host_indices(), os); + if (host_indices.defined()) { + std::size_t length = host_indices.numel(); + _check(host_indices.dim() == 1, "Host indices must be 1D tensor", nid); + _check(length == node->length(), "Wrong host indices size", nid); + } + os << std::endl; + } +} + +} // namespace radix_tree_v2 diff --git a/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_impl.h b/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_impl.h new file mode 100644 index 000000000..cb9f9dde5 --- /dev/null +++ b/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_impl.h @@ -0,0 +1,276 @@ +#pragma once +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "tree_v2.h" +#include "tree_v2_node.h" + +namespace radix_tree_v2 { + +using node_iterator_t = typename TreeNode::iterator_t; + +struct RadixTree::Impl { + public: + Impl(bool disabled, bool use_hicache, std::size_t page_size, std::size_t host_size, std::size_t threshold) + : m_root(/*node_id_=*/0), + m_evictable_size(0), + m_protected_size(0), + m_cached_vec(), + m_node_map(), + m_node_counter(1), // start from 1 to avoid confusion with root node + disabled(disabled), + use_hicache(use_hicache), + page_size(page_size), + threshold(threshold) { + _assert(page_size > 0, "Page size must be greater than zero"); + _assert(use_hicache == (host_size > 0), "Hierarchical cache is enabled iff host size > 0"); + m_root.ref_count = 1; // root node is always protected + m_cached_vec.reserve(page_size); // to avoid repeated allocations + m_node_map[m_root.node_id] = &m_root; // add root to the map + } + + TreeNode* split_node(node_iterator_t iterator, std::size_t prefix_length) { + // from `parent -> old_node` to `parent-> new_node -> old_node` + // the prefix part of the old node is moved to the new node + auto old_node_ptr = std::move(iterator->second); + auto new_node_ptr = std::make_unique(m_node_counter++); + auto* old_node = old_node_ptr.get(); + auto* new_node = new_node_ptr.get(); + auto* parent = old_node->parent(); + // set up data structures + split_prefix(new_node, old_node, prefix_length); + // set up parent-child relationship + add_child(new_node, std::move(old_node_ptr)); + add_child(parent, std::move(new_node_ptr), iterator); + m_node_map[new_node->node_id] = new_node; // add to the map + return new_node; + } + + // node: x -> [GPU] + TreeNode* create_device_node(TreeNode* parent, token_vec_t vec, at::Tensor indices) { + auto new_node_ptr = std::make_unique(m_node_counter++); + auto new_node = new_node_ptr.get(); + new_node_ptr->_unsafe_tokens() = std::move(vec); + new_node_ptr->_unsafe_device_indices() = std::move(indices); + m_evictable_size += new_node_ptr->length(); + add_child(parent, std::move(new_node_ptr)); + m_node_map[new_node->node_id] = new_node; // add to the map + return new_node; + } + + // node: [GPU] -> x + void remove_device_node(TreeNode* node) { + _assert(node->on_gpu_only() && node->ref_count == 0); + m_evictable_size -= node->length(); + node->parent()->erase_child(get_key(node)); + m_node_map.erase(node->node_id); // remove from the map + } + + /** + * @brief Walk the tree to find the node that matches the key. + * If the key partially matches a node, it will split that node. + * @return A pair containing the last node that matches the key and + * the total prefix length matched (on gpu and cpu) so far. + */ + std::pair tree_walk(token_slice key) { + _assert(key.size() % page_size == 0, "Key should be page-aligned"); + + std::size_t total_prefix_length = 0; + TreeNode* node = &m_root; + + const auto now = std::chrono::steady_clock::now(); + while (key.size() > 0) { + const auto iterator = node->find_child(get_key(key)); + if (iterator == node->end()) break; + + // walk to the child node + node = iterator->second.get(); + + // at least `page_size` tokens are matched, and there may be more tokens to match + // the return value prefix_length is no less than `page_size` + const auto prefix_length = align(node->diff_key(key, page_size) + page_size); + total_prefix_length += prefix_length; + + // split the node if the prefix is not the whole token vector + if (prefix_length < node->length()) { + return {split_node(iterator, prefix_length), total_prefix_length}; + } + + // we have matched the whole key, continue to the next node + node->access(now); + key = key.subspan(prefix_length); + } + + return {node, total_prefix_length}; + } + + std::vector collect_leaves() const { + std::vector leaves; + std::vector stack = {}; + for (const auto& [_, child] : m_root) { + stack.push_back(child.get()); + } + while (!stack.empty()) { + const auto node = stack.back(); + stack.pop_back(); + if (node->is_leaf()) { + if (node->ref_count == 0) { + leaves.push_back(node); + } + } else { + for (const auto& [_, child] : *node) { + stack.push_back(child.get()); + } + } + } + return leaves; + } + + std::vector collect_leaves_device() const { + // for non-hicache, every leaf device node is a leaf node (since no backup on host) + if (!use_hicache) return collect_leaves(); + std::vector leaves; + std::vector stack = {}; + for (const auto& [_, child] : m_root) { + stack.push_back(child.get()); + } + while (!stack.empty()) { + const auto node = stack.back(); + stack.pop_back(); + if (!node->on_gpu()) continue; // skip nodes that are not on GPU + if (node->is_leaf_device()) { + if (node->ref_count == 0) { + leaves.push_back(node); + } + } else { + for (const auto& [_, child] : *node) { + stack.push_back(child.get()); + } + } + } + return leaves; + } + + void lock_ref(TreeNode* node, bool increment) { + if (node->is_root()) return; // skip root node + _assert(node->on_gpu(), "Cannot lock reference on an evicted node"); + if (increment) + walk_to_root(node, [this](TreeNode* n) { + if (n->ref_count == 0) { + m_evictable_size -= n->length(); + m_protected_size += n->length(); + } + n->ref_count++; + }); + else + walk_to_root(node, [this](TreeNode* n) { + _assert(n->ref_count != 0, "Cannot decrement reference count = zero"); + n->ref_count--; + if (n->ref_count == 0) { + m_protected_size -= n->length(); + m_evictable_size += n->length(); + } + }); + } + + void lock_ref(NodeHandle node_ptr, bool increment) { + return lock_ref(id2node(node_ptr), increment); + } + + void lock(TreeNode* node) { + return lock_ref(node, /*increment=*/true); + } + + void unlock(TreeNode* node) { + return lock_ref(node, /*increment=*/false); + } + + std::size_t total_size() const { + std::size_t size = 0; + std::vector stack = {&m_root}; + while (!stack.empty()) { + auto* node = stack.back(); + stack.pop_back(); + size += node->length(); + for (const auto& [_, child] : *node) + stack.push_back(child.get()); + } + return size; + } + + std::size_t evictable_size() const { + return m_evictable_size; + } + + std::size_t protected_size() const { + return m_protected_size; + } + + std::size_t align(std::size_t size) const { + return (size / page_size) * page_size; // align to page size + } + + TreeNode* id2node(NodeHandle node_id) const { + const auto iterator = m_node_map.find(node_id); + _assert(iterator != m_node_map.end(), "Node not found in the map"); + return iterator->second; + } + + void reset() { + _assert(m_root.ref_count == 1, "Root node must be protected during reset"); + m_node_counter = 1; // reset node counter + m_root.root_reset(); + m_evictable_size = 0; + m_protected_size = 0; + m_node_map.clear(); + m_node_map[m_root.node_id] = &m_root; // re-add root to the map + } + + void debug_print(std::ostream& os) const; + + private: + // some auxiliary functions + token_vec_t& get_key(token_slice tokens) { + _assert(tokens.size() >= page_size, "Key should be at least page-sized"); + tokens = tokens.subspan(0, page_size); + m_cached_vec.assign(tokens.begin(), tokens.end()); + return m_cached_vec; + } + + // justify for _unsafe call: we need to read the key part of the tokens + token_vec_t& get_key(TreeNode* node) { + return get_key(node->_unsafe_tokens()); + } + + void add_child(TreeNode* parent, std::unique_ptr&& child) { + parent->add_child(get_key(child.get()), std::move(child)); + } + + void add_child(TreeNode* parent, std::unique_ptr&& child, node_iterator_t it) { + parent->add_child(it, std::move(child)); + } + + TreeNode m_root; // root node of the tree + std::size_t m_evictable_size; // number of evictable tokens on GPU (lock ref = 0) + std::size_t m_protected_size; // number of protected tokens on GPU (lock ref > 0) + token_vec_t m_cached_vec; // cached vector of tokens for the current operation + std::unordered_map m_node_map; // map of node keys to nodes + std::size_t m_node_counter; // counter for node IDs + + public: + // some public constant configurations (without m_ prefix) + const bool disabled; // whether the cache is enabled, or just a temporary cache + const bool use_hicache; // whether to use the HiCache for this tree + const std::size_t page_size; // size of each page in the cache + const std::size_t threshold; // threshold for write_through +}; + +} // namespace radix_tree_v2 diff --git a/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_node.h b/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_node.h new file mode 100644 index 000000000..4eac86ea4 --- /dev/null +++ b/python/sglang/srt/mem_cache/cpp_radix_tree/tree_v2_node.h @@ -0,0 +1,257 @@ +#pragma once +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" + +namespace radix_tree_v2 { + +struct std_vector_hash { + // see https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector + std::size_t operator()(const token_vec_t& vec) const { + std::size_t hash = 0; + for (const auto& token : vec) { + hash ^= token + 0x9e3779b9 + (hash << 6) + (hash >> 2); + } + return hash; + } +}; + +struct TreeNode { + public: + using childern_map_t = std::unordered_map, std_vector_hash>; + using iterator_t = typename childern_map_t::iterator; + using const_iterator_t = typename childern_map_t::const_iterator; + using timestamp_t = std::chrono::steady_clock::time_point; + + TreeNode(std::size_t node_id_) + : ref_count(0), + hit_count(0), + m_io_locked(std::nullopt), + m_io_status(IOStatus::None), + m_io_ticket(), + m_tokens(), + m_device_indices(), + m_host_indices(), + m_parent(), + m_children(), + m_last_access_time(std::chrono::steady_clock::now()), + node_id(node_id_) {} + + void access(timestamp_t time = std::chrono::steady_clock::now()) { + m_last_access_time = time; + } + + bool is_root() const { + return m_parent == nullptr; + } + + timestamp_t time() const { + return m_last_access_time; + } + + bool on_gpu() const { + return m_device_indices.defined(); + } + + bool on_cpu() const { + return m_host_indices.defined(); + } + + bool on_gpu_only() const { + return on_gpu() && !on_cpu(); + } + + bool on_cpu_only() const { + return !on_gpu() && on_cpu(); + } + + bool on_both() const { + return on_gpu() && on_cpu(); + } + + std::size_t length() const { + return m_tokens.size(); + } + + bool is_leaf() const { + return m_children.empty(); + } + + bool is_leaf_device() const { + for (const auto& [_, child] : m_children) + if (child->on_gpu()) return false; // at least one child is on the device + return true; + } + + void add_child(const token_vec_t& v, std::unique_ptr&& child) { + child->m_parent = this; + m_children[v] = std::move(child); + } + + void add_child(iterator_t it, std::unique_ptr&& child) { + child->m_parent = this; + it->second = std::move(child); + } + + void erase_child(const token_vec_t& v) { + _assert(m_children.erase(v) > 0, "Child node not found"); + } + + iterator_t find_child(const token_vec_t& v) { + return m_children.find(v); + } + + iterator_t begin() { + return m_children.begin(); + } + + iterator_t end() { + return m_children.end(); + } + + const_iterator_t begin() const { + return m_children.begin(); + } + + const_iterator_t end() const { + return m_children.end(); + } + + TreeNode* parent() { + return m_parent; + } + + // set up all data structures except for parent-child relationship + friend void split_prefix(TreeNode* new_node, TreeNode* old_node, std::size_t prefix_length) { + auto tokens = std::move(old_node->m_tokens); + _assert(0 < prefix_length && prefix_length < tokens.size(), "Invalid prefix size for split"); + + // set up tokens + old_node->m_tokens = token_vec_t(tokens.begin() + prefix_length, tokens.end()); + new_node->m_tokens = std::move(tokens); + new_node->m_tokens.resize(prefix_length); + + // set up values + const int64_t new_size = new_node->length(); + const int64_t old_size = old_node->length(); + if (old_node->m_device_indices.defined()) { + auto new_indices = old_node->m_device_indices.split_with_sizes({new_size, old_size}); + new_node->m_device_indices = std::move(new_indices[0]); + old_node->m_device_indices = std::move(new_indices[1]); + } + if (old_node->m_host_indices.defined()) { + auto new_indices = old_node->m_host_indices.split_with_sizes({new_size, old_size}); + new_node->m_host_indices = std::move(new_indices[0]); + old_node->m_host_indices = std::move(new_indices[1]); + } + + // set up ref counts and hit counts + new_node->ref_count = old_node->ref_count; + new_node->hit_count = old_node->hit_count; + + // If the old node (child) was locked for IO, the new node (parent) does not need + // to be locked, since it is naturally protected by the child node's lock. + if (old_node->m_io_locked.has_value()) { + new_node->m_io_locked = false; + new_node->m_io_status = old_node->m_io_status; + new_node->m_io_ticket = old_node->m_io_ticket; + } + } + + /// @return The first index in `m_tokens` that differs from `key`. + std::size_t diff_key(token_slice key, std::size_t offset) const { + const auto a = token_slice{key}.subspan(offset); + const auto b = token_slice{m_tokens}.subspan(offset); + const auto [it_a, it_b] = std::ranges::mismatch(a, b); + return it_a - a.begin(); // return the index of the first differing token + } + + at::Tensor device_indices() const { + return m_device_indices; + } + at::Tensor host_indices() const { + return m_host_indices; + } + + // visiting tokens are always unsafe (use `diff_key` instead) + token_vec_t& _unsafe_tokens() { + return m_tokens; + } + at::Tensor& _unsafe_device_indices() { + return m_device_indices; + } + at::Tensor& _unsafe_host_indices() { + return m_host_indices; + } + + bool is_io_free() const { + return m_io_status == IOStatus::None; + } + + bool is_io_device_to_host() const { + return m_io_status == IOStatus::DeviceToHost; + } + + bool is_io_host_to_device() const { + return m_io_status == IOStatus::HostToDevice; + } + + void root_reset() { + _assert(is_root(), "Only root node can call root_reset"); + _assert( + m_io_status == IOStatus::None && m_io_locked == std::nullopt, + "IO operation in progress, cannot reset root node"); + _assert(this->m_tokens.empty(), "Root node tokens should be empty on reset"); + _assert( + !this->m_device_indices.defined() && !this->m_host_indices.defined(), + "Root node indices should be always be empty and never assigned"); + m_children.clear(); + this->access(); + } + + public: + std::size_t ref_count; + std::size_t hit_count; + + private: + enum class IOStatus : std::uint8_t { + None, + HostToDevice, + DeviceToHost, + }; + + std::optional m_io_locked; // whether the node is locked in IO operation + IOStatus m_io_status; + IOTicket m_io_ticket; + + token_vec_t m_tokens; + at::Tensor m_device_indices; // indices of device value + at::Tensor m_host_indices; // indices of host value + TreeNode* m_parent; + childern_map_t m_children; + timestamp_t m_last_access_time; + + public: + const std::size_t node_id; // unique ID for the node +}; + +template +inline TreeNode* walk_to_root(TreeNode* t, const F& f) { + while (!t->is_root()) { + f(t); + t = t->parent(); + } + return t; // return the root node +} + +} // namespace radix_tree_v2 diff --git a/python/sglang/srt/mem_cache/radix_cache_cpp.py b/python/sglang/srt/mem_cache/radix_cache_cpp.py new file mode 100644 index 000000000..5234f1a0f --- /dev/null +++ b/python/sglang/srt/mem_cache/radix_cache_cpp.py @@ -0,0 +1,229 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, List, Set + +import torch + +from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult +from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import ( + IOHandle, + RadixTreeCpp, + TreeNodeCpp, +) +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool + +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import Req + + +logger = logging.getLogger(__name__) + + +class RadixCacheCpp(BasePrefixCache): + def _merge_tensor(self, l: List[torch.Tensor]) -> torch.Tensor: + """ + Merge a list of tensors into a single tensor. + Args: + l (List[torch.Tensor]): List of tensors to merge. + Returns: + torch.Tensor: Merged tensor. + """ + if len(l) == 0: + return torch.empty(0, dtype=torch.int64, device=self.device) + elif len(l) == 1: + return l[0] + else: + return torch.cat(l) + + def __init__( + self, + disable: bool, + use_hicache: bool, + req_to_token_pool: ReqToTokenPool, + token_to_kv_pool: BaseTokenToKVPoolAllocator, + tp_cache_group: torch.distributed.ProcessGroup, + page_size: int, + hicache_ratio: float, + hicache_size: int, + hicache_write_policy: str, + enable_kv_cache_events: bool = False, + hicache_oracle: bool = False, + enable_write_cancel: bool = False, + ): + self.disable = disable + self.enable_write_cancel = enable_write_cancel + + assert ( + enable_kv_cache_events is False + ), "HiRadixCache does not support kv cache events yet" + self.kv_cache = token_to_kv_pool.get_kvcache() + + # record the nodes with ongoing write through + self.ongoing_write_through: Set[IOHandle] = set() + # record the node segments with ongoing load back + self.ongoing_load_back: Set[IOHandle] = set() + # todo: dynamically adjust the threshold + self.write_through_threshold = ( + 1 if hicache_write_policy == "write_through" else 2 + ) + self.device = token_to_kv_pool.device + self.token_to_kv_pool = token_to_kv_pool + self.req_to_token_pool = req_to_token_pool + self.page_size = page_size + + self.tp_group = tp_cache_group + + if not use_hicache: + self.tree = RadixTreeCpp( + disabled=self.disable, + page_size=page_size, + host_size=None, # no host cache, this should be removed in the future + write_through_threshold=self.write_through_threshold, + ) + self.cache_controller = None + return # early return if hicache is not used + + raise NotImplementedError("Host cache is not supported yet") + + def reset(self): + if self.cache_controller is not None: + # need to clear the acks before resetting the cache controller + raise NotImplementedError("Host cache is not supported yet") + self.tree.reset() + + def match_prefix(self, key: List[int], **kwargs) -> MatchResult: + device_indices_vec, host_indices_length, node_gpu, node_cpu = ( + self.tree.match_prefix(key) + ) + return MatchResult( + device_indices=self._merge_tensor(device_indices_vec), + last_device_node=node_gpu, + last_host_node=node_cpu, + host_hit_length=host_indices_length, + ) + + def _insert(self, key: List[int], value: torch.Tensor) -> int: + """ + Insert a key-value pair into the radix tree. + Args: + key (List[int]): The key to insert, represented as a list of integers. + value (torch.Tensor): The value to associate with the key. + Returns: + int: Number of device indices that were already present in the tree before the insertion. + """ + ongoing_write, length = self.tree.writing_through(key, value) + if self.cache_controller is None: + assert len(ongoing_write) == 0, "Implementation error" + return length + + raise NotImplementedError("Host cache is not supported yet") + + def dec_lock_ref(self, node: TreeNodeCpp): + """ + Decrement the reference count of a node to root of the radix tree. + Args: + node (TreeNodeCpp): The handle of the node to decrement the reference count for. + """ + self.tree.lock_ref(node, False) # do not increment + + def inc_lock_ref(self, node: TreeNodeCpp): + """ + Increment the reference count of from a node to root of the radix tree. + Args: + node (TreeNodeCpp): The handle of the node to increment the reference count for. + """ + self.tree.lock_ref(node, True) + + def evict(self, num_tokens: int): + evicted_device_indices = self.tree.evict(num_tokens) + for indice in evicted_device_indices: + self.token_to_kv_pool.free(indice) + + def evictable_size(self): + return self.tree.evictable_size() + + def protected_size(self): + return self.tree.protected_size() + + def total_size(self): + return self.tree.total_size() + + def cache_finished_req(self, req: Req): + """Cache request when it finishes.""" + assert req.req_pool_idx is not None + token_ids = (req.origin_input_ids + req.output_ids)[:-1] + overall_len = len(token_ids) # prefill + decode + kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len] + + # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned + # it will automatically align them, but length of them should be equal + old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size + new_prefix_len = self._insert(token_ids, kv_indices) + + # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices + assert old_prefix_len <= new_prefix_len, "Wrong prefix indices" + + # KVCache between old & new is newly generated, but already exists in the pool + # we need to free this newly generated kv indices + if old_prefix_len < new_prefix_len: + self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len]) + + # need to free the unaligned part, since it cannot be inserted into the radix tree + if self.page_size != 1 and ( # unaligned tail only exists when page_size > 1 + (unaligned_len := overall_len % self.page_size) > 0 + ): + # NOTE: sglang PagedAllocator support unaligned free (which will automatically align it) + self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :]) + + # Remove req slot release the cache lock + self.dec_lock_ref(req.last_node) + self.req_to_token_pool.free(req.req_pool_idx) + + def cache_unfinished_req(self, req: Req): + """Cache request when it is unfinished.""" + assert req.req_pool_idx is not None + token_ids = req.fill_ids + prefill_len = len(token_ids) # prefill only (maybe chunked) + kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :prefill_len] + + # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned + # it will automatically align them, but length of them should be equal + old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size + new_prefix_len = self._insert(token_ids, kv_indices) + + # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices + assert old_prefix_len <= new_prefix_len, "Wrong prefix indices" + + # TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function) + # The prefix indices need to updated to reuse the kv indices in the pool + new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids) + new_indices = self._merge_tensor(new_indices_vec) + assert new_prefix_len <= len(new_indices) + + # KVCache between old & new is newly generated, but already exists in the pool + # we need to free this newly generated kv indices and reuse the indices in the pool + if old_prefix_len < new_prefix_len: + self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len]) + reused_indices = new_indices[old_prefix_len:new_prefix_len] + self.req_to_token_pool.req_to_token[ + req.req_pool_idx, old_prefix_len:new_prefix_len + ] = reused_indices + + if req.last_node != new_last_node: + self.dec_lock_ref(req.last_node) + self.inc_lock_ref(new_last_node) + + # NOTE: there might be unaligned tail, so we may need to append it + assert len(new_indices) <= prefill_len < len(new_indices) + self.page_size + if self.page_size != 1 and len(new_indices) < prefill_len: + req.prefix_indices = torch.cat( + [new_indices, kv_indices[len(new_indices) :]] + ) + else: + req.prefix_indices = new_indices + req.last_node = new_last_node + + def pretty_print(self): + return self.tree.debug_print() diff --git a/test/srt/test_cpp_radix_cache.py b/test/srt/test_cpp_radix_cache.py new file mode 100644 index 000000000..cb2822b88 --- /dev/null +++ b/test/srt/test_cpp_radix_cache.py @@ -0,0 +1,47 @@ +import os +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestCppRadixCache(CustomTestCase): + @classmethod + def setUpClass(cls): + os.environ["SGLANG_EXPERIMENTAL_CPP_RADIX_TREE"] = "1" + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + print(metrics) + self.assertGreaterEqual(metrics["score"], 0.65) + + +if __name__ == "__main__": + unittest.main()