use crate::tree::Tree;

use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
use futures_util::{Stream, StreamExt, TryStreamExt};

use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
2024-10-28 09:49:48 -07:00
#[ derive(Debug) ]
2024-11-06 00:02:02 -08:00
pub enum Router {
RoundRobin {
worker_urls : Vec < String > ,
2024-11-10 21:57:32 -08:00
current_index : AtomicUsize ,
2024-11-06 00:02:02 -08:00
} ,
Random {
worker_urls : Vec < String > ,
} ,
2024-11-23 08:34:48 -08:00
CacheAware {
/*
Cache - Aware Load Balancing Router
This router combines two strategies to optimize both cache utilization and request distribution :
1. Cache - Aware Routing ( Approximate Tree )
2. Load Balancing ( Shortest Queue )
For each incoming request , the router chooses between these strategies :
- With probability P : Uses cache - aware routing
- With probability ( 1 - P ) : Uses load balancing
where P is configured via ` cache_routing_prob `
Strategy Details :
1. Cache - Aware Routing ( Approximate Tree )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
This strategy maintains an approximate radix tree for each worker based on request history ,
eliminating the need for direct cache state queries . The tree stores raw text characters
instead of token IDs to avoid tokenization overhead .
Process :
a . For each request , find the worker with the highest prefix match
b . If match rate > cache_threshold :
Route to the worker with highest match ( likely has relevant data cached )
c . If match rate ≤ cache_threshold :
Route to the worker with smallest tree size ( most available cache capacity )
d . Background maintenance :
Periodically evict least recently used leaf nodes to prevent memory overflow
2. Load Balancing ( Shortest Queue )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
This strategy tracks pending request counts per worker and routes new requests
to the least busy worker for optimal load distribution .
Configuration Parameters :
- - - - - - - - - - - - - - - - - - - - - - - -
1. cache_routing_prob : ( float , 0.0 to 1.0 )
- 0.0 : Exclusively use load balancing
- 1.0 : Exclusively use cache - aware routing
- Between 0 - 1 : Probability of using cache - aware routing vs load balancing
2. cache_threshold : ( float , 0.0 to 1.0 )
Minimum prefix match ratio to use highest - match routing .
Below this threshold , routes to worker with most available cache space .
3. eviction_interval_secs : ( integer )
Interval between LRU eviction cycles for the approximate trees .
4. max_tree_size : ( integer )
Maximum nodes per tree . When exceeded , LRU leaf nodes are evicted
during the next eviction cycle .
* /
2024-11-10 21:57:32 -08:00
worker_urls : Vec < String > ,
2024-11-23 08:34:48 -08:00
tree : Arc < Mutex < Tree > > ,
running_queue : Arc < Mutex < HashMap < String , usize > > > ,
processed_queue : Arc < Mutex < HashMap < String , usize > > > ,
2024-11-10 21:57:32 -08:00
cache_threshold : f32 ,
2024-11-23 08:34:48 -08:00
cache_routing_prob : f32 ,
_eviction_thread : Option < thread ::JoinHandle < ( ) > > , // Store thread handle
2024-11-10 21:57:32 -08:00
} ,
}
2024-11-23 08:34:48 -08:00
/// Per-policy configuration used by [`Router::new`] to pick a routing
/// strategy. Variants mirror the [`Router`] variants.
#[derive(Debug)]
pub enum PolicyConfig {
    RandomConfig,
    RoundRobinConfig,
    CacheAwareConfig {
        /// Minimum prefix-match ratio (0.0..=1.0) required to route to the
        /// highest-matching worker.
        cache_threshold: f32,
        /// Probability (0.0..=1.0) of using cache-aware routing instead of
        /// shortest-queue load balancing.
        cache_routing_prob: f32,
        /// Seconds between LRU eviction cycles of the approximate trees.
        eviction_interval_secs: u64,
        /// Maximum nodes per tree before LRU leaves are evicted.
        max_tree_size: usize,
    },
}
fn get_text_from_request ( body : & Bytes ) -> String {
2024-11-10 21:57:32 -08:00
// 1. convert body to json
let json = serde_json ::from_slice ::< serde_json ::Value > ( body ) . unwrap ( ) ;
// 2. get the text field
let text = json . get ( " text " ) . and_then ( | t | t . as_str ( ) ) . unwrap_or ( " " ) ;
2024-11-23 08:34:48 -08:00
return text . to_string ( ) ;
2024-10-28 09:49:48 -07:00
}
2024-11-06 00:02:02 -08:00
impl Router {
2024-11-10 21:57:32 -08:00
pub fn new ( worker_urls : Vec < String > , policy_config : PolicyConfig ) -> Self {
match policy_config {
PolicyConfig ::RandomConfig = > Router ::Random { worker_urls } ,
PolicyConfig ::RoundRobinConfig = > Router ::RoundRobin {
2024-11-06 00:02:02 -08:00
worker_urls ,
current_index : std ::sync ::atomic ::AtomicUsize ::new ( 0 ) ,
} ,
2024-11-23 08:34:48 -08:00
PolicyConfig ::CacheAwareConfig {
2024-11-10 21:57:32 -08:00
cache_threshold ,
2024-11-23 08:34:48 -08:00
cache_routing_prob ,
eviction_interval_secs ,
max_tree_size ,
2024-11-10 21:57:32 -08:00
} = > {
2024-11-23 08:34:48 -08:00
let mut running_queue = HashMap ::new ( ) ;
for url in & worker_urls {
running_queue . insert ( url . clone ( ) , 0 ) ;
}
let mut processed_queue = HashMap ::new ( ) ;
for url in & worker_urls {
processed_queue . insert ( url . clone ( ) , 0 ) ;
}
let tree = Arc ::new ( Mutex ::new ( Tree ::new ( ) ) ) ;
let running_queue = Arc ::new ( Mutex ::new ( running_queue ) ) ;
let processed_queue = Arc ::new ( Mutex ::new ( processed_queue ) ) ;
// Create background eviction thread
let tree_clone = Arc ::clone ( & tree ) ;
let processed_queue_clone = Arc ::clone ( & processed_queue ) ;
let eviction_thread = thread ::spawn ( move | | {
loop {
// Sleep for the specified interval
thread ::sleep ( Duration ::from_secs ( eviction_interval_secs ) ) ;
let locked_tree_clone = tree_clone . lock ( ) . unwrap ( ) ;
// Run eviction
locked_tree_clone . evict_tenant_data ( max_tree_size ) ;
// Print the process queue
let locked_processed_queue = processed_queue_clone . lock ( ) . unwrap ( ) ;
println! ( " Processed Queue: {:?} " , locked_processed_queue ) ;
}
} ) ;
2024-11-10 21:57:32 -08:00
for url in & worker_urls {
2024-11-23 08:34:48 -08:00
tree . lock ( ) . unwrap ( ) . insert ( & " " . to_string ( ) , url ) ;
2024-11-10 21:57:32 -08:00
}
2024-11-23 08:34:48 -08:00
Router ::CacheAware {
2024-11-10 21:57:32 -08:00
worker_urls ,
2024-11-23 08:34:48 -08:00
tree ,
running_queue ,
processed_queue ,
2024-11-10 21:57:32 -08:00
cache_threshold ,
2024-11-23 08:34:48 -08:00
cache_routing_prob ,
_eviction_thread : Some ( eviction_thread ) ,
2024-11-10 21:57:32 -08:00
}
}
2024-10-28 09:49:48 -07:00
}
}
2024-11-06 00:02:02 -08:00
pub fn get_first ( & self ) -> Option < String > {
match self {
2024-11-10 21:57:32 -08:00
Router ::RoundRobin { worker_urls , .. }
| Router ::Random { worker_urls }
2024-11-23 08:34:48 -08:00
| Router ::CacheAware { worker_urls , .. } = > {
2024-11-06 00:02:02 -08:00
if worker_urls . is_empty ( ) {
None
} else {
Some ( worker_urls [ 0 ] . clone ( ) )
}
}
2024-10-28 09:49:48 -07:00
}
}
2024-11-06 00:02:02 -08:00
pub async fn dispatch (
& self ,
client : & reqwest ::Client ,
req : HttpRequest ,
body : Bytes ,
) -> HttpResponse {
2024-11-23 08:34:48 -08:00
let text = get_text_from_request ( & body ) ;
2024-11-10 21:57:32 -08:00
2024-11-06 00:02:02 -08:00
let worker_url = match self {
Router ::RoundRobin {
worker_urls ,
current_index ,
} = > {
2024-11-10 21:57:32 -08:00
let idx = current_index
2024-11-06 00:02:02 -08:00
. fetch_update (
std ::sync ::atomic ::Ordering ::SeqCst ,
std ::sync ::atomic ::Ordering ::SeqCst ,
| x | Some ( ( x + 1 ) % worker_urls . len ( ) ) ,
)
2024-11-10 21:57:32 -08:00
. unwrap ( ) ;
2024-10-28 09:49:48 -07:00
2024-11-10 21:57:32 -08:00
worker_urls [ idx ] . clone ( )
2024-11-06 00:02:02 -08:00
}
2024-11-10 21:57:32 -08:00
2024-11-06 00:02:02 -08:00
Router ::Random { worker_urls } = > {
2024-11-10 21:57:32 -08:00
worker_urls [ rand ::random ::< usize > ( ) % worker_urls . len ( ) ] . clone ( )
}
2024-11-23 08:34:48 -08:00
Router ::CacheAware {
2024-11-10 21:57:32 -08:00
worker_urls ,
2024-11-23 08:34:48 -08:00
tree ,
running_queue ,
processed_queue ,
2024-11-10 21:57:32 -08:00
cache_threshold ,
2024-11-23 08:34:48 -08:00
cache_routing_prob ,
2024-11-10 21:57:32 -08:00
..
} = > {
2024-11-23 08:34:48 -08:00
// even though the tree is thread-safe, we still put a lock to ensure the whole op (tree read + queue read + tree write + queue write) is atomic to handle some edge cases (e.g. multiple requests with long prefix entering at the same time)
2024-11-10 21:57:32 -08:00
2024-11-23 08:34:48 -08:00
let mut tree = tree . lock ( ) . unwrap ( ) ;
let mut running_queue = running_queue . lock ( ) . unwrap ( ) ;
2024-11-10 21:57:32 -08:00
2024-11-23 08:34:48 -08:00
// Generate a random float between 0 and 1 for probability check
let sampled_p : f32 = rand ::random ( ) ;
2024-11-10 21:57:32 -08:00
2024-11-23 08:34:48 -08:00
let selected_url = if sampled_p < * cache_routing_prob {
// Cache-aware routing logic
let ( matched_text , matched_worker ) = tree . prefix_match ( & text ) ;
let matched_rate =
matched_text . chars ( ) . count ( ) as f32 / text . chars ( ) . count ( ) as f32 ;
2024-11-10 21:57:32 -08:00
2024-11-23 08:34:48 -08:00
if matched_rate > * cache_threshold {
matched_worker . to_string ( )
} else {
let m_map : HashMap < String , usize > = tree
. tenant_char_count
. iter ( )
. map ( | entry | ( entry . key ( ) . clone ( ) , * entry . value ( ) ) )
. collect ( ) ;
2024-11-10 21:57:32 -08:00
2024-11-23 08:34:48 -08:00
println! ( " map: {:?} , mmap: {:?} " , tree . get_tenant_char_count ( ) , m_map ) ;
tree . get_smallest_tenant ( )
2024-11-10 21:57:32 -08:00
}
2024-11-23 08:34:48 -08:00
} else {
// Shortest queue routing logic
running_queue
. iter ( )
. min_by_key ( | ( _url , & count ) | count )
. map ( | ( url , _ ) | url . clone ( ) )
. unwrap_or_else ( | | worker_urls [ 0 ] . clone ( ) )
} ;
2024-11-10 21:57:32 -08:00
2024-11-23 08:34:48 -08:00
// Update running queue
let count = running_queue . get_mut ( & selected_url ) . unwrap ( ) ;
* count + = 1 ;
// Update processed queue
let mut locked_processed_queue = processed_queue . lock ( ) . unwrap ( ) ;
let count = locked_processed_queue . get_mut ( & selected_url ) . unwrap ( ) ;
* count + = 1 ;
// Update tree with the new request
tree . insert ( & text , & selected_url ) ;
selected_url
2024-11-06 00:02:02 -08:00
}
} ;
2024-10-28 09:49:48 -07:00
2024-11-06 00:02:02 -08:00
let is_stream = serde_json ::from_slice ::< serde_json ::Value > ( & body )
. map ( | v | v . get ( " stream " ) . and_then ( | s | s . as_bool ( ) ) . unwrap_or ( false ) )
. unwrap_or ( false ) ;
2024-11-04 10:56:52 -08:00
2024-11-06 00:02:02 -08:00
let res = match client
2024-11-23 08:34:48 -08:00
. post ( format! ( " {} /generate " , worker_url . clone ( ) ) )
2024-11-06 00:02:02 -08:00
. header (
" Content-Type " ,
req . headers ( )
. get ( " Content-Type " )
. and_then ( | h | h . to_str ( ) . ok ( ) )
. unwrap_or ( " application/json " ) ,
)
. body ( body . to_vec ( ) )
. send ( )
. await
{
Ok ( res ) = > res ,
Err ( _ ) = > return HttpResponse ::InternalServerError ( ) . finish ( ) ,
} ;
2024-11-04 10:56:52 -08:00
2024-11-06 00:02:02 -08:00
let status = actix_web ::http ::StatusCode ::from_u16 ( res . status ( ) . as_u16 ( ) )
. unwrap_or ( actix_web ::http ::StatusCode ::INTERNAL_SERVER_ERROR ) ;
2024-10-28 09:49:48 -07:00
2024-11-06 00:02:02 -08:00
if ! is_stream {
2024-11-23 08:34:48 -08:00
// For non-streaming requests, get response first
let response = match res . bytes ( ) . await {
2024-11-06 00:02:02 -08:00
Ok ( body ) = > HttpResponse ::build ( status ) . body ( body . to_vec ( ) ) ,
Err ( _ ) = > HttpResponse ::InternalServerError ( ) . finish ( ) ,
2024-11-23 08:34:48 -08:00
} ;
// Then decrement running queue counter if using CacheAware
if let Router ::CacheAware { running_queue , .. } = self {
if let Ok ( mut queue ) = running_queue . lock ( ) {
if let Some ( count ) = queue . get_mut ( & worker_url ) {
* count = count . saturating_sub ( 1 ) ;
}
}
2024-11-06 00:02:02 -08:00
}
2024-11-23 08:34:48 -08:00
response
} else if let Router ::CacheAware { running_queue , .. } = self {
let running_queue = Arc ::clone ( running_queue ) ;
let worker_url = worker_url . clone ( ) ;
HttpResponse ::build ( status )
. insert_header ( ( CONTENT_TYPE , HeaderValue ::from_static ( " text/event-stream " ) ) )
. streaming (
res . bytes_stream ( )
. map_err ( | _ | {
actix_web ::error ::ErrorInternalServerError ( " Failed to read stream " )
} )
. inspect ( move | bytes | {
let bytes = bytes . as_ref ( ) . unwrap ( ) ;
if bytes
. as_ref ( )
. windows ( 12 )
. any ( | window | window = = b " data: [DONE] " )
{
let mut locked_queue = running_queue . lock ( ) . unwrap ( ) ;
let count = locked_queue . get_mut ( & worker_url ) . unwrap ( ) ;
* count = count . saturating_sub ( 1 ) ;
// print
// println!("streaming is done!!")
}
} ) ,
)
2024-11-06 00:02:02 -08:00
} else {
HttpResponse ::build ( status )
. insert_header ( ( CONTENT_TYPE , HeaderValue ::from_static ( " text/event-stream " ) ) )
. streaming ( res . bytes_stream ( ) . map_err ( | _ | {
2024-11-23 08:34:48 -08:00
actix_web ::error ::ErrorInternalServerError ( " Failed to read stream " )
2024-11-06 00:02:02 -08:00
} ) )
2024-10-28 09:49:48 -07:00
}
}
2024-11-04 10:56:52 -08:00
}