diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index 77d9141c0..d9cbf9bac 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -35,7 +35,7 @@ pub struct PDRouter { pub interval_secs: u64, pub worker_loads: Arc>>, pub load_monitor_handle: Option>>, - pub http_client: reqwest::Client, + pub http_client: Client, _prefill_health_checker: Option, _decode_health_checker: Option, } @@ -206,51 +206,17 @@ impl PDRouter { } // Initialize cache-aware components if needed for prefill policy - let prefill_tree = if prefill_policy.name() == "cache_aware" { - // Initialize the policy's internal tree with prefill workers - if let Some(cache_policy) = prefill_policy - .as_any() - .downcast_ref::() - { - cache_policy.init_workers(&prefill_workers); - } - - let tree = Arc::new(Mutex::new(Tree::new())); - // Initialize tree with prefill workers - for worker in &prefill_workers { - tree.lock().unwrap().insert("", worker.url()); - } - Some(tree) - } else { - None - }; + let prefill_tree = Self::initialize_radix_tree(&prefill_policy, &prefill_workers)?; // Initialize cache-aware components if needed for decode policy - let decode_tree = if decode_policy.name() == "cache_aware" { - // Initialize the policy's internal tree with decode workers - if let Some(cache_policy) = decode_policy - .as_any() - .downcast_ref::() - { - cache_policy.init_workers(&decode_workers); - } - - let tree = Arc::new(Mutex::new(Tree::new())); - // Initialize tree with decode workers - for worker in &decode_workers { - tree.lock().unwrap().insert("", worker.url()); - } - Some(tree) - } else { - None - }; + let decode_tree = Self::initialize_radix_tree(&decode_policy, &decode_workers)?; // Set up background load monitoring for power-of-two selection let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); let worker_loads = Arc::new(rx); // Create a shared HTTP client for all operations - let http_client = reqwest::Client::builder() + let http_client = Client::builder() .timeout(Duration::from_secs(timeout_secs)) .build() .map_err(|e| format!("Failed to create HTTP client: {}", e))?; @@ -304,6 +270,35 @@ impl PDRouter { }) } + // Helper function to initialize radix tree for cache-aware policies + fn initialize_radix_tree( + policy: &Arc, + workers: &[Box], + ) -> Result>>, String> { + if let Some(cache_policy) = policy + .as_any() + .downcast_ref::() + { + // Initialize the policy's internal tree with workers + cache_policy.init_workers(workers); + + let tree = Arc::new(Mutex::new(Tree::new())); + + { + let tree_guard = tree + .lock() + .map_err(|e| format!("Failed to lock tree: {}", e))?; + for worker in workers { + tree_guard.insert("", worker.url()); + } + } + + Ok(Some(tree)) + } else { + Ok(None) + } + } + // Route a typed generate request pub async fn route_generate( &self, @@ -329,7 +324,7 @@ impl PDRouter { }); // Select servers - let (prefill, decode) = match self.select_pd_pair(client, request_text).await { + let (prefill, decode) = match self.select_pd_pair(request_text).await { Ok(pair) => pair, Err(e) => { error!("Failed to select PD pair error={}", e); @@ -417,7 +412,7 @@ impl PDRouter { .and_then(|content| content.as_str()); // Select servers - let (prefill, decode) = match self.select_pd_pair(client, request_text).await { + let (prefill, decode) = match self.select_pd_pair(request_text).await { Ok(pair) => pair, Err(e) => { error!("Failed to select PD pair error={}", e); @@ -498,7 +493,7 @@ impl PDRouter { }; // Select servers - let (prefill, decode) = match self.select_pd_pair(client, request_text).await { + let (prefill, decode) = match self.select_pd_pair(request_text).await { Ok(pair) => pair, Err(e) => { error!("Failed to select PD pair error={}", e); @@ -833,7 +828,6 @@ impl PDRouter { // Select a pair of prefill and decode servers async fn select_pd_pair( &self, - _client: &Client, request_text: Option<&str>, ) -> Result<(Box, Box), String> { // Get read locks for both worker lists @@ -998,7 +992,7 @@ impl PDRouter { // Note: This endpoint actually causes the model to generate tokens, so we only test one pair // Select a random worker pair using the policy - let (prefill, decode) = match self.select_pd_pair(client, None).await { + let (prefill, decode) = match self.select_pd_pair(None).await { Ok(pair) => pair, Err(e) => { return ( @@ -1921,8 +1915,7 @@ mod tests { router.prefill_workers.write().unwrap().push(healthy_worker); router.decode_workers.write().unwrap().push(decode_worker); - let client = reqwest::Client::new(); - let result = router.select_pd_pair(&client, None).await; + let result = router.select_pd_pair(None).await; assert!(result.is_ok()); let (prefill, _decode) = result.unwrap(); @@ -1936,8 +1929,7 @@ mod tests { async fn test_empty_worker_lists() { let router = create_test_pd_router(); - let client = reqwest::Client::new(); - let result = router.select_pd_pair(&client, None).await; + let result = router.select_pd_pair(None).await; assert!(result.is_err()); assert!(result.unwrap_err().contains("No prefill workers available"));