[router] Impl radix tree and set up CI (#1893)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
184
rust/src/tree.rs
Normal file
184
rust/src/tree.rs
Normal file
@@ -0,0 +1,184 @@
|
||||
use std::collections::HashMap;
|
||||
use std::mem;
|
||||
|
||||
/// One node of the radix tree: a run of token ids shared by every sequence
/// that passes through it.
#[derive(Clone, Debug, Default)]
pub struct Node {
    /// Child nodes, keyed by the first id of each child's `ids`.
    /// Sibling children therefore never start with the same id.
    pub children: HashMap<usize, Node>,
    /// The ids stored on the edge leading into this node (empty for the root).
    pub ids: Vec<usize>,
    /// Number of inserted sequences whose path goes through (or ends at) this node.
    pub count: usize,
}
|
||||
|
||||
pub struct RadixTree {
|
||||
pub root: Node,
|
||||
}
|
||||
|
||||
/// Returns how many leading elements `a` and `b` have in common.
///
/// The result is at most `min(a.len(), b.len())` and is `0` when either
/// slice is empty or their first elements differ.
fn common_prefix_len(a: &[usize], b: &[usize]) -> usize {
    // `zip` stops at the shorter slice, so no explicit length checks or
    // indexed (bounds-checked) accesses are needed.
    a.iter().zip(b).take_while(|(x, y)| x == y).count()
}
|
||||
|
||||
impl Default for RadixTree {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl RadixTree {
|
||||
pub fn new() -> Self {
|
||||
RadixTree {
|
||||
root: Node {
|
||||
children: HashMap::new(),
|
||||
ids: Vec::new(),
|
||||
count: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, input_ids: &[usize]) {
|
||||
let mut curr = &mut self.root;
|
||||
curr.count += 1;
|
||||
|
||||
let mut curr_idx = 0;
|
||||
let input_ids_len = input_ids.len();
|
||||
|
||||
while curr_idx < input_ids_len {
|
||||
let first_id = &input_ids[curr_idx];
|
||||
// TODO: changing this get_mut causes error
|
||||
if curr.children.contains_key(first_id) {
|
||||
let child = curr.children.get_mut(first_id).unwrap();
|
||||
|
||||
let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids);
|
||||
|
||||
if prefix_len == child.ids.len() {
|
||||
// move curr to child
|
||||
curr = child;
|
||||
curr.count += 1;
|
||||
curr_idx += prefix_len;
|
||||
} else {
|
||||
// split child
|
||||
// [child]->... => [child]->[new child]->...
|
||||
let new_child = Node {
|
||||
// to avoid clone: replace child.children with default value (empty vector) and return the original value
|
||||
children: mem::take(&mut child.children),
|
||||
ids: child.ids[prefix_len..].to_vec(),
|
||||
count: child.count,
|
||||
};
|
||||
|
||||
child.ids = child.ids[..prefix_len].to_vec();
|
||||
child.children = HashMap::new();
|
||||
child.children.insert(new_child.ids[0], new_child);
|
||||
|
||||
curr = child;
|
||||
curr.count += 1;
|
||||
curr_idx += prefix_len;
|
||||
}
|
||||
} else {
|
||||
// create new child
|
||||
let new_child = Node {
|
||||
children: HashMap::new(),
|
||||
ids: input_ids[curr_idx..].to_vec(),
|
||||
count: 0,
|
||||
};
|
||||
|
||||
let first_id = new_child.ids[0];
|
||||
curr.children.insert(first_id, new_child);
|
||||
|
||||
curr = curr.children.get_mut(&first_id).unwrap();
|
||||
curr.count += 1;
|
||||
curr_idx = input_ids_len;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn prefix_match<'a>(&self, input_ids: &'a [usize]) -> &'a [usize] {
|
||||
let mut curr = &self.root;
|
||||
|
||||
let mut curr_idx = 0;
|
||||
let input_ids_len = input_ids.len();
|
||||
|
||||
while curr_idx < input_ids_len {
|
||||
match curr.children.get(&input_ids[curr_idx]) {
|
||||
Some(child) => {
|
||||
let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids);
|
||||
|
||||
if prefix_len == child.ids.len() {
|
||||
curr_idx += prefix_len;
|
||||
curr = child;
|
||||
} else {
|
||||
curr_idx += prefix_len;
|
||||
break;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
&input_ids[..curr_idx]
|
||||
}
|
||||
|
||||
pub fn delete(&mut self, input_ids: &[usize]) {
|
||||
let mut curr = &mut self.root;
|
||||
curr.count -= 1;
|
||||
|
||||
let mut curr_idx = 0;
|
||||
let input_ids_len = input_ids.len();
|
||||
|
||||
while curr_idx < input_ids_len {
|
||||
let first_id = &input_ids[curr_idx];
|
||||
|
||||
if curr.children.contains_key(first_id) {
|
||||
let child = curr.children.get(first_id).unwrap();
|
||||
|
||||
let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids);
|
||||
|
||||
if prefix_len == child.ids.len() {
|
||||
if child.count == 1 {
|
||||
// If count will become 0, remove the child
|
||||
let child = curr.children.get_mut(first_id).unwrap();
|
||||
child.count -= 1;
|
||||
curr.children.remove(first_id);
|
||||
break;
|
||||
} else {
|
||||
// Otherwise decrement count and continue
|
||||
let child = curr.children.get_mut(first_id).unwrap();
|
||||
|
||||
child.count -= 1;
|
||||
curr = child;
|
||||
curr_idx += prefix_len;
|
||||
}
|
||||
} else {
|
||||
panic!("No match found for {:?}", input_ids);
|
||||
}
|
||||
} else {
|
||||
panic!("No match found for {:?}", input_ids);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for debug
|
||||
pub fn pretty_print(&self) {
|
||||
println!("RadixTree:");
|
||||
Self::print_node(&self.root, String::from(""));
|
||||
}
|
||||
|
||||
fn print_node(node: &Node, prefix: String) {
|
||||
// Print current node info with "count" word
|
||||
println!("{}└── {:?} (count: {})", prefix, node.ids, node.count);
|
||||
|
||||
// Print children with proper prefixes
|
||||
for (i, child) in node.children.values().enumerate() {
|
||||
let is_last = i == node.children.len() - 1;
|
||||
let child_prefix = if is_last {
|
||||
format!("{} ", prefix) // Add space for last child
|
||||
} else {
|
||||
format!("{}│ ", prefix) // Add vertical line for other children
|
||||
};
|
||||
Self::print_node(child, child_prefix);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user