From cbedd1db1d8bdde867efadf90b3c801dfe4e9964 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Sat, 23 Nov 2024 08:34:48 -0800 Subject: [PATCH] [router] cache-aware load-balancing router v1 (#2114) --- .../multi_turn_chat/long_prompt_multi_turn.py | 53 +- python/sglang/bench_serving.py | 4 +- python/sglang/test/few_shot_gsm8k.py | 10 +- rust/Cargo.lock | 24 +- rust/Cargo.toml | 2 + rust/README.md | 3 + rust/demo.py | 10 - rust/dp_demo.py | 156 -- rust/py_src/sglang_router/launch_router.py | 204 +++ rust/py_src/sglang_router/launch_server.py | 178 +++ rust/py_src/sglang_router/router.py | 30 +- rust/src/lib.rs | 49 +- rust/src/main.rs | 71 +- rust/src/router.rs | 296 ++-- rust/src/server.rs | 1 + rust/src/tree.rs | 1343 +++++++++++++++-- rust/tests/test_tree.rs | 131 -- 17 files changed, 1963 insertions(+), 602 deletions(-) delete mode 100644 rust/demo.py delete mode 100644 rust/dp_demo.py create mode 100644 rust/py_src/sglang_router/launch_router.py create mode 100644 rust/py_src/sglang_router/launch_server.py delete mode 100644 rust/tests/test_tree.rs diff --git a/benchmark/multi_turn_chat/long_prompt_multi_turn.py b/benchmark/multi_turn_chat/long_prompt_multi_turn.py index c6fa67438..decd8a72f 100644 --- a/benchmark/multi_turn_chat/long_prompt_multi_turn.py +++ b/benchmark/multi_turn_chat/long_prompt_multi_turn.py @@ -1,21 +1,24 @@ import itertools import json +import os import random import string import threading import time from argparse import ArgumentParser +from pathlib import Path +from typing import Union + +from tqdm import tqdm import sglang as sgl -from sglang.srt.hf_transformers_utils import get_tokenize +from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.test.test_utils import ( add_common_sglang_args_and_parse, select_sglang_backend, ) from sglang.utils import dump_state_text -random.seed(42) - def gen_prompt(tokenizer, token_num): all_available_tokens = list(tokenizer.get_vocab().values()) @@ -24,12 +27,34 @@ def gen_prompt(tokenizer, token_num): return ret +def get_cache_path(args): + # Create cache directory under ~/.cache/sglang + cache_dir = Path.home() / ".cache" / "sglang" + + # Create a unique cache filename based on the arguments that affect generation + cache_key = f"qa_{args.num_qa}_{args.turns}_{args.system_prompt_len}_{args.len_q}_{args.len_a}_{args.tokenizer.replace('/', '_')}.json" + return cache_dir / cache_key + + def gen_arguments(args, tokenizer): - multi_qas = [ - {"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []} - for _ in range(args.num_qa) - ] - for i in range(args.num_qa): + cache_path = get_cache_path(args) + + # Try to load from cache first + if cache_path.exists(): + print(f"Loading cached arguments from {cache_path}") + with open(cache_path, "r") as f: + return json.load(f) + + print("Generating new arguments...") + # First progress bar for system prompts + multi_qas = [] + for _ in tqdm(range(args.num_qa), desc="Generating system prompts"): + multi_qas.append( + {"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []} + ) + + # Nested progress bars for QA pairs + for i in tqdm(range(args.num_qa), desc="Generating QA pairs"): qas = multi_qas[i]["qas"] for j in range(args.turns): qas.append( @@ -38,6 +63,13 @@ def gen_arguments(args, tokenizer): "new_tokens": args.len_a, } ) + + # Save to cache + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, "w") as f: + json.dump(multi_qas, f) + print(f"Cached arguments saved to {cache_path}") + return multi_qas @@ -45,7 
+77,7 @@ def gen_arguments(args, tokenizer): def multi_turns(s, system_prompt, qas): s += system_prompt - for qa in qas: + for i, qa in enumerate(qas): s += qa["prompt"] s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True) @@ -62,7 +94,7 @@ def main(args): multi_qas, temperature=0, backend=backend, - num_threads=args.parallel, + num_threads="auto", progress_bar=True, ) latency = time.time() - tic @@ -75,7 +107,6 @@ def main(args): value = { "task": "multi_turn_system_prompt_chat", "backend": args.backend, - "num_gpus": 1, "latency": round(latency, 3), "num_requests": args.num_qa, "num_turns": args.turns, diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 203d79fff..a1e43e4cd 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -727,9 +727,9 @@ def sample_generated_shared_prefix_requests( total_input_tokens = 0 total_output_tokens = 0 - for group_idx in range(num_groups): + for group_idx in tqdm(range(num_groups), desc="Generating system prompt"): system_prompt = system_prompts[group_idx] - for prompt_idx in range(prompts_per_group): + for prompt_idx in tqdm(range(prompts_per_group), desc="Generating questions"): question = questions[group_idx * prompts_per_group + prompt_idx] full_prompt = f"{system_prompt}\n\n{question}" prompt_len = len(tokenizer.encode(full_prompt)) diff --git a/python/sglang/test/few_shot_gsm8k.py b/python/sglang/test/few_shot_gsm8k.py index 1f2af1629..8e6572da6 100644 --- a/python/sglang/test/few_shot_gsm8k.py +++ b/python/sglang/test/few_shot_gsm8k.py @@ -48,9 +48,13 @@ def run_eval(args): # Select backend set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}")) - # Read data - url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" - filename = download_and_cache_file(url) + if args.data_path is None: + # Read data + url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + filename = download_and_cache_file(url) + else: + filename = args.data_path + lines = list(read_jsonl(filename)) # Construct prompts diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 69b666df1..ac73bb855 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -591,6 +591,20 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "deranged" version = "0.3.11" @@ -904,6 +918,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.1" @@ -1226,7 +1246,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.1", ] [[package]] @@ -2097,7 +2117,9 @@ dependencies = [ "actix-web", "bytes", "clap", + "dashmap", "futures-util", + "http 1.1.0", "pyo3", "rand", "reqwest", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 8c4b56fd1..cb39daa90 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -24,6 +24,8 @@ futures-util = "0.3" serde_json = 
"1.0" pyo3 = { version = "0.22.5", features = ["extension-module"] } tokenizers = { version = "0.20.3", features = ["http"] } +dashmap = "6.1.0" +http = "1.1.0" [profile.release] lto = "thin" diff --git a/rust/README.md b/rust/README.md index 63d8a2317..51cee80f2 100644 --- a/rust/README.md +++ b/rust/README.md @@ -46,6 +46,9 @@ pip install #### Option B: Development Mode For development purposes, you can install the package in editable mode: + +Warning: Using editable python binding can suffer from performance degradation!! Please build a fresh wheel for every update if you want to test performance. + ```bash pip install -e . ``` diff --git a/rust/demo.py b/rust/demo.py deleted file mode 100644 index 6d83580d2..000000000 --- a/rust/demo.py +++ /dev/null @@ -1,10 +0,0 @@ -from sglang_router import PolicyType, Router - -router = Router( - worker_urls=[ - "http://localhost:30000", - "http://localhost:30001", - ] -) - -router.start() diff --git a/rust/dp_demo.py b/rust/dp_demo.py deleted file mode 100644 index 8b601e95a..000000000 --- a/rust/dp_demo.py +++ /dev/null @@ -1,156 +0,0 @@ -import argparse -import os -import signal -import subprocess -import sys -import time -from typing import Dict, List - -import requests -from sglang_router import PolicyType, Router - -# Global processes list for cleanup -_processes: List[subprocess.Popen] = [] - - -def cleanup_processes(signum=None, frame=None): - """Cleanup function to kill all worker processes.""" - print("\nCleaning up processes...") - for process in _processes: - try: - # Kill the entire process group - pgid = os.getpgid(process.pid) - os.killpg(pgid, signal.SIGKILL) - process.wait() - except: - pass - sys.exit(1) - - -# Register signal handlers -signal.signal(signal.SIGINT, cleanup_processes) -signal.signal(signal.SIGTERM, cleanup_processes) - - -def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="Launch SGLang Router Server") - parser.add_argument( - "--host", type=str, default="localhost", help="Host address to bind the server" - ) - parser.add_argument( - "--port", type=int, default=30000, help="Base port number for workers" - ) - parser.add_argument( - "--dp", - type=int, - default=2, - help="Number of worker processes (degree of parallelism)", - ) - parser.add_argument( - "--model-path", type=str, required=True, help="Path to the model" - ) - parser.add_argument( - "--local-tokenizer-path", - type=str, - required=True, - help="Path to the local tokenizer", - ) - return parser.parse_args() - - -def launch_workers(args) -> tuple[List[subprocess.Popen], List[str]]: - """Launch all worker processes concurrently using subprocess.""" - processes = [] - worker_urls = [] - - # Launch each worker process - for i in range(args.dp): - port = args.port + i - url = f"http://{args.host}:{port}" - worker_urls.append(url) - # TODO: replace this with launch_server, and move this file to sglang/ because it depends on sglang - # We don't - command = f"export CUDA_VISIBLE_DEVICES={i}; python -m sglang.launch_server --model-path {args.model_path} --host {args.host} --port {port}" - print(command) - process = subprocess.Popen(command, shell=True) - processes.append(process) - _processes.append(process) # Add to global list for cleanup - - return processes, worker_urls - - -def wait_for_healthy_workers(worker_urls: List[str], timeout: int = 300) -> bool: - """Block until all workers are healthy or timeout is reached.""" - start_time = time.time() - healthy_workers: Dict[str, bool] = {url: False for url in 
worker_urls} - - while time.time() - start_time < timeout: - print("checking healthiness...") - all_healthy = True - - for url in worker_urls: - if not healthy_workers[url]: # Only check workers that aren't healthy yet - try: - response = requests.get(f"{url}/health") - if response.status_code == 200: - print(f"Worker at {url} is healthy") - healthy_workers[url] = True - else: - all_healthy = False - except requests.RequestException: - all_healthy = False - - if all_healthy: - print("All workers are healthy!") - return True - - time.sleep(5) - - # If we get here, we've timed out - unhealthy_workers = [url for url, healthy in healthy_workers.items() if not healthy] - print(f"Timeout waiting for workers: {unhealthy_workers}") - return False - - -def main(): - """Main function to launch the router and workers.""" - args = parse_args() - processes = None - - try: - # Launch all workers concurrently - processes, worker_urls = launch_workers(args) - - # Block until all workers are healthy - if not wait_for_healthy_workers(worker_urls): - raise RuntimeError("Failed to start all workers") - - # Initialize and start the router - router = Router( - worker_urls=worker_urls, - policy=PolicyType.ApproxTree, - tokenizer_path=args.local_tokenizer_path, - ) - - print("Starting router...") - router.start() - - # Keep the main process running - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - print("\nShutting down...") - - except Exception as e: - print(f"Error: {e}") - finally: - # Cleanup: Kill all worker processes - if processes: - for process in processes: - process.kill() - - -if __name__ == "__main__": - main() diff --git a/rust/py_src/sglang_router/launch_router.py b/rust/py_src/sglang_router/launch_router.py new file mode 100644 index 000000000..b0726f2fa --- /dev/null +++ b/rust/py_src/sglang_router/launch_router.py @@ -0,0 +1,204 @@ +import argparse +import dataclasses +import sys +from typing import List, Optional + +from sglang_router import Router +from sglang_router_rs import PolicyType + + +@dataclasses.dataclass +class RouterArgs: + # Worker configuration + worker_urls: List[str] + host: str = "127.0.0.1" + port: int = 30000 + + # Routing policy + policy: str = "cache_aware" + cache_threshold: float = 0.5 + cache_routing_prob: float = 1.0 + eviction_interval: int = 60 + max_tree_size: int = 2**24 + + @staticmethod + def add_cli_args( + parser: argparse.ArgumentParser, + use_router_prefix: bool = False, + exclude_host_port: bool = False, + ): + """ + Add router-specific arguments to an argument parser. 
+ + Args: + parser: The argument parser to add arguments to + use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts + exclude_host_port: If True, don't add host and port arguments (used when inheriting from server) + """ + prefix = "router-" if use_router_prefix else "" + + # Worker configuration + if not exclude_host_port: + parser.add_argument( + "--host", + type=str, + default=RouterArgs.host, + help="Host address to bind the router server", + ) + parser.add_argument( + "--port", + type=int, + default=RouterArgs.port, + help="Port number to bind the router server", + ) + + parser.add_argument( + "--worker-urls", + type=str, + nargs="+", + help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)", + ) + + # Routing policy configuration + parser.add_argument( + f"--{prefix}policy", + type=str, + default=RouterArgs.policy, + choices=["random", "round_robin", "cache_aware"], + help="Load balancing policy to use", + ) + parser.add_argument( + f"--{prefix}cache-threshold", + type=float, + default=RouterArgs.cache_threshold, + help="Cache threshold (0.0-1.0) for cache-aware routing", + ) + parser.add_argument( + f"--{prefix}cache-routing-prob", + type=float, + default=RouterArgs.cache_routing_prob, + help="Probability of using cache-aware routing (0.0-1.0)", + ) + parser.add_argument( + f"--{prefix}eviction-interval", + type=int, + default=RouterArgs.eviction_interval, + help="Interval in seconds between cache eviction operations", + ) + parser.add_argument( + f"--{prefix}max-tree-size", + type=int, + default=RouterArgs.max_tree_size, + help="Maximum size of the approximation tree for cache-aware routing", + ) + + @classmethod + def from_cli_args( + cls, args: argparse.Namespace, use_router_prefix: bool = False + ) -> "RouterArgs": + """ + Create RouterArgs instance from parsed command line arguments. + + Args: + args: Parsed command line arguments + use_router_prefix: If True, look for arguments with 'router-' prefix + """ + prefix = "router_" if use_router_prefix else "" + return cls( + worker_urls=args.worker_urls, + host=args.host, + port=args.port, + policy=getattr(args, f"{prefix}policy"), + cache_threshold=getattr(args, f"{prefix}cache_threshold"), + cache_routing_prob=getattr(args, f"{prefix}cache_routing_prob"), + eviction_interval=getattr(args, f"{prefix}eviction_interval"), + max_tree_size=getattr(args, f"{prefix}max_tree_size"), + ) + + +def policy_from_str(policy_str: str) -> PolicyType: + """Convert policy string to PolicyType enum.""" + policy_map = { + "random": PolicyType.Random, + "round_robin": PolicyType.RoundRobin, + "cache_aware": PolicyType.CacheAware, + } + return policy_map[policy_str] + + +def launch_router(args: argparse.Namespace) -> Optional[Router]: + """ + Launch the SGLang router with the configuration from parsed arguments. 
+ + Args: + args: Namespace object containing router configuration + Can be either raw argparse.Namespace or converted RouterArgs + + Returns: + Router instance if successful, None if failed + """ + try: + # Convert to RouterArgs if needed + if not isinstance(args, RouterArgs): + router_args = RouterArgs.from_cli_args(args) + else: + router_args = args + + router = Router( + worker_urls=router_args.worker_urls, + policy=policy_from_str(router_args.policy), + host=router_args.host, + port=router_args.port, + cache_threshold=router_args.cache_threshold, + cache_routing_prob=router_args.cache_routing_prob, + eviction_interval_secs=router_args.eviction_interval, + max_tree_size=router_args.max_tree_size, + ) + + router.start() + return router + + except Exception as e: + print(f"Error starting router: {e}", file=sys.stderr) + return None + + +class CustomHelpFormatter( + argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter +): + """Custom formatter that preserves both description formatting and shows defaults""" + + pass + + +def parse_router_args(args: List[str]) -> RouterArgs: + """Parse command line arguments and return RouterArgs instance.""" + parser = argparse.ArgumentParser( + description="""SGLang Router - High-performance request distribution across worker nodes + +Usage: +This launcher enables starting a router with individual worker instances. It is useful for +multi-node setups or when you want to start workers and router separately. + +Examples: + python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 + python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --cache-routing-prob 0.5 + + """, + formatter_class=CustomHelpFormatter, + ) + + RouterArgs.add_cli_args(parser, use_router_prefix=False) + return RouterArgs.from_cli_args(parser.parse_args(args), use_router_prefix=False) + + +def main() -> None: + router_args = parse_router_args(sys.argv[1:]) + router = launch_router(router_args) + + if router is None: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/rust/py_src/sglang_router/launch_server.py b/rust/py_src/sglang_router/launch_server.py new file mode 100644 index 000000000..614c870ff --- /dev/null +++ b/rust/py_src/sglang_router/launch_server.py @@ -0,0 +1,178 @@ +import argparse +import copy +import multiprocessing as mp +import os +import signal +import sys +import time +from typing import List + +import requests +from sglang_router.launch_router import RouterArgs, launch_router + +from sglang.srt.server import launch_server +from sglang.srt.server_args import ServerArgs, prepare_server_args +from sglang.srt.utils import is_port_available +from sglang.utils import get_exception_traceback + + +# Create new process group +def run_server(server_args, dp_rank): + os.setpgrp() # Create new process group + + # Set DP_RANK environment variable + os.environ["DP_RANK"] = str(dp_rank) + + launch_server(server_args) + + +def launch_server_process( + server_args: ServerArgs, worker_port: int, dp_id: int +) -> mp.Process: + """Launch a single server process with the given args and port.""" + server_args = copy.deepcopy(server_args) + server_args.port = worker_port + server_args.base_gpu_id = dp_id * server_args.tp_size + server_args.dp_size = 1 + + proc = mp.Process(target=run_server, args=(server_args, dp_id)) + proc.start() + return proc + + +def cleanup_processes(processes: List[mp.Process]): + """Clean up all processes using process groups.""" + 
print("\nCleaning up processes...") + for proc in processes: + if proc.is_alive(): + try: + # Kill the entire process group + os.killpg(os.getpgid(proc.pid), signal.SIGTERM) + # Give processes some time to terminate gracefully + proc.join(timeout=3) + # If process is still alive, force kill + if proc.is_alive(): + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + except ProcessLookupError: + pass # Process already terminated + + +def setup_signal_handlers(cleanup_func): + """Setup handlers for various termination signals.""" + + def signal_handler(signum, frame): + cleanup_func() + sys.exit(1) + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + if hasattr(signal, "SIGQUIT"): + signal.signal(signal.SIGQUIT, signal_handler) + + +def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool: + """Wait for server to be healthy by checking /health endpoint.""" + start_time = time.time() + url = f"http://{host}:{port}/health" + + while time.time() - start_time < timeout: + try: + response = requests.get(url, timeout=5) + if response.status_code == 200: + return True + except requests.exceptions.RequestException: + pass + time.sleep(1) + return False + + +def find_available_ports(base_port: int, count: int) -> List[int]: + """Find consecutive available ports starting from base_port.""" + available_ports = [] + current_port = base_port + + while len(available_ports) < count: + if is_port_available(current_port): + available_ports.append(current_port) + current_port += 1 + + return available_ports + + +def main(): + # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes + mp.set_start_method("spawn") + + parser = argparse.ArgumentParser( + description="Launch SGLang router and server processes" + ) + + ServerArgs.add_cli_args(parser) + RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True) + parser.add_argument( + "--router-dp-worker-base-port", + type=int, + default=31000, + help="Base port number for data parallel workers", + ) + + args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + router_args = RouterArgs.from_cli_args(args, use_router_prefix=True) + + # Find available ports for workers + worker_ports = find_available_ports( + args.router_dp_worker_base_port, server_args.dp_size + ) + + # Start server processes + server_processes = [] + + try: + # Launch server processes + for i, worker_port in enumerate(worker_ports): + proc = launch_server_process(server_args, worker_port, i) + server_processes.append(proc) + + # Setup cleanup handler + setup_signal_handlers(lambda: cleanup_processes(server_processes)) + + # Wait for all servers to be healthy + all_healthy = True + for port in worker_ports: + if not wait_for_server_health(server_args.host, port): + print(f"Server on port {port} failed to become healthy") + all_healthy = False + break + + if not all_healthy: + print("Not all servers are healthy. Shutting down...") + cleanup_processes(server_processes) + sys.exit(1) + + print("All servers are healthy. Starting router...") + + # Update router args with worker URLs + router_args.worker_urls = [ + f"http://{server_args.host}:{port}" for port in worker_ports + ] + + # Start the router + router = launch_router(router_args) + + if router is None: + print("Failed to start router. 
Shutting down...") + cleanup_processes(server_processes) + sys.exit(1) + + except KeyboardInterrupt: + print("\nReceived shutdown signal...") + except Exception as e: + print(f"Error occurred: {e}") + print(get_exception_traceback()) + finally: + cleanup_processes(server_processes) + + +if __name__ == "__main__": + main() diff --git a/rust/py_src/sglang_router/router.py b/rust/py_src/sglang_router/router.py index 200027937..7965d8c02 100644 --- a/rust/py_src/sglang_router/router.py +++ b/rust/py_src/sglang_router/router.py @@ -9,16 +9,23 @@ class Router: A high-performance router for distributing requests across worker nodes. Args: - worker_urls: List of URLs for worker nodes that will handle requests + worker_urls: List of URLs for worker nodes that will handle requests. Each URL should include + the protocol, host, and port (e.g., ['http://worker1:8000', 'http://worker2:8000']) policy: Load balancing policy to use. Options: - PolicyType.Random: Randomly select workers - PolicyType.RoundRobin: Distribute requests in round-robin fashion - - PolicyType.ApproxTree: Tree-based routing using tokenizer similarity - host: Host address to bind the router server - port: Port number to bind the router server - tokenizer_path: Path to tokenizer model file (required for ApproxTree policy) - cache_threshold: Caching threshold value between 0-1 - + - PolicyType.CacheAware: Distribute requests in cache-aware fashion + host: Host address to bind the router server. Default: '127.0.0.1' + port: Port number to bind the router server. Default: 3001 + cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker + if the match rate exceeds threshold, otherwise routes to the worker with the smallest + tree. Default: 0.5 + cache_routing_prob: Probability of using cache-aware routing (0.0-1.0). Default 1.0 for + full cache-aware routing, suitable for perfectly divided prefix workloads. For uneven + workloads, use a lower value to better distribute requests + eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware + routing. Default: 60 + max_tree_size: Maximum size of the approximation tree for cache-aware routing. 
Default: 2^24 """ def __init__( @@ -27,17 +34,20 @@ class Router: policy: PolicyType = PolicyType.RoundRobin, host: str = "127.0.0.1", port: int = 3001, - tokenizer_path: Optional[str] = None, cache_threshold: float = 0.50, + cache_routing_prob: float = 1.0, + eviction_interval_secs: int = 60, + max_tree_size: int = 2**24, ): - self._router = _Router( worker_urls=worker_urls, policy=policy, host=host, port=port, - tokenizer_path=tokenizer_path, cache_threshold=cache_threshold, + cache_routing_prob=cache_routing_prob, + eviction_interval_secs=eviction_interval_secs, + max_tree_size=max_tree_size, ) def start(self) -> None: diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 9f3fb6fc2..f9a8603a6 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,7 +1,6 @@ -// Python Binding use pyo3::prelude::*; pub mod router; -mod server; +pub mod server; pub mod tree; #[pyclass(eq)] @@ -9,7 +8,7 @@ pub mod tree; pub enum PolicyType { Random, RoundRobin, - ApproxTree, + CacheAware, } #[pyclass] @@ -18,8 +17,10 @@ struct Router { port: u16, worker_urls: Vec, policy: PolicyType, - tokenizer_path: Option, - cache_threshold: Option, + cache_threshold: f32, + cache_routing_prob: f32, + eviction_interval_secs: u64, + max_tree_size: usize, } #[pymethods] @@ -30,33 +31,30 @@ impl Router { policy = PolicyType::RoundRobin, host = String::from("127.0.0.1"), port = 3001, - tokenizer_path = None, - cache_threshold = Some(0.50) + cache_threshold = 0.50, + cache_routing_prob = 1.0, + eviction_interval_secs = 60, + max_tree_size = 2usize.pow(24) ))] fn new( worker_urls: Vec, policy: PolicyType, host: String, port: u16, - tokenizer_path: Option, - cache_threshold: Option, + cache_threshold: f32, + cache_routing_prob: f32, + eviction_interval_secs: u64, + max_tree_size: usize, ) -> PyResult { - // Validate required parameters for approx_tree policy - if matches!(policy, PolicyType::ApproxTree) { - if tokenizer_path.is_none() { - return Err(PyErr::new::( - "tokenizer_path is required for approx_tree policy", - )); - } - } - Ok(Router { host, port, worker_urls, policy, - tokenizer_path, cache_threshold, + cache_routing_prob, + eviction_interval_secs, + max_tree_size, }) } @@ -68,14 +66,11 @@ impl Router { let policy_config = match &self.policy { PolicyType::Random => router::PolicyConfig::RandomConfig, PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig, - PolicyType::ApproxTree => router::PolicyConfig::ApproxTreeConfig { - tokenizer_path: self - .tokenizer_path - .clone() - .expect("tokenizer_path is required for approx_tree policy"), - cache_threshold: self - .cache_threshold - .expect("cache_threshold is required for approx_tree policy"), + PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig { + cache_threshold: self.cache_threshold, + cache_routing_prob: self.cache_routing_prob, + eviction_interval_secs: self.eviction_interval_secs, + max_tree_size: self.max_tree_size, }, }; diff --git a/rust/src/main.rs b/rust/src/main.rs index f7c8943eb..4725a315f 100644 --- a/rust/src/main.rs +++ b/rust/src/main.rs @@ -1,18 +1,14 @@ // src/main.rs use clap::Parser; use clap::ValueEnum; -// declare child modules -mod router; -mod server; -mod tree; -use crate::router::PolicyConfig; +use sglang_router_rs::{router::PolicyConfig, server}; #[derive(Debug, Clone, ValueEnum)] pub enum PolicyType { Random, RoundRobin, - ApproxTree, + CacheAware, } #[derive(Parser, Debug)] @@ -21,44 +17,70 @@ struct Args { #[arg( long, default_value = "127.0.0.1", - help = "Host address to bind the server to" + help = "Host 
address to bind the router server to. Default: 127.0.0.1" )] host: String, - #[arg(long, default_value_t = 3001, help = "Port number to listen on")] + #[arg( + long, + default_value_t = 3001, + help = "Port number to bind the router server to. Default: 3001" + )] port: u16, #[arg( long, value_delimiter = ',', - help = "Comma-separated list of worker URLs to distribute requests to" + help = "Comma-separated list of worker URLs that will handle the requests. Each URL should include the protocol, host, and port (e.g., http://worker1:8000,http://worker2:8000)" )] worker_urls: Vec, #[arg( long, - default_value_t = PolicyType::RoundRobin, + default_value_t = PolicyType::CacheAware, value_enum, - help = "Load balancing policy to use: random, round_robin, or approx_tree" + help = "Load balancing policy to use for request distribution:\n\ + - random: Randomly select workers\n\ + - round_robin: Distribute requests in round-robin fashion\n\ + - cache_aware: Distribute requests in cache-aware fashion\n" )] policy: PolicyType, #[arg( long, + default_value_t = 0.5, requires = "policy", - required_if_eq("policy", "approx_tree"), - help = "Path to the tokenizer file, required when using approx_tree policy" + required_if_eq("policy", "cache_aware"), + help = "Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker if the match rate exceeds threshold, otherwise routes to the worker with the smallest tree. Default: 0.5" )] - tokenizer_path: Option, + cache_threshold: f32, #[arg( long, - default_value = "0.50", + default_value_t = 1.0, requires = "policy", - required_if_eq("policy", "approx_tree"), - help = "Cache threshold (0.0-1.0) for approx_tree routing. Routes to cached worker if match rate exceeds threshold, otherwise routes to shortest queue worker" + required_if_eq("policy", "cache_aware"), + help = "Probability of using cache-aware routing (0.0-1.0). Default 1.0 for full cache-aware routing, suitable for perfectly divided prefix workloads. For uneven workloads, use a lower value to better distribute requests" )] - cache_threshold: Option, + cache_routing_prob: f32, + + #[arg( + long, + default_value_t = 60, + requires = "policy", + required_if_eq("policy", "cache_aware"), + help = "Interval in seconds between cache eviction operations in cache-aware routing. Default: 60" + )] + eviction_interval_secs: u64, + + #[arg( + long, + default_value_t = 2usize.pow(24), + requires = "policy", + required_if_eq("policy", "cache_aware"), + help = "Maximum size of the approximation tree for cache-aware routing. 
Default: 2^24" + )] + max_tree_size: usize, } impl Args { @@ -66,14 +88,11 @@ impl Args { match self.policy { PolicyType::Random => PolicyConfig::RandomConfig, PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig, - PolicyType::ApproxTree => PolicyConfig::ApproxTreeConfig { - tokenizer_path: self - .tokenizer_path - .clone() - .expect("tokenizer_path is required for approx_tree policy"), - cache_threshold: self - .cache_threshold - .expect("cache_threshold is required for approx_tree policy"), + PolicyType::CacheAware => PolicyConfig::CacheAwareConfig { + cache_threshold: self.cache_threshold, + cache_routing_prob: self.cache_routing_prob, + eviction_interval_secs: self.eviction_interval_secs, + max_tree_size: self.max_tree_size, }, } } diff --git a/rust/src/router.rs b/rust/src/router.rs index 65ab8214e..64738cc57 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -1,13 +1,16 @@ -use crate::tree::RadixTree; +use crate::tree::Tree; use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::{HttpRequest, HttpResponse}; use bytes::Bytes; -use futures_util::TryStreamExt; +use futures_util::{Stream, StreamExt, TryStreamExt}; use std::collections::HashMap; use std::fmt::Debug; +use std::hash::Hash; +use std::pin::Pin; use std::sync::atomic::AtomicUsize; use std::sync::{Arc, Mutex}; -use tokenizers::tokenizer::Tokenizer; +use std::thread; +use std::time::Duration; #[derive(Debug)] pub enum Router { @@ -18,34 +21,88 @@ pub enum Router { Random { worker_urls: Vec, }, - ApproxTree { + CacheAware { + /* + Cache-Aware Load Balancing Router + + This router combines two strategies to optimize both cache utilization and request distribution: + + 1. Cache-Aware Routing (Approximate Tree) + 2. Load Balancing (Shortest Queue) + + For each incoming request, the router chooses between these strategies: + - With probability P: Uses cache-aware routing + - With probability (1-P): Uses load balancing + where P is configured via `cache_routing_prob` + + Strategy Details: + + 1. Cache-Aware Routing (Approximate Tree) + ------------------------------------------- + This strategy maintains an approximate radix tree for each worker based on request history, + eliminating the need for direct cache state queries. The tree stores raw text characters + instead of token IDs to avoid tokenization overhead. + + Process: + a. For each request, find the worker with the highest prefix match + b. If match rate > cache_threshold: + Route to the worker with highest match (likely has relevant data cached) + c. If match rate ≤ cache_threshold: + Route to the worker with smallest tree size (most available cache capacity) + d. Background maintenance: + Periodically evict least recently used leaf nodes to prevent memory overflow + + 2. Load Balancing (Shortest Queue) + ------------------------------------------- + This strategy tracks pending request counts per worker and routes new requests + to the least busy worker for optimal load distribution. + + Configuration Parameters: + ------------------------ + 1. cache_routing_prob: (float, 0.0 to 1.0) + - 0.0: Exclusively use load balancing + - 1.0: Exclusively use cache-aware routing + - Between 0-1: Probability of using cache-aware routing vs load balancing + + 2. cache_threshold: (float, 0.0 to 1.0) + Minimum prefix match ratio to use highest-match routing. + Below this threshold, routes to worker with most available cache space. + + 3. eviction_interval_secs: (integer) + Interval between LRU eviction cycles for the approximate trees. + + 4. 
max_tree_size: (integer) + Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted + during the next eviction cycle. + */ worker_urls: Vec, - // TODO: don't lock the whole tree - url_to_tree: Arc>>, - tokenizer: Tokenizer, - url_to_count: Arc>>, + tree: Arc>, + running_queue: Arc>>, + processed_queue: Arc>>, cache_threshold: f32, + cache_routing_prob: f32, + _eviction_thread: Option>, // Store thread handle }, } +#[derive(Debug)] pub enum PolicyConfig { RandomConfig, RoundRobinConfig, - ApproxTreeConfig { - tokenizer_path: String, + CacheAwareConfig { cache_threshold: f32, + cache_routing_prob: f32, + eviction_interval_secs: u64, + max_tree_size: usize, }, } -fn get_token_ids_from_request(body: &Bytes, tokenizer: &Tokenizer) -> Vec { +fn get_text_from_request(body: &Bytes) -> String { // 1. convert body to json let json = serde_json::from_slice::(body).unwrap(); // 2. get the text field let text = json.get("text").and_then(|t| t.as_str()).unwrap_or(""); - // 3. tokenize the text field - let tokens = tokenizer.encode(text, false).unwrap(); - - tokens.get_ids().to_vec() + return text.to_string(); } impl Router { @@ -56,25 +113,56 @@ impl Router { worker_urls, current_index: std::sync::atomic::AtomicUsize::new(0), }, - PolicyConfig::ApproxTreeConfig { - tokenizer_path, + PolicyConfig::CacheAwareConfig { cache_threshold, + cache_routing_prob, + eviction_interval_secs, + max_tree_size, } => { - let mut url_to_tree = HashMap::new(); - let mut url_to_count = HashMap::new(); - + let mut running_queue = HashMap::new(); for url in &worker_urls { - url_to_tree.insert(url.clone(), RadixTree::new()); - url_to_count.insert(url.clone(), 0); + running_queue.insert(url.clone(), 0); } - Router::ApproxTree { + let mut processed_queue = HashMap::new(); + for url in &worker_urls { + processed_queue.insert(url.clone(), 0); + } + + let tree = Arc::new(Mutex::new(Tree::new())); + let running_queue = Arc::new(Mutex::new(running_queue)); + let processed_queue = Arc::new(Mutex::new(processed_queue)); + + // Create background eviction thread + let tree_clone = Arc::clone(&tree); + let processed_queue_clone = Arc::clone(&processed_queue); + let eviction_thread = thread::spawn(move || { + loop { + // Sleep for the specified interval + thread::sleep(Duration::from_secs(eviction_interval_secs)); + + let locked_tree_clone = tree_clone.lock().unwrap(); + // Run eviction + locked_tree_clone.evict_tenant_data(max_tree_size); + + // Print the process queue + let locked_processed_queue = processed_queue_clone.lock().unwrap(); + println!("Processed Queue: {:?}", locked_processed_queue); + } + }); + + for url in &worker_urls { + tree.lock().unwrap().insert(&"".to_string(), url); + } + + Router::CacheAware { worker_urls, - url_to_tree: Arc::new(Mutex::new(url_to_tree)), - // TODO: rust ::from_pretrained cannot load from local file, so use ::from_file to load local file - tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(), - url_to_count: Arc::new(Mutex::new(url_to_count)), + tree, + running_queue, + processed_queue, cache_threshold, + cache_routing_prob, + _eviction_thread: Some(eviction_thread), } } } @@ -84,7 +172,7 @@ impl Router { match self { Router::RoundRobin { worker_urls, .. } | Router::Random { worker_urls } - | Router::ApproxTree { worker_urls, .. } => { + | Router::CacheAware { worker_urls, .. 
} => { if worker_urls.is_empty() { None } else { @@ -100,10 +188,7 @@ impl Router { req: HttpRequest, body: Bytes, ) -> HttpResponse { - let mut input_ids: Vec = Vec::new(); - if let Router::ApproxTree { tokenizer, .. } = self { - input_ids = get_token_ids_from_request(&body, tokenizer); - } + let text = get_text_from_request(&body); let worker_url = match self { Router::RoundRobin { @@ -125,78 +210,73 @@ impl Router { worker_urls[rand::random::() % worker_urls.len()].clone() } - Router::ApproxTree { + Router::CacheAware { worker_urls, - url_to_tree, - url_to_count, + tree, + running_queue, + processed_queue, cache_threshold, + cache_routing_prob, .. } => { - // TODO: pipeline the locks. Release one earlier. + // even though the tree is thread-safe, we still put a lock to ensure the whole op (tree read + queue read + tree write + queue write) is atomic to handle some edge cases (e.g. multiple requests with long prefix entering at the same time) - let mut max_matched_rate = 0.0; - let mut max_matched_idx = 0; + let mut tree = tree.lock().unwrap(); + let mut running_queue = running_queue.lock().unwrap(); - let locked_url_to_tree = url_to_tree.lock().unwrap(); + // Generate a random float between 0 and 1 for probability check + let sampled_p: f32 = rand::random(); - // 1. Find the highest matched worker - for (i, url) in worker_urls.iter().enumerate() { - let tree = locked_url_to_tree.get(url).unwrap(); - let matched = tree.prefix_match(&input_ids[..]).len(); - let matched_rate = matched as f32 / input_ids.len() as f32; + let selected_url = if sampled_p < *cache_routing_prob { + // Cache-aware routing logic + let (matched_text, matched_worker) = tree.prefix_match(&text); + let matched_rate = + matched_text.chars().count() as f32 / text.chars().count() as f32; - if matched_rate > max_matched_rate { - max_matched_rate = matched_rate; - max_matched_idx = i; + if matched_rate > *cache_threshold { + matched_worker.to_string() + } else { + let m_map: HashMap = tree + .tenant_char_count + .iter() + .map(|entry| (entry.key().clone(), *entry.value())) + .collect(); + + println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map); + + tree.get_smallest_tenant() } - } - - // 2. If the rate is higher than the threshold, select the worker. If not, select the worker with the shortest queue - if max_matched_rate > *cache_threshold { - worker_urls[max_matched_idx].clone() } else { - // pick the shortest queue from url_to_count - let locked_url_to_count = url_to_count.lock().unwrap(); + // Shortest queue routing logic + running_queue + .iter() + .min_by_key(|(_url, &count)| count) + .map(|(url, _)| url.clone()) + .unwrap_or_else(|| worker_urls[0].clone()) + }; - let mut min_count = std::usize::MAX; - let mut min_count_id = 0; + // Update running queue + let count = running_queue.get_mut(&selected_url).unwrap(); + *count += 1; - for (i, url) in worker_urls.iter().enumerate() { - let count = locked_url_to_count.get(url).unwrap(); - if *count < min_count { - min_count = *count; - min_count_id = i; - } - } + // Update processed queue + let mut locked_processed_queue = processed_queue.lock().unwrap(); + let count = locked_processed_queue.get_mut(&selected_url).unwrap(); + *count += 1; - worker_urls[min_count_id].clone() - } + // Update tree with the new request + tree.insert(&text, &selected_url); + + selected_url } }; - if let Router::ApproxTree { - url_to_tree, - url_to_count, - .. 
- } = self - { - // Insert input_ids to the tree - let mut locked_url_to_tree = url_to_tree.lock().unwrap(); - let selected_tree = locked_url_to_tree.get_mut(&worker_url).unwrap(); - selected_tree.insert(&input_ids[..]); - - let mut locked_url_to_count = url_to_count.lock().unwrap(); - let count = locked_url_to_count.get_mut(&worker_url).unwrap(); - *count += 1; - } - - // Check if client requested streaming let is_stream = serde_json::from_slice::(&body) .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false)) .unwrap_or(false); let res = match client - .post(format!("{}/generate", worker_url)) + .post(format!("{}/generate", worker_url.clone())) .header( "Content-Type", req.headers() @@ -216,23 +296,53 @@ impl Router { .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); if !is_stream { - // TODO: do the correction on the tree based on the cached input_ids - if let Router::ApproxTree { url_to_count, .. } = self { - let mut locked_url_to_count = url_to_count.lock().unwrap(); - let count = locked_url_to_count.get_mut(&worker_url).unwrap(); - *count -= 1; - } - - match res.bytes().await { + // For non-streaming requests, get response first + let response = match res.bytes().await { Ok(body) => HttpResponse::build(status).body(body.to_vec()), Err(_) => HttpResponse::InternalServerError().finish(), + }; + + // Then decrement running queue counter if using CacheAware + if let Router::CacheAware { running_queue, .. } = self { + if let Ok(mut queue) = running_queue.lock() { + if let Some(count) = queue.get_mut(&worker_url) { + *count = count.saturating_sub(1); + } + } } + + response + } else if let Router::CacheAware { running_queue, .. } = self { + let running_queue = Arc::clone(running_queue); + let worker_url = worker_url.clone(); + + HttpResponse::build(status) + .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) + .streaming( + res.bytes_stream() + .map_err(|_| { + actix_web::error::ErrorInternalServerError("Failed to read stream") + }) + .inspect(move |bytes| { + let bytes = bytes.as_ref().unwrap(); + if bytes + .as_ref() + .windows(12) + .any(|window| window == b"data: [DONE]") + { + let mut locked_queue = running_queue.lock().unwrap(); + let count = locked_queue.get_mut(&worker_url).unwrap(); + *count = count.saturating_sub(1); + // print + // println!("streaming is done!!") + } + }), + ) } else { - // TODO: do the correction on the tree based on the cached input_ids. 
The streaming might be tricker to handle HttpResponse::build(status) .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) .streaming(res.bytes_stream().map_err(|_| { - actix_web::error::ErrorInternalServerError("Failed to read string") + actix_web::error::ErrorInternalServerError("Failed to read stream") })) } } diff --git a/rust/src/server.rs b/rust/src/server.rs index 05a2f150c..51df65f97 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -76,6 +76,7 @@ pub async fn startup( ) -> std::io::Result<()> { println!("Starting server on {}:{}", host, port); println!("Worker URLs: {:?}", worker_urls); + println!("Policy Config: {:?}", policy_config); // Create client once with configuration let client = reqwest::Client::builder() diff --git a/rust/src/tree.rs b/rust/src/tree.rs index 2bcb84bef..516c991ec 100644 --- a/rust/src/tree.rs +++ b/rust/src/tree.rs @@ -1,185 +1,1264 @@ +use dashmap::mapref::entry::Entry; +use dashmap::DashMap; +use rand::distributions::{Alphanumeric, DistString}; +use rand::thread_rng; +use std::cmp::min; +use std::cmp::Reverse; +use std::collections::BinaryHeap; use std::collections::HashMap; -use std::mem; +use std::sync::Arc; +use std::sync::RwLock; +use std::thread; +use std::time::Duration; +use std::time::{SystemTime, UNIX_EPOCH}; + +type NodeRef = Arc; #[derive(Debug)] -pub struct Node { - pub children: HashMap, // the key is first id of the child because each child must have unique first id - pub ids: Vec, - pub count: u32, +struct Node { + children: DashMap, + text: RwLock, + tenant_last_access_time: DashMap, + parent: RwLock>, } #[derive(Debug)] -pub struct RadixTree { - pub root: Node, +pub struct Tree { + root: NodeRef, + // TODO: Char Count per tenant + pub tenant_char_count: DashMap, } -fn common_prefix_len(a: &[u32], b: &[u32]) -> usize { +// For the heap + +struct EvictionEntry { + timestamp: u128, + tenant: String, + node: NodeRef, +} + +impl Eq for EvictionEntry {} + +impl PartialOrd for EvictionEntry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.timestamp.cmp(&other.timestamp)) + } +} + +impl Ord for EvictionEntry { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.timestamp.cmp(&other.timestamp) + } +} + +impl PartialEq for EvictionEntry { + fn eq(&self, other: &Self) -> bool { + self.timestamp == other.timestamp + } +} + +// For char operations +// Note that in rust, `.len()` or slice is operated on the "byte" level. It causes issues for UTF-8 characters because one character might use multiple bytes. +// https://en.wikipedia.org/wiki/UTF-8 + +fn shared_prefix_count(a: &str, b: &str) -> usize { let mut i = 0; - while i < a.len() && i < b.len() && a[i] == b[i] { - i += 1; - } - i -} + let mut a_iter = a.chars(); + let mut b_iter = b.chars(); -impl Default for RadixTree { - fn default() -> Self { - Self::new() - } -} - -impl RadixTree { - pub fn new() -> Self { - RadixTree { - root: Node { - children: HashMap::new(), - ids: Vec::new(), - count: 0, - }, + loop { + match (a_iter.next(), b_iter.next()) { + (Some(a_char), Some(b_char)) if a_char == b_char => { + i += 1; + } + _ => break, } } - pub fn insert(&mut self, input_ids: &[u32]) { - let mut curr = &mut self.root; - curr.count += 1; + return i; +} +fn slice_by_chars(s: &str, start: usize, end: usize) -> String { + s.chars().skip(start).take(end - start).collect() +} + +impl Tree { + /* + Thread-safe multi tenant radix tree + + 1. Storing data for multiple tenants (the overlap of multiple radix tree) + 2. 
Node-level lock to enable concurrent acesss on nodes + 3. Leaf LRU eviction based on tenant access time + */ + + pub fn new() -> Self { + Tree { + root: Arc::new(Node { + children: DashMap::new(), + text: RwLock::new("".to_string()), + tenant_last_access_time: DashMap::new(), + parent: RwLock::new(None), + }), + tenant_char_count: DashMap::new(), + } + } + + pub fn insert(&self, text: &str, tenant: &str) { + // Insert text into tree with given tenant + + let mut curr = Arc::clone(&self.root); let mut curr_idx = 0; - let input_ids_len = input_ids.len(); - while curr_idx < input_ids_len { - let first_id = &input_ids[curr_idx]; - // TODO: changing this get_mut causes error - if curr.children.contains_key(first_id) { - let child = curr.children.get_mut(first_id).unwrap(); + let timestamp_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis(); - let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids); + curr.tenant_last_access_time + .insert(tenant.to_string(), timestamp_ms); - if prefix_len == child.ids.len() { - // move curr to child - curr = child; - curr.count += 1; - curr_idx += prefix_len; - } else { - // split child - // [child]->... => [child]->[new child]->... - let new_child = Node { - // to avoid clone: replace child.children with default value (empty vector) and return the original value - children: mem::take(&mut child.children), - ids: child.ids[prefix_len..].to_vec(), - count: child.count, - }; + self.tenant_char_count + .entry(tenant.to_string()) + .or_insert(0); - child.ids = child.ids[..prefix_len].to_vec(); - child.children = HashMap::new(); - child.children.insert(new_child.ids[0], new_child); + let mut prev = Arc::clone(&self.root); - curr = child; - curr.count += 1; - curr_idx += prefix_len; + let text_count = text.chars().count(); + + while curr_idx < text_count { + let first_char = text.chars().nth(curr_idx).unwrap(); + + curr = prev; + + // dashmap.entry locks the entry until the op is done + // if using contains_key + insert, there will be an issue that + // 1. "apple" and "app" entered at the same time + // 2. 
and get inserted to the dashmap concurrently, so only one is inserted + + match curr.children.entry(first_char) { + Entry::Vacant(entry) => { + /* + no matched + [curr] + becomes + [curr] => [new node] + */ + + let curr_text = slice_by_chars(text, curr_idx, text_count); + let curr_text_count = curr_text.chars().count(); + let new_node = Arc::new(Node { + children: DashMap::new(), + text: RwLock::new(curr_text), + tenant_last_access_time: DashMap::new(), + parent: RwLock::new(Some(Arc::clone(&curr))), + }); + + // Increment char count when creating new node with tenant + self.tenant_char_count + .entry(tenant.to_string()) + .and_modify(|count| *count += curr_text_count) + .or_insert(curr_text_count); + + new_node + .tenant_last_access_time + .insert(tenant.to_string(), timestamp_ms); + + entry.insert(Arc::clone(&new_node)); + + prev = Arc::clone(&new_node); + curr_idx = text_count; } - } else { - // create new child - let new_child = Node { - children: HashMap::new(), - ids: input_ids[curr_idx..].to_vec(), - count: 0, - }; - let first_id = new_child.ids[0]; - curr.children.insert(first_id, new_child); + Entry::Occupied(mut entry) => { + // matched + let matched_node = entry.get().clone(); - curr = curr.children.get_mut(&first_id).unwrap(); - curr.count += 1; - curr_idx = input_ids_len; + let matched_node_text = matched_node.text.read().unwrap().to_owned(); + let matched_node_text_count = matched_node_text.chars().count(); + + let curr_text = slice_by_chars(text, curr_idx, text_count); + let shared_count = shared_prefix_count(&matched_node_text, &curr_text); + + if shared_count < matched_node_text_count { + /* + split the matched node + [curr] -> [matched_node] => + becomes + [curr] -> [new_node] -> [contracted_matched_node] + */ + + let matched_text = slice_by_chars(&matched_node_text, 0, shared_count); + let contracted_text = slice_by_chars( + &matched_node_text, + shared_count, + matched_node_text_count, + ); + let matched_text_count = matched_text.chars().count(); + + let new_node = Arc::new(Node { + text: RwLock::new(matched_text), + children: DashMap::new(), + parent: RwLock::new(Some(Arc::clone(&curr))), + tenant_last_access_time: matched_node.tenant_last_access_time.clone(), + }); + + let first_new_char = contracted_text.chars().nth(0).unwrap(); + new_node + .children + .insert(first_new_char, Arc::clone(&matched_node)); + + entry.insert(Arc::clone(&new_node)); + + *matched_node.text.write().unwrap() = contracted_text; + *matched_node.parent.write().unwrap() = Some(Arc::clone(&new_node)); + + prev = Arc::clone(&new_node); + + // Increment char count for the tenant in the new split node + if !prev.tenant_last_access_time.contains_key(tenant) { + self.tenant_char_count + .entry(tenant.to_string()) + .and_modify(|count| *count += matched_text_count) + .or_insert(matched_text_count); + } + + prev.tenant_last_access_time + .insert(tenant.to_string(), timestamp_ms); + + curr_idx += shared_count; + } else { + // move to next node + prev = Arc::clone(&matched_node); + + // Increment char count when adding tenant to existing node + if !prev.tenant_last_access_time.contains_key(tenant) { + self.tenant_char_count + .entry(tenant.to_string()) + .and_modify(|count| *count += matched_node_text_count) + .or_insert(matched_node_text_count); + } + + prev.tenant_last_access_time + .insert(tenant.to_string(), timestamp_ms); + curr_idx += shared_count; + } + } } } } - pub fn prefix_match<'a>(&self, input_ids: &'a [u32]) -> &'a [u32] { - let mut curr = &self.root; - + pub fn prefix_match(&self, text: 
&str) -> (String, String) { + let mut curr = Arc::clone(&self.root); let mut curr_idx = 0; - let input_ids_len = input_ids.len(); - while curr_idx < input_ids_len { - match curr.children.get(&input_ids[curr_idx]) { - Some(child) => { - let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids); + let mut prev = Arc::clone(&self.root); + let text_count = text.chars().count(); - if prefix_len == child.ids.len() { - curr_idx += prefix_len; - curr = child; + while curr_idx < text_count { + let first_char = text.chars().nth(curr_idx).unwrap(); + let curr_text = slice_by_chars(text, curr_idx, text_count); + + curr = prev.clone(); + + match curr.children.entry(first_char) { + Entry::Occupied(entry) => { + let matched_node = entry.get().clone(); + let shared_count = + shared_prefix_count(&matched_node.text.read().unwrap(), &curr_text); + + let matched_node_text_count = matched_node.text.read().unwrap().chars().count(); + + if shared_count == matched_node_text_count { + // Full match with current node's text, continue to next node + curr_idx += shared_count; + prev = Arc::clone(&matched_node); } else { - curr_idx += prefix_len; + // Partial match, stop here + curr_idx += shared_count; + prev = Arc::clone(&matched_node); break; } } - None => { + Entry::Vacant(_) => { + // No match found, stop here break; } } } - &input_ids[..curr_idx] - } + curr = prev.clone(); - pub fn delete(&mut self, input_ids: &[u32]) { - let mut curr = &mut self.root; - curr.count -= 1; + // Select the first tenant (key in the map) + let tenant = curr + .tenant_last_access_time + .iter() + .next() + .map(|kv| kv.key().to_owned()) + .unwrap_or("empty".to_string()); - let mut curr_idx = 0; - let input_ids_len = input_ids.len(); + // Traverse from the curr node to the root and update the timestamp - while curr_idx < input_ids_len { - let first_id = &input_ids[curr_idx]; + let timestamp_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis(); - if curr.children.contains_key(first_id) { - let child = curr.children.get(first_id).unwrap(); - - let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids); - - if prefix_len == child.ids.len() { - if child.count == 1 { - // If count will become 0, remove the child - let child = curr.children.get_mut(first_id).unwrap(); - child.count -= 1; - curr.children.remove(first_id); - break; - } else { - // Otherwise decrement count and continue - let child = curr.children.get_mut(first_id).unwrap(); - - child.count -= 1; - curr = child; - curr_idx += prefix_len; - } - } else { - panic!("No match found for {:?}", input_ids); - } - } else { - panic!("No match found for {:?}", input_ids); + if !tenant.eq("empty") { + let mut current_node = Some(curr); + while let Some(node) = current_node { + node.tenant_last_access_time + .insert(tenant.clone(), timestamp_ms); + current_node = node.parent.read().unwrap().clone(); } } + + let ret_text = slice_by_chars(text, 0, curr_idx); + (ret_text, tenant) } - // for debug - pub fn pretty_print(&self) { - println!("RadixTree:"); - Self::print_node(&self.root, String::from("")); + fn leaf_of(node: &NodeRef) -> Vec { + /* + Return the list of tenants if it's a leaf for the tenant + */ + let mut candidates: HashMap = node + .tenant_last_access_time + .iter() + .map(|entry| (entry.key().clone(), true)) + .collect(); + + for child in node.children.iter() { + for tenant in child.value().tenant_last_access_time.iter() { + candidates.insert(tenant.key().clone(), false); + } + } + + candidates + .into_iter() + .filter(|(_, 
is_leaf)| *is_leaf) + .map(|(tenant, _)| tenant) + .collect() } - fn print_node(node: &Node, prefix: String) { - // Print current node info with "count" word - println!("{}└── {:?} (count: {})", prefix, node.ids, node.count); + pub fn evict_tenant_data(&self, max_size: usize) { + // Calculate used size and collect leaves + let mut stack = vec![Arc::clone(&self.root)]; + let mut used_size_per_tenant: HashMap = HashMap::new(); + let mut pq = BinaryHeap::new(); - // Print children with proper prefixes - for (i, child) in node.children.values().enumerate() { - let is_last = i == node.children.len() - 1; - let child_prefix = if is_last { - format!("{} ", prefix) // Add space for last child - } else { - format!("{}│ ", prefix) // Add vertical line for other children - }; - Self::print_node(child, child_prefix); + while let Some(curr) = stack.pop() { + for tenant in curr.tenant_last_access_time.iter() { + let size = used_size_per_tenant + .entry(tenant.key().clone()) + .or_insert(0); + *size += curr.text.read().unwrap().chars().count(); + } + + for child in curr.children.iter() { + stack.push(Arc::clone(child.value())); + } + + // Add leaves to priority queue + for tenant in Tree::leaf_of(&curr) { + if let Some(timestamp) = curr.tenant_last_access_time.get(&tenant) { + pq.push(Reverse(EvictionEntry { + timestamp: *timestamp, + tenant: tenant.clone(), + node: Arc::clone(&curr), + })); + } + } + } + + println!("Before eviction - Used size per tenant:"); + for (tenant, size) in &used_size_per_tenant { + println!("Tenant: {}, Size: {}", tenant, size); + } + + // Process eviction + while let Some(Reverse(entry)) = pq.pop() { + let EvictionEntry { tenant, node, .. } = entry; + + if let Some(&used_size) = used_size_per_tenant.get(&tenant) { + if used_size <= max_size { + continue; + } + + // Update used size + if let Some(size) = used_size_per_tenant.get_mut(&tenant) { + *size -= node.text.read().unwrap().chars().count(); + } + + // Decrement when removing tenant from node + if node.tenant_last_access_time.contains_key(&tenant) { + self.tenant_char_count + .entry(tenant.clone()) + .and_modify(|count| { + if *count > 0 { + *count -= node.text.read().unwrap().chars().count(); + } + }); + } + + // Remove tenant from node + node.tenant_last_access_time.remove(&tenant); + + // Remove empty nodes + if node.children.is_empty() && node.tenant_last_access_time.is_empty() { + if let Some(parent) = node.parent.write().unwrap().as_ref() { + let first_char = node.text.read().unwrap().chars().next().unwrap(); + parent.children.remove(&first_char); + } + } + + // Add parent to queue if it becomes a leaf + if let Some(parent) = node.parent.read().unwrap().as_ref() { + if Tree::leaf_of(parent).contains(&tenant) { + if let Some(timestamp) = parent.tenant_last_access_time.get(&tenant) { + pq.push(Reverse(EvictionEntry { + timestamp: *timestamp, + tenant: tenant.clone(), + node: Arc::clone(parent), + })); + } + } + } + } + } + + println!("\nAfter eviction - Used size per tenant:"); + for (tenant, size) in &used_size_per_tenant { + println!("Tenant: {}, Size: {}", tenant, size); } } + + pub fn get_tenant_char_count(&self) -> HashMap { + self.tenant_char_count + .iter() + .map(|entry| (entry.key().clone(), *entry.value())) + .collect() + } + + pub fn get_smallest_tenant(&self) -> String { + // Return a placeholder if there are no tenants + if self.tenant_char_count.is_empty() { + return "empty".to_string(); + } + + // Find the tenant with minimum char count + let mut min_tenant = None; + let mut min_count = usize::MAX; + + 
for entry in self.tenant_char_count.iter() { + let tenant = entry.key(); + let count = *entry.value(); + + if count < min_count { + min_count = count; + min_tenant = Some(tenant.clone()); + } + } + + // Return the found tenant or "empty" if somehow none was found + min_tenant.unwrap_or_else(|| "empty".to_string()) + } + + pub fn get_used_size_per_tenant(&self) -> HashMap { + // perform a DFS to traverse all nodes and calculate the total size used by each tenant + + let mut used_size_per_tenant: HashMap = HashMap::new(); + let mut stack = vec![Arc::clone(&self.root)]; + + while let Some(curr) = stack.pop() { + let text_count = curr.text.read().unwrap().chars().count(); + + for tenant in curr.tenant_last_access_time.iter() { + let size = used_size_per_tenant + .entry(tenant.key().clone()) + .or_insert(0); + *size += text_count; + } + + for child in curr.children.iter() { + stack.push(Arc::clone(child.value())); + } + } + + used_size_per_tenant + } + + fn node_to_string(node: &NodeRef, prefix: &str, is_last: bool) -> String { + let mut result = String::new(); + + // Add prefix and branch character + result.push_str(prefix); + result.push_str(if is_last { "└── " } else { "├── " }); + + // Add node text + let node_text = node.text.read().unwrap(); + result.push_str(&format!("'{}' [", node_text)); + + // Add tenant information with timestamps + let mut tenant_info = Vec::new(); + for entry in node.tenant_last_access_time.iter() { + let tenant_id = entry.key(); + let timestamp_ms = entry.value(); + + // Convert milliseconds to seconds and remaining milliseconds + let seconds = (timestamp_ms / 1000) as u64; + let millis = (timestamp_ms % 1000) as u32; + + // Create SystemTime from Unix timestamp + let system_time = UNIX_EPOCH + Duration::from_secs(seconds); + + // Format time as HH:MM:SS.mmm + let datetime = system_time.duration_since(UNIX_EPOCH).unwrap(); + let hours = (datetime.as_secs() % 86400) / 3600; + let minutes = (datetime.as_secs() % 3600) / 60; + let seconds = datetime.as_secs() % 60; + + tenant_info.push(format!( + "{} | {:02}:{:02}:{:02}.{:03}", + tenant_id, hours, minutes, seconds, millis + )); + } + + result.push_str(&tenant_info.join(", ")); + result.push_str("]\n"); + + // Process children + let children: Vec<_> = node.children.iter().collect(); + let child_count = children.len(); + + for (i, entry) in children.iter().enumerate() { + let is_last_child = i == child_count - 1; + let new_prefix = format!("{}{}", prefix, if is_last { " " } else { "│ " }); + + result.push_str(&Tree::node_to_string( + entry.value(), + &new_prefix, + is_last_child, + )); + } + + result + } + + pub fn pretty_print(&self) { + if self.root.children.is_empty() { + return; + } + + let mut result = String::new(); + let children: Vec<_> = self.root.children.iter().collect(); + let child_count = children.len(); + + for (i, entry) in children.iter().enumerate() { + let is_last = i == child_count - 1; + result.push_str(&Tree::node_to_string(entry.value(), "", is_last)); + } + + println!("{result}"); + + return; + } +} + +// Unit tests +#[cfg(test)] +mod tests { + use std::time::Instant; + + use rand::Rng; + + use super::*; + + #[test] + fn test_get_smallest_tenant() { + let tree = Tree::new(); + + // Test empty tree + assert_eq!(tree.get_smallest_tenant(), "empty"); + + // Insert data for tenant1 - "ap" + "icot" = 6 chars + tree.insert("ap", "tenant1"); + tree.insert("icot", "tenant1"); + + // Insert data for tenant2 - "cat" = 3 chars + tree.insert("cat", "tenant2"); + + // Test - tenant2 should be smallest with 3 
chars vs 6 chars + assert_eq!( + tree.get_smallest_tenant(), + "tenant2", + "Expected tenant2 to be smallest with 3 characters" + ); + + // Insert overlapping data for tenant3 and tenant4 to test equal counts + // tenant3: "do" = 2 chars + // tenant4: "hi" = 2 chars + tree.insert("do", "tenant3"); + tree.insert("hi", "tenant4"); + + // Test - should return either tenant3 or tenant4 (both have 2 chars) + let smallest = tree.get_smallest_tenant(); + assert!( + smallest == "tenant3" || smallest == "tenant4", + "Expected either tenant3 or tenant4 (both have 2 characters), got {}", + smallest + ); + + // Add more text to tenant4 to make it larger + tree.insert("hello", "tenant4"); // Now tenant4 has "hi" + "hello" = 6 chars + + // Now tenant3 should be smallest (2 chars vs 6 chars for tenant4) + assert_eq!( + tree.get_smallest_tenant(), + "tenant3", + "Expected tenant3 to be smallest with 2 characters" + ); + + // Test eviction + tree.evict_tenant_data(3); // This should evict tenants with more than 3 chars + + let post_eviction_smallest = tree.get_smallest_tenant(); + println!("Smallest tenant after eviction: {}", post_eviction_smallest); + } + + #[test] + fn test_tenant_char_count() { + let tree = Tree::new(); + + // Phase 1: Initial insertions + tree.insert("apple", "tenant1"); + tree.insert("apricot", "tenant1"); + tree.insert("banana", "tenant1"); + tree.insert("amplify", "tenant2"); + tree.insert("application", "tenant2"); + + let computed_sizes = tree.get_used_size_per_tenant(); + let maintained_counts: HashMap = tree + .tenant_char_count + .iter() + .map(|entry| (entry.key().clone(), *entry.value())) + .collect(); + + println!("Phase 1 - Maintained vs Computed counts:"); + println!( + "Maintained: {:?}\nComputed: {:?}", + maintained_counts, computed_sizes + ); + assert_eq!( + maintained_counts, computed_sizes, + "Phase 1: Initial insertions" + ); + + // Phase 2: Additional insertions + tree.insert("apartment", "tenant1"); + tree.insert("appetite", "tenant2"); + tree.insert("ball", "tenant1"); + tree.insert("box", "tenant2"); + + let computed_sizes = tree.get_used_size_per_tenant(); + let maintained_counts: HashMap = tree + .tenant_char_count + .iter() + .map(|entry| (entry.key().clone(), *entry.value())) + .collect(); + + println!("Phase 2 - Maintained vs Computed counts:"); + println!( + "Maintained: {:?}\nComputed: {:?}", + maintained_counts, computed_sizes + ); + assert_eq!( + maintained_counts, computed_sizes, + "Phase 2: Additional insertions" + ); + + // Phase 3: Overlapping insertions + tree.insert("zebra", "tenant1"); + tree.insert("zebra", "tenant2"); + tree.insert("zero", "tenant1"); + tree.insert("zero", "tenant2"); + + let computed_sizes = tree.get_used_size_per_tenant(); + let maintained_counts: HashMap = tree + .tenant_char_count + .iter() + .map(|entry| (entry.key().clone(), *entry.value())) + .collect(); + + println!("Phase 3 - Maintained vs Computed counts:"); + println!( + "Maintained: {:?}\nComputed: {:?}", + maintained_counts, computed_sizes + ); + assert_eq!( + maintained_counts, computed_sizes, + "Phase 3: Overlapping insertions" + ); + + // Phase 4: Eviction test + tree.evict_tenant_data(10); + + let computed_sizes = tree.get_used_size_per_tenant(); + let maintained_counts: HashMap = tree + .tenant_char_count + .iter() + .map(|entry| (entry.key().clone(), *entry.value())) + .collect(); + + println!("Phase 4 - Maintained vs Computed counts:"); + println!( + "Maintained: {:?}\nComputed: {:?}", + maintained_counts, computed_sizes + ); + assert_eq!(maintained_counts, 
computed_sizes, "Phase 4: After eviction"); + } + + fn random_string(len: usize) -> String { + Alphanumeric.sample_string(&mut thread_rng(), len) + } + + #[test] + fn test_cold_start() { + let tree = Tree::new(); + + let (matched_text, tenant) = tree.prefix_match("hello"); + + assert_eq!(matched_text, ""); + assert_eq!(tenant, "empty"); + } + + #[test] + fn test_exact_match_seq() { + let tree = Tree::new(); + tree.insert("hello", "tenant1"); + tree.pretty_print(); + tree.insert("apple", "tenant2"); + tree.pretty_print(); + tree.insert("banana", "tenant3"); + tree.pretty_print(); + + let (matched_text, tenant) = tree.prefix_match("hello"); + assert_eq!(matched_text, "hello"); + assert_eq!(tenant, "tenant1"); + + let (matched_text, tenant) = tree.prefix_match("apple"); + assert_eq!(matched_text, "apple"); + assert_eq!(tenant, "tenant2"); + + let (matched_text, tenant) = tree.prefix_match("banana"); + assert_eq!(matched_text, "banana"); + assert_eq!(tenant, "tenant3"); + } + + #[test] + fn test_exact_match_concurrent() { + let tree = Arc::new(Tree::new()); + + // spawn 3 threads for insert + let tree_clone = Arc::clone(&tree); + + let texts = vec!["hello", "apple", "banana"]; + let tenants = vec!["tenant1", "tenant2", "tenant3"]; + + let mut handles = vec![]; + + for i in 0..3 { + let tree_clone = Arc::clone(&tree_clone); + let text = texts[i]; + let tenant = tenants[i]; + + let handle = thread::spawn(move || { + tree_clone.insert(text, tenant); + }); + + handles.push(handle); + } + + // wait + for handle in handles { + handle.join().unwrap(); + } + + // spawn 3 threads for match + let mut handles = vec![]; + + let tree_clone = Arc::clone(&tree); + + for i in 0..3 { + let tree_clone = Arc::clone(&tree_clone); + let text = texts[i]; + let tenant = tenants[i]; + + let handle = thread::spawn(move || { + let (matched_text, matched_tenant) = tree_clone.prefix_match(text); + assert_eq!(matched_text, text); + assert_eq!(matched_tenant, tenant); + }); + + handles.push(handle); + } + + // wait + for handle in handles { + handle.join().unwrap(); + } + } + + #[test] + fn test_partial_match_concurrent() { + let tree = Arc::new(Tree::new()); + + // spawn 3 threads for insert + let tree_clone = Arc::clone(&tree); + + let texts = vec!["apple", "apabc", "acbdeds"]; + + let mut handles = vec![]; + + for i in 0..3 { + let tree_clone = Arc::clone(&tree_clone); + let text = texts[i]; + let tenant = "tenant0"; + + let handle = thread::spawn(move || { + tree_clone.insert(text, tenant); + }); + + handles.push(handle); + } + + // wait + for handle in handles { + handle.join().unwrap(); + } + + // spawn 3 threads for match + let mut handles = vec![]; + + let tree_clone = Arc::clone(&tree); + + for i in 0..3 { + let tree_clone = Arc::clone(&tree_clone); + let text = texts[i]; + let tenant = "tenant0"; + + let handle = thread::spawn(move || { + let (matched_text, matched_tenant) = tree_clone.prefix_match(text); + assert_eq!(matched_text, text); + assert_eq!(matched_tenant, tenant); + }); + + handles.push(handle); + } + + // wait + for handle in handles { + handle.join().unwrap(); + } + } + + #[test] + fn test_group_prefix_insert_match_concurrent() { + let prefix = vec![ + "Clock strikes midnight, I'm still wide awake", + "Got dreams bigger than these city lights", + "Time waits for no one, gotta make my move", + "Started from the bottom, that's no metaphor", + ]; + let suffix = vec![ + "Got too much to prove, ain't got time to lose", + "History in the making, yeah, you can't erase this", + ]; + let tree = 
Arc::new(Tree::new()); + + let mut handles = vec![]; + + for i in 0..prefix.len() { + for j in 0..suffix.len() { + let tree_clone = Arc::clone(&tree); + let text = format!("{} {}", prefix[i], suffix[j]); + let tenant = format!("tenant{}", i); + + let handle = thread::spawn(move || { + tree_clone.insert(&text, &tenant); + }); + + handles.push(handle); + } + } + + // wait + for handle in handles { + handle.join().unwrap(); + } + + tree.pretty_print(); + + // check matching using multi threads + + let mut handles = vec![]; + + for i in 0..prefix.len() { + let tree_clone = Arc::clone(&tree); + let text = prefix[i]; + + let handle = thread::spawn(move || { + let (matched_text, matched_tenant) = tree_clone.prefix_match(text); + let tenant = format!("tenant{}", i); + assert_eq!(matched_text, text); + assert_eq!(matched_tenant, tenant); + }); + + handles.push(handle); + } + + // wait + for handle in handles { + handle.join().unwrap(); + } + } + + #[test] + fn test_mixed_concurrent_insert_match() { + // ensure it does not deadlock instead of doing correctness check + + let prefix = vec![ + "Clock strikes midnight, I'm still wide awake", + "Got dreams bigger than these city lights", + "Time waits for no one, gotta make my move", + "Started from the bottom, that's no metaphor", + ]; + let suffix = vec![ + "Got too much to prove, ain't got time to lose", + "History in the making, yeah, you can't erase this", + ]; + let tree = Arc::new(Tree::new()); + + let mut handles = vec![]; + + for i in 0..prefix.len() { + for j in 0..suffix.len() { + let tree_clone = Arc::clone(&tree); + let text = format!("{} {}", prefix[i], suffix[j]); + let tenant = format!("tenant{}", i); + + let handle = thread::spawn(move || { + tree_clone.insert(&text, &tenant); + }); + + handles.push(handle); + } + } + + // check matching using multi threads + + for i in 0..prefix.len() { + let tree_clone = Arc::clone(&tree); + let text = prefix[i]; + + let handle = thread::spawn(move || { + let (matched_text, matched_tenant) = tree_clone.prefix_match(text); + }); + + handles.push(handle); + } + + // wait + for handle in handles { + handle.join().unwrap(); + } + } + + #[test] + fn test_utf8_split_seq() { + // The string should be indexed and splitted by a utf-8 value basis instead of byte basis + // use .chars() to get the iterator of the utf-8 value + let tree = Arc::new(Tree::new()); + + let test_pairs = vec![ + ("你好嗎", "tenant1"), + ("你好喔", "tenant2"), + ("你心情好嗎", "tenant3"), + ]; + + // Insert sequentially + for i in 0..test_pairs.len() { + let text = test_pairs[i].0; + let tenant = test_pairs[i].1; + tree.insert(text, tenant); + } + + tree.pretty_print(); + + // Test sequentially + + for i in 0..test_pairs.len() { + let (matched_text, matched_tenant) = tree.prefix_match(test_pairs[i].0); + assert_eq!(matched_text, test_pairs[i].0); + assert_eq!(matched_tenant, test_pairs[i].1); + } + } + + #[test] + fn test_utf8_split_concurrent() { + let tree = Arc::new(Tree::new()); + + let test_pairs = vec![ + ("你好嗎", "tenant1"), + ("你好喔", "tenant2"), + ("你心情好嗎", "tenant3"), + ]; + + // Create multiple threads for insertion + let mut handles = vec![]; + + for i in 0..test_pairs.len() { + let tree_clone = Arc::clone(&tree); + let text = test_pairs[i].0.to_string(); + let tenant = test_pairs[i].1.to_string(); + + let handle = thread::spawn(move || { + tree_clone.insert(&text, &tenant); + }); + + handles.push(handle); + } + + // Wait for all insertions to complete + for handle in handles { + handle.join().unwrap(); + } + + tree.pretty_print(); + + // 
Create multiple threads for matching + let mut handles = vec![]; + + for i in 0..test_pairs.len() { + let tree_clone = Arc::clone(&tree); + let text = test_pairs[i].0.to_string(); + let tenant = test_pairs[i].1.to_string(); + + let handle = thread::spawn(move || { + let (matched_text, matched_tenant) = tree_clone.prefix_match(&text); + assert_eq!(matched_text, text); + assert_eq!(matched_tenant, tenant); + }); + + handles.push(handle); + } + + // Wait for all matches to complete + for handle in handles { + handle.join().unwrap(); + } + } + + #[test] + fn test_simple_eviction() { + let tree = Tree::new(); + let max_size = 5; + + // Insert strings for both tenants + tree.insert("hello", "tenant1"); // size 5 + + tree.insert("hello", "tenant2"); // size 5 + thread::sleep(Duration::from_millis(10)); + tree.insert("world", "tenant2"); // size 5, total for tenant2 = 10 + + tree.pretty_print(); + + // Verify initial sizes + let sizes_before = tree.get_used_size_per_tenant(); + assert_eq!(sizes_before.get("tenant1").unwrap(), &5); // "hello" = 5 + assert_eq!(sizes_before.get("tenant2").unwrap(), &10); // "hello" + "world" = 10 + + // Evict - should remove "hello" from tenant2 as it's the oldest + tree.evict_tenant_data(max_size); + + tree.pretty_print(); + + // Verify sizes after eviction + let sizes_after = tree.get_used_size_per_tenant(); + assert_eq!(sizes_after.get("tenant1").unwrap(), &5); // Should be unchanged + assert_eq!(sizes_after.get("tenant2").unwrap(), &5); // Only "world" remains + + // Verify "world" remains for tenant2 + let (matched, tenant) = tree.prefix_match("world"); + assert_eq!(matched, "world"); + assert_eq!(tenant, "tenant2"); + } + + #[test] + fn test_advanced_eviction() { + let tree = Tree::new(); + + // Set limits for each tenant + let max_size: usize = 100; + + // Define prefixes + let prefixes = vec!["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"]; + + // Insert strings with shared prefixes + for i in 0..100 { + for (j, prefix) in prefixes.iter().enumerate() { + let random_suffix = random_string(10); + let text = format!("{}{}", prefix, random_suffix); + let tenant = format!("tenant{}", j + 1); + tree.insert(&text, &tenant); + } + } + + // Perform eviction + tree.evict_tenant_data(max_size); + + // Check sizes after eviction + let sizes_after = tree.get_used_size_per_tenant(); + // Verify all tenants are under their size limits + for (tenant, &size) in sizes_after.iter() { + assert!( + size <= max_size, + "Tenant {} exceeds size limit. 
Current size: {}, Limit: {}", + tenant, + size, + max_size + ); + } + } + + #[test] + fn test_concurrent_operations_with_eviction() { + // Ensure eviction works fine with concurrent insert and match operations for a given period + + let tree = Arc::new(Tree::new()); + let mut handles = vec![]; + let test_duration = Duration::from_secs(10); + let start_time = Instant::now(); + let max_size = 100; // Single max size for all tenants + + // Spawn eviction thread + { + let tree = Arc::clone(&tree); + let handle = thread::spawn(move || { + while start_time.elapsed() < test_duration { + // Run eviction + tree.evict_tenant_data(max_size); + + // Sleep for 5 seconds + thread::sleep(Duration::from_secs(5)); + } + }); + handles.push(handle); + } + + // Spawn 4 worker threads + for thread_id in 0..4 { + let tree = Arc::clone(&tree); + let handle = thread::spawn(move || { + let mut rng = rand::thread_rng(); + let tenant = format!("tenant{}", thread_id + 1); + let prefix = format!("prefix{}", thread_id); + + while start_time.elapsed() < test_duration { + // Random decision: match or insert (70% match, 30% insert) + if rng.gen_bool(0.7) { + // Perform match operation + let random_len = rng.gen_range(3..10); + let search_str = format!("{}{}", prefix, random_string(random_len)); + let (matched, _) = tree.prefix_match(&search_str); + } else { + // Perform insert operation + let random_len = rng.gen_range(5..15); + let insert_str = format!("{}{}", prefix, random_string(random_len)); + tree.insert(&insert_str, &tenant); + // println!("Thread {} inserted: {}", thread_id, insert_str); + } + + // Small random sleep to vary timing + thread::sleep(Duration::from_millis(rng.gen_range(10..100))); + } + }); + handles.push(handle); + } + + // Wait for all threads to complete + for handle in handles { + handle.join().unwrap(); + } + + // final eviction + tree.evict_tenant_data(max_size); + + // Final size check + let final_sizes = tree.get_used_size_per_tenant(); + println!("Final sizes after test completion: {:?}", final_sizes); + + // Verify all tenants are under limit + for (_, &size) in final_sizes.iter() { + assert!( + size <= max_size, + "Tenant exceeds size limit. 
Final size: {}, Limit: {}", + size, + max_size + ); + } + } + + #[test] + fn test_leaf_of() { + let tree = Tree::new(); + + // Single node + tree.insert("hello", "tenant1"); + let leaves = Tree::leaf_of(&tree.root.children.get(&'h').unwrap()); + assert_eq!(leaves, vec!["tenant1"]); + + // Node with multiple tenants + tree.insert("hello", "tenant2"); + let leaves = Tree::leaf_of(&tree.root.children.get(&'h').unwrap()); + assert_eq!(leaves.len(), 2); + assert!(leaves.contains(&"tenant1".to_string())); + assert!(leaves.contains(&"tenant2".to_string())); + + // Non-leaf node + tree.insert("hi", "tenant1"); + let leaves = Tree::leaf_of(&tree.root.children.get(&'h').unwrap()); + assert!(leaves.is_empty()); + } + + #[test] + fn test_get_used_size_per_tenant() { + let tree = Tree::new(); + + // Single tenant + tree.insert("hello", "tenant1"); + tree.insert("world", "tenant1"); + let sizes = tree.get_used_size_per_tenant(); + + tree.pretty_print(); + println!("{:?}", sizes); + assert_eq!(sizes.get("tenant1").unwrap(), &10); // "hello" + "world" + + // Multiple tenants sharing nodes + tree.insert("hello", "tenant2"); + tree.insert("help", "tenant2"); + let sizes = tree.get_used_size_per_tenant(); + + tree.pretty_print(); + println!("{:?}", sizes); + assert_eq!(sizes.get("tenant1").unwrap(), &10); + assert_eq!(sizes.get("tenant2").unwrap(), &6); // "hello" + "p" + + // UTF-8 characters + tree.insert("你好", "tenant3"); + let sizes = tree.get_used_size_per_tenant(); + tree.pretty_print(); + println!("{:?}", sizes); + assert_eq!(sizes.get("tenant3").unwrap(), &2); // 2 Chinese characters + + tree.pretty_print(); + } } diff --git a/rust/tests/test_tree.rs b/rust/tests/test_tree.rs deleted file mode 100644 index ed0b85e0b..000000000 --- a/rust/tests/test_tree.rs +++ /dev/null @@ -1,131 +0,0 @@ -use sglang_router_rs::tree::RadixTree; - -#[test] -fn test_new_tree() { - let tree = RadixTree::new(); - assert_eq!(tree.root.count, 0); - assert!(tree.root.children.is_empty()); - assert!(tree.root.ids.is_empty()); -} - -#[test] -fn test_single_insertion() { - let mut tree = RadixTree::new(); - tree.insert(&[1, 2, 3]); - - assert_eq!(tree.root.count, 1); - assert_eq!(tree.root.children.len(), 1); - assert_eq!(tree.root.children[&1].ids, vec![1, 2, 3]); - assert_eq!(tree.root.children[&1].count, 1); -} - -#[test] -fn test_multiple_insertions_no_split() { - let mut tree = RadixTree::new(); - tree.insert(&[1, 2, 3]); - tree.insert(&[4, 5, 6]); - - assert_eq!(tree.root.count, 2); - assert_eq!(tree.root.children.len(), 2); - assert_eq!(tree.root.children[&1].ids, vec![1, 2, 3]); - assert_eq!(tree.root.children[&4].ids, vec![4, 5, 6]); -} - -#[test] -fn test_insertion_with_split() { - let mut tree = RadixTree::new(); - tree.insert(&[1, 2, 3, 4]); - tree.insert(&[1, 2, 5, 6]); - - assert_eq!(tree.root.count, 2); - assert_eq!(tree.root.children.len(), 1); - assert_eq!(tree.root.children[&1].ids, vec![1, 2]); - assert_eq!(tree.root.children[&1].children.len(), 2); - assert_eq!(tree.root.children[&1].children[&3].ids, vec![3, 4]); - assert_eq!(tree.root.children[&1].children[&5].ids, vec![5, 6]); -} - -#[test] -fn test_prefix_match_exact() { - let mut tree = RadixTree::new(); - tree.insert(&[1, 2, 3, 4]); - - assert_eq!(tree.prefix_match(&[1, 2, 3, 4]), &[1, 2, 3, 4]); -} - -#[test] -fn test_prefix_match_partial() { - let mut tree = RadixTree::new(); - tree.insert(&[1, 2, 3, 4]); - - assert_eq!(tree.prefix_match(&[1, 2, 3, 5]), &[1, 2, 3]); - assert_eq!(tree.prefix_match(&[1, 2, 5]), &[1, 2]); - 
assert_eq!(tree.prefix_match(&[1, 5]), &[1]); -} - -#[test] -fn test_prefix_match_no_match() { - let mut tree = RadixTree::new(); - tree.insert(&[1, 2, 3, 4]); - let empty_slices: &[u32] = &[]; - assert_eq!(tree.prefix_match(&[5, 6, 7]), empty_slices); -} - -#[test] -fn test_delete_leaf() { - let mut tree = RadixTree::new(); - tree.insert(&[1, 2, 3]); - tree.delete(&[1, 2, 3]); - - assert_eq!(tree.root.count, 0); - assert_eq!(tree.root.children.len(), 0); -} - -#[test] -fn test_delete_with_siblings() { - let mut tree = RadixTree::new(); - tree.insert(&[1, 2, 3]); - tree.insert(&[1, 2, 4]); - tree.delete(&[1, 2, 3]); - - assert_eq!(tree.root.count, 1); - assert_eq!(tree.root.children[&1].children[&4].ids, vec![4]); -} - -#[test] -fn test_multiple_operations() { - let mut tree = RadixTree::new(); - - // Insert several paths - tree.insert(&[1, 2, 3]); - tree.insert(&[1, 2, 4]); - tree.insert(&[1, 5, 6]); - - // Verify structure - assert_eq!(tree.root.count, 3); - assert_eq!(tree.prefix_match(&[1, 2, 3]), &[1, 2, 3]); - assert_eq!(tree.prefix_match(&[1, 2, 4]), &[1, 2, 4]); - assert_eq!(tree.prefix_match(&[1, 5, 6]), &[1, 5, 6]); - - // Delete and verify - tree.delete(&[1, 2, 3]); - assert_eq!(tree.root.count, 2); - assert_eq!(tree.prefix_match(&[1, 2, 3]), &[1, 2]); // Now only matches prefix -} - -#[test] -#[should_panic(expected = "No match found")] -fn test_delete_nonexistent() { - let mut tree = RadixTree::new(); - tree.insert(&[1, 2, 3]); - tree.delete(&[4, 5, 6]); // Should panic -} - -#[test] -fn test_empty_input() { - let mut tree = RadixTree::new(); - let empty_slice: &[u32] = &[]; - tree.insert(empty_slice); - assert_eq!(tree.prefix_match(empty_slice), empty_slice); - tree.delete(empty_slice); // Should not panic -}
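
A minimal usage sketch of the tenant-aware Tree that this patch adds in rust/src/tree.rs, following the unit tests above. It assumes the new Tree type is exported from sglang_router_rs::tree the same way the removed RadixTree was; the worker names and the 64-character budget are illustrative only, not values used by the router.

    use sglang_router_rs::tree::Tree;

    fn main() {
        let tree = Tree::new();

        // Record which worker ("tenant") already holds the KV cache for each prompt prefix.
        tree.insert("Clock strikes midnight, I'm still wide awake", "worker_0");
        tree.insert("Got dreams bigger than these city lights", "worker_1");

        // Route a new request: the longest shared prefix identifies the warm worker.
        let (matched, tenant) = tree.prefix_match("Clock strikes midnight, I'm heading home");
        println!("matched {} chars on {}", matched.chars().count(), tenant);

        // Bound per-tenant usage by evicting the least recently used leaves for oversized tenants.
        tree.evict_tenant_data(64);
        println!("{:?}", tree.get_used_size_per_tenant());
    }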