2025-09-04 00:35:51 -04:00
use crate ::config ::types ::RetryConfig ;
2025-08-11 05:53:26 -07:00
use crate ::core ::{
2025-09-19 15:37:57 -04:00
is_retryable_status , BasicWorkerBuilder , CircuitBreakerConfig , ConnectionMode , RetryExecutor ,
2025-09-19 01:52:57 -04:00
Worker , WorkerRegistry , WorkerType ,
2025-08-11 05:53:26 -07:00
} ;
2025-07-18 22:09:17 -07:00
use crate ::metrics ::RouterMetrics ;
2025-09-12 19:18:27 -04:00
use crate ::policies ::{ LoadBalancingPolicy , PolicyRegistry } ;
2025-08-22 14:18:47 -07:00
use crate ::protocols ::spec ::{
2025-09-15 09:44:35 +08:00
ChatCompletionRequest , CompletionRequest , EmbeddingRequest , GenerateRequest , GenerationRequest ,
RerankRequest , RerankResponse , RerankResult , ResponsesRequest ,
2025-08-18 18:07:58 -07:00
} ;
2025-08-28 12:07:06 -07:00
use crate ::routers ::header_utils ;
2025-07-30 17:47:19 -07:00
use crate ::routers ::{ RouterTrait , WorkerManagement } ;
2025-09-13 00:10:18 +08:00
use axum ::body ::to_bytes ;
2025-07-30 17:47:19 -07:00
use axum ::{
body ::Body ,
extract ::Request ,
2025-09-12 16:19:38 -07:00
http ::{
header ::CONTENT_LENGTH , header ::CONTENT_TYPE , HeaderMap , HeaderValue , Method , StatusCode ,
} ,
2025-07-30 17:47:19 -07:00
response ::{ IntoResponse , Response } ,
Json ,
} ;
use futures_util ::StreamExt ;
2025-08-04 20:42:07 -07:00
use reqwest ::Client ;
2025-07-18 14:24:24 -07:00
use std ::collections ::HashMap ;
2025-09-12 19:18:27 -04:00
use std ::sync ::Arc ;
2025-07-18 14:24:24 -07:00
use std ::time ::{ Duration , Instant } ;
2025-07-30 17:47:19 -07:00
use tokio_stream ::wrappers ::UnboundedReceiverStream ;
2025-07-18 14:24:24 -07:00
use tracing ::{ debug , error , info , warn } ;
/// Regular router that uses injected load balancing policies
#[ derive(Debug) ]
pub struct Router {
2025-09-12 19:18:27 -04:00
worker_registry : Arc < WorkerRegistry > ,
policy_registry : Arc < PolicyRegistry > ,
2025-08-02 19:16:47 -07:00
client : Client ,
2025-09-02 04:57:04 +02:00
worker_startup_timeout_secs : u64 ,
worker_startup_check_interval_secs : u64 ,
2025-07-30 20:58:48 +08:00
dp_aware : bool ,
2025-09-22 03:28:38 +08:00
#[ allow(dead_code) ]
2025-07-30 20:58:48 +08:00
api_key : Option < String > ,
2025-08-04 20:42:07 -07:00
retry_config : RetryConfig ,
2025-08-08 09:20:22 -07:00
circuit_breaker_config : CircuitBreakerConfig ,
2025-07-18 14:24:24 -07:00
_worker_loads : Arc < tokio ::sync ::watch ::Receiver < HashMap < String , isize > > > ,
_load_monitor_handle : Option < Arc < tokio ::task ::JoinHandle < ( ) > > > ,
}
impl Router {
2025-08-02 19:16:47 -07:00
/// Create a new router with injected policy and client
2025-09-19 15:37:57 -04:00
pub async fn new ( ctx : & Arc < crate ::server ::AppContext > ) -> Result < Self , String > {
let workers = ctx . worker_registry . get_workers_filtered (
None , // any model
Some ( WorkerType ::Regular ) ,
Some ( ConnectionMode ::Http ) ,
false , // include all workers
) ;
2025-07-18 14:24:24 -07:00
2025-09-19 15:37:57 -04:00
// Update active workers gauge
RouterMetrics ::set_active_workers ( workers . len ( ) ) ;
2025-07-18 14:24:24 -07:00
2025-09-19 15:37:57 -04:00
// Get worker URLs for monitoring
let worker_urls : Vec < String > = workers . iter ( ) . map ( | w | w . url ( ) . to_string ( ) ) . collect ( ) ;
2025-07-30 20:58:48 +08:00
2025-08-08 09:20:22 -07:00
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
2025-09-04 00:35:51 -04:00
let circuit_breaker_config = ctx . router_config . effective_circuit_breaker_config ( ) ;
2025-08-08 09:20:22 -07:00
let core_cb_config = CircuitBreakerConfig {
failure_threshold : circuit_breaker_config . failure_threshold ,
success_threshold : circuit_breaker_config . success_threshold ,
2025-08-11 05:53:26 -07:00
timeout_duration : Duration ::from_secs ( circuit_breaker_config . timeout_duration_secs ) ,
window_duration : Duration ::from_secs ( circuit_breaker_config . window_duration_secs ) ,
2025-08-08 09:20:22 -07:00
} ;
2025-09-19 23:54:40 -04:00
// Cache-aware policies are initialized in WorkerInitializer
2025-07-18 14:24:24 -07:00
// Setup load monitoring for PowerOfTwo policy
let ( tx , rx ) = tokio ::sync ::watch ::channel ( HashMap ::new ( ) ) ;
let worker_loads = Arc ::new ( rx ) ;
2025-09-19 23:54:40 -04:00
// Get default policy to check if we need load monitoring
let default_policy = ctx . policy_registry . get_default_policy ( ) ;
2025-09-12 19:18:27 -04:00
// Check if default policy is power_of_two for load monitoring
let load_monitor_handle = if default_policy . name ( ) = = " power_of_two " {
2025-07-18 14:24:24 -07:00
let monitor_urls = worker_urls . clone ( ) ;
2025-09-22 03:28:38 +08:00
let monitor_api_keys = monitor_urls
. iter ( )
. map ( | url | {
ctx . worker_registry
. get_by_url ( url )
. and_then ( | w | w . api_key ( ) . clone ( ) )
} )
. collect ::< Vec < Option < String > > > ( ) ;
2025-09-04 00:35:51 -04:00
let monitor_interval = ctx . router_config . worker_startup_check_interval_secs ;
2025-09-12 19:18:27 -04:00
let policy_clone = default_policy . clone ( ) ;
2025-09-04 00:35:51 -04:00
let client_clone = ctx . client . clone ( ) ;
2025-07-18 14:24:24 -07:00
Some ( Arc ::new ( tokio ::spawn ( async move {
2025-08-02 19:16:47 -07:00
Self ::monitor_worker_loads (
monitor_urls ,
2025-09-22 03:28:38 +08:00
monitor_api_keys ,
2025-08-02 19:16:47 -07:00
tx ,
monitor_interval ,
policy_clone ,
client_clone ,
)
. await ;
2025-07-18 14:24:24 -07:00
} ) ) )
} else {
None
} ;
Ok ( Router {
2025-09-12 19:18:27 -04:00
worker_registry : ctx . worker_registry . clone ( ) ,
policy_registry : ctx . policy_registry . clone ( ) ,
2025-09-04 00:35:51 -04:00
client : ctx . client . clone ( ) ,
worker_startup_timeout_secs : ctx . router_config . worker_startup_timeout_secs ,
worker_startup_check_interval_secs : ctx
. router_config
. worker_startup_check_interval_secs ,
dp_aware : ctx . router_config . dp_aware ,
api_key : ctx . router_config . api_key . clone ( ) ,
retry_config : ctx . router_config . effective_retry_config ( ) ,
2025-08-08 09:20:22 -07:00
circuit_breaker_config : core_cb_config ,
2025-07-18 14:24:24 -07:00
_worker_loads : worker_loads ,
_load_monitor_handle : load_monitor_handle ,
} )
}
/// Get the current list of worker URLs
pub fn get_worker_urls ( & self ) -> Vec < String > {
2025-09-12 19:18:27 -04:00
self . worker_registry . get_all_urls ( )
}
/// Get worker URLs for a specific model
pub fn get_worker_urls_for_model ( & self , model_id : Option < & str > ) -> Vec < String > {
let workers = match model_id {
Some ( model ) = > self . worker_registry . get_by_model_fast ( model ) ,
None = > self . worker_registry . get_all ( ) ,
} ;
workers . iter ( ) . map ( | w | w . url ( ) . to_string ( ) ) . collect ( )
2025-07-18 14:24:24 -07:00
}
2025-08-11 21:37:36 -07:00
pub async fn wait_for_healthy_workers (
2025-07-18 14:24:24 -07:00
worker_urls : & [ String ] ,
2025-09-02 04:57:04 +02:00
worker_startup_timeout_secs : u64 ,
worker_startup_check_interval_secs : u64 ,
2025-07-18 14:24:24 -07:00
) -> Result < ( ) , String > {
2025-08-04 20:42:07 -07:00
if worker_urls . is_empty ( ) {
return Err (
" Timeout waiting for workers to become healthy: no workers provided " . to_string ( ) ,
) ;
}
2025-08-11 21:37:36 -07:00
// Perform health check asynchronously
2025-09-02 04:57:04 +02:00
Self ::wait_for_healthy_workers_async (
worker_urls ,
worker_startup_timeout_secs ,
worker_startup_check_interval_secs ,
)
. await
2025-08-11 21:37:36 -07:00
}
async fn wait_for_healthy_workers_async (
worker_urls : & [ String ] ,
2025-09-02 04:57:04 +02:00
worker_startup_timeout_secs : u64 ,
worker_startup_check_interval_secs : u64 ,
2025-08-11 21:37:36 -07:00
) -> Result < ( ) , String > {
info! (
" Waiting for {} workers to become healthy (timeout: {}s) " ,
worker_urls . len ( ) ,
2025-09-02 04:57:04 +02:00
worker_startup_timeout_secs
2025-08-11 21:37:36 -07:00
) ;
2025-07-18 14:24:24 -07:00
let start_time = std ::time ::Instant ::now ( ) ;
2025-08-11 21:37:36 -07:00
let client = reqwest ::Client ::builder ( )
. timeout ( Duration ::from_secs ( 2 ) )
2025-07-18 14:24:24 -07:00
. build ( )
. map_err ( | e | format! ( " Failed to create HTTP client: {} " , e ) ) ? ;
loop {
2025-09-02 04:57:04 +02:00
if start_time . elapsed ( ) > Duration ::from_secs ( worker_startup_timeout_secs ) {
2025-07-18 14:24:24 -07:00
error! (
" Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value " ,
2025-09-02 04:57:04 +02:00
worker_startup_timeout_secs , worker_urls
2025-07-18 14:24:24 -07:00
) ;
return Err ( format! (
" Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value " ,
2025-09-02 04:57:04 +02:00
worker_startup_timeout_secs , worker_urls
2025-07-18 14:24:24 -07:00
) ) ;
}
2025-08-11 21:37:36 -07:00
// Perform all health checks concurrently
let mut health_checks = Vec ::new ( ) ;
for url in worker_urls {
let client_clone = client . clone ( ) ;
let url_clone = url . clone ( ) ;
let check_health = tokio ::spawn ( async move {
let health_url = format! ( " {} /health " , url_clone ) ;
match client_clone . get ( & health_url ) . send ( ) . await {
Ok ( res ) = > {
if res . status ( ) . is_success ( ) {
None
} else {
Some ( ( url_clone , format! ( " status: {} " , res . status ( ) ) ) )
}
}
Err ( _ ) = > Some ( ( url_clone , " not ready " . to_string ( ) ) ) ,
}
} ) ;
health_checks . push ( check_health ) ;
}
// Wait for all health checks to complete
let results = futures ::future ::join_all ( health_checks ) . await ;
2025-07-18 14:24:24 -07:00
let mut all_healthy = true ;
let mut unhealthy_workers = Vec ::new ( ) ;
2025-08-11 21:37:36 -07:00
for result in results {
match result {
Ok ( None ) = > {
// Worker is healthy
2025-07-18 14:24:24 -07:00
}
2025-08-11 21:37:36 -07:00
Ok ( Some ( ( url , reason ) ) ) = > {
2025-07-18 14:24:24 -07:00
all_healthy = false ;
2025-08-11 21:37:36 -07:00
unhealthy_workers . push ( ( url , reason ) ) ;
}
Err ( e ) = > {
all_healthy = false ;
unhealthy_workers
. push ( ( " unknown " . to_string ( ) , format! ( " task error: {} " , e ) ) ) ;
2025-07-18 14:24:24 -07:00
}
}
}
if all_healthy {
2025-07-27 19:30:19 -07:00
info! ( " All {} workers are healthy " , worker_urls . len ( ) ) ;
2025-07-18 14:24:24 -07:00
return Ok ( ( ) ) ;
} else {
2025-07-27 19:30:19 -07:00
debug! (
2025-08-11 21:37:36 -07:00
" Waiting for {} workers to become healthy ({} unhealthy: {:?}) " ,
2025-07-27 19:30:19 -07:00
worker_urls . len ( ) ,
2025-08-11 21:37:36 -07:00
unhealthy_workers . len ( ) ,
unhealthy_workers
2025-07-27 19:30:19 -07:00
) ;
2025-09-02 04:57:04 +02:00
tokio ::time ::sleep ( Duration ::from_secs ( worker_startup_check_interval_secs ) ) . await ;
2025-07-18 14:24:24 -07:00
}
}
}
2025-07-30 20:58:48 +08:00
fn get_worker_dp_size ( worker_url : & str , api_key : & Option < String > ) -> Result < usize , String > {
let sync_client = reqwest ::blocking ::Client ::new ( ) ;
2025-08-15 11:01:21 -07:00
let mut req_builder = sync_client . get ( format! ( " {} /get_server_info " , worker_url ) ) ;
2025-07-30 20:58:48 +08:00
if let Some ( key ) = api_key {
req_builder = req_builder . bearer_auth ( key ) ;
}
match req_builder . send ( ) {
Ok ( res ) = > {
if res . status ( ) . is_success ( ) {
let server_info = res
. text ( )
. map_err ( | e | format! ( " failed to read text from response: {} " , e ) ) ? ;
let server_info : serde_json ::Value = serde_json ::from_str ( & server_info )
. map_err ( | e | format! ( " failed to decode JSON: {} " , e ) ) ? ;
let dp_size = server_info
. get ( " dp_size " )
. and_then ( | v | v . as_u64 ( ) )
. ok_or_else ( | | String ::from ( " dp_size not found or not an u64 " ) ) ? ;
Ok ( if dp_size > usize ::MAX as u64 {
return Err ( format! ( " dp_size is too large: {} " , dp_size ) ) ;
} else {
dp_size as usize
} )
} else {
Err ( format! ( " unexpected status code: {} " , res . status ( ) ) )
}
}
Err ( e ) = > Err ( format! ( " error response: {} " , e ) ) ,
}
}
// Given a list of workers, return a list of workers with dp_rank as suffix
fn get_dp_aware_workers (
worker_urls : & [ String ] ,
api_key : & Option < String > ,
) -> Result < Vec < String > , String > {
let mut dp_aware_workers : Vec < String > = Vec ::new ( ) ;
for url in worker_urls {
match Self ::get_worker_dp_size ( url , api_key ) {
Ok ( dp_size ) = > {
for i in 0 .. dp_size {
dp_aware_workers . push ( format! ( " {} @ {} " , url , i ) ) ;
}
}
Err ( e ) = > return Err ( format! ( " Failed to get DP size for {} : {} " , url , e ) ) ,
}
}
Ok ( dp_aware_workers )
}
2025-07-18 14:24:24 -07:00
fn select_first_worker ( & self ) -> Result < String , String > {
2025-09-12 19:18:27 -04:00
let workers = self . worker_registry . get_all ( ) ;
if workers . is_empty ( ) {
2025-07-18 14:24:24 -07:00
Err ( " No workers are available " . to_string ( ) )
} else {
2025-09-12 19:18:27 -04:00
Ok ( workers [ 0 ] . url ( ) . to_string ( ) )
}
}
#[ allow(dead_code) ]
fn select_first_worker_for_model ( & self , model_id : Option < & str > ) -> Result < String , String > {
let workers = match model_id {
Some ( model ) = > self . worker_registry . get_by_model_fast ( model ) ,
None = > self . worker_registry . get_all ( ) ,
} ;
if workers . is_empty ( ) {
Err ( format! (
" No workers are available for model: {:?} " ,
model_id
) )
} else {
Ok ( workers [ 0 ] . url ( ) . to_string ( ) )
2025-07-18 14:24:24 -07:00
}
}
2025-08-02 19:16:47 -07:00
pub async fn send_health_check ( & self , worker_url : & str ) -> Response {
2025-07-30 17:47:19 -07:00
let health_url = if self . dp_aware {
2025-07-30 20:58:48 +08:00
// Need to extract the URL from "http://host:port@dp_rank"
2025-07-30 17:47:19 -07:00
match Self ::extract_dp_rank ( worker_url ) {
Ok ( ( worker_url_prefix , _dp_rank ) ) = > worker_url_prefix ,
2025-07-30 20:58:48 +08:00
Err ( e ) = > {
2025-07-30 17:47:19 -07:00
error! ( " Failed to extract dp_rank for health check: {} " , e ) ;
return (
StatusCode ::INTERNAL_SERVER_ERROR ,
format! ( " Failed to extract dp_rank: {} " , e ) ,
)
. into_response ( ) ;
2025-07-30 20:58:48 +08:00
}
2025-07-30 17:47:19 -07:00
}
2025-07-30 20:58:48 +08:00
} else {
worker_url
} ;
2025-08-02 19:16:47 -07:00
let request_builder = self . client . get ( format! ( " {} /health " , health_url ) ) ;
2025-07-18 14:24:24 -07:00
let response = match request_builder . send ( ) . await {
Ok ( res ) = > {
2025-07-30 17:47:19 -07:00
let status = StatusCode ::from_u16 ( res . status ( ) . as_u16 ( ) )
. unwrap_or ( StatusCode ::INTERNAL_SERVER_ERROR ) ;
2025-07-18 14:24:24 -07:00
match res . bytes ( ) . await {
2025-07-30 17:47:19 -07:00
Ok ( body ) = > ( status , body ) . into_response ( ) ,
2025-07-27 19:30:19 -07:00
Err ( e ) = > {
error! (
2025-07-30 17:47:19 -07:00
worker_url = % health_url ,
2025-07-27 19:30:19 -07:00
error = % e ,
2025-07-30 17:47:19 -07:00
" Failed to read health response body "
2025-07-27 19:30:19 -07:00
) ;
2025-07-30 17:47:19 -07:00
(
StatusCode ::INTERNAL_SERVER_ERROR ,
format! ( " Failed to read response body: {} " , e ) ,
)
. into_response ( )
2025-07-27 19:30:19 -07:00
}
2025-07-18 14:24:24 -07:00
}
}
2025-07-27 19:30:19 -07:00
Err ( e ) = > {
error! (
2025-07-30 17:47:19 -07:00
worker_url = % health_url ,
2025-07-27 19:30:19 -07:00
error = % e ,
2025-07-30 17:47:19 -07:00
" Failed to send health request to worker "
2025-07-27 19:30:19 -07:00
) ;
2025-07-30 17:47:19 -07:00
(
StatusCode ::INTERNAL_SERVER_ERROR ,
format! ( " Failed to send request to worker {} : {} " , health_url , e ) ,
)
. into_response ( )
2025-07-27 19:30:19 -07:00
}
2025-07-18 14:24:24 -07:00
} ;
2025-07-30 17:47:19 -07:00
// Don't record metrics for health checks
2025-07-18 14:24:24 -07:00
response
}
2025-07-30 17:47:19 -07:00
// Helper method to proxy GET requests to the first available worker
2025-08-02 19:16:47 -07:00
async fn proxy_get_request ( & self , req : Request < Body > , endpoint : & str ) -> Response {
2025-08-28 12:07:06 -07:00
let headers = header_utils ::copy_request_headers ( & req ) ;
2025-07-30 17:47:19 -07:00
match self . select_first_worker ( ) {
Ok ( worker_url ) = > {
2025-08-02 19:16:47 -07:00
let mut request_builder = self . client . get ( format! ( " {} / {} " , worker_url , endpoint ) ) ;
2025-07-30 17:47:19 -07:00
for ( name , value ) in headers {
2025-08-08 13:10:14 -07:00
let name_lc = name . to_lowercase ( ) ;
if name_lc ! = " content-type " & & name_lc ! = " content-length " {
2025-07-30 17:47:19 -07:00
request_builder = request_builder . header ( name , value ) ;
}
}
2025-07-18 14:24:24 -07:00
2025-07-30 17:47:19 -07:00
match request_builder . send ( ) . await {
Ok ( res ) = > {
let status = StatusCode ::from_u16 ( res . status ( ) . as_u16 ( ) )
. unwrap_or ( StatusCode ::INTERNAL_SERVER_ERROR ) ;
2025-08-15 11:01:47 -07:00
// Preserve headers from backend
let response_headers =
header_utils ::preserve_response_headers ( res . headers ( ) ) ;
2025-07-30 17:47:19 -07:00
match res . bytes ( ) . await {
2025-08-15 11:01:47 -07:00
Ok ( body ) = > {
let mut response = Response ::new ( axum ::body ::Body ::from ( body ) ) ;
* response . status_mut ( ) = status ;
* response . headers_mut ( ) = response_headers ;
response
}
2025-07-30 17:47:19 -07:00
Err ( e ) = > (
StatusCode ::INTERNAL_SERVER_ERROR ,
format! ( " Failed to read response: {} " , e ) ,
)
. into_response ( ) ,
2025-07-18 14:24:24 -07:00
}
}
2025-07-30 17:47:19 -07:00
Err ( e ) = > (
StatusCode ::INTERNAL_SERVER_ERROR ,
format! ( " Request failed: {} " , e ) ,
)
. into_response ( ) ,
2025-07-18 14:24:24 -07:00
}
}
2025-07-30 17:47:19 -07:00
Err ( e ) = > ( StatusCode ::SERVICE_UNAVAILABLE , e ) . into_response ( ) ,
2025-07-18 14:24:24 -07:00
}
}
2025-09-12 19:18:27 -04:00
/// Pick a worker for `model_id` among those reporting `is_available()`.
///
/// Candidates come from the model-specific lookup when `model_id` is given,
/// otherwise from the whole registry. The model's policy (or the default
/// policy) chooses an index into the available workers. Returns `None` when
/// no worker is available or the policy makes no selection.
fn select_worker_for_model(
    &self,
    model_id: Option<&str>,
    text: Option<&str>,
) -> Option<Arc<dyn Worker>> {
    let candidates = if let Some(model) = model_id {
        self.worker_registry.get_by_model_fast(model)
    } else {
        self.worker_registry.get_all()
    };

    let available: Vec<Arc<dyn Worker>> = candidates
        .iter()
        .filter(|worker| worker.is_available())
        .cloned()
        .collect();
    if available.is_empty() {
        return None;
    }

    // Resolve the policy only once there is something to choose from.
    let policy = if let Some(model) = model_id {
        self.policy_registry.get_policy_or_default(model)
    } else {
        self.policy_registry.get_default_policy()
    };

    policy
        .select_worker(&available, text)
        .map(|idx| available[idx].clone())
}
2025-08-18 18:07:58 -07:00
pub async fn route_typed_request < T : GenerationRequest + serde ::Serialize + Clone > (
2025-07-18 14:24:24 -07:00
& self ,
2025-07-30 17:47:19 -07:00
headers : Option < & HeaderMap > ,
2025-07-18 14:24:24 -07:00
typed_req : & T ,
route : & str ,
2025-09-12 19:18:27 -04:00
model_id : Option < & str > ,
2025-07-30 17:47:19 -07:00
) -> Response {
2025-07-18 14:24:24 -07:00
let start = Instant ::now ( ) ;
2025-08-10 21:19:30 -07:00
let is_stream = typed_req . is_stream ( ) ;
let text = typed_req . extract_text_for_routing ( ) ;
let response = RetryExecutor ::execute_response_with_retry (
& self . retry_config ,
// operation per attempt
| _ : u32 | async {
2025-09-12 19:18:27 -04:00
let worker = match self . select_worker_for_model ( model_id , Some ( & text ) ) {
2025-08-10 21:19:30 -07:00
Some ( w ) = > w ,
None = > {
RouterMetrics ::record_request_error ( route , " no_available_workers " ) ;
return (
StatusCode ::SERVICE_UNAVAILABLE ,
" No available workers (all circuits open or unhealthy) " ,
)
. into_response ( ) ;
}
} ;
2025-07-18 14:24:24 -07:00
2025-08-10 21:19:30 -07:00
// Optional load tracking for cache-aware policy
2025-09-12 19:18:27 -04:00
// Get the policy for this model to check if it's cache-aware
let policy = match model_id {
Some ( model ) = > self . policy_registry . get_policy_or_default ( model ) ,
None = > self . policy_registry . get_default_policy ( ) ,
} ;
let load_incremented = if policy . name ( ) = = " cache_aware " {
2025-08-10 21:19:30 -07:00
worker . increment_load ( ) ;
RouterMetrics ::set_running_requests ( worker . url ( ) , worker . load ( ) ) ;
true
2025-07-18 14:24:24 -07:00
} else {
false
} ;
2025-08-26 06:40:51 -07:00
// Keep a clone for potential cleanup on retry
let worker_for_cleanup = if load_incremented {
2025-09-15 00:07:23 -04:00
Some ( worker . clone ( ) )
2025-08-26 06:40:51 -07:00
} else {
None
} ;
2025-07-18 14:24:24 -07:00
let response = self
. send_typed_request (
2025-07-30 17:47:19 -07:00
headers ,
2025-07-18 14:24:24 -07:00
typed_req ,
route ,
2025-08-10 21:19:30 -07:00
worker . url ( ) ,
2025-07-18 14:24:24 -07:00
is_stream ,
load_incremented ,
)
. await ;
2025-08-10 21:19:30 -07:00
worker . record_outcome ( response . status ( ) . is_success ( ) ) ;
2025-08-26 06:40:51 -07:00
// For retryable failures, we need to decrement load since send_typed_request
// won't have done it (it only decrements on success or non-retryable failures)
if is_retryable_status ( response . status ( ) ) & & load_incremented {
if let Some ( cleanup_worker ) = worker_for_cleanup {
cleanup_worker . decrement_load ( ) ;
RouterMetrics ::set_running_requests (
cleanup_worker . url ( ) ,
cleanup_worker . load ( ) ,
) ;
}
}
2025-08-10 21:19:30 -07:00
response
} ,
// should_retry predicate
2025-08-11 05:53:26 -07:00
| res , _attempt | is_retryable_status ( res . status ( ) ) ,
2025-08-10 21:19:30 -07:00
// on_backoff hook
| delay , attempt | {
RouterMetrics ::record_retry ( route ) ;
RouterMetrics ::record_retry_backoff_duration ( delay , attempt ) ;
} ,
// on_exhausted hook
| | RouterMetrics ::record_retries_exhausted ( route ) ,
2025-07-30 17:47:19 -07:00
)
2025-08-10 21:19:30 -07:00
. await ;
if response . status ( ) . is_success ( ) {
let duration = start . elapsed ( ) ;
RouterMetrics ::record_request ( route ) ;
RouterMetrics ::record_generate_duration ( duration ) ;
2025-08-11 05:53:26 -07:00
} else if ! is_retryable_status ( response . status ( ) ) {
2025-08-10 21:19:30 -07:00
RouterMetrics ::record_request_error ( route , " non_retryable_error " ) ;
2025-07-18 14:24:24 -07:00
}
2025-08-10 21:19:30 -07:00
response
2025-07-18 14:24:24 -07:00
}
2025-09-12 16:19:38 -07:00
// Helper: base worker URL as an owned string. In DP-aware mode the
// `@<dp_rank>` suffix is removed when the URL can be split; in every other
// case the URL is returned unchanged.
fn worker_base_url(&self, worker_url: &str) -> String {
    let base = if self.dp_aware {
        Self::extract_dp_rank(worker_url).map_or(worker_url, |(prefix, _)| prefix)
    } else {
        worker_url
    };
    base.to_string()
}
// Generic simple routing for GET/POST without JSON body
async fn route_simple_request (
& self ,
headers : Option < & HeaderMap > ,
endpoint : & str ,
method : Method ,
) -> Response {
// TODO: currently the sglang worker is using in-memory state management, so this implementation has to fan out to all workers.
// Eventually, we need to have router to manage the chat history with a proper database, will update this implementation accordingly.
let worker_urls = self . get_worker_urls ( ) ;
if worker_urls . is_empty ( ) {
return ( StatusCode ::SERVICE_UNAVAILABLE , " No available workers " ) . into_response ( ) ;
}
let mut last_response : Option < Response > = None ;
for worker_url in worker_urls {
let base = self . worker_base_url ( & worker_url ) ;
let url = format! ( " {} / {} " , base , endpoint ) ;
let mut request_builder = match method {
Method ::GET = > self . client . get ( url ) ,
Method ::POST = > self . client . post ( url ) ,
_ = > {
return (
StatusCode ::METHOD_NOT_ALLOWED ,
" Unsupported method for simple routing " ,
)
. into_response ( )
}
} ;
if let Some ( hdrs ) = headers {
for ( name , value ) in hdrs {
let name_lc = name . as_str ( ) . to_lowercase ( ) ;
if name_lc ! = " content-type " & & name_lc ! = " content-length " {
request_builder = request_builder . header ( name , value ) ;
}
}
}
match request_builder . send ( ) . await {
Ok ( res ) = > {
let status = StatusCode ::from_u16 ( res . status ( ) . as_u16 ( ) )
. unwrap_or ( StatusCode ::INTERNAL_SERVER_ERROR ) ;
let response_headers = header_utils ::preserve_response_headers ( res . headers ( ) ) ;
match res . bytes ( ) . await {
Ok ( body ) = > {
let mut response = Response ::new ( axum ::body ::Body ::from ( body ) ) ;
* response . status_mut ( ) = status ;
* response . headers_mut ( ) = response_headers ;
if status . is_success ( ) {
return response ;
}
last_response = Some ( response ) ;
}
Err ( e ) = > {
last_response = Some (
(
StatusCode ::INTERNAL_SERVER_ERROR ,
format! ( " Failed to read response: {} " , e ) ,
)
. into_response ( ) ,
) ;
}
}
}
Err ( e ) = > {
last_response = Some (
(
StatusCode ::INTERNAL_SERVER_ERROR ,
format! ( " Request failed: {} " , e ) ,
)
. into_response ( ) ,
) ;
}
}
}
last_response
. unwrap_or_else ( | | ( StatusCode ::BAD_GATEWAY , " No worker response " ) . into_response ( ) )
}
// GET flavour of `route_simple_request`: forwards the given headers to
// `endpoint` on the workers.
async fn route_get_request(&self, headers: Option<&HeaderMap>, endpoint: &str) -> Response {
    let method = Method::GET;
    self.route_simple_request(headers, endpoint, method).await
}
// POST flavour of `route_simple_request`: sends a body-less POST with the
// given headers to `endpoint` on the workers.
async fn route_post_empty_request(
    &self,
    headers: Option<&HeaderMap>,
    endpoint: &str,
) -> Response {
    let method = Method::POST;
    self.route_simple_request(headers, endpoint, method).await
}
2025-07-30 20:58:48 +08:00
// TODO (rui): Better accommodate to the Worker abstraction
/// Split a DP-aware worker URL of the form `<base_url>@<dp_rank>` into its
/// base URL and rank.
///
/// The split is made at the last `@`, so a base URL that itself contains
/// `@` (for example one carrying credentials) is still accepted. The base
/// URL is returned as a slice of `worker_url`; nothing is allocated on the
/// success path.
///
/// # Errors
/// Returns a message when `worker_url` contains no `@`, or when the text
/// after the last `@` does not parse as `usize`.
fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
    let (base_url, rank_text) = worker_url
        .rsplit_once('@')
        .ok_or_else(|| format!("invalid worker_url format: {}", worker_url))?;

