From 35ef3f2902a771bce04e814a970c4da9e4f61e56 Mon Sep 17 00:00:00 2001 From: Chang Su Date: Mon, 15 Sep 2025 18:05:00 -0700 Subject: [PATCH] [router] fix worker registration in multi model mode (#10486) --- sgl-router/src/core/worker.rs | 31 +++++++++++ sgl-router/src/routers/router_manager.rs | 67 +++++++++--------------- sgl-router/src/server.rs | 25 ++++----- 3 files changed, 69 insertions(+), 54 deletions(-) diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index 27165bb0e..b6fb25e75 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -804,6 +804,37 @@ impl WorkerFactory { Box::new(worker) } + /// Create a prefill worker with labels + pub fn create_prefill_with_labels( + url: String, + bootstrap_port: Option, + labels: std::collections::HashMap, + circuit_breaker_config: CircuitBreakerConfig, + ) -> Box { + let mut worker = BasicWorker::new(url.clone(), WorkerType::Prefill { bootstrap_port }) + .with_circuit_breaker_config(circuit_breaker_config); + + // Add labels to metadata + worker.metadata.labels = labels; + + Box::new(worker) + } + + /// Create a decode worker with labels + pub fn create_decode_with_labels( + url: String, + labels: std::collections::HashMap, + circuit_breaker_config: CircuitBreakerConfig, + ) -> Box { + let mut worker = BasicWorker::new(url.clone(), WorkerType::Decode) + .with_circuit_breaker_config(circuit_breaker_config); + + // Add labels to metadata + worker.metadata.labels = labels; + + Box::new(worker) + } + /// Create a DP-aware worker of specified type pub fn create_dp_aware( base_url: String, diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs index 2be9b6efd..dbccb21e0 100644 --- a/sgl-router/src/routers/router_manager.rs +++ b/sgl-router/src/routers/router_manager.rs @@ -41,7 +41,6 @@ impl RouterId { } /// Router Manager - Central coordinator for routers and workers -/// Only created when enable_igw=true pub struct RouterManager { /// Worker registry (single source of truth in multi-router mode) worker_registry: Arc, @@ -49,7 +48,7 @@ pub struct RouterManager { /// Policy registry for managing model-to-policy mappings policy_registry: Arc, - /// All routers managed by this manager (max 4 routers in Phase 2) + /// All routers managed by this manager /// RouterId examples: "http-regular", "http-pd", "grpc-regular", "grpc-pd" routers: Arc>>, @@ -83,12 +82,7 @@ impl RouterManager { } /// Register a router with the manager - pub fn register_router( - &mut self, - id: RouterId, - router: Arc, - _models: Vec, // Keep parameter for backward compatibility but ignore it - ) { + pub fn register_router(&mut self, id: RouterId, router: Arc) { // Store router self.routers.insert(id.clone(), router); @@ -210,32 +204,28 @@ impl RouterManager { labels.insert("chat_template".to_string(), chat_template); } - // Create worker based on type - // Note: For prefill and decode workers, we can't easily add labels after creation - // since they return Box. We'll need to enhance WorkerFactory in the future. let worker = match config.worker_type.as_deref() { - Some("prefill") => { - // For now, prefill workers won't have custom labels - // TODO: Enhance WorkerFactory to accept labels for prefill workers - WorkerFactory::create_prefill(config.url.clone(), config.bootstrap_port) - } - Some("decode") => { - // For now, decode workers won't have custom labels - // TODO: Enhance WorkerFactory to accept labels for decode workers - WorkerFactory::create_decode(config.url.clone()) - } - _ => { - // Regular workers can have labels - WorkerFactory::create_regular_with_labels( - config.url.clone(), - labels.clone(), - CircuitBreakerConfig::default(), - ) - } + Some("prefill") => WorkerFactory::create_prefill_with_labels( + config.url.clone(), + config.bootstrap_port, + labels.clone(), + CircuitBreakerConfig::default(), + ), + Some("decode") => WorkerFactory::create_decode_with_labels( + config.url.clone(), + labels.clone(), + CircuitBreakerConfig::default(), + ), + _ => WorkerFactory::create_regular_with_labels( + config.url.clone(), + labels.clone(), + CircuitBreakerConfig::default(), + ), }; // Register worker - let worker_id = self.worker_registry.register(Arc::from(worker)); + let worker_arc: Arc = Arc::from(worker); + let worker_id = self.worker_registry.register(worker_arc.clone()); // Notify PolicyRegistry about the new worker // Extract policy hint from labels if provided @@ -262,7 +252,6 @@ impl RouterManager { ); // Return worker info - let worker_arc = self.worker_registry.get(&worker_id).unwrap(); let worker_info = self.worker_to_info(worker_id.as_str(), &worker_arc); Ok(WorkerApiResponse { @@ -375,7 +364,11 @@ impl RouterManager { model_id: worker.model_id().to_string(), priority: worker.priority(), cost: worker.cost(), - worker_type: format!("{:?}", worker.worker_type()), + worker_type: match worker.worker_type() { + WorkerType::Regular => "regular".to_string(), + WorkerType::Prefill { .. } => "prefill".to_string(), + WorkerType::Decode => "decode".to_string(), + }, is_healthy: worker.is_healthy(), load: worker.load(), connection_mode: format!("{:?}", worker.connection_mode()), @@ -387,11 +380,6 @@ impl RouterManager { } } - // Note: calculate_stats removed - using WorkerRegistry::stats() instead - - // === Phase 2: Router Management === - // Note: Dynamic router creation removed - routers are created and registered externally - /// Get the appropriate router for a request based on headers and request content pub fn select_router_for_request( &self, @@ -474,11 +462,6 @@ impl RouterManager { } } -// Note: Default implementation removed as RouterManager now requires AppContext -// which cannot be defaulted. RouterManager must be created with explicit context. - -// === Phase 2: RouterManager as RouterTrait === - /// RouterManager implements RouterTrait to act as a meta-router /// that delegates requests to the appropriate underlying router #[async_trait] diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 8428794d8..2bb9bbb05 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -317,8 +317,9 @@ async fn create_worker( State(state): State>, Json(config): Json, ) -> Response { - // Check if RouterManager is available (enable_igw=true) - if let Some(router_manager) = &state.context.router_manager { + // Check if the router is actually a RouterManager (enable_igw=true) + if let Some(router_manager) = state.router.as_any().downcast_ref::() { + // Call RouterManager's add_worker method directly with the full config match router_manager.add_worker(config).await { Ok(response) => (StatusCode::OK, Json(response)).into_response(), Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(), @@ -347,7 +348,7 @@ async fn create_worker( /// GET /workers - List all workers with details async fn list_workers_rest(State(state): State>) -> Response { - if let Some(router_manager) = &state.context.router_manager { + if let Some(router_manager) = state.router.as_any().downcast_ref::() { let response = router_manager.list_workers(); Json(response).into_response() } else { @@ -358,7 +359,11 @@ async fn list_workers_rest(State(state): State>) -> Response { let mut worker_info = serde_json::json!({ "url": worker.url(), "model_id": worker.model_id(), - "worker_type": format!("{:?}", worker.worker_type()), + "worker_type": match worker.worker_type() { + WorkerType::Regular => "regular", + WorkerType::Prefill { .. } => "prefill", + WorkerType::Decode => "decode", + }, "is_healthy": worker.is_healthy(), "load": worker.load(), "connection_mode": format!("{:?}", worker.connection_mode()), @@ -386,7 +391,7 @@ async fn list_workers_rest(State(state): State>) -> Response { /// GET /workers/{url} - Get specific worker info async fn get_worker(State(state): State>, Path(url): Path) -> Response { - if let Some(router_manager) = &state.context.router_manager { + if let Some(router_manager) = state.router.as_any().downcast_ref::() { if let Some(worker) = router_manager.get_worker(&url) { Json(worker).into_response() } else { @@ -417,7 +422,7 @@ async fn get_worker(State(state): State>, Path(url): Path) /// DELETE /workers/{url} - Remove a worker async fn delete_worker(State(state): State>, Path(url): Path) -> Response { - if let Some(router_manager) = &state.context.router_manager { + if let Some(router_manager) = state.router.as_any().downcast_ref::() { match router_manager.remove_worker_from_registry(&url) { Ok(response) => (StatusCode::OK, Json(response)).into_response(), Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(), @@ -603,7 +608,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box { @@ -624,11 +628,8 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box { info!("Created HTTP PD router"); - router_manager.register_router( - RouterId::new("http-pd".to_string()), - Arc::from(http_pd), - vec![], - ); + router_manager + .register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd)); } Err(e) => { warn!("Failed to create HTTP PD router: {e}");