diff --git a/sgl-router/src/routers/http/router.rs b/sgl-router/src/routers/http/router.rs index 5498b4cc9..076ea2e23 100644 --- a/sgl-router/src/routers/http/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -133,10 +133,12 @@ impl Router { /// Get worker URLs for a specific model pub fn get_worker_urls_for_model(&self, model_id: Option<&str>) -> Vec { - let workers = match model_id { - Some(model) => self.worker_registry.get_by_model_fast(model), - None => self.worker_registry.get_all(), - }; + let workers = self.worker_registry.get_workers_filtered( + model_id, + Some(WorkerType::Regular), + Some(ConnectionMode::Http), + false, // get all workers + ); workers.iter().map(|w| w.url().to_string()).collect() } @@ -315,22 +317,6 @@ impl Router { } } - #[allow(dead_code)] - fn select_first_worker_for_model(&self, model_id: Option<&str>) -> Result { - let workers = match model_id { - Some(model) => self.worker_registry.get_by_model_fast(model), - None => self.worker_registry.get_all(), - }; - if workers.is_empty() { - Err(format!( - "No workers are available for model: {:?}", - model_id - )) - } else { - Ok(workers[0].url().to_string()) - } - } - pub async fn send_health_check(&self, worker_url: &str) -> Response { let health_url = if self.dp_aware { // Need to extract the URL from "http://host:port@dp_rank" @@ -444,11 +430,13 @@ impl Router { model_id: Option<&str>, text: Option<&str>, ) -> Option> { - // Get workers for the specified model (O(1) lookup if model_id is provided) - let workers = match model_id { - Some(model) => self.worker_registry.get_by_model_fast(model), - None => self.worker_registry.get_all(), - }; + // Get workers for the specified model O(1), filtered by connection mode + let workers = self.worker_registry.get_workers_filtered( + model_id, + Some(WorkerType::Regular), + Some(ConnectionMode::Http), + false, // get all workers, we'll filter by is_available() next + ); let available: Vec> = workers .iter() @@ -982,8 +970,12 @@ impl Router { self.policy_registry.on_worker_added(model_id, None); // Initialize cache-aware policy if applicable - let model_workers = - self.worker_registry.get_by_model_fast(model_id); + let model_workers = self.worker_registry.get_workers_filtered( + Some(model_id), + Some(WorkerType::Regular), + Some(ConnectionMode::Http), + false, + ); self.policy_registry .init_cache_aware_policy(model_id, &model_workers); @@ -1018,7 +1010,12 @@ impl Router { self.policy_registry.on_worker_added(model_id, None); // Initialize cache-aware policy if applicable - let model_workers = self.worker_registry.get_by_model_fast(model_id); + let model_workers = self.worker_registry.get_workers_filtered( + Some(model_id), + Some(WorkerType::Regular), + Some(ConnectionMode::Http), + false, + ); self.policy_registry .init_cache_aware_policy(model_id, &model_workers); }