    rank_text
        .parse::<usize>()
        .map(|dp_rank| (base_url, dp_rank))
        .map_err(|_| {
            format!(
                "failed to parse dp_rank from worker_url: {}",
                worker_url
            )
        })
}
2025-07-18 14:24:24 -07:00
// Send typed request directly without conversion
async fn send_typed_request < T : serde ::Serialize > (
& self ,
2025-07-30 17:47:19 -07:00
headers : Option < & HeaderMap > ,
2025-07-18 14:24:24 -07:00
typed_req : & T ,
route : & str ,
worker_url : & str ,
is_stream : bool ,
load_incremented : bool , // Whether load was incremented for this request
2025-07-30 17:47:19 -07:00
) -> Response {
2025-07-30 20:58:48 +08:00
let mut request_builder = if self . dp_aware {
let ( worker_url_prefix , dp_rank ) = match Self ::extract_dp_rank ( worker_url ) {
Ok ( tup ) = > tup ,
Err ( e ) = > {
error! ( " Failed to extract dp_rank: {} " , e ) ;
2025-07-30 17:47:19 -07:00
return (
StatusCode ::INTERNAL_SERVER_ERROR ,
format! ( " Failed to extract dp_rank: {} " , e ) ,
)
. into_response ( ) ;
2025-07-30 20:58:48 +08:00
}
} ;
// Parse the request body
let mut json_val = match serde_json ::to_value ( typed_req ) {
Ok ( j ) = > j ,
Err ( e ) = > {
2025-07-30 17:47:19 -07:00
return (
StatusCode ::BAD_REQUEST ,
format! ( " Convert into serde_json::Value failed: {} " , e ) ,
)
. into_response ( ) ;
2025-07-30 20:58:48 +08:00
}
} ;
// Insert the data_parallel_rank field
if let Some ( map ) = json_val . as_object_mut ( ) {
map . insert (
String ::from ( " data_parallel_rank " ) ,
serde_json ::json! ( dp_rank ) ,
) ;
debug! (
" Modified request body: {} " ,
serde_json ::to_string ( & json_val ) . unwrap_or ( String ::from ( " ERR " ) )
) ;
} else {
2025-07-30 17:47:19 -07:00
return (
StatusCode ::BAD_REQUEST ,
" Failed to insert the data_parallel_rank field into the request body " ,
)
. into_response ( ) ;
2025-07-30 20:58:48 +08:00
}
2025-08-02 19:16:47 -07:00
self . client
2025-07-30 20:58:48 +08:00
. post ( format! ( " {} {} " , worker_url_prefix , route ) )
. json ( & json_val )
} else {
2025-08-02 19:16:47 -07:00
self . client
2025-07-30 20:58:48 +08:00
. post ( format! ( " {} {} " , worker_url , route ) )
. json ( typed_req ) // Use json() directly with typed request
} ;
2025-07-18 14:24:24 -07:00
2025-07-30 17:47:19 -07:00
// Copy all headers from original request if provided
if let Some ( headers ) = headers {
for ( name , value ) in headers {
// Skip Content-Type and Content-Length as .json() sets them
2025-08-08 13:10:14 -07:00
if * name ! = CONTENT_TYPE & & * name ! = CONTENT_LENGTH {
2025-07-30 17:47:19 -07:00
request_builder = request_builder . header ( name , value ) ;
}
2025-07-18 14:24:24 -07:00
}
}
let res = match request_builder . send ( ) . await {
Ok ( res ) = > res ,
Err ( e ) = > {
2025-07-27 19:30:19 -07:00
error! (
" Failed to send typed request worker_url={} route={} error={} " ,
worker_url , route , e
) ;
2025-07-18 14:24:24 -07:00
// Decrement load on error if it was incremented
if load_incremented {
2025-09-12 19:18:27 -04:00
if let Some ( worker ) = self . worker_registry . get_by_url ( worker_url ) {
worker . decrement_load ( ) ;
RouterMetrics ::set_running_requests ( worker_url , worker . load ( ) ) ;
2025-07-18 14:24:24 -07:00
}
}
2025-07-30 17:47:19 -07:00
return (
StatusCode ::INTERNAL_SERVER_ERROR ,
format! ( " Request failed: {} " , e ) ,
)
. into_response ( ) ;
2025-07-18 14:24:24 -07:00
}
} ;
2025-07-30 17:47:19 -07:00
let status = StatusCode ::from_u16 ( res . status ( ) . as_u16 ( ) )
. unwrap_or ( StatusCode ::INTERNAL_SERVER_ERROR ) ;
2025-07-18 14:24:24 -07:00
if ! is_stream {
2025-08-15 11:01:47 -07:00
// For non-streaming requests, preserve headers
2025-08-28 12:07:06 -07:00
let response_headers = header_utils ::preserve_response_headers ( res . headers ( ) ) ;
2025-08-15 11:01:47 -07:00
2025-07-18 14:24:24 -07:00
let response = match res . bytes ( ) . await {
2025-08-15 11:01:47 -07:00
Ok ( body ) = > {
let mut response = Response ::new ( axum ::body ::Body ::from ( body ) ) ;
* response . status_mut ( ) = status ;
* response . headers_mut ( ) = response_headers ;
response
}
2025-07-18 14:24:24 -07:00
Err ( e ) = > {
2025-08-26 06:40:51 -07:00
// IMPORTANT: Decrement load on error before returning
if load_incremented {
2025-09-12 19:18:27 -04:00
if let Some ( worker ) = self . worker_registry . get_by_url ( worker_url ) {
worker . decrement_load ( ) ;
RouterMetrics ::set_running_requests ( worker_url , worker . load ( ) ) ;
2025-08-26 06:40:51 -07:00
}
}
2025-07-18 14:24:24 -07:00
let error_msg = format! ( " Failed to get response body: {} " , e ) ;
2025-07-30 17:47:19 -07:00
( StatusCode ::INTERNAL_SERVER_ERROR , error_msg ) . into_response ( )
2025-07-18 14:24:24 -07:00
}
} ;
// Decrement load counter for non-streaming requests if it was incremented
2025-08-26 06:40:51 -07:00
if load_incremented {
2025-09-12 19:18:27 -04:00
if let Some ( worker ) = self . worker_registry . get_by_url ( worker_url ) {
worker . decrement_load ( ) ;
RouterMetrics ::set_running_requests ( worker_url , worker . load ( ) ) ;
2025-07-18 14:24:24 -07:00
}
}
response
} else if load_incremented {
// For streaming with load tracking, we need to manually decrement when done
2025-09-12 19:18:27 -04:00
let registry = Arc ::clone ( & self . worker_registry ) ;
2025-07-18 14:24:24 -07:00
let worker_url = worker_url . to_string ( ) ;
2025-08-15 11:01:47 -07:00
// Preserve headers for streaming response
let mut response_headers = header_utils ::preserve_response_headers ( res . headers ( ) ) ;
// Ensure we set the correct content-type for SSE
response_headers . insert ( CONTENT_TYPE , HeaderValue ::from_static ( " text/event-stream " ) ) ;
2025-07-30 17:47:19 -07:00
let stream = res . bytes_stream ( ) ;
let ( tx , rx ) = tokio ::sync ::mpsc ::unbounded_channel ( ) ;
// Spawn task to forward stream and detect completion
tokio ::spawn ( async move {
let mut stream = stream ;
2025-08-08 13:10:14 -07:00
let mut decremented = false ;
2025-07-30 17:47:19 -07:00
while let Some ( chunk ) = stream . next ( ) . await {
match chunk {
Ok ( bytes ) = > {
// Check for stream end marker
if bytes
. as_ref ( )
. windows ( 12 )
. any ( | window | window = = b " data: [DONE] " )
{
2025-09-12 19:18:27 -04:00
if let Some ( worker ) = registry . get_by_url ( & worker_url ) {
worker . decrement_load ( ) ;
RouterMetrics ::set_running_requests ( & worker_url , worker . load ( ) ) ;
decremented = true ;
2025-07-18 14:24:24 -07:00
}
}
2025-07-30 17:47:19 -07:00
if tx . send ( Ok ( bytes ) ) . is_err ( ) {
break ;
}
}
Err ( e ) = > {
let _ = tx . send ( Err ( format! ( " Stream error: {} " , e ) ) ) ;
break ;
}
}
}
2025-08-08 13:10:14 -07:00
if ! decremented {
2025-09-12 19:18:27 -04:00
if let Some ( worker ) = registry . get_by_url ( & worker_url ) {
worker . decrement_load ( ) ;
RouterMetrics ::set_running_requests ( & worker_url , worker . load ( ) ) ;
2025-08-08 13:10:14 -07:00
}
}
2025-07-30 17:47:19 -07:00
} ) ;
let stream = UnboundedReceiverStream ::new ( rx ) ;
let body = Body ::from_stream ( stream ) ;
let mut response = Response ::new ( body ) ;
* response . status_mut ( ) = status ;
2025-08-15 11:01:47 -07:00
* response . headers_mut ( ) = response_headers ;
2025-07-30 17:47:19 -07:00
response
2025-07-18 14:24:24 -07:00
} else {
// For requests without load tracking, just stream
2025-08-15 11:01:47 -07:00
// Preserve headers for streaming response
let mut response_headers = header_utils ::preserve_response_headers ( res . headers ( ) ) ;
// Ensure we set the correct content-type for SSE
response_headers . insert ( CONTENT_TYPE , HeaderValue ::from_static ( " text/event-stream " ) ) ;
2025-07-30 17:47:19 -07:00
let stream = res . bytes_stream ( ) ;
let ( tx , rx ) = tokio ::sync ::mpsc ::unbounded_channel ( ) ;
// Spawn task to forward stream
tokio ::spawn ( async move {
let mut stream = stream ;
while let Some ( chunk ) = stream . next ( ) . await {
match chunk {
Ok ( bytes ) = > {
if tx . send ( Ok ( bytes ) ) . is_err ( ) {
break ;
}
}
Err ( e ) = > {
let _ = tx . send ( Err ( format! ( " Stream error: {} " , e ) ) ) ;
break ;
}
}
}
} ) ;
let stream = UnboundedReceiverStream ::new ( rx ) ;
let body = Body ::from_stream ( stream ) ;
let mut response = Response ::new ( body ) ;
* response . status_mut ( ) = status ;
2025-08-15 11:01:47 -07:00
* response . headers_mut ( ) = response_headers ;
2025-07-30 17:47:19 -07:00
response
2025-07-18 14:24:24 -07:00
}
}
2025-09-22 03:28:38 +08:00
pub async fn add_worker (
& self ,
worker_url : & str ,
api_key : & Option < String > ,
) -> Result < String , String > {
2025-07-18 14:24:24 -07:00
let start_time = std ::time ::Instant ::now ( ) ;
let client = reqwest ::Client ::builder ( )
2025-09-02 04:57:04 +02:00
. timeout ( Duration ::from_secs ( self . worker_startup_timeout_secs ) )
2025-07-18 14:24:24 -07:00
. build ( )
. map_err ( | e | format! ( " Failed to create HTTP client: {} " , e ) ) ? ;
loop {
2025-09-02 04:57:04 +02:00
if start_time . elapsed ( ) > Duration ::from_secs ( self . worker_startup_timeout_secs ) {
2025-07-18 14:24:24 -07:00
error! (
" Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value " ,
2025-09-02 04:57:04 +02:00
self . worker_startup_timeout_secs , worker_url
2025-07-18 14:24:24 -07:00
) ;
return Err ( format! (
" Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value " ,
2025-09-02 04:57:04 +02:00
self . worker_startup_timeout_secs , worker_url
2025-07-18 14:24:24 -07:00
) ) ;
}
2025-08-15 11:01:21 -07:00
match client . get ( format! ( " {} /health " , worker_url ) ) . send ( ) . await {
2025-07-18 14:24:24 -07:00
Ok ( res ) = > {
if res . status ( ) . is_success ( ) {
2025-07-30 20:58:48 +08:00
if self . dp_aware {
// Need to contact the worker to extract the dp_size,
// and add them as multiple workers
let url_vec = vec! [ String ::from ( worker_url ) ] ;
2025-09-22 03:28:38 +08:00
let dp_url_vec = Self ::get_dp_aware_workers ( & url_vec , api_key )
2025-07-30 20:58:48 +08:00
. map_err ( | e | format! ( " Failed to get dp-aware workers: {} " , e ) ) ? ;
let mut worker_added : bool = false ;
for dp_url in & dp_url_vec {
2025-09-12 19:18:27 -04:00
if self . worker_registry . get_by_url ( dp_url ) . is_some ( ) {
2025-07-30 20:58:48 +08:00
warn! ( " Worker {} already exists " , dp_url ) ;
continue ;
}
info! ( " Added worker: {} " , dp_url ) ;
2025-09-12 19:18:27 -04:00
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
2025-09-22 03:28:38 +08:00
let new_worker_builder =
BasicWorkerBuilder ::new ( dp_url . to_string ( ) )
. worker_type ( WorkerType ::Regular )
. circuit_breaker_config (
self . circuit_breaker_config . clone ( ) ,
) ;
let new_worker = if let Some ( api_key ) = api_key {
new_worker_builder . api_key ( api_key ) . build ( )
} else {
new_worker_builder . build ( )
} ;
2025-09-12 19:18:27 -04:00
let worker_arc = Arc ::new ( new_worker ) ;
self . worker_registry . register ( worker_arc . clone ( ) ) ;
// Notify PolicyRegistry about the new worker
let model_id = worker_arc . model_id ( ) ;
2025-09-19 23:54:40 -04:00
self . policy_registry . on_worker_added ( model_id , None ) ;
// Initialize cache-aware policy if applicable
let model_workers =
self . worker_registry . get_by_model_fast ( model_id ) ;
self . policy_registry
. init_cache_aware_policy ( model_id , & model_workers ) ;
2025-09-12 19:18:27 -04:00
2025-07-30 20:58:48 +08:00
worker_added = true ;
}
if ! worker_added {
return Err ( format! ( " No worker added for {} " , worker_url ) ) ;
}
} else {
2025-09-12 19:18:27 -04:00
if self . worker_registry . get_by_url ( worker_url ) . is_some ( ) {
2025-07-30 20:58:48 +08:00
return Err ( format! ( " Worker {} already exists " , worker_url ) ) ;
}
info! ( " Added worker: {} " , worker_url ) ;
2025-07-18 14:24:24 -07:00
2025-09-12 19:18:27 -04:00
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
2025-09-22 03:28:38 +08:00
let new_worker_builder =
BasicWorkerBuilder ::new ( worker_url . to_string ( ) )
. worker_type ( WorkerType ::Regular )
. circuit_breaker_config ( self . circuit_breaker_config . clone ( ) ) ;
let new_worker = if let Some ( api_key ) = api_key {
new_worker_builder . api_key ( api_key ) . build ( )
} else {
new_worker_builder . build ( )
} ;
2025-09-12 19:18:27 -04:00
let worker_arc = Arc ::new ( new_worker ) ;
self . worker_registry . register ( worker_arc . clone ( ) ) ;
// Notify PolicyRegistry about the new worker
let model_id = worker_arc . model_id ( ) ;
2025-09-19 23:54:40 -04:00
self . policy_registry . on_worker_added ( model_id , None ) ;
// Initialize cache-aware policy if applicable
let model_workers = self . worker_registry . get_by_model_fast ( model_id ) ;
self . policy_registry
. init_cache_aware_policy ( model_id , & model_workers ) ;
2025-07-18 14:24:24 -07:00
}
2025-09-12 19:18:27 -04:00
RouterMetrics ::set_active_workers ( self . worker_registry . get_all ( ) . len ( ) ) ;
2025-07-18 14:24:24 -07:00
return Ok ( format! ( " Successfully added worker: {} " , worker_url ) ) ;
} else {
2025-07-27 19:30:19 -07:00
debug! (
" Worker {} health check pending - status: {} " ,
2025-07-18 14:24:24 -07:00
worker_url ,
res . status ( )
) ;
// if the url does not have http or https prefix, warn users
if ! worker_url . starts_with ( " http:// " ) & & ! worker_url . starts_with ( " https:// " )
{
warn! ( " The worker url {} does not have http or https prefix. Please add the prefix to the url. " , worker_url ) ;
}
2025-09-02 04:57:04 +02:00
tokio ::time ::sleep ( Duration ::from_secs (
self . worker_startup_check_interval_secs ,
) )
. await ;
2025-07-18 14:24:24 -07:00
continue ;
}
}
Err ( e ) = > {
2025-07-27 19:30:19 -07:00
debug! ( " Worker {} health check pending - error: {} " , worker_url , e ) ;
2025-07-18 14:24:24 -07:00
// if the url does not have http or https prefix, warn users
if ! worker_url . starts_with ( " http:// " ) & & ! worker_url . starts_with ( " https:// " ) {
warn! ( " The worker url {} does not have http or https prefix. Please add the prefix to the url. " , worker_url ) ;
}
2025-09-02 04:57:04 +02:00
tokio ::time ::sleep ( Duration ::from_secs (
self . worker_startup_check_interval_secs ,
) )
. await ;
2025-07-18 14:24:24 -07:00
continue ;
}
}
}
}
pub fn remove_worker ( & self , worker_url : & str ) {
2025-07-30 20:58:48 +08:00
if self . dp_aware {
// remove dp-aware workers in a prefix-matching fashion
// without contacting the remote worker
let mut removed_workers : Vec < String > = Vec ::new ( ) ;
let worker_url_prefix = format! ( " {} @ " , worker_url ) ;
2025-09-12 19:18:27 -04:00
// Find and remove all workers with matching prefix
let all_workers = self . worker_registry . get_all ( ) ;
for w in all_workers . iter ( ) {
if w . url ( ) . starts_with ( & worker_url_prefix ) {
// Get model_id before removing
let model_id = w . model_id ( ) . to_string ( ) ;
if self . worker_registry . remove_by_url ( w . url ( ) ) . is_some ( ) {
info! ( " Removed worker: {} " , w . url ( ) ) ;
removed_workers . push ( w . url ( ) . to_string ( ) ) ;
2025-07-30 20:58:48 +08:00
2025-09-12 19:18:27 -04:00
// Notify PolicyRegistry about the removed worker
self . policy_registry . on_worker_removed ( & model_id ) ;
2025-07-30 20:58:48 +08:00
} else {
2025-09-12 19:18:27 -04:00
warn! ( " Worker {} not found, skipping removal " , w . url ( ) ) ;
2025-07-30 20:58:48 +08:00
}
}
}
2025-09-12 19:18:27 -04:00
RouterMetrics ::set_active_workers ( self . worker_registry . get_all ( ) . len ( ) ) ;
for dp_url in removed_workers . iter ( ) {
if let Some ( worker ) = self . worker_registry . get_by_url ( dp_url ) {
let model_id = worker . model_id ( ) ;
2025-09-19 23:54:40 -04:00
self . policy_registry
. remove_worker_from_cache_aware ( model_id , dp_url ) ;
2025-07-30 20:58:48 +08:00
}
}
} else {
2025-09-12 19:18:27 -04:00
// Get the worker first to extract model_id
let model_id = if let Some ( worker ) = self . worker_registry . get_by_url ( worker_url ) {
worker . model_id ( ) . to_string ( )
2025-07-30 20:58:48 +08:00
} else {
warn! ( " Worker {} not found, skipping removal " , worker_url ) ;
return ;
2025-09-12 19:18:27 -04:00
} ;
if self . worker_registry . remove_by_url ( worker_url ) . is_some ( ) {
info! ( " Removed worker: {} " , worker_url ) ;
// Notify PolicyRegistry about the removed worker
self . policy_registry . on_worker_removed ( & model_id ) ;
RouterMetrics ::set_active_workers ( self . worker_registry . get_all ( ) . len ( ) ) ;
2025-07-30 20:58:48 +08:00
}
2025-09-19 23:54:40 -04:00
self . policy_registry
. remove_worker_from_cache_aware ( & model_id , worker_url ) ;
2025-07-30 20:58:48 +08:00
}
}
2025-09-22 03:28:38 +08:00
async fn get_worker_load ( & self , worker_url : & str , api_key : & Option < String > ) -> Option < isize > {
2025-07-30 20:58:48 +08:00
let worker_url = if self . dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let ( worker_url_prefix , _dp_rank ) = match Self ::extract_dp_rank ( worker_url ) {
Ok ( tup ) = > tup ,
Err ( e ) = > {
error! ( " Failed to extract dp_rank: {} " , e ) ;
return None ;
}
} ;
worker_url_prefix
} else {
worker_url
} ;
2025-09-22 03:28:38 +08:00
let mut req_builder = self . client . get ( format! ( " {} /get_load " , worker_url ) ) ;
if let Some ( key ) = api_key {
req_builder = req_builder . bearer_auth ( key ) ;
}
match req_builder . send ( ) . await {
2025-07-18 14:24:24 -07:00
Ok ( res ) if res . status ( ) . is_success ( ) = > match res . bytes ( ) . await {
Ok ( bytes ) = > match serde_json ::from_slice ::< serde_json ::Value > ( & bytes ) {
Ok ( data ) = > data
. get ( " load " )
. and_then ( | v | v . as_i64 ( ) )
. map ( | v | v as isize ) ,
Err ( e ) = > {
debug! ( " Failed to parse load response from {}: {} " , worker_url , e ) ;
None
}
} ,
Err ( e ) = > {
debug! ( " Failed to read load response from {}: {} " , worker_url , e ) ;
None
}
} ,
Ok ( res ) = > {
debug! (
" Worker {} returned non-success status: {} " ,
worker_url ,
res . status ( )
) ;
None
}
Err ( e ) = > {
debug! ( " Failed to get load from {}: {} " , worker_url , e ) ;
None
}
}
}
// Background task to monitor worker loads
async fn monitor_worker_loads (
worker_urls : Vec < String > ,
2025-09-22 03:28:38 +08:00
worker_api_keys : Vec < Option < String > > ,
2025-07-18 14:24:24 -07:00
tx : tokio ::sync ::watch ::Sender < HashMap < String , isize > > ,
interval_secs : u64 ,
policy : Arc < dyn LoadBalancingPolicy > ,
2025-08-02 19:16:47 -07:00
client : Client ,
2025-07-18 14:24:24 -07:00
) {
let mut interval = tokio ::time ::interval ( Duration ::from_secs ( interval_secs ) ) ;
loop {
interval . tick ( ) . await ;
let mut loads = HashMap ::new ( ) ;
2025-09-22 03:28:38 +08:00
for ( url , api_key ) in worker_urls . iter ( ) . zip ( worker_api_keys . iter ( ) ) {
if let Some ( load ) = Self ::get_worker_load_static ( & client , url , api_key ) . await {
2025-07-18 14:24:24 -07:00
loads . insert ( url . clone ( ) , load ) ;
}
}
if ! loads . is_empty ( ) {
// Update policy with new loads
policy . update_loads ( & loads ) ;
// Send to watchers
if let Err ( e ) = tx . send ( loads ) {
error! ( " Failed to send load update: {} " , e ) ;
}
}
}
}
// Static version of get_worker_load for use in monitoring task
2025-09-22 03:28:38 +08:00
async fn get_worker_load_static (
client : & reqwest ::Client ,
worker_url : & str ,
api_key : & Option < String > ,
) -> Option < isize > {
2025-07-30 20:58:48 +08:00
let worker_url = if worker_url . contains ( " @ " ) {
// Need to extract the URL from "http://host:port@dp_rank"
let ( worker_url_prefix , _dp_rank ) = match Self ::extract_dp_rank ( worker_url ) {
Ok ( tup ) = > tup ,
Err ( e ) = > {
debug! ( " Failed to extract dp_rank: {} " , e ) ;
return None ;
}
} ;
worker_url_prefix
} else {
worker_url
} ;
2025-09-22 03:28:38 +08:00
let mut req_builder = client . get ( format! ( " {} /get_load " , worker_url ) ) ;
if let Some ( key ) = api_key {
req_builder = req_builder . bearer_auth ( key ) ;
}
match req_builder . send ( ) . await {
2025-07-18 14:24:24 -07:00
Ok ( res ) if res . status ( ) . is_success ( ) = > match res . bytes ( ) . await {
Ok ( bytes ) = > match serde_json ::from_slice ::< serde_json ::Value > ( & bytes ) {
Ok ( data ) = > data
. get ( " load " )
. and_then ( | v | v . as_i64 ( ) )
. map ( | v | v as isize ) ,
Err ( e ) = > {
debug! ( " Failed to parse load response from {}: {} " , worker_url , e ) ;
None
}
} ,
Err ( e ) = > {
debug! ( " Failed to read load response from {}: {} " , worker_url , e ) ;
None
}
} ,
Ok ( res ) = > {
debug! (
" Worker {} returned non-success status: {} " ,
worker_url ,
res . status ( )
) ;
None
}
Err ( e ) = > {
debug! ( " Failed to get load from {}: {} " , worker_url , e ) ;
None
}
}
}
2025-09-13 00:10:18 +08:00
/// Turns a worker's raw rerank reply into the router's rerank response.
///
/// The body of `response` is read in full and parsed as a JSON array of
/// `RerankResult`; the status and headers of `response` are discarded. The
/// results are wrapped in a `RerankResponse` carrying the request's `model`
/// and `rid`, sorted by score, cut to `req.top_k` when set, and stripped of
/// documents unless `req.return_documents` is true.
///
/// # Errors
///
/// Fails when the body cannot be read or is not valid JSON of that shape.
async fn build_rerank_response(
    req: &RerankRequest,
    response: Response,
) -> anyhow::Result<Response> {
    // No size cap is applied while buffering the body.
    let raw_body = to_bytes(response.into_body(), usize::MAX).await?;
    let results: Vec<RerankResult> = serde_json::from_slice(&raw_body)?;

    let mut reranked = RerankResponse::new(results, req.model.clone(), req.rid.clone());
    reranked.sort_by_score();
    if let Some(top_k) = req.top_k {
        reranked.apply_top_k(top_k);
    }
    if !req.return_documents {
        reranked.drop_documents();
    }

    Ok(Json(reranked).into_response())
}
2025-07-18 14:24:24 -07:00
}
use async_trait ::async_trait ;
#[async_trait]
impl WorkerManagement for Router {
    /// Delegates to the inherent `Router::add_worker`.
    async fn add_worker(
        &self,
        worker_url: &str,
        api_key: &Option<String>,
    ) -> Result<String, String> {
        Self::add_worker(self, worker_url, api_key).await
    }

