diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs index dbccb21e0..a6c204e39 100644 --- a/sgl-router/src/routers/router_manager.rs +++ b/sgl-router/src/routers/router_manager.rs @@ -53,7 +53,7 @@ pub struct RouterManager { routers: Arc>>, /// Default router for requests without specific routing - default_router: Option, + default_router: Arc>>, /// HTTP client for querying worker info client: reqwest::Client, @@ -75,27 +75,29 @@ impl RouterManager { worker_registry, policy_registry, routers: Arc::new(DashMap::new()), - default_router: None, + default_router: Arc::new(std::sync::RwLock::new(None)), client, config, } } /// Register a router with the manager - pub fn register_router(&mut self, id: RouterId, router: Arc) { + pub fn register_router(&self, id: RouterId, router: Arc) { // Store router self.routers.insert(id.clone(), router); // Set as default if first router - if self.default_router.is_none() { - self.default_router = Some(id.clone()); + let mut default_router = self.default_router.write().unwrap(); + if default_router.is_none() { + *default_router = Some(id.clone()); info!("Set default router to {}", id.as_str()); } } /// Set the default router - pub fn set_default_router(&mut self, id: RouterId) { - self.default_router = Some(id); + pub fn set_default_router(&self, id: RouterId) { + let mut default_router = self.default_router.write().unwrap(); + *default_router = Some(id); } /// Get the number of registered routers @@ -130,7 +132,8 @@ impl RouterManager { } // Fall back to default router - if let Some(ref default_id) = self.default_router { + let default_router = self.default_router.read().unwrap(); + if let Some(ref default_id) = *default_router { self.routers.get(default_id).map(|r| r.clone()) } else { None @@ -808,7 +811,7 @@ impl std::fmt::Debug for RouterManager { f.debug_struct("RouterManager") .field("routers_count", &self.routers.len()) .field("workers_count", &self.worker_registry.get_all().len()) - .field("default_router", &self.default_router) + .field("default_router", &*self.default_router.read().unwrap()) .finish() } } diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 348d70aea..215e2b54c 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -122,6 +122,7 @@ pub struct AppState { pub router: Arc, pub context: Arc, pub concurrency_queue_tx: Option>, + pub router_manager: Option>, } // Fallback handler for unmatched routes @@ -326,8 +327,8 @@ async fn create_worker( State(state): State>, Json(config): Json, ) -> Response { - // Check if the router is actually a RouterManager (enable_igw=true) - if let Some(router_manager) = state.router.as_any().downcast_ref::() { + // Check if we have a RouterManager (enable_igw=true) + if let Some(router_manager) = &state.router_manager { // Call RouterManager's add_worker method directly with the full config match router_manager.add_worker(config).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -357,7 +358,7 @@ async fn create_worker( /// GET /workers - List all workers with details async fn list_workers_rest(State(state): State>) -> Response { - if let Some(router_manager) = state.router.as_any().downcast_ref::() { + if let Some(router_manager) = &state.router_manager { let response = router_manager.list_workers(); Json(response).into_response() } else { @@ -400,7 +401,7 @@ async fn list_workers_rest(State(state): State>) -> Response { /// GET /workers/{url} - Get specific worker info async fn get_worker(State(state): State>, Path(url): Path) -> Response { - if let Some(router_manager) = state.router.as_any().downcast_ref::() { + if let Some(router_manager) = &state.router_manager { if let Some(worker) = router_manager.get_worker(&url) { Json(worker).into_response() } else { @@ -431,7 +432,7 @@ async fn get_worker(State(state): State>, Path(url): Path) /// DELETE /workers/{url} - Remove a worker async fn delete_worker(State(state): State>, Path(url): Path) -> Response { - if let Some(router_manager) = state.router.as_any().downcast_ref::() { + if let Some(router_manager) = &state.router_manager { match router_manager.remove_worker_from_registry(&url) { Ok(response) => (StatusCode::OK, Json(response)).into_response(), Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(), @@ -594,69 +595,76 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box = if config.router_config.enable_igw { - info!("Multi-router mode enabled (enable_igw=true)"); + let (router, router_manager): (Arc, Option>) = + if config.router_config.enable_igw { + info!("Multi-router mode enabled (enable_igw=true)"); - // Create RouterManager with shared registries from AppContext - let mut router_manager = RouterManager::new( - config.router_config.clone(), - client.clone(), - app_context.worker_registry.clone(), - app_context.policy_registry.clone(), - ); + // Create RouterManager with shared registries from AppContext + let router_manager = Arc::new(RouterManager::new( + config.router_config.clone(), + client.clone(), + app_context.worker_registry.clone(), + app_context.policy_registry.clone(), + )); - // 1. HTTP Regular Router - match RouterFactory::create_regular_router( - &[], // Empty worker list - workers added later - &app_context, - ) - .await - { - Ok(http_regular) => { - info!("Created HTTP Regular router"); - router_manager.register_router( - RouterId::new("http-regular".to_string()), - Arc::from(http_regular), - ); + // 1. HTTP Regular Router + match RouterFactory::create_regular_router( + &[], // Empty worker list - workers added later + &app_context, + ) + .await + { + Ok(http_regular) => { + info!("Created HTTP Regular router"); + router_manager.register_router( + RouterId::new("http-regular".to_string()), + Arc::from(http_regular), + ); + } + Err(e) => { + warn!("Failed to create HTTP Regular router: {e}"); + } } - Err(e) => { - warn!("Failed to create HTTP Regular router: {e}"); - } - } - // 2. HTTP PD Router - match RouterFactory::create_pd_router( - &[], - &[], - None, - None, - &config.router_config.policy, - &app_context, - ) - .await - { - Ok(http_pd) => { - info!("Created HTTP PD router"); - router_manager - .register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd)); + // 2. HTTP PD Router + match RouterFactory::create_pd_router( + &[], + &[], + None, + None, + &config.router_config.policy, + &app_context, + ) + .await + { + Ok(http_pd) => { + info!("Created HTTP PD router"); + router_manager + .register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd)); + } + Err(e) => { + warn!("Failed to create HTTP PD router: {e}"); + } } - Err(e) => { - warn!("Failed to create HTTP PD router: {e}"); - } - } - // TODO: Add gRPC routers once we have dynamic tokenizer loading + // TODO: Add gRPC routers once we have dynamic tokenizer loading - info!( - "RouterManager initialized with {} routers", - router_manager.router_count() - ); - Box::new(router_manager) - } else { - info!("Single router mode (enable_igw=false)"); - // Create single router with the context - RouterFactory::create_router(&app_context).await? - }; + info!( + "RouterManager initialized with {} routers", + router_manager.router_count() + ); + ( + router_manager.clone() as Arc, + Some(router_manager), + ) + } else { + info!("Single router mode (enable_igw=false)"); + // Create single router with the context + ( + Arc::from(RouterFactory::create_router(&app_context).await?), + None, + ) + }; // Start health checker for all workers in the registry let _health_checker = app_context @@ -685,9 +693,10 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box