From 36efd5be8a0f0c5e0f07dcd3a0b6b4df5d210c89 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Fri, 19 Sep 2025 09:19:57 -0400 Subject: [PATCH] [router] refactor router and worker management 1/n (#10664) --- sgl-router/src/routers/grpc/pd_router.rs | 141 ++++++++++++++--------- sgl-router/src/routers/grpc/router.rs | 58 ++++++---- 2 files changed, 119 insertions(+), 80 deletions(-) diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index 86f7acb5e..cd692e20e 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -2,11 +2,11 @@ use crate::config::types::RetryConfig; use crate::core::{ - BasicWorkerBuilder, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType, + BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType, }; use crate::grpc::SglangSchedulerClient; use crate::metrics::RouterMetrics; -use crate::policies::LoadBalancingPolicy; +use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::reasoning_parser::ParserFactory; use crate::routers::{RouterTrait, WorkerManagement}; use crate::tokenizer::traits::Tokenizer; @@ -19,21 +19,17 @@ use axum::{ response::{IntoResponse, Response}, }; use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use std::time::Duration; use tracing::{info, warn}; /// gRPC PD (Prefill-Decode) router implementation for SGLang #[allow(dead_code)] // Fields will be used once implementation is complete pub struct GrpcPDRouter { - /// Prefill worker connections - prefill_workers: Arc>>>, - /// Decode worker connections - decode_workers: Arc>>>, - /// gRPC clients for prefill workers - prefill_grpc_clients: Arc>>, - /// gRPC clients for decode workers - decode_grpc_clients: Arc>>, + /// Centralized worker registry + worker_registry: Arc, + /// Centralized policy registry + policy_registry: Arc, /// Load balancing policy for prefill prefill_policy: Arc, /// Load balancing policy for decode @@ -44,9 +40,6 @@ pub struct GrpcPDRouter { reasoning_parser_factory: ParserFactory, /// Tool parser registry for function/tool calls tool_parser_registry: &'static ParserRegistry, - /// Worker health checkers - _prefill_health_checker: Option, - _decode_health_checker: Option, /// Configuration timeout_secs: u64, interval_secs: u64, @@ -65,6 +58,10 @@ impl GrpcPDRouter { decode_policy: Arc, ctx: &Arc, ) -> Result { + // Get registries from context + let worker_registry = ctx.worker_registry.clone(); + let policy_registry = ctx.policy_registry.clone(); + // Update metrics RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len()); @@ -126,10 +123,9 @@ impl GrpcPDRouter { return Err("Failed to connect to any gRPC workers".to_string()); } - // Create Prefill Worker trait objects with gRPC connection mode - let prefill_workers: Vec> = prefill_urls - .iter() - .map(|(url, bootstrap_port)| { + // Create Prefill Worker trait objects with gRPC connection mode and register them + for (url, bootstrap_port) in &prefill_urls { + if let Some(client) = prefill_grpc_clients.remove(url) { let worker = BasicWorkerBuilder::new(url.clone()) .worker_type(WorkerType::Prefill { bootstrap_port: *bootstrap_port, @@ -145,15 +141,17 @@ impl GrpcPDRouter { failure_threshold: ctx.router_config.health_check.failure_threshold, success_threshold: ctx.router_config.health_check.success_threshold, }) + .grpc_client(client) .build(); - Arc::new(worker) as Arc - }) - .collect(); - // Create Decode Worker trait objects with gRPC connection mode - let decode_workers: Vec> = decode_urls - .iter() - .map(|url| { + // Register worker in the centralized registry + worker_registry.register(Arc::new(worker)); + } + } + + // Create Decode Worker trait objects with gRPC connection mode and register them + for url in &decode_urls { + if let Some(client) = decode_grpc_clients.remove(url) { let worker = BasicWorkerBuilder::new(url.clone()) .worker_type(WorkerType::Decode) .connection_mode(crate::core::ConnectionMode::Grpc { port: None }) @@ -165,12 +163,23 @@ impl GrpcPDRouter { failure_threshold: ctx.router_config.health_check.failure_threshold, success_threshold: ctx.router_config.health_check.success_threshold, }) + .grpc_client(client) .build(); - Arc::new(worker) as Arc - }) - .collect(); - // Initialize policies with workers if needed + // Register worker in the centralized registry + worker_registry.register(Arc::new(worker)); + } + } + + // Initialize policies with workers if needed - filter for gRPC workers only + let prefill_workers = worker_registry.get_workers_filtered( + None, // any model + Some(WorkerType::Prefill { + bootstrap_port: None, + }), + Some(crate::core::ConnectionMode::Grpc { port: None }), + false, // include unhealthy workers during initialization + ); if let Some(cache_aware) = prefill_policy .as_any() .downcast_ref::() @@ -178,6 +187,12 @@ impl GrpcPDRouter { cache_aware.init_workers(&prefill_workers); } + let decode_workers = worker_registry.get_workers_filtered( + None, // any model + Some(WorkerType::Decode), + Some(crate::core::ConnectionMode::Grpc { port: None }), + false, // include unhealthy workers during initialization + ); if let Some(cache_aware) = decode_policy .as_any() .downcast_ref::() @@ -185,30 +200,16 @@ impl GrpcPDRouter { cache_aware.init_workers(&decode_workers); } - let prefill_workers = Arc::new(RwLock::new(prefill_workers)); - let decode_workers = Arc::new(RwLock::new(decode_workers)); - - let prefill_health_checker = crate::core::start_health_checker( - Arc::clone(&prefill_workers), - ctx.router_config.worker_startup_check_interval_secs, - ); - let decode_health_checker = crate::core::start_health_checker( - Arc::clone(&decode_workers), - ctx.router_config.worker_startup_check_interval_secs, - ); + // No need for local health checkers - WorkerRegistry handles health checking Ok(GrpcPDRouter { - prefill_workers, - decode_workers, - prefill_grpc_clients: Arc::new(RwLock::new(prefill_grpc_clients)), - decode_grpc_clients: Arc::new(RwLock::new(decode_grpc_clients)), + worker_registry, + policy_registry, prefill_policy, decode_policy, tokenizer, reasoning_parser_factory, tool_parser_registry, - _prefill_health_checker: Some(prefill_health_checker), - _decode_health_checker: Some(decode_health_checker), timeout_secs: ctx.router_config.worker_startup_timeout_secs, interval_secs: ctx.router_config.worker_startup_check_interval_secs, dp_aware: ctx.router_config.dp_aware, @@ -221,15 +222,23 @@ impl GrpcPDRouter { impl std::fmt::Debug for GrpcPDRouter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let prefill_workers = self.worker_registry.get_workers_filtered( + None, + Some(WorkerType::Prefill { + bootstrap_port: None, + }), + Some(crate::core::ConnectionMode::Grpc { port: None }), + false, + ); + let decode_workers = self.worker_registry.get_workers_filtered( + None, + Some(WorkerType::Decode), + Some(crate::core::ConnectionMode::Grpc { port: None }), + false, + ); f.debug_struct("GrpcPDRouter") - .field( - "prefill_workers_count", - &self.prefill_workers.read().unwrap().len(), - ) - .field( - "decode_workers_count", - &self.decode_workers.read().unwrap().len(), - ) + .field("prefill_workers_count", &prefill_workers.len()) + .field("decode_workers_count", &decode_workers.len()) .field("timeout_secs", &self.timeout_secs) .field("interval_secs", &self.interval_secs) .field("dp_aware", &self.dp_aware) @@ -351,6 +360,28 @@ impl WorkerManagement for GrpcPDRouter { fn remove_worker(&self, _worker_url: &str) {} fn get_worker_urls(&self) -> Vec { - vec![] + let mut urls = Vec::new(); + + // Get gRPC prefill worker URLs only + let prefill_workers = self.worker_registry.get_workers_filtered( + None, + Some(WorkerType::Prefill { + bootstrap_port: None, + }), + Some(crate::core::ConnectionMode::Grpc { port: None }), + false, + ); + urls.extend(prefill_workers.iter().map(|w| w.url().to_string())); + + // Get gRPC decode worker URLs only + let decode_workers = self.worker_registry.get_workers_filtered( + None, + Some(WorkerType::Decode), + Some(crate::core::ConnectionMode::Grpc { port: None }), + false, + ); + urls.extend(decode_workers.iter().map(|w| w.url().to_string())); + + urls } } diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index f88cf9ed2..3808952de 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -2,11 +2,11 @@ use crate::config::types::RetryConfig; use crate::core::{ - BasicWorkerBuilder, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType, + BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType, }; use crate::grpc::SglangSchedulerClient; use crate::metrics::RouterMetrics; -use crate::policies::LoadBalancingPolicy; +use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::reasoning_parser::ParserFactory; use crate::routers::{RouterTrait, WorkerManagement}; use crate::tokenizer::traits::Tokenizer; @@ -19,17 +19,17 @@ use axum::{ response::{IntoResponse, Response}, }; use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use std::time::Duration; use tracing::{info, warn}; /// gRPC router implementation for SGLang #[allow(dead_code)] // Fields will be used once implementation is complete pub struct GrpcRouter { - /// Worker connections - workers: Arc>>>, - /// gRPC clients for each worker - grpc_clients: Arc>>, + /// Centralized worker registry + worker_registry: Arc, + /// Centralized policy registry + policy_registry: Arc, /// Load balancing policy policy: Arc, /// Tokenizer for handling text encoding/decoding @@ -38,8 +38,6 @@ pub struct GrpcRouter { reasoning_parser_factory: ParserFactory, /// Tool parser registry for function/tool calls tool_parser_registry: &'static ParserRegistry, - /// Worker health checker - _health_checker: Option, /// Configuration timeout_secs: u64, interval_secs: u64, @@ -102,10 +100,11 @@ impl GrpcRouter { return Err("Failed to connect to any gRPC workers".to_string()); } - // Create Worker trait objects with gRPC connection mode - let mut workers: Vec> = Vec::new(); + // Get registries from context + let worker_registry = ctx.worker_registry.clone(); + let policy_registry = ctx.policy_registry.clone(); - // Move clients from the HashMap to the workers + // Create Worker trait objects with gRPC connection mode and register them for url in &worker_urls { if let Some(client) = grpc_clients.remove(url) { let worker = BasicWorkerBuilder::new(url.clone()) @@ -122,12 +121,21 @@ impl GrpcRouter { .grpc_client(client) .build(); - workers.push(Arc::new(worker) as Arc); + // Register worker in the centralized registry + worker_registry.register(Arc::new(worker)); } else { warn!("No gRPC client for worker {}, skipping", url); } } + // Get only gRPC workers from registry for policy initialization + let workers = worker_registry.get_workers_filtered( + None, // any model + Some(WorkerType::Regular), + Some(crate::core::ConnectionMode::Grpc { port: None }), + false, // include unhealthy workers during initialization + ); + // Initialize policy with workers if needed if let Some(cache_aware) = policy .as_any() @@ -136,20 +144,15 @@ impl GrpcRouter { cache_aware.init_workers(&workers); } - let workers = Arc::new(RwLock::new(workers)); - let health_checker = crate::core::start_health_checker( - Arc::clone(&workers), - ctx.router_config.worker_startup_check_interval_secs, - ); + // No need for local health checkers - WorkerRegistry handles health checking Ok(GrpcRouter { - workers, - grpc_clients: Arc::new(RwLock::new(grpc_clients)), + worker_registry, + policy_registry, policy, tokenizer, reasoning_parser_factory, tool_parser_registry, - _health_checker: Some(health_checker), timeout_secs: ctx.router_config.worker_startup_timeout_secs, interval_secs: ctx.router_config.worker_startup_check_interval_secs, dp_aware: ctx.router_config.dp_aware, @@ -162,8 +165,9 @@ impl GrpcRouter { impl std::fmt::Debug for GrpcRouter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let stats = self.worker_registry.stats(); f.debug_struct("GrpcRouter") - .field("workers_count", &self.workers.read().unwrap().len()) + .field("workers_count", &stats.total_workers) .field("timeout_secs", &self.timeout_secs) .field("interval_secs", &self.interval_secs) .field("dp_aware", &self.dp_aware) @@ -285,9 +289,13 @@ impl WorkerManagement for GrpcRouter { fn remove_worker(&self, _worker_url: &str) {} fn get_worker_urls(&self) -> Vec { - self.workers - .read() - .unwrap() + self.worker_registry + .get_workers_filtered( + None, // any model + Some(WorkerType::Regular), + Some(crate::core::ConnectionMode::Grpc { port: None }), + false, // include all workers + ) .iter() .map(|w| w.url().to_string()) .collect()