diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs index fe8e3844b..2be9b6efd 100644 --- a/sgl-router/src/routers/router_manager.rs +++ b/sgl-router/src/routers/router_manager.rs @@ -5,7 +5,7 @@ //! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything use crate::config::RouterConfig; -use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry}; +use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType}; use crate::protocols::spec::{ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesRequest, @@ -56,10 +56,6 @@ pub struct RouterManager { /// Default router for requests without specific routing default_router: Option, - /// Model to router mapping for model-aware routing - /// Multiple models can be served by the same router - model_routers: Arc>>, - /// HTTP client for querying worker info client: reqwest::Client, @@ -81,7 +77,6 @@ impl RouterManager { policy_registry, routers: Arc::new(DashMap::new()), default_router: None, - model_routers: Arc::new(DashMap::new()), client, config, } @@ -92,19 +87,11 @@ impl RouterManager { &mut self, id: RouterId, router: Arc, - models: Vec, + _models: Vec, // Keep parameter for backward compatibility but ignore it ) { // Store router self.routers.insert(id.clone(), router); - // Update model mappings - for model in models { - self.model_routers - .entry(model) - .or_default() - .push(id.clone()); - } - // Set as default if first router if self.default_router.is_none() { self.default_router = Some(id.clone()); @@ -122,14 +109,29 @@ impl RouterManager { self.routers.len() } - /// Get router for a specific model + /// Get router for a specific model based on worker types pub fn get_router_for_model(&self, model_id: &str) -> Option> { - // First try model-specific routers - if let Some(router_ids) = self.model_routers.get(model_id) { - if let Some(router_id) = router_ids.first() { - if let Some(router) = self.routers.get(router_id) { - return Some(router.clone()); - } + // Query workers for this model from registry + let workers = self.worker_registry.get_by_model(model_id); + + if !workers.is_empty() { + // Determine router based on worker types + let has_pd_workers = workers.iter().any(|w| { + matches!( + w.worker_type(), + WorkerType::Prefill { .. } | WorkerType::Decode + ) + }); + + let router_id = if has_pd_workers { + RouterId::new("http-pd".to_string()) + } else { + RouterId::new("http-regular".to_string()) + }; + + // Return the router if it exists + if let Some(router) = self.routers.get(&router_id) { + return Some(router.clone()); } } @@ -240,6 +242,17 @@ impl RouterManager { let policy_hint = labels.get("policy").map(|s| s.as_str()); let policy = self.policy_registry.on_worker_added(&model_id, policy_hint); + // Log which type of router would handle this worker (for debugging) + let expected_router = match config.worker_type.as_deref() { + Some("prefill") | Some("decode") => "http-pd", + _ => "http-regular", + }; + + info!( + "Worker for model '{}' would be handled by '{}' router based on type", + model_id, expected_router + ); + info!( "Added worker {} with URL {} for model {} using policy {}", worker_id.as_str(), @@ -272,8 +285,9 @@ impl RouterManager { if let Some(_worker) = self.worker_registry.remove_by_url(url) { // Notify PolicyRegistry about worker removal - if let Some(model_id) = model_id { - self.policy_registry.on_worker_removed(&model_id); + if let Some(ref model_id) = model_id { + self.policy_registry.on_worker_removed(model_id); + info!("Removed worker with URL {} for model {}", url, model_id); } else { info!("Removed worker with URL {}", url); @@ -406,14 +420,10 @@ impl RouterManager { }) .unwrap_or(false); - // If model specified, find routers serving that model + // If model specified, use get_router_for_model let candidate_routers = if let Some(model) = model_id { - // Get routers for specific model - if let Some(router_ids) = self.model_routers.get(model) { - router_ids - .iter() - .filter_map(|id| self.routers.get(id).map(|r| r.clone())) - .collect::>() + if let Some(router) = self.get_router_for_model(model) { + vec![router] } else { Vec::new() } @@ -547,14 +557,10 @@ impl RouterTrait for RouterManager { .into_response() } - /// Get available models - aggregate from all routers + /// Get available models - query from worker registry async fn get_models(&self, _req: Request) -> Response { - // Return models that have registered routers - let models = self - .model_routers - .iter() - .map(|entry| entry.key().clone()) - .collect::>(); + // Get models from worker registry + let models = self.worker_registry.get_models(); if models.is_empty() { (StatusCode::SERVICE_UNAVAILABLE, "No models available").into_response()