2025-06-18 11:28:15 -07:00
use crate ::pd_router ::PDRouter ;
use crate ::pd_types ::PDSelectionPolicy ;
2024-11-23 08:34:48 -08:00
use crate ::tree ::Tree ;
2025-05-24 22:28:15 -07:00
use ::metrics ::{ counter , gauge , histogram } ;
2024-11-06 00:02:02 -08:00
use actix_web ::http ::header ::{ HeaderValue , CONTENT_TYPE } ;
use actix_web ::{ HttpRequest , HttpResponse } ;
2024-11-26 15:00:41 -08:00
use futures_util ::{ StreamExt , TryStreamExt } ;
2024-11-10 21:57:32 -08:00
use std ::collections ::HashMap ;
2024-10-28 09:49:48 -07:00
use std ::fmt ::Debug ;
2024-11-10 21:57:32 -08:00
use std ::sync ::atomic ::AtomicUsize ;
2024-12-06 01:17:04 -08:00
use std ::sync ::{ Arc , Mutex , RwLock } ;
2024-11-23 08:34:48 -08:00
use std ::thread ;
use std ::time ::Duration ;
2025-05-24 22:28:15 -07:00
use std ::time ::Instant ;
2024-12-07 15:39:54 -08:00
use tokio ;
2025-04-29 11:26:38 -07:00
use tracing ::{ debug , error , info , warn } ;
2024-10-28 09:49:48 -07:00
2025-06-18 11:28:15 -07:00
pub fn copy_request_headers ( req : & HttpRequest ) -> Vec < ( String , String ) > {
2025-01-23 20:30:31 -08:00
req . headers ( )
. iter ( )
. filter_map ( | ( name , value ) | {
value
. to_str ( )
. ok ( )
. map ( | v | ( name . to_string ( ) , v . to_string ( ) ) )
} )
. collect ( )
}
2024-10-28 09:49:48 -07:00
#[ derive(Debug) ]
2024-11-06 00:02:02 -08:00
pub enum Router {
RoundRobin {
2024-12-06 01:17:04 -08:00
worker_urls : Arc < RwLock < Vec < String > > > ,
2024-11-10 21:57:32 -08:00
current_index : AtomicUsize ,
2025-01-20 12:45:13 -08:00
timeout_secs : u64 ,
2025-01-20 14:36:54 -08:00
interval_secs : u64 ,
2024-11-06 00:02:02 -08:00
} ,
Random {
2024-12-06 01:17:04 -08:00
worker_urls : Arc < RwLock < Vec < String > > > ,
2025-01-20 12:45:13 -08:00
timeout_secs : u64 ,
2025-01-20 14:36:54 -08:00
interval_secs : u64 ,
2024-11-06 00:02:02 -08:00
} ,
2025-06-18 11:28:15 -07:00
PrefillDecode {
pd_router : Arc < PDRouter > ,
} ,
2024-11-23 08:34:48 -08:00
CacheAware {
/*
2024-11-24 23:17:11 -08:00
Cache - Aware Load Balancing Router
This router combines two strategies to optimize both cache utilization and request distribution :
1. Cache - Aware Routing ( Approximate Tree )
2. Load Balancing ( Shortest Queue with Balance Thresholds )
The router dynamically switches between these strategies based on load conditions :
- Uses load balancing when the system is imbalanced
- Uses cache - aware routing when the system is balanced
A system is considered imbalanced if both conditions are met :
1. ( max - min ) > abs_threshold
2. max > rel_threshold * min
Strategy Details :
1. Cache - Aware Routing ( Approximate Tree )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
This strategy maintains an approximate radix tree for each worker based on request history ,
eliminating the need for direct cache state queries . The tree stores raw text characters
instead of token IDs to avoid tokenization overhead .
Process :
a . For each request , find the worker with the highest prefix match
b . If match rate > cache_threshold :
Route to the worker with highest match ( likely has relevant data cached )
c . If match rate ≤ cache_threshold :
Route to the worker with smallest tree size ( most available cache capacity )
d . Background maintenance :
Periodically evict least recently used leaf nodes to prevent memory overflow
2. Load Balancing ( Shortest Queue )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
This strategy tracks pending request counts per worker and routes new requests
to the least busy worker when the system is detected to be imbalanced .
Configuration Parameters :
- - - - - - - - - - - - - - - - - - - - - - - -
1. cache_threshold : ( float , 0.0 to 1.0 )
Minimum prefix match ratio to use highest - match routing .
Below this threshold , routes to worker with most available cache space .
2. balance_abs_threshold : ( integer )
Absolute difference threshold for load imbalance detection .
System is potentially imbalanced if ( max_load - min_load ) > abs_threshold
3. balance_rel_threshold : ( float )
Relative ratio threshold for load imbalance detection .
System is potentially imbalanced if max_load > min_load * rel_threshold
Used in conjunction with abs_threshold to determine final imbalance state .
4. eviction_interval_secs : ( integer )
Interval between LRU eviction cycles for the approximate trees .
5. max_tree_size : ( integer )
Maximum nodes per tree . When exceeded , LRU leaf nodes are evicted
during the next eviction cycle .
2024-11-23 08:34:48 -08:00
* /
2024-12-06 01:17:04 -08:00
worker_urls : Arc < RwLock < Vec < String > > > ,
2024-11-23 08:34:48 -08:00
tree : Arc < Mutex < Tree > > ,
running_queue : Arc < Mutex < HashMap < String , usize > > > ,
processed_queue : Arc < Mutex < HashMap < String , usize > > > ,
2024-11-10 21:57:32 -08:00
cache_threshold : f32 ,
2024-11-24 23:17:11 -08:00
balance_abs_threshold : usize ,
balance_rel_threshold : f32 ,
2025-01-20 12:45:13 -08:00
timeout_secs : u64 ,
2025-01-20 14:36:54 -08:00
interval_secs : u64 ,
2024-11-24 23:17:11 -08:00
_eviction_thread : Option < thread ::JoinHandle < ( ) > > ,
2024-11-10 21:57:32 -08:00
} ,
}
2024-12-08 17:17:37 -08:00
#[ derive(Debug, Clone) ]
2024-11-10 21:57:32 -08:00
pub enum PolicyConfig {
2025-01-20 12:45:13 -08:00
RandomConfig {
timeout_secs : u64 ,
2025-01-20 14:36:54 -08:00
interval_secs : u64 ,
2025-01-20 12:45:13 -08:00
} ,
RoundRobinConfig {
timeout_secs : u64 ,
2025-01-20 14:36:54 -08:00
interval_secs : u64 ,
2025-01-20 12:45:13 -08:00
} ,
2024-11-23 08:34:48 -08:00
CacheAwareConfig {
2024-11-10 21:57:32 -08:00
cache_threshold : f32 ,
2024-11-24 23:17:11 -08:00
balance_abs_threshold : usize ,
balance_rel_threshold : f32 ,
2024-11-23 08:34:48 -08:00
eviction_interval_secs : u64 ,
max_tree_size : usize ,
2025-01-20 12:45:13 -08:00
timeout_secs : u64 ,
2025-01-20 14:36:54 -08:00
interval_secs : u64 ,
2024-11-10 21:57:32 -08:00
} ,
2025-06-18 11:28:15 -07:00
PrefillDecodeConfig {
selection_policy : PDSelectionPolicy ,
prefill_urls : Vec < ( String , Option < u16 > ) > , // (url, bootstrap_port)
decode_urls : Vec < String > ,
timeout_secs : u64 ,
interval_secs : u64 ,
} ,
2024-11-10 21:57:32 -08:00
}
2024-11-06 00:02:02 -08:00
impl Router {
2024-12-08 17:17:37 -08:00
pub fn new ( worker_urls : Vec < String > , policy_config : PolicyConfig ) -> Result < Self , String > {
2025-05-24 22:28:15 -07:00
// Update active workers gauge
gauge! ( " sgl_router_active_workers " ) . set ( worker_urls . len ( ) as f64 ) ;
2025-01-20 14:36:54 -08:00
// Get timeout and interval from policy config
let ( timeout_secs , interval_secs ) = match & policy_config {
PolicyConfig ::RandomConfig {
timeout_secs ,
interval_secs ,
} = > ( * timeout_secs , * interval_secs ) ,
PolicyConfig ::RoundRobinConfig {
timeout_secs ,
interval_secs ,
} = > ( * timeout_secs , * interval_secs ) ,
PolicyConfig ::CacheAwareConfig {
timeout_secs ,
interval_secs ,
..
} = > ( * timeout_secs , * interval_secs ) ,
2025-06-18 11:28:15 -07:00
PolicyConfig ::PrefillDecodeConfig {
timeout_secs ,
interval_secs ,
..
} = > ( * timeout_secs , * interval_secs ) ,
2025-01-20 12:45:13 -08:00
} ;
2025-06-18 11:28:15 -07:00
// For PrefillDecode, we need to handle workers differently
match & policy_config {
PolicyConfig ::PrefillDecodeConfig { .. } = > {
// PD mode doesn't use the worker_urls parameter
// We'll validate PD workers separately
}
_ = > {
// Wait until all workers are healthy for regular modes
2025-06-29 23:41:34 +02:00
let worker_urls = worker_urls . clone ( ) ;
std ::thread ::spawn ( move | | {
Self ::wait_for_healthy_workers ( & worker_urls , timeout_secs , interval_secs )
} )
. join ( )
. map_err ( | e | {
error! ( " Health-check thread panicked: {:?} " , e ) ;
format! ( " Health-check thread panicked: {e:?} " )
} ) ? ? ;
2025-06-18 11:28:15 -07:00
}
}
2024-12-08 17:17:37 -08:00
// Create router based on policy...
Ok ( match policy_config {
2025-01-20 14:36:54 -08:00
PolicyConfig ::RandomConfig {
timeout_secs ,
interval_secs ,
} = > Router ::Random {
2024-12-06 01:17:04 -08:00
worker_urls : Arc ::new ( RwLock ::new ( worker_urls ) ) ,
2025-01-20 12:45:13 -08:00
timeout_secs ,
2025-01-20 14:36:54 -08:00
interval_secs ,
2024-12-06 01:17:04 -08:00
} ,
2025-01-20 14:36:54 -08:00
PolicyConfig ::RoundRobinConfig {
timeout_secs ,
interval_secs ,
} = > Router ::RoundRobin {
2024-12-06 01:17:04 -08:00
worker_urls : Arc ::new ( RwLock ::new ( worker_urls ) ) ,
2024-11-06 00:02:02 -08:00
current_index : std ::sync ::atomic ::AtomicUsize ::new ( 0 ) ,
2025-01-20 12:45:13 -08:00
timeout_secs ,
2025-01-20 14:36:54 -08:00
interval_secs ,
2024-11-06 00:02:02 -08:00
} ,
2024-11-23 08:34:48 -08:00
PolicyConfig ::CacheAwareConfig {
2024-11-10 21:57:32 -08:00
cache_threshold ,
2024-11-24 23:17:11 -08:00
balance_abs_threshold ,
balance_rel_threshold ,
2024-11-23 08:34:48 -08:00
eviction_interval_secs ,
max_tree_size ,
2025-01-20 12:45:13 -08:00
timeout_secs ,
2025-01-20 14:36:54 -08:00
interval_secs ,
2024-11-10 21:57:32 -08:00
} = > {
2024-11-23 08:34:48 -08:00
let mut running_queue = HashMap ::new ( ) ;
for url in & worker_urls {
running_queue . insert ( url . clone ( ) , 0 ) ;
}
let mut processed_queue = HashMap ::new ( ) ;
for url in & worker_urls {
processed_queue . insert ( url . clone ( ) , 0 ) ;
}
let tree = Arc ::new ( Mutex ::new ( Tree ::new ( ) ) ) ;
let running_queue = Arc ::new ( Mutex ::new ( running_queue ) ) ;
let processed_queue = Arc ::new ( Mutex ::new ( processed_queue ) ) ;
// Create background eviction thread
let tree_clone = Arc ::clone ( & tree ) ;
let processed_queue_clone = Arc ::clone ( & processed_queue ) ;
2024-11-24 23:17:11 -08:00
let running_queue_clone = Arc ::clone ( & running_queue ) ;
2024-11-23 08:34:48 -08:00
let eviction_thread = thread ::spawn ( move | | {
loop {
// Sleep for the specified interval
thread ::sleep ( Duration ::from_secs ( eviction_interval_secs ) ) ;
let locked_tree_clone = tree_clone . lock ( ) . unwrap ( ) ;
// Run eviction
2024-12-06 11:53:15 -08:00
locked_tree_clone . evict_tenant_by_size ( max_tree_size ) ;
2024-11-23 08:34:48 -08:00
// Print the process queue
let locked_processed_queue = processed_queue_clone . lock ( ) . unwrap ( ) ;
2024-11-25 13:36:02 -08:00
info! ( " Processed Queue: {:?} " , locked_processed_queue ) ;
2024-11-24 23:17:11 -08:00
// Print the running queue
let locked_running_queue = running_queue_clone . lock ( ) . unwrap ( ) ;
2024-11-25 13:36:02 -08:00
info! ( " Running Queue: {:?} " , locked_running_queue ) ;
2024-11-23 08:34:48 -08:00
}
} ) ;
2024-11-10 21:57:32 -08:00
for url in & worker_urls {
2025-06-18 11:28:15 -07:00
tree . lock ( ) . unwrap ( ) . insert ( " " , url ) ;
2024-11-10 21:57:32 -08:00
}
2024-11-23 08:34:48 -08:00
Router ::CacheAware {
2024-12-06 01:17:04 -08:00
worker_urls : Arc ::new ( RwLock ::new ( worker_urls ) ) ,
2024-11-23 08:34:48 -08:00
tree ,
running_queue ,
processed_queue ,
2024-11-10 21:57:32 -08:00
cache_threshold ,
2024-11-24 23:17:11 -08:00
balance_abs_threshold ,
balance_rel_threshold ,
2025-01-20 12:45:13 -08:00
timeout_secs ,
2025-01-20 14:36:54 -08:00
interval_secs ,
2024-11-23 08:34:48 -08:00
_eviction_thread : Some ( eviction_thread ) ,
2024-11-10 21:57:32 -08:00
}
}
2025-06-18 11:28:15 -07:00
PolicyConfig ::PrefillDecodeConfig {
selection_policy ,
prefill_urls ,
decode_urls ,
timeout_secs ,
interval_secs ,
} = > {
// Create PDRouter instance
let pd_router = PDRouter ::new (
prefill_urls ,
decode_urls ,
selection_policy ,
timeout_secs ,
interval_secs ,
) ? ;
Router ::PrefillDecode {
pd_router : Arc ::new ( pd_router ) ,
}
}
2024-12-08 17:17:37 -08:00
} )
2024-10-28 09:49:48 -07:00
}
2025-04-29 10:21:19 -07:00
/// Get a reference to the worker URLs shared across threads
pub fn get_worker_urls ( & self ) -> Arc < RwLock < Vec < String > > > {
match self {
Router ::RoundRobin { worker_urls , .. } = > Arc ::clone ( worker_urls ) ,
Router ::Random { worker_urls , .. } = > Arc ::clone ( worker_urls ) ,
Router ::CacheAware { worker_urls , .. } = > Arc ::clone ( worker_urls ) ,
2025-06-18 11:28:15 -07:00
Router ::PrefillDecode { .. } = > {
// For PD mode, return empty list since we manage workers differently
Arc ::new ( RwLock ::new ( Vec ::new ( ) ) )
}
2025-04-29 10:21:19 -07:00
}
}
2025-06-18 11:28:15 -07:00
pub fn wait_for_healthy_workers (
2024-12-08 17:17:37 -08:00
worker_urls : & [ String ] ,
timeout_secs : u64 ,
interval_secs : u64 ,
) -> Result < ( ) , String > {
let start_time = std ::time ::Instant ::now ( ) ;
2025-06-18 11:28:15 -07:00
let sync_client = reqwest ::blocking ::Client ::builder ( )
. timeout ( Duration ::from_secs ( timeout_secs ) )
. build ( )
. map_err ( | e | format! ( " Failed to create HTTP client: {} " , e ) ) ? ;
2024-12-08 17:17:37 -08:00
loop {
if start_time . elapsed ( ) > Duration ::from_secs ( timeout_secs ) {
2025-01-20 12:45:13 -08:00
error! (
2025-01-22 17:56:21 -08:00
" Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value " ,
timeout_secs , worker_urls
2025-01-20 12:45:13 -08:00
) ;
2024-12-08 17:17:37 -08:00
return Err ( format! (
2025-01-22 17:56:21 -08:00
" Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value " ,
timeout_secs , worker_urls
2024-12-08 17:17:37 -08:00
) ) ;
}
let mut all_healthy = true ;
let mut unhealthy_workers = Vec ::new ( ) ;
for url in worker_urls {
match sync_client . get ( & format! ( " {} /health " , url ) ) . send ( ) {
Ok ( res ) = > {
if ! res . status ( ) . is_success ( ) {
2025-03-16 22:49:47 -07:00
let msg = format! (
" Worker heatlh check is pending with status {} " ,
2024-12-08 17:17:37 -08:00
res . status ( )
) ;
2025-03-16 22:49:47 -07:00
info! ( " {} " , msg ) ;
2024-12-08 17:17:37 -08:00
all_healthy = false ;
2025-03-16 22:49:47 -07:00
unhealthy_workers . push ( ( url , msg ) ) ;
2024-12-08 17:17:37 -08:00
}
}
2025-03-16 22:49:47 -07:00
Err ( _ ) = > {
let msg = format! ( " Worker is not ready yet " ) ;
info! ( " {} " , msg ) ;
2024-12-08 17:17:37 -08:00
all_healthy = false ;
2025-03-16 22:49:47 -07:00
unhealthy_workers . push ( ( url , msg ) ) ;
2024-12-08 17:17:37 -08:00
}
}
}
if all_healthy {
info! ( " All workers are healthy " ) ;
return Ok ( ( ) ) ;
} else {
2025-03-16 22:49:47 -07:00
info! ( " Initializing workers: " ) ;
2024-12-08 17:17:37 -08:00
for ( url , reason ) in & unhealthy_workers {
info! ( " {} - {} " , url , reason ) ;
}
thread ::sleep ( Duration ::from_secs ( interval_secs ) ) ;
}
}
}
2024-12-11 00:51:21 -08:00
fn select_first_worker ( & self ) -> Result < String , String > {
match self {
Router ::RoundRobin { worker_urls , .. }
2025-01-20 12:45:13 -08:00
| Router ::Random { worker_urls , .. }
2024-12-11 00:51:21 -08:00
| Router ::CacheAware { worker_urls , .. } = > {
if worker_urls . read ( ) . unwrap ( ) . is_empty ( ) {
Err ( " No workers are available " . to_string ( ) )
} else {
Ok ( worker_urls . read ( ) . unwrap ( ) [ 0 ] . clone ( ) )
}
}
2025-06-18 11:28:15 -07:00
Router ::PrefillDecode { .. } = > {
// For PD mode, we don't need this method as routing is handled by PDRouter
Err ( " PrefillDecode mode doesn't use select_first_worker " . to_string ( ) )
}
2024-12-11 00:51:21 -08:00
}
}
2025-06-18 11:28:15 -07:00
pub async fn send_request (
2024-11-06 00:02:02 -08:00
& self ,
client : & reqwest ::Client ,
2024-12-11 01:38:50 -08:00
worker_url : & str ,
2024-11-23 15:10:26 -08:00
route : & str ,
2025-01-23 20:30:31 -08:00
req : & HttpRequest ,
2024-11-06 00:02:02 -08:00
) -> HttpResponse {
2025-05-24 22:28:15 -07:00
let start = Instant ::now ( ) ;
2025-01-23 20:30:31 -08:00
let mut request_builder = client . get ( format! ( " {} {} " , worker_url , route ) ) ;
// Copy all headers from original request except for /health because it does not need authorization
if route ! = " /health " {
for ( name , value ) in copy_request_headers ( req ) {
2025-06-18 11:28:15 -07:00
// Skip Content-Type and Content-Length as .json() sets them
if name . to_lowercase ( ) ! = " content-type " & & name . to_lowercase ( ) ! = " content-length "
{
request_builder = request_builder . header ( name , value ) ;
}
2025-01-23 20:30:31 -08:00
}
}
2025-05-24 22:28:15 -07:00
let response = match request_builder . send ( ) . await {
2024-12-11 00:51:21 -08:00
Ok ( res ) = > {
let status = actix_web ::http ::StatusCode ::from_u16 ( res . status ( ) . as_u16 ( ) )
. unwrap_or ( actix_web ::http ::StatusCode ::INTERNAL_SERVER_ERROR ) ;
match res . bytes ( ) . await {
Ok ( body ) = > HttpResponse ::build ( status ) . body ( body . to_vec ( ) ) ,
Err ( e ) = > HttpResponse ::InternalServerError ( )
. body ( format! ( " Failed to read response body: {} " , e ) ) ,
}
}
Err ( e ) = > HttpResponse ::InternalServerError ( ) . body ( format! (
" Failed to send request to worker {}: {} " ,
worker_url , e
) ) ,
2025-05-24 22:28:15 -07:00
} ;
// Record request metrics
if route ! = " /health " {
let duration = start . elapsed ( ) ;
counter! ( " sgl_router_requests_total " , " route " = > route . to_string ( ) ) . increment ( 1 ) ;
histogram! ( " sgl_router_request_duration_seconds " , " route " = > route . to_string ( ) )
. record ( duration . as_secs_f64 ( ) ) ;
if ! response . status ( ) . is_success ( ) {
counter! ( " sgl_router_request_errors_total " , " route " = > route . to_string ( ) )
. increment ( 1 ) ;
}
2024-12-11 00:51:21 -08:00
}
2025-05-24 22:28:15 -07:00
response
2024-12-11 00:51:21 -08:00
}
2025-01-23 20:30:31 -08:00
pub async fn route_to_first (
& self ,
client : & reqwest ::Client ,
route : & str ,
req : & HttpRequest ,
) -> HttpResponse {
2024-12-11 12:13:08 -08:00
const MAX_REQUEST_RETRIES : u32 = 3 ;
const MAX_TOTAL_RETRIES : u32 = 6 ;
let mut total_retries = 0 ;
while total_retries < MAX_TOTAL_RETRIES {
match self . select_first_worker ( ) {
Ok ( worker_url ) = > {
let mut request_retries = 0 ;
// Try the same worker multiple times
while request_retries < MAX_REQUEST_RETRIES {
if total_retries > = 1 {
info! ( " Retrying request after {} failed attempts " , total_retries ) ;
}
2025-01-23 20:30:31 -08:00
let response = self . send_request ( client , & worker_url , route , req ) . await ;
2024-12-11 12:13:08 -08:00
if response . status ( ) . is_success ( ) {
return response ;
2025-01-23 20:30:31 -08:00
} else {
// if the worker is healthy, it means the request is bad, so return the error response
let health_response =
self . send_request ( client , & worker_url , " /health " , req ) . await ;
if health_response . status ( ) . is_success ( ) {
return response ;
}
2024-12-11 12:13:08 -08:00
}
warn! (
" Request to {} failed (attempt {}/{}) " ,
worker_url ,
request_retries + 1 ,
MAX_REQUEST_RETRIES
) ;
request_retries + = 1 ;
total_retries + = 1 ;
if request_retries = = MAX_REQUEST_RETRIES {
warn! ( " Removing failed worker: {} " , worker_url ) ;
self . remove_worker ( & worker_url ) ;
break ;
}
}
}
Err ( e ) = > return HttpResponse ::InternalServerError ( ) . body ( e ) ,
}
2024-12-11 00:51:21 -08:00
}
2024-12-11 12:13:08 -08:00
HttpResponse ::InternalServerError ( ) . body ( " All retry attempts failed " )
2024-12-11 00:51:21 -08:00
}
2025-06-18 11:28:15 -07:00
pub async fn route_to_all (
& self ,
client : & reqwest ::Client ,
route : & str ,
req : & HttpRequest ,
) -> HttpResponse {
// Get all worker URLs based on router type
let worker_urls = match self {
Router ::PrefillDecode { .. } = > {
// For PD mode, route_to_all is not supported directly
// It should be handled by PDRouter if needed
return HttpResponse ::NotImplemented ( )
. body ( " route_to_all not implemented for PrefillDecode mode " ) ;
2024-12-11 00:51:21 -08:00
}
2025-06-18 11:28:15 -07:00
_ = > self . get_worker_urls ( ) . read ( ) . unwrap ( ) . clone ( ) ,
2025-03-04 04:06:30 -08:00
} ;
2024-12-11 00:51:21 -08:00
2025-06-18 11:28:15 -07:00
// Send requests to all workers concurrently
let mut tasks = Vec ::new ( ) ;
for worker_url in & worker_urls {
let mut request_builder = client . post ( format! ( " {} {} " , worker_url , route ) ) ;
// Copy headers from original request
for ( name , value ) in copy_request_headers ( req ) {
request_builder = request_builder . header ( name , value ) ;
2025-03-04 04:06:30 -08:00
}
2025-06-18 11:28:15 -07:00
tasks . push ( request_builder . send ( ) ) ;
}
// Wait for all responses
let results = futures_util ::future ::join_all ( tasks ) . await ;
// Check if all succeeded
let all_success = results . iter ( ) . all ( | r | {
r . as_ref ( )
. map ( | res | res . status ( ) . is_success ( ) )
. unwrap_or ( false )
} ) ;
if all_success {
HttpResponse ::Ok ( ) . body ( " Operation completed on all servers " )
} else {
HttpResponse ::InternalServerError ( ) . body ( " Operation failed on one or more servers " )
}
}
pub async fn get_all_loads (
& self ,
client : & reqwest ::Client ,
_req : & HttpRequest ,
) -> HttpResponse {
// For PD mode, delegate to PDRouter
match self {
Router ::PrefillDecode { pd_router } = > {
return pd_router . get_loads ( client ) . await ;
2025-03-04 04:06:30 -08:00
}
_ = > {
2025-06-18 11:28:15 -07:00
// For non-PD routers, handle normally
2025-03-04 04:06:30 -08:00
}
}
2025-06-18 11:28:15 -07:00
let urls = self . get_worker_urls ( ) . read ( ) . unwrap ( ) . clone ( ) ;
let prefill_urls : Vec < String > = Vec ::new ( ) ;
let decode_urls = urls ;
// Collect loads from all servers
let mut prefill_loads = Vec ::new ( ) ;
let mut decode_loads = Vec ::new ( ) ;
// Get prefill loads
for url in & prefill_urls {
let load = self . get_worker_load ( client , url ) . await . unwrap_or ( - 1 ) ;
prefill_loads . push ( serde_json ::json! ( {
" engine " : format ! ( " (Prefill@{}) " , url ) ,
" load " : load as i64
} ) ) ;
}
// Get decode loads
for url in & decode_urls {
let load = self . get_worker_load ( client , url ) . await . unwrap_or ( - 1 ) ;
decode_loads . push ( serde_json ::json! ( {
" engine " : format ! ( " (Decode@{}) " , url ) ,
" load " : load as i64
} ) ) ;
}
HttpResponse ::Ok ( ) . json ( serde_json ::json! ( {
" prefill " : prefill_loads ,
" decode " : decode_loads
} ) )
2024-12-11 00:51:21 -08:00
}
2025-06-18 11:28:15 -07:00
// New method to route typed requests directly
pub async fn route_typed_request <
T : crate ::openai_api_types ::GenerationRequest + serde ::Serialize + Clone ,
> (
& self ,
client : & reqwest ::Client ,
req : & HttpRequest ,
typed_req : & T ,
route : & str ,
) -> HttpResponse {
match self {
Router ::PrefillDecode { .. } = > HttpResponse ::InternalServerError ( )
. body ( " PD routing should use specialized typed handlers " ) ,
_ = > {
// Handle retries like the original implementation
let start = Instant ::now ( ) ;
const MAX_REQUEST_RETRIES : u32 = 3 ;
const MAX_TOTAL_RETRIES : u32 = 6 ;
let mut total_retries = 0 ;
while total_retries < MAX_TOTAL_RETRIES {
// Extract routing text directly from typed request
let text = typed_req . extract_text_for_routing ( ) ;
let is_stream = typed_req . is_stream ( ) ;
// Select worker based on text
let worker_url = self . select_generate_worker_from_text ( & text ) ;
let mut request_retries = 0 ;
// Try the same worker multiple times
while request_retries < MAX_REQUEST_RETRIES {
if total_retries > = 1 {
info! ( " Retrying request after {} failed attempts " , total_retries ) ;
counter! ( " sgl_router_retries_total " , " route " = > route . to_string ( ) )
. increment ( 1 ) ;
}
// Send typed request directly
let response = self
. send_typed_request (
client ,
req ,
typed_req ,
route ,
& worker_url ,
is_stream ,
)
. await ;
if response . status ( ) . is_success ( ) {
let duration = start . elapsed ( ) ;
histogram! ( " sgl_router_generate_duration_seconds " , " route " = > route . to_string ( ) )
. record ( duration . as_secs_f64 ( ) ) ;
return response ;
} else {
// if the worker is healthy, it means the request is bad, so return the error response
let health_response =
self . send_request ( client , & worker_url , " /health " , req ) . await ;
if health_response . status ( ) . is_success ( ) {
counter! ( " sgl_router_request_errors_total " , " route " = > route . to_string ( ) )
. increment ( 1 ) ;
return response ;
}
}
warn! (
" Generate request to {} failed (attempt {}/{}) " ,
worker_url ,
request_retries + 1 ,
MAX_REQUEST_RETRIES
) ;
request_retries + = 1 ;
total_retries + = 1 ;
2024-11-10 21:57:32 -08:00
2025-06-18 11:28:15 -07:00
if request_retries = = MAX_REQUEST_RETRIES {
warn! ( " Removing failed worker: {} " , worker_url ) ;
self . remove_worker ( & worker_url ) ;
break ;
}
}
}
counter! ( " sgl_router_request_errors_total " , " route " = > route . to_string ( ) )
. increment ( 1 ) ;
HttpResponse ::InternalServerError ( ) . body ( " All retry attempts failed " )
}
}
}
// Helper method to select worker from text
fn select_generate_worker_from_text ( & self , text : & str ) -> String {
match self {
2024-11-06 00:02:02 -08:00
Router ::RoundRobin {
worker_urls ,
current_index ,
2025-01-20 12:45:13 -08:00
..
2024-11-06 00:02:02 -08:00
} = > {
2024-11-10 21:57:32 -08:00
let idx = current_index
2024-11-06 00:02:02 -08:00
. fetch_update (
std ::sync ::atomic ::Ordering ::SeqCst ,
std ::sync ::atomic ::Ordering ::SeqCst ,
2024-12-06 01:17:04 -08:00
| x | Some ( ( x + 1 ) % worker_urls . read ( ) . unwrap ( ) . len ( ) ) ,
2024-11-06 00:02:02 -08:00
)
2024-11-10 21:57:32 -08:00
. unwrap ( ) ;
2024-12-06 01:17:04 -08:00
worker_urls . read ( ) . unwrap ( ) [ idx ] . clone ( )
2024-11-06 00:02:02 -08:00
}
2024-11-10 21:57:32 -08:00
2025-01-20 12:45:13 -08:00
Router ::Random { worker_urls , .. } = > worker_urls . read ( ) . unwrap ( )
2024-12-06 01:17:04 -08:00
[ rand ::random ::< usize > ( ) % worker_urls . read ( ) . unwrap ( ) . len ( ) ]
. clone ( ) ,
2024-11-10 21:57:32 -08:00
2024-11-23 08:34:48 -08:00
Router ::CacheAware {
2024-11-10 21:57:32 -08:00
worker_urls ,
2024-11-23 08:34:48 -08:00
tree ,
running_queue ,
processed_queue ,
2024-11-10 21:57:32 -08:00
cache_threshold ,
2024-11-24 23:17:11 -08:00
balance_abs_threshold ,
balance_rel_threshold ,
2024-11-10 21:57:32 -08:00
..
} = > {
2024-11-26 15:00:41 -08:00
let tree = tree . lock ( ) . unwrap ( ) ;
2024-11-23 08:34:48 -08:00
let mut running_queue = running_queue . lock ( ) . unwrap ( ) ;
2024-11-10 21:57:32 -08:00
2024-11-24 23:17:11 -08:00
// Get current load statistics
let max_load = * running_queue . values ( ) . max ( ) . unwrap_or ( & 0 ) ;
let min_load = * running_queue . values ( ) . min ( ) . unwrap_or ( & 0 ) ;
// Load is considered imbalanced if:
// 1. (max - min) > abs_threshold AND
// 2. max > rel_threshold * min
let is_imbalanced = max_load . saturating_sub ( min_load ) > * balance_abs_threshold
& & ( max_load as f32 ) > ( min_load as f32 * balance_rel_threshold ) ;
let selected_url = if is_imbalanced {
// Log load balancing trigger and current queue state
2024-11-25 13:36:02 -08:00
info! (
2024-11-24 23:17:11 -08:00
" Load balancing triggered due to workload imbalance: \n \
Max load : { } , Min load : { } \ n \
Current running queue : { :? } " ,
max_load , min_load , running_queue
) ;
2025-05-24 22:28:15 -07:00
counter! ( " sgl_router_load_balancing_events_total " ) . increment ( 1 ) ;
gauge! ( " sgl_router_max_load " ) . set ( max_load as f64 ) ;
gauge! ( " sgl_router_min_load " ) . set ( min_load as f64 ) ;
2024-11-24 23:17:11 -08:00
// Use shortest queue routing when load is imbalanced
running_queue
. iter ( )
. min_by_key ( | ( _url , & count ) | count )
. map ( | ( url , _ ) | url . clone ( ) )
2024-12-06 01:17:04 -08:00
. unwrap_or_else ( | | worker_urls . read ( ) . unwrap ( ) [ 0 ] . clone ( ) )
2024-11-24 23:17:11 -08:00
} else {
// Use cache-aware routing when load is balanced
2024-11-23 08:34:48 -08:00
let ( matched_text , matched_worker ) = tree . prefix_match ( & text ) ;
let matched_rate =
matched_text . chars ( ) . count ( ) as f32 / text . chars ( ) . count ( ) as f32 ;
2024-11-10 21:57:32 -08:00
2024-11-23 08:34:48 -08:00
if matched_rate > * cache_threshold {
2025-05-24 22:28:15 -07:00
counter! ( " sgl_router_cache_hits_total " ) . increment ( 1 ) ;
2024-11-23 08:34:48 -08:00
matched_worker . to_string ( )
} else {
2025-05-24 22:28:15 -07:00
counter! ( " sgl_router_cache_misses_total " ) . increment ( 1 ) ;
2024-11-23 08:34:48 -08:00
tree . get_smallest_tenant ( )
2024-11-10 21:57:32 -08:00
}
2024-11-23 08:34:48 -08:00
} ;
2024-11-10 21:57:32 -08:00
2024-11-24 23:17:11 -08:00
// Update queues and tree
* running_queue . get_mut ( & selected_url ) . unwrap ( ) + = 1 ;
2024-11-23 08:34:48 -08:00
2024-11-24 23:17:11 -08:00
* processed_queue
. lock ( )
. unwrap ( )
. get_mut ( & selected_url )
. unwrap ( ) + = 1 ;
2025-05-24 22:28:15 -07:00
gauge! ( " sgl_router_running_requests " , " worker " = > selected_url . to_string ( ) )
. set ( * running_queue . get ( & selected_url ) . unwrap ( ) as f64 ) ;
counter! ( " sgl_router_processed_requests_total " , " worker " = > selected_url . to_string ( ) ) . increment ( 1 ) ;
2024-11-23 08:34:48 -08:00
tree . insert ( & text , & selected_url ) ;
selected_url
2024-11-06 00:02:02 -08:00
}
2025-06-18 11:28:15 -07:00
Router ::PrefillDecode { .. } = > {
// For PD mode, we don't use this method
return " PD_MODE_ERROR " . to_string ( ) ;
}
}
2024-12-11 00:51:21 -08:00
}
2025-06-18 11:28:15 -07:00
// Send typed request directly without conversion
async fn send_typed_request < T : serde ::Serialize > (
2024-12-11 00:51:21 -08:00
& self ,
client : & reqwest ::Client ,
2024-12-11 01:38:50 -08:00
req : & HttpRequest ,
2025-06-18 11:28:15 -07:00
typed_req : & T ,
2024-12-11 00:51:21 -08:00
route : & str ,
worker_url : & str ,
2025-06-18 11:28:15 -07:00
is_stream : bool ,
2024-12-11 00:51:21 -08:00
) -> HttpResponse {
2025-06-18 11:28:15 -07:00
let start = Instant ::now ( ) ;
// Debug: Log what we're sending
if let Ok ( json_str ) = serde_json ::to_string_pretty ( typed_req ) {
debug! ( " Sending request to {}: {} " , route , json_str ) ;
}
2024-11-04 10:56:52 -08:00
2025-01-23 20:30:31 -08:00
let mut request_builder = client
2024-12-11 00:51:21 -08:00
. post ( format! ( " {} {} " , worker_url , route ) )
2025-06-18 11:28:15 -07:00
. json ( typed_req ) ; // Use json() directly with typed request
2025-01-23 20:30:31 -08:00
// Copy all headers from original request
for ( name , value ) in copy_request_headers ( req ) {
2025-06-18 11:28:15 -07:00
// Skip Content-Type and Content-Length as .json() sets them
if name . to_lowercase ( ) ! = " content-type " & & name . to_lowercase ( ) ! = " content-length " {
request_builder = request_builder . header ( & name , & value ) ;
}
2025-01-23 20:30:31 -08:00
}
let res = match request_builder . send ( ) . await {
2024-11-06 00:02:02 -08:00
Ok ( res ) = > res ,
2025-06-18 11:28:15 -07:00
Err ( e ) = > {
error! ( " Failed to send request to {}: {} " , worker_url , e ) ;
return HttpResponse ::InternalServerError ( ) . body ( format! ( " Request failed: {} " , e ) ) ;
}
2024-11-06 00:02:02 -08:00
} ;
2024-11-04 10:56:52 -08:00
2024-11-06 00:02:02 -08:00
let status = actix_web ::http ::StatusCode ::from_u16 ( res . status ( ) . as_u16 ( ) )
. unwrap_or ( actix_web ::http ::StatusCode ::INTERNAL_SERVER_ERROR ) ;
2024-10-28 09:49:48 -07:00
2024-11-06 00:02:02 -08:00
if ! is_stream {
2024-11-23 08:34:48 -08:00
// For non-streaming requests, get response first
let response = match res . bytes ( ) . await {
2024-11-06 00:02:02 -08:00
Ok ( body ) = > HttpResponse ::build ( status ) . body ( body . to_vec ( ) ) ,
2024-12-06 01:17:04 -08:00
Err ( e ) = > {
let error_msg = format! ( " Failed to get response body: {} " , e ) ;
HttpResponse ::InternalServerError ( ) . body ( error_msg )
}
2024-11-23 08:34:48 -08:00
} ;
// Then decrement running queue counter if using CacheAware
if let Router ::CacheAware { running_queue , .. } = self {
if let Ok ( mut queue ) = running_queue . lock ( ) {
2024-12-11 00:51:21 -08:00
if let Some ( count ) = queue . get_mut ( worker_url ) {
2024-11-23 08:34:48 -08:00
* count = count . saturating_sub ( 1 ) ;
}
}
2024-11-06 00:02:02 -08:00
}
2024-11-23 08:34:48 -08:00
2025-06-18 11:28:15 -07:00
// Record metrics
let duration = start . elapsed ( ) ;
histogram! ( " sgl_router_generate_duration_seconds " , " route " = > route . to_string ( ) )
. record ( duration . as_secs_f64 ( ) ) ;
counter! ( " sgl_router_requests_total " , " route " = > route . to_string ( ) ) . increment ( 1 ) ;
2024-11-23 08:34:48 -08:00
response
} else if let Router ::CacheAware { running_queue , .. } = self {
let running_queue = Arc ::clone ( running_queue ) ;
2024-12-11 00:51:21 -08:00
let worker_url = worker_url . to_string ( ) ;
2024-11-23 08:34:48 -08:00
HttpResponse ::build ( status )
. insert_header ( ( CONTENT_TYPE , HeaderValue ::from_static ( " text/event-stream " ) ) )
. streaming (
res . bytes_stream ( )
. map_err ( | _ | {
actix_web ::error ::ErrorInternalServerError ( " Failed to read stream " )
} )
. inspect ( move | bytes | {
let bytes = bytes . as_ref ( ) . unwrap ( ) ;
if bytes
. as_ref ( )
. windows ( 12 )
. any ( | window | window = = b " data: [DONE] " )
{
let mut locked_queue = running_queue . lock ( ) . unwrap ( ) ;
let count = locked_queue . get_mut ( & worker_url ) . unwrap ( ) ;
* count = count . saturating_sub ( 1 ) ;
2024-12-11 00:51:21 -08:00
debug! ( " Streaming is done!! " )
2024-11-23 08:34:48 -08:00
}
} ) ,
)
2024-11-06 00:02:02 -08:00
} else {
HttpResponse ::build ( status )
. insert_header ( ( CONTENT_TYPE , HeaderValue ::from_static ( " text/event-stream " ) ) )
. streaming ( res . bytes_stream ( ) . map_err ( | _ | {
2024-11-23 08:34:48 -08:00
actix_web ::error ::ErrorInternalServerError ( " Failed to read stream " )
2024-11-06 00:02:02 -08:00
} ) )
2024-10-28 09:49:48 -07:00
}
}
2024-12-06 01:17:04 -08:00
2024-12-11 01:38:50 -08:00
pub async fn add_worker ( & self , worker_url : & str ) -> Result < String , String > {
2025-01-20 14:36:54 -08:00
let ( timeout_secs , interval_secs ) = match self {
Router ::Random {
timeout_secs ,
interval_secs ,
..
} = > ( * timeout_secs , * interval_secs ) ,
Router ::RoundRobin {
timeout_secs ,
interval_secs ,
..
} = > ( * timeout_secs , * interval_secs ) ,
Router ::CacheAware {
timeout_secs ,
interval_secs ,
..
} = > ( * timeout_secs , * interval_secs ) ,
2025-06-18 11:28:15 -07:00
Router ::PrefillDecode { .. } = > {
// For PD mode, we don't support adding workers via this method
return Err ( " Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods. " . to_string ( ) ) ;
}
2025-01-20 12:45:13 -08:00
} ;
2024-12-07 15:39:54 -08:00
let start_time = std ::time ::Instant ::now ( ) ;
2025-06-18 11:28:15 -07:00
let client = reqwest ::Client ::builder ( )
. timeout ( Duration ::from_secs ( timeout_secs ) )
. build ( )
. map_err ( | e | format! ( " Failed to create HTTP client: {} " , e ) ) ? ;
2024-12-07 15:39:54 -08:00
loop {
if start_time . elapsed ( ) > Duration ::from_secs ( timeout_secs ) {
2025-01-20 12:45:13 -08:00
error! (
2025-01-22 17:56:21 -08:00
" Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value " ,
2025-01-20 12:45:13 -08:00
timeout_secs , worker_url
) ;
2024-12-08 17:17:37 -08:00
return Err ( format! (
2025-01-22 17:56:21 -08:00
" Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value " ,
2024-12-07 15:39:54 -08:00
timeout_secs , worker_url
) ) ;
}
match client . get ( & format! ( " {} /health " , worker_url ) ) . send ( ) . await {
Ok ( res ) = > {
if res . status ( ) . is_success ( ) {
match self {
Router ::RoundRobin { worker_urls , .. }
2025-01-20 12:45:13 -08:00
| Router ::Random { worker_urls , .. }
2024-12-07 15:39:54 -08:00
| Router ::CacheAware { worker_urls , .. } = > {
info! ( " Worker {} health check passed " , worker_url ) ;
let mut urls = worker_urls . write ( ) . unwrap ( ) ;
2024-12-11 01:38:50 -08:00
if urls . contains ( & worker_url . to_string ( ) ) {
2024-12-08 17:17:37 -08:00
return Err ( format! ( " Worker {} already exists " , worker_url ) ) ;
2024-12-07 15:39:54 -08:00
}
info! ( " Added worker: {} " , worker_url ) ;
2024-12-11 01:38:50 -08:00
urls . push ( worker_url . to_string ( ) ) ;
2025-05-24 22:28:15 -07:00
gauge! ( " sgl_router_active_workers " ) . set ( urls . len ( ) as f64 ) ;
2024-12-07 15:39:54 -08:00
}
2025-06-18 11:28:15 -07:00
Router ::PrefillDecode { .. } = > {
return Err ( " Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods. " . to_string ( ) ) ;
}
2024-12-07 15:39:54 -08:00
}
2024-12-08 17:17:37 -08:00
// If cache aware, initialize the queues for the new worker
if let Router ::CacheAware {
running_queue ,
processed_queue ,
tree ,
..
} = self
{
// Add worker to running queue with initial count of 0
2024-12-11 01:38:50 -08:00
running_queue
. lock ( )
. unwrap ( )
. insert ( worker_url . to_string ( ) , 0 ) ;
2024-12-08 17:17:37 -08:00
// Add worker to processed queue with initial count of 0
processed_queue
. lock ( )
. unwrap ( )
2024-12-11 01:38:50 -08:00
. insert ( worker_url . to_string ( ) , 0 ) ;
2024-12-08 17:17:37 -08:00
// Add worker to tree
2025-06-18 11:28:15 -07:00
tree . lock ( ) . unwrap ( ) . insert ( " " , worker_url ) ;
2024-12-08 17:17:37 -08:00
}
return Ok ( format! ( " Successfully added worker: {} " , worker_url ) ) ;
2024-12-07 15:39:54 -08:00
} else {
info! (
2024-12-08 17:17:37 -08:00
" Worker {} health check is pending with status: {}. " ,
worker_url ,
res . status ( )
2024-12-07 15:39:54 -08:00
) ;
// if the url does not have http or https prefix, warn users
if ! worker_url . starts_with ( " http:// " ) & & ! worker_url . starts_with ( " https:// " )
{
warn! ( " The worker url {} does not have http or https prefix. Please add the prefix to the url. " , worker_url ) ;
}
tokio ::time ::sleep ( Duration ::from_secs ( interval_secs ) ) . await ;
continue ;
}
}
Err ( e ) = > {
2024-12-08 17:17:37 -08:00
info! (
" Worker {} health check is pending with error: {} " ,
worker_url , e
) ;
2024-12-07 15:39:54 -08:00
// if the url does not have http or https prefix, warn users
if ! worker_url . starts_with ( " http:// " ) & & ! worker_url . starts_with ( " https:// " ) {
warn! ( " The worker url {} does not have http or https prefix. Please add the prefix to the url. " , worker_url ) ;
}
tokio ::time ::sleep ( Duration ::from_secs ( interval_secs ) ) . await ;
continue ;
}
2024-12-06 01:17:04 -08:00
}
}
}
2024-12-06 17:16:03 -08:00
2024-12-11 01:38:50 -08:00
pub fn remove_worker ( & self , worker_url : & str ) {
2024-12-06 17:16:03 -08:00
match self {
Router ::RoundRobin { worker_urls , .. }
2025-01-20 12:45:13 -08:00
| Router ::Random { worker_urls , .. }
2024-12-06 17:16:03 -08:00
| Router ::CacheAware { worker_urls , .. } = > {
let mut urls = worker_urls . write ( ) . unwrap ( ) ;
2024-12-11 12:13:08 -08:00
if let Some ( index ) = urls . iter ( ) . position ( | url | url = = & worker_url ) {
urls . remove ( index ) ;
info! ( " Removed worker: {} " , worker_url ) ;
2025-05-24 22:28:15 -07:00
gauge! ( " sgl_router_active_workers " ) . set ( urls . len ( ) as f64 ) ;
2024-12-11 12:13:08 -08:00
} else {
warn! ( " Worker {} not found, skipping removal " , worker_url ) ;
return ;
}
2024-12-06 17:16:03 -08:00
}
2025-06-18 11:28:15 -07:00
Router ::PrefillDecode { .. } = > {
warn! ( " Removing workers from PrefillDecode router not supported via remove_worker. Use dedicated PD management methods. " ) ;
return ;
}
2024-12-06 17:16:03 -08:00
}
// if cache aware, remove the worker from the tree
2024-12-08 17:17:37 -08:00
if let Router ::CacheAware {
tree ,
running_queue ,
processed_queue ,
..
} = self
{
2024-12-06 17:16:03 -08:00
tree . lock ( ) . unwrap ( ) . remove_tenant ( & worker_url ) ;
2024-12-11 01:38:50 -08:00
running_queue
. lock ( )
. unwrap ( )
. remove ( & worker_url . to_string ( ) ) ;
processed_queue
. lock ( )
. unwrap ( )
. remove ( & worker_url . to_string ( ) ) ;
2024-12-08 17:17:37 -08:00
info! (
" Removed worker from tree and cleaned up queues: {} " ,
worker_url
) ;
2024-12-06 17:16:03 -08:00
}
}
2025-06-18 11:28:15 -07:00
2025-06-22 17:54:14 -07:00
/// Add a worker with PD mode support
pub async fn add_pd_worker (
& self ,
worker_url : & str ,
pod_type : crate ::service_discovery ::PodType ,
bootstrap_port : Option < u16 > ,
) -> Result < String , String > {
match self {
Router ::PrefillDecode { pd_router } = > match pod_type {
crate ::service_discovery ::PodType ::Prefill = > pd_router
. add_prefill_server ( worker_url . to_string ( ) , bootstrap_port )
. await
. map_err ( | e | e . to_string ( ) ) ,
crate ::service_discovery ::PodType ::Decode = > pd_router
. add_decode_server ( worker_url . to_string ( ) )
. await
. map_err ( | e | e . to_string ( ) ) ,
crate ::service_discovery ::PodType ::Regular = > {
Err ( " Regular pod type not supported in PD mode " . to_string ( ) )
}
} ,
_ = > Err ( " add_pd_worker only supported in PD mode " . to_string ( ) ) ,
}
}
/// Remove a worker with PD mode support
pub async fn remove_pd_worker (
& self ,
worker_url : & str ,
pod_type : crate ::service_discovery ::PodType ,
) -> Result < String , String > {
match self {
Router ::PrefillDecode { pd_router } = > match pod_type {
crate ::service_discovery ::PodType ::Prefill = > pd_router
. remove_prefill_server ( worker_url )
. await
. map_err ( | e | e . to_string ( ) ) ,
crate ::service_discovery ::PodType ::Decode = > pd_router
. remove_decode_server ( worker_url )
. await
. map_err ( | e | e . to_string ( ) ) ,
crate ::service_discovery ::PodType ::Regular = > {
Err ( " Regular pod type not supported in PD mode " . to_string ( ) )
}
} ,
_ = > Err ( " remove_pd_worker only supported in PD mode " . to_string ( ) ) ,
}
}
2025-06-18 11:28:15 -07:00
async fn get_worker_load ( & self , client : & reqwest ::Client , worker_url : & str ) -> Option < isize > {
match client . get ( & format! ( " {} /get_load " , worker_url ) ) . send ( ) . await {
Ok ( res ) if res . status ( ) . is_success ( ) = > match res . bytes ( ) . await {
Ok ( bytes ) = > match serde_json ::from_slice ::< serde_json ::Value > ( & bytes ) {
Ok ( data ) = > data
. get ( " load " )
. and_then ( | v | v . as_i64 ( ) )
. map ( | v | v as isize ) ,
Err ( e ) = > {
debug! ( " Failed to parse load response from {}: {} " , worker_url , e ) ;
None
}
} ,
Err ( e ) = > {
debug! ( " Failed to read load response from {}: {} " , worker_url , e ) ;
None
}
} ,
Ok ( res ) = > {
debug! (
" Worker {} returned non-success status: {} " ,
worker_url ,
res . status ( )
) ;
None
}
Err ( e ) = > {
debug! ( " Failed to get load from {}: {} " , worker_url , e ) ;
None
}
}
}
// PD-specific wrapper methods that delegate to PDRouter
pub async fn route_pd_health_generate (
& self ,
_client : & reqwest ::Client ,
_req : & HttpRequest ,
) -> HttpResponse {
match self {
Router ::PrefillDecode { pd_router } = > {
pd_router . health_generate ( & pd_router . http_client ) . await
}
_ = > HttpResponse ::InternalServerError ( ) . body ( " Not in PrefillDecode mode " ) ,
}
}
pub async fn route_pd_generate_typed (
& self ,
_client : & reqwest ::Client ,
req : & HttpRequest ,
typed_req : crate ::pd_types ::GenerateReqInput ,
route : & str ,
) -> HttpResponse {
match self {
Router ::PrefillDecode { pd_router } = > {
pd_router
. route_generate ( & pd_router . http_client , req , typed_req , route )
. await
}
_ = > HttpResponse ::InternalServerError ( ) . body ( " Not in PrefillDecode mode " ) ,
}
}
pub async fn route_pd_chat_typed (
& self ,
_client : & reqwest ::Client ,
req : & HttpRequest ,
typed_req : crate ::pd_types ::ChatReqInput ,
route : & str ,
) -> HttpResponse {
match self {
Router ::PrefillDecode { pd_router } = > {
pd_router
. route_chat ( & pd_router . http_client , req , typed_req , route )
. await
}
_ = > HttpResponse ::InternalServerError ( ) . body ( " Not in PrefillDecode mode " ) ,
}
}
pub async fn get_pd_server_info (
& self ,
_client : & reqwest ::Client ,
_req : & HttpRequest ,
) -> HttpResponse {
match self {
Router ::PrefillDecode { pd_router } = > {
pd_router . get_server_info ( & pd_router . http_client ) . await
}
_ = > HttpResponse ::InternalServerError ( ) . body ( " Not in PrefillDecode mode " ) ,
}
}
pub async fn get_pd_models (
& self ,
_client : & reqwest ::Client ,
req : & HttpRequest ,
) -> HttpResponse {
match self {
Router ::PrefillDecode { pd_router } = > {
pd_router . get_models ( & pd_router . http_client , req ) . await
}
_ = > HttpResponse ::InternalServerError ( ) . body ( " Not in PrefillDecode mode " ) ,
}
}
pub async fn route_pd_flush_cache ( & self , _client : & reqwest ::Client ) -> HttpResponse {
match self {
Router ::PrefillDecode { pd_router } = > {
pd_router . flush_cache ( & pd_router . http_client ) . await
}
_ = > HttpResponse ::InternalServerError ( ) . body ( " Not in PrefillDecode mode " ) ,
}
}
pub async fn get_pd_model_info (
& self ,
_client : & reqwest ::Client ,
req : & HttpRequest ,
) -> HttpResponse {
match self {
Router ::PrefillDecode { pd_router } = > {
pd_router . get_model_info ( & pd_router . http_client , req ) . await
}
_ = > HttpResponse ::InternalServerError ( ) . body ( " Not in PrefillDecode mode " ) ,
}
}
2024-11-04 10:56:52 -08:00
}
2025-06-22 17:54:14 -07:00
#[cfg(test)]
mod tests {
    use super::*;
    use crate::service_discovery::PodType;

    // NOTE: there are no PD-mode counterparts for the tests below
    // (get_worker_urls, add/remove_pd_worker with the Regular pod type,
    // select_first_worker). PDRouter::new requires health checks which fail in
    // tests, so those cases would need a mock server or a different test setup.

    /// Builds a `Random` router preloaded with two worker URLs.
    fn create_test_regular_router() -> Router {
        let seeded_urls = vec![
            " http://worker1:8080 ".to_string(),
            " http://worker2:8080 ".to_string(),
        ];
        Router::Random {
            worker_urls: Arc::new(RwLock::new(seeded_urls)),
            timeout_secs: 5,
            interval_secs: 1,
        }
    }

    /// The worker list of a regular router exposes exactly the seeded URLs.
    #[test]
    fn test_router_get_worker_urls_regular() {
        let router = create_test_regular_router();
        let shared = router.get_worker_urls();
        let snapshot = shared.read().unwrap();

        assert_eq!(snapshot.len(), 2);
        for expected in [" http://worker1:8080 ", " http://worker2:8080 "].iter() {
            assert!(snapshot.contains(&expected.to_string()));
        }
    }

    /// `add_pd_worker` is rejected on a non-PD router.
    #[tokio::test]
    async fn test_add_pd_worker_with_regular_router() {
        let router = create_test_regular_router();
        let outcome = router
            .add_pd_worker(" http://new-worker:8080 ", PodType::Prefill, Some(8081))
            .await;

        assert!(outcome.is_err());
        assert!(matches!(
            &outcome,
            Err(msg) if msg.contains(" add_pd_worker only supported in PD mode ")
        ));
    }

    /// `remove_pd_worker` is rejected on a non-PD router.
    #[tokio::test]
    async fn test_remove_pd_worker_with_regular_router() {
        let router = create_test_regular_router();
        let outcome = router
            .remove_pd_worker(" http://worker:8080 ", PodType::Decode)
            .await;

        assert!(outcome.is_err());
        assert!(matches!(
            &outcome,
            Err(msg) if msg.contains(" remove_pd_worker only supported in PD mode ")
        ));
    }

    /// The first seeded worker is the one returned by `select_first_worker`.
    #[test]
    fn test_select_first_worker_regular() {
        let router = create_test_regular_router();
        let first = router.select_first_worker();

        assert!(first.is_ok());
        assert_eq!(first.unwrap(), " http://worker1:8080 ");
    }

    /// With no workers to wait for, the health wait succeeds.
    #[test]
    fn test_wait_for_healthy_workers_empty_list() {
        assert!(Router::wait_for_healthy_workers(&[], 1, 1).is_ok());
    }

    /// An unreachable worker makes the health wait fail with a timeout message.
    #[test]
    fn test_wait_for_healthy_workers_invalid_urls() {
        // This test will timeout quickly since the URLs are invalid
        let unreachable = vec![" http://nonexistent:8080 ".to_string()];
        let outcome = Router::wait_for_healthy_workers(&unreachable, 1, 1);

        assert!(outcome.is_err());
        assert!(matches!(&outcome, Err(msg) if msg.contains(" Timeout ")));
    }
}