diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index 418bf8d2c..30b248e87 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -21,7 +21,7 @@ serde_json = "1.0" pyo3 = { version = "0.22.5", features = ["extension-module"] } dashmap = "6.1.0" http = "1.1.0" -tokio = "1.42.0" +tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread"] } # Added for enhanced logging system tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "chrono"] } diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index e60760415..0d6cf6910 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -18,7 +18,7 @@ use tracing::{error, info, warn, Level}; #[derive(Debug)] pub struct AppState { - router: Router, + router: Arc<Router>, client: Client, } @@ -29,7 +29,7 @@ impl AppState { policy_config: PolicyConfig, ) -> Result<Self, String> { // Create router based on policy - let router = Router::new(worker_urls, policy_config)?; + let router = Arc::new(Router::new(worker_urls, policy_config)?); Ok(Self { router, client }) } } @@ -218,24 +218,23 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { .build() .expect("Failed to create HTTP client"); - let app_state = web::Data::new( - AppState::new( - config.worker_urls.clone(), - client.clone(), // Clone the client here - config.policy_config.clone(), - ) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?, - ); + let app_state_init = AppState::new( + config.worker_urls.clone(), + client.clone(), + config.policy_config.clone(), + ) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + let router_arc = Arc::clone(&app_state_init.router); + let app_state = web::Data::new(app_state_init); // Start the service discovery if enabled if let Some(service_discovery_config) = config.service_discovery_config { if service_discovery_config.enabled { - let worker_urls = Arc::clone(&app_state.router.get_worker_urls()); - - match 
start_service_discovery(service_discovery_config, worker_urls).await { + info!("🚧 Initializing Kubernetes service discovery"); + // Pass the Arc directly + match start_service_discovery(service_discovery_config, router_arc).await { Ok(handle) => { info!("✅ Service discovery started successfully"); - // Spawn a task to handle the service discovery thread spawn(async move { if let Err(e) = handle.await { @@ -252,7 +251,10 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { } info!("✅ Serving router on {}:{}", config.host, config.port); - info!("✅ Serving workers on {:?}", config.worker_urls); + info!( + "✅ Serving workers on {:?}", + app_state.router.get_worker_urls().read().unwrap() + ); HttpServer::new(move || { App::new() diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index 28210f9fa..103551891 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -1,3 +1,5 @@ +use crate::router::Router; + use futures::{StreamExt, TryStreamExt}; use k8s_openapi::api::core::v1::Pod; use kube::{ @@ -7,11 +9,12 @@ use kube::{ Client, }; use std::collections::{HashMap, HashSet}; -use std::sync::{Arc, Mutex, RwLock}; + +use std::sync::{Arc, Mutex}; use std::time::Duration; use tokio::task; use tokio::time; -use tracing::{error, info, warn}; +use tracing::{debug, error, info, warn}; /// Represents the service discovery configuration #[derive(Debug, Clone)] @@ -81,7 +84,7 @@ impl PodInfo { pub async fn start_service_discovery( config: ServiceDiscoveryConfig, - worker_urls: Arc<RwLock<Vec<String>>>, + router: Arc<Router>, ) -> Result<tokio::task::JoinHandle<()>, kube::Error> { // Don't initialize anything if service discovery is disabled if !config.enabled { @@ -136,7 +139,6 @@ pub async fn start_service_discovery( // Clone Arcs for the closures let selector_clone = Arc::clone(&selector); let tracked_pods_clone = Arc::clone(&tracked_pods); - let worker_urls_clone = Arc::clone(&worker_urls); // Apply label selector filter separately since we 
can't do it directly with the watcher anymore let filtered_stream = watcher_stream.filter_map(move |obj_res| { @@ -164,12 +166,12 @@ pub async fn start_service_discovery( // Clone again for the next closure let tracked_pods_clone2 = Arc::clone(&tracked_pods_clone); - let worker_urls_clone2 = Arc::clone(&worker_urls_clone); + let router_clone = Arc::clone(&router); match filtered_stream .try_for_each(move |pod| { let tracked_pods_inner = Arc::clone(&tracked_pods_clone2); - let worker_urls_inner = Arc::clone(&worker_urls_clone2); + let router_inner = Arc::clone(&router_clone); async move { if let Some(pod_info) = PodInfo::from_pod(&pod) { @@ -177,18 +179,13 @@ pub async fn start_service_discovery( handle_pod_deletion( &pod_info, tracked_pods_inner, - worker_urls_inner, + router_inner, port, ) .await; } else { - handle_pod_event( - &pod_info, - tracked_pods_inner, - worker_urls_inner, - port, - ) - .await; + handle_pod_event(&pod_info, tracked_pods_inner, router_inner, port) + .await; } } Ok(()) @@ -219,7 +216,7 @@ pub async fn start_service_discovery( async fn handle_pod_event( pod_info: &PodInfo, tracked_pods: Arc<Mutex<HashSet<PodInfo>>>, - worker_urls: Arc<RwLock<Vec<String>>>, + router: Arc<Router>, port: u16, ) { let worker_url = pod_info.worker_url(port); @@ -234,52 +231,275 @@ async fn handle_pod_event( if pod_info.is_healthy() { if !already_tracked { info!( - "Adding healthy pod {} ({}) as worker", - pod_info.name, pod_info.ip + "Healthy pod found: {}. 
Adding worker: {}", + pod_info.name, worker_url ); - - // Add URL to worker list - let mut urls = worker_urls.write().unwrap(); - if !urls.contains(&worker_url) { - urls.push(worker_url.clone()); - info!("Added new worker URL: {}", worker_url); + match router.add_worker(&worker_url).await { + Ok(msg) => { + info!("Router add_worker: {}", msg); + let mut tracker = tracked_pods.lock().unwrap(); + tracker.insert(pod_info.clone()); + } + Err(e) => error!("Failed to add worker {} to router: {}", worker_url, e), } - - // Track this pod - let mut tracker = tracked_pods.lock().unwrap(); - tracker.insert(pod_info.clone()); } } else if already_tracked { // If pod was healthy before but not anymore, remove it - handle_pod_deletion(pod_info, tracked_pods, worker_urls, port).await; + handle_pod_deletion(pod_info, tracked_pods, router, port).await; } } async fn handle_pod_deletion( pod_info: &PodInfo, tracked_pods: Arc<Mutex<HashSet<PodInfo>>>, - worker_urls: Arc<RwLock<Vec<String>>>, + router: Arc<Router>, port: u16, ) { let worker_url = pod_info.worker_url(port); + let mut tracked = tracked_pods.lock().unwrap(); - // Remove the pod from our tracking - let was_tracked = { - let mut tracker = tracked_pods.lock().unwrap(); - tracker.remove(pod_info) - }; - - if was_tracked { + if tracked.remove(pod_info) { info!( - "Removing pod {} ({}) from workers", - pod_info.name, pod_info.ip + "Pod deleted: {}. Removing worker: {}", + pod_info.name, worker_url + ); + router.remove_worker(&worker_url); + } else { + // This case might occur if a pod is deleted before it was ever marked healthy and added. + // Or if the event is duplicated. No action needed on the router if it wasn't tracked (and thus not added). + debug!( + "Pod deletion event for untracked/already removed pod: {}. 
Worker URL: {}", + pod_info.name, worker_url ); - - // Remove URL from worker list - let mut urls = worker_urls.write().unwrap(); - if let Some(idx) = urls.iter().position(|url| url == &worker_url) { - urls.remove(idx); - info!("Removed worker URL: {}", worker_url); - } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::router::Router; + use k8s_openapi::api::core::v1::{Pod, PodCondition, PodSpec, PodStatus}; + use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta; + use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time; + use std::sync::RwLock; + + // Helper function to create a Pod for testing PodInfo::from_pod + fn create_k8s_pod( + name: Option<&str>, + ip: Option<&str>, + phase: Option<&str>, + ready_status: Option<&str>, + deletion_timestamp: Option