    /// Delegates to the inherent `Router::remove_worker`.
    fn remove_worker(&self, worker_url: &str) {
        Self::remove_worker(self, worker_url)
    }

    /// Delegates to the inherent `Router::get_worker_urls`.
    fn get_worker_urls(&self) -> Vec<String> {
        Self::get_worker_urls(self)
    }
}
2025-07-30 17:47:19 -07:00
#[ async_trait ]
2025-07-18 14:24:24 -07:00
impl RouterTrait for Router {
/// Returns this router as a type-erased `&dyn Any` reference.
fn as_any(&self) -> &dyn std::any::Any {
    self
}
2025-08-02 19:16:47 -07:00
async fn health ( & self , _req : Request < Body > ) -> Response {
2025-09-12 19:18:27 -04:00
let workers = self . worker_registry . get_all ( ) ;
2025-07-30 17:47:19 -07:00
let unhealthy_servers : Vec < _ > = workers
. iter ( )
. filter ( | w | ! w . is_healthy ( ) )
. map ( | w | w . url ( ) . to_string ( ) )
. collect ( ) ;
2025-07-18 14:24:24 -07:00
2025-07-30 17:47:19 -07:00
if unhealthy_servers . is_empty ( ) {
( StatusCode ::OK , " All servers healthy " ) . into_response ( )
2025-07-18 14:24:24 -07:00
} else {
2025-07-30 17:47:19 -07:00
(
StatusCode ::SERVICE_UNAVAILABLE ,
format! ( " Unhealthy servers: {:?} " , unhealthy_servers ) ,
)
. into_response ( )
2025-07-18 14:24:24 -07:00
}
}
2025-08-02 19:16:47 -07:00
/// Forwards the request through `proxy_get_request` to the `health_generate` endpoint.
async fn health_generate(&self, req: Request<Body>) -> Response {
    let endpoint = "health_generate";
    self.proxy_get_request(req, endpoint).await
}
2025-08-02 19:16:47 -07:00
/// Forwards the request through `proxy_get_request` to the `get_server_info` endpoint.
async fn get_server_info(&self, req: Request<Body>) -> Response {
    let endpoint = "get_server_info";
    self.proxy_get_request(req, endpoint).await
}
2025-08-02 19:16:47 -07:00
/// Forwards the request through `proxy_get_request` to the `v1/models` endpoint.
async fn get_models(&self, req: Request<Body>) -> Response {
    let endpoint = "v1/models";
    self.proxy_get_request(req, endpoint).await
}
2025-08-02 19:16:47 -07:00
/// Forwards the request through `proxy_get_request` to the `get_model_info` endpoint.
async fn get_model_info(&self, req: Request<Body>) -> Response {
    let endpoint = "get_model_info";
    self.proxy_get_request(req, endpoint).await
}
async fn route_generate (
& self ,
2025-07-30 17:47:19 -07:00
headers : Option < & HeaderMap > ,
body : & GenerateRequest ,
2025-09-12 19:18:27 -04:00
model_id : Option < & str > ,
2025-07-30 17:47:19 -07:00
) -> Response {
2025-09-12 19:18:27 -04:00
self . route_typed_request ( headers , body , " /generate " , model_id )
. await
2025-07-18 14:24:24 -07:00
}
async fn route_chat (
& self ,
2025-07-30 17:47:19 -07:00
headers : Option < & HeaderMap > ,
body : & ChatCompletionRequest ,
2025-09-12 19:18:27 -04:00
model_id : Option < & str > ,
2025-07-30 17:47:19 -07:00
) -> Response {
2025-09-12 19:18:27 -04:00
self . route_typed_request ( headers , body , " /v1/chat/completions " , model_id )
2025-07-30 17:47:19 -07:00
. await
2025-07-18 14:24:24 -07:00
}
async fn route_completion (
& self ,
2025-07-30 17:47:19 -07:00
headers : Option < & HeaderMap > ,
body : & CompletionRequest ,
2025-09-12 19:18:27 -04:00
model_id : Option < & str > ,
2025-07-30 17:47:19 -07:00
) -> Response {
2025-09-12 19:18:27 -04:00
self . route_typed_request ( headers , body , " /v1/completions " , model_id )
2025-07-30 17:47:19 -07:00
. await
2025-07-18 14:24:24 -07:00
}
2025-09-11 20:56:17 -07:00
/// Routes a responses request by handing it to `route_typed_request` with the
/// `/v1/responses` route.
async fn route_responses(
    &self,
    headers: Option<&HeaderMap>,
    body: &ResponsesRequest,
    model_id: Option<&str>,
) -> Response {
    let route = "/v1/responses";
    self.route_typed_request(headers, body, route, model_id)
        .await
}
2025-09-12 16:19:38 -07:00
/// Fetches a stored response by issuing a GET for `v1/responses/{response_id}`
/// through `route_get_request`. `response_id` is inserted into the path as-is.
async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
    let path = format!("v1/responses/{}", response_id);
    self.route_get_request(headers, &path).await
}
/// Cancels a response by issuing a body-less POST to
/// `v1/responses/{response_id}/cancel` through `route_post_empty_request`.
/// `response_id` is inserted into the path as-is.
async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
    let path = format!("v1/responses/{}/cancel", response_id);
    self.route_post_empty_request(headers, &path).await
}
2025-09-15 09:44:35 +08:00
/// Routes an embeddings request to `/v1/embeddings` and records
/// embeddings-specific metrics on top of whatever `route_typed_request`
/// records itself.
///
/// On a success status the request counter and the elapsed duration are
/// recorded; on any other status an error of type `http_<status code>` is
/// recorded (no duration in that case). The routed response is returned
/// unchanged.
async fn route_embeddings(
    &self,
    headers: Option<&HeaderMap>,
    body: &EmbeddingRequest,
    model_id: Option<&str>,
) -> Response {
    let started = Instant::now();
    let response = self
        .route_typed_request(headers, body, "/v1/embeddings", model_id)
        .await;

    // Embedding specific metrics
    let status = response.status();
    if status.is_success() {
        RouterMetrics::record_embeddings_request();
        RouterMetrics::record_embeddings_duration(started.elapsed());
    } else {
        let error_type = format!("http_{}", status.as_u16());
        RouterMetrics::record_embeddings_error(&error_type);
    }

    response
}
2025-09-12 19:18:27 -04:00
async fn route_rerank (
& self ,
headers : Option < & HeaderMap > ,
body : & RerankRequest ,
model_id : Option < & str > ,
) -> Response {
2025-09-13 00:10:18 +08:00
if let Err ( e ) = body . validate ( ) {
return ( StatusCode ::BAD_REQUEST , e ) . into_response ( ) ;
}
2025-09-12 19:18:27 -04:00
let response = self
. route_typed_request ( headers , body , " /v1/rerank " , model_id )
. await ;
2025-09-13 00:10:18 +08:00
if response . status ( ) . is_success ( ) {
match Self ::build_rerank_response ( body , response ) . await {
Ok ( rerank_response ) = > rerank_response ,
Err ( e ) = > {
error! ( " Failed to build rerank response: {} " , e ) ;
return (
StatusCode ::INTERNAL_SERVER_ERROR ,
" Failed to build rerank response " . to_string ( ) ,
)
. into_response ( ) ;
}
}
} else {
response
}
2025-08-28 12:07:06 -07:00
}
2025-08-02 19:16:47 -07:00
async fn flush_cache ( & self ) -> Response {
2025-07-18 14:24:24 -07:00
// Get all worker URLs
let worker_urls = self . get_worker_urls ( ) ;
// Send requests to all workers concurrently without headers
let mut tasks = Vec ::new ( ) ;
for worker_url in & worker_urls {
2025-07-30 20:58:48 +08:00
let worker_url = if self . dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let ( worker_url_prefix , _dp_rank ) = match Self ::extract_dp_rank ( worker_url ) {
Ok ( tup ) = > tup ,
Err ( e ) = > {
error! ( " Failed to extract dp_rank: {} " , e ) ;
2025-07-30 17:47:19 -07:00
return (
StatusCode ::INTERNAL_SERVER_ERROR ,
format! ( " Failed to extract dp_rank: {} " , e ) ,
)
. into_response ( ) ;
2025-07-30 20:58:48 +08:00
}
} ;
worker_url_prefix
} else {
worker_url
} ;
2025-08-02 19:16:47 -07:00
let request_builder = self . client . post ( format! ( " {} /flush_cache " , worker_url ) ) ;
2025-07-18 14:24:24 -07:00
tasks . push ( request_builder . send ( ) ) ;
}
// Wait for all responses
let results = futures_util ::future ::join_all ( tasks ) . await ;
// Check if all succeeded
let all_success = results . iter ( ) . all ( | r | {
r . as_ref ( )
. map ( | res | res . status ( ) . is_success ( ) )
. unwrap_or ( false )
} ) ;
if all_success {
2025-07-30 17:47:19 -07:00
( StatusCode ::OK , " Cache flushed on all servers " ) . into_response ( )
2025-07-18 14:24:24 -07:00
} else {
2025-07-30 17:47:19 -07:00
(
StatusCode ::INTERNAL_SERVER_ERROR ,
" Cache flush failed on one or more servers " ,
)
. into_response ( )
2025-07-18 14:24:24 -07:00
}
}
2025-08-02 19:16:47 -07:00
async fn get_worker_loads ( & self ) -> Response {
2025-09-22 03:28:38 +08:00
let urls_with_key = self . worker_registry . get_all_urls_with_api_key ( ) ;
2025-07-18 14:24:24 -07:00
let mut loads = Vec ::new ( ) ;
// Get loads from all workers
2025-09-22 03:28:38 +08:00
for ( url , api_key ) in & urls_with_key {
let load = self . get_worker_load ( url , api_key ) . await . unwrap_or ( - 1 ) ;
2025-07-18 14:24:24 -07:00
loads . push ( serde_json ::json! ( {
" worker " : url ,
" load " : load
} ) ) ;
}
2025-07-30 17:47:19 -07:00
Json ( serde_json ::json! ( {
2025-07-18 14:24:24 -07:00
" workers " : loads
} ) )
2025-07-30 17:47:19 -07:00
. into_response ( )
2025-07-18 14:24:24 -07:00
}
/// Returns the fixed type label of this router: `"regular"`.
fn router_type(&self) -> &'static str {
    "regular"
}
2025-07-30 17:47:19 -07:00
fn readiness ( & self ) -> Response {
2025-07-18 14:24:24 -07:00
// Regular router is ready if it has at least one healthy worker
2025-09-12 19:18:27 -04:00
let workers = self . worker_registry . get_all ( ) ;
let healthy_count = workers . iter ( ) . filter ( | w | w . is_healthy ( ) ) . count ( ) ;
let total_workers = workers . len ( ) ;
2025-07-18 14:24:24 -07:00
if healthy_count > 0 {
2025-07-30 17:47:19 -07:00
Json ( serde_json ::json! ( {
2025-07-18 14:24:24 -07:00
" status " : " ready " ,
" healthy_workers " : healthy_count ,
2025-09-12 19:18:27 -04:00
" total_workers " : total_workers
2025-07-18 14:24:24 -07:00
} ) )
2025-07-30 17:47:19 -07:00
. into_response ( )
2025-07-18 14:24:24 -07:00
} else {
2025-07-30 17:47:19 -07:00
(
StatusCode ::SERVICE_UNAVAILABLE ,
Json ( serde_json ::json! ( {
" status " : " not_ready " ,
" reason " : " no healthy workers available " ,
2025-09-12 19:18:27 -04:00
" total_workers " : total_workers
2025-07-30 17:47:19 -07:00
} ) ) ,
)
. into_response ( )
2025-07-18 14:24:24 -07:00
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a non-dp-aware router with two registered regular workers
    /// (`http://worker1:8080` and `http://worker2:8080`), a round-robin policy
    /// registry and no load-monitor task.
    fn create_test_regular_router() -> Router {
        // Create registries
        let worker_registry = Arc::new(WorkerRegistry::new());
        let policy_registry = Arc::new(PolicyRegistry::new(
            crate::config::types::PolicyConfig::RoundRobin,
        ));

        // Register test workers
        for url in ["http://worker1:8080", "http://worker2:8080"] {
            let worker = BasicWorkerBuilder::new(url)
                .worker_type(WorkerType::Regular)
                .api_key("test_api_key")
                .build();
            worker_registry.register(Arc::new(worker));
        }

        // Only the receiving half is kept; the sender is dropped right away.
        let (_, rx) = tokio::sync::watch::channel(HashMap::new());

        Router {
            worker_registry,
            policy_registry,
            worker_startup_timeout_secs: 5,
            worker_startup_check_interval_secs: 1,
            dp_aware: false,
            api_key: None,
            client: Client::new(),
            retry_config: RetryConfig::default(),
            circuit_breaker_config: CircuitBreakerConfig::default(),
            _worker_loads: Arc::new(rx),
            _load_monitor_handle: None,
        }
    }

    /// Both registered worker URLs are reported, in any order.
    #[test]
    fn test_router_get_worker_urls_regular() {
        let urls = create_test_regular_router().get_worker_urls();

        assert_eq!(urls.len(), 2);
        for expected in ["http://worker1:8080", "http://worker2:8080"] {
            assert!(urls.contains(&expected.to_string()));
        }
    }

    /// Selecting the first worker succeeds and yields one of the registered URLs.
    #[test]
    fn test_select_first_worker_regular() {
        let router = create_test_regular_router();
        let result = router.select_first_worker();

        assert!(result.is_ok());
        let url = result.unwrap();
        // DashMap doesn't guarantee order, so just check we get one of the workers
        let is_known_worker = url == "http://worker1:8080" || url == "http://worker2:8080";
        assert!(is_known_worker);
    }

    /// Waiting on an empty worker list fails with a "no workers provided" error.
    #[tokio::test]
    async fn test_wait_for_healthy_workers_empty_list() {
        // Empty list will return error immediately
        let result = Router::wait_for_healthy_workers(&[], 1, 1).await;

        assert!(result.is_err());
        let message = result.unwrap_err();
        assert!(message.contains("no workers provided"));
    }

    /// Waiting on a worker URL that cannot be reached ends in a timeout error.
    #[tokio::test]
    async fn test_wait_for_healthy_workers_invalid_urls() {
        // This test will timeout quickly since the URLs are invalid
        let unreachable = ["http://nonexistent:8080".to_string()];
        let result = Router::wait_for_healthy_workers(&unreachable, 1, 1).await;

        assert!(result.is_err());
        let message = result.unwrap_err();
        assert!(message.contains("Timeout"));
    }
}