Files
sglang/rust/src/tree.rs
2024-11-10 21:57:32 -08:00

186 lines
5.8 KiB
Rust

use std::collections::HashMap;
use std::mem;
/// A single node of the radix tree: one run of ids plus the subtrees that continue it.
#[derive(Debug)]
pub struct Node {
    /// Child nodes, keyed by the first id of each child's `ids`.
    /// Siblings never share a first id, so that id identifies the child uniquely.
    pub children: HashMap<u32, Node>,
    /// The run of ids stored on this node (empty for the root built by `RadixTree::new`).
    pub ids: Vec<u32>,
    /// Raised by `RadixTree::insert` every time an insertion passes through this node
    /// and lowered again by `RadixTree::delete`.
    pub count: u32,
}
/// Radix tree over sequences of `u32` ids; sequences with a common prefix share nodes.
#[derive(Debug)]
pub struct RadixTree {
    /// Root node. It is created with no ids of its own, and its `count` is adjusted
    /// once per `insert` / `delete` call.
    pub root: Node,
}
/// Returns how many leading positions of `a` and `b` hold equal ids.
///
/// The result never exceeds the length of the shorter slice.
fn common_prefix_len(a: &[u32], b: &[u32]) -> usize {
    a.iter()
        .zip(b.iter())
        .take_while(|(lhs, rhs)| lhs == rhs)
        .count()
}
impl Default for RadixTree {
fn default() -> Self {
Self::new()
}
}
impl RadixTree {
pub fn new() -> Self {
RadixTree {
root: Node {
children: HashMap::new(),
ids: Vec::new(),
count: 0,
},
}
}
pub fn insert(&mut self, input_ids: &[u32]) {
let mut curr = &mut self.root;
curr.count += 1;
let mut curr_idx = 0;
let input_ids_len = input_ids.len();
while curr_idx < input_ids_len {
let first_id = &input_ids[curr_idx];
// TODO: changing this get_mut causes error
if curr.children.contains_key(first_id) {
let child = curr.children.get_mut(first_id).unwrap();
let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids);
if prefix_len == child.ids.len() {
// move curr to child
curr = child;
curr.count += 1;
curr_idx += prefix_len;
} else {
// split child
// [child]->... => [child]->[new child]->...
let new_child = Node {
// to avoid clone: replace child.children with default value (empty vector) and return the original value
children: mem::take(&mut child.children),
ids: child.ids[prefix_len..].to_vec(),
count: child.count,
};
child.ids = child.ids[..prefix_len].to_vec();
child.children = HashMap::new();
child.children.insert(new_child.ids[0], new_child);
curr = child;
curr.count += 1;
curr_idx += prefix_len;
}
} else {
// create new child
let new_child = Node {
children: HashMap::new(),
ids: input_ids[curr_idx..].to_vec(),
count: 0,
};
let first_id = new_child.ids[0];
curr.children.insert(first_id, new_child);
curr = curr.children.get_mut(&first_id).unwrap();
curr.count += 1;
curr_idx = input_ids_len;
}
}
}
pub fn prefix_match<'a>(&self, input_ids: &'a [u32]) -> &'a [u32] {
let mut curr = &self.root;
let mut curr_idx = 0;
let input_ids_len = input_ids.len();
while curr_idx < input_ids_len {
match curr.children.get(&input_ids[curr_idx]) {
Some(child) => {
let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids);
if prefix_len == child.ids.len() {
curr_idx += prefix_len;
curr = child;
} else {
curr_idx += prefix_len;
break;
}
}
None => {
break;
}
}
}
&input_ids[..curr_idx]
}
pub fn delete(&mut self, input_ids: &[u32]) {
let mut curr = &mut self.root;
curr.count -= 1;
let mut curr_idx = 0;
let input_ids_len = input_ids.len();
while curr_idx < input_ids_len {
let first_id = &input_ids[curr_idx];
if curr.children.contains_key(first_id) {
let child = curr.children.get(first_id).unwrap();
let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids);
if prefix_len == child.ids.len() {
if child.count == 1 {
// If count will become 0, remove the child
let child = curr.children.get_mut(first_id).unwrap();
child.count -= 1;
curr.children.remove(first_id);
break;
} else {
// Otherwise decrement count and continue
let child = curr.children.get_mut(first_id).unwrap();
child.count -= 1;
curr = child;
curr_idx += prefix_len;
}
} else {
panic!("No match found for {:?}", input_ids);
}
} else {
panic!("No match found for {:?}", input_ids);
}
}
}
// for debug
pub fn pretty_print(&self) {
println!("RadixTree:");
Self::print_node(&self.root, String::from(""));
}
fn print_node(node: &Node, prefix: String) {
// Print current node info with "count" word
println!("{}└── {:?} (count: {})", prefix, node.ids, node.count);
// Print children with proper prefixes
for (i, child) in node.children.values().enumerate() {
let is_last = i == node.children.len() - 1;
let child_prefix = if is_last {
format!("{} ", prefix) // Add space for last child
} else {
format!("{}", prefix) // Add vertical line for other children
};
Self::print_node(child, child_prefix);
}
}
}