[router] refactor router and worker management 1/n (#10664)

This commit is contained in:
Simo Lin
2025-09-19 09:19:57 -04:00
committed by GitHub
parent 68cdc1893d
commit 36efd5be8a
2 changed files with 119 additions and 80 deletions

View File

@@ -2,11 +2,11 @@
use crate::config::types::RetryConfig;
use crate::core::{
BasicWorkerBuilder, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType,
};
use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterTrait, WorkerManagement};
use crate::tokenizer::traits::Tokenizer;
@@ -19,17 +19,17 @@ use axum::{
response::{IntoResponse, Response},
};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::sync::Arc;
use std::time::Duration;
use tracing::{info, warn};
/// gRPC router implementation for SGLang
#[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcRouter {
/// Worker connections
workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>,
/// gRPC clients for each worker
grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
/// Centralized worker registry
worker_registry: Arc<WorkerRegistry>,
/// Centralized policy registry
policy_registry: Arc<PolicyRegistry>,
/// Load balancing policy
policy: Arc<dyn LoadBalancingPolicy>,
/// Tokenizer for handling text encoding/decoding
@@ -38,8 +38,6 @@ pub struct GrpcRouter {
reasoning_parser_factory: ParserFactory,
/// Tool parser registry for function/tool calls
tool_parser_registry: &'static ParserRegistry,
/// Worker health checker
_health_checker: Option<HealthChecker>,
/// Configuration
timeout_secs: u64,
interval_secs: u64,
@@ -102,10 +100,11 @@ impl GrpcRouter {
return Err("Failed to connect to any gRPC workers".to_string());
}
// Create Worker trait objects with gRPC connection mode
let mut workers: Vec<Arc<dyn Worker>> = Vec::new();
// Get registries from context
let worker_registry = ctx.worker_registry.clone();
let policy_registry = ctx.policy_registry.clone();
// Move clients from the HashMap to the workers
// Create Worker trait objects with gRPC connection mode and register them
for url in &worker_urls {
if let Some(client) = grpc_clients.remove(url) {
let worker = BasicWorkerBuilder::new(url.clone())
@@ -122,12 +121,21 @@ impl GrpcRouter {
.grpc_client(client)
.build();
workers.push(Arc::new(worker) as Arc<dyn Worker>);
// Register worker in the centralized registry
worker_registry.register(Arc::new(worker));
} else {
warn!("No gRPC client for worker {}, skipping", url);
}
}
// Get only gRPC workers from registry for policy initialization
let workers = worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Regular),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false, // include unhealthy workers during initialization
);
// Initialize policy with workers if needed
if let Some(cache_aware) = policy
.as_any()
@@ -136,20 +144,15 @@ impl GrpcRouter {
cache_aware.init_workers(&workers);
}
let workers = Arc::new(RwLock::new(workers));
let health_checker = crate::core::start_health_checker(
Arc::clone(&workers),
ctx.router_config.worker_startup_check_interval_secs,
);
// No need for local health checkers - WorkerRegistry handles health checking
Ok(GrpcRouter {
workers,
grpc_clients: Arc::new(RwLock::new(grpc_clients)),
worker_registry,
policy_registry,
policy,
tokenizer,
reasoning_parser_factory,
tool_parser_registry,
_health_checker: Some(health_checker),
timeout_secs: ctx.router_config.worker_startup_timeout_secs,
interval_secs: ctx.router_config.worker_startup_check_interval_secs,
dp_aware: ctx.router_config.dp_aware,
@@ -162,8 +165,9 @@ impl GrpcRouter {
impl std::fmt::Debug for GrpcRouter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let stats = self.worker_registry.stats();
f.debug_struct("GrpcRouter")
.field("workers_count", &self.workers.read().unwrap().len())
.field("workers_count", &stats.total_workers)
.field("timeout_secs", &self.timeout_secs)
.field("interval_secs", &self.interval_secs)
.field("dp_aware", &self.dp_aware)
@@ -285,9 +289,13 @@ impl WorkerManagement for GrpcRouter {
fn remove_worker(&self, _worker_url: &str) {}
fn get_worker_urls(&self) -> Vec<String> {
self.workers
.read()
.unwrap()
self.worker_registry
.get_workers_filtered(
None, // any model
Some(WorkerType::Regular),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false, // include all workers
)
.iter()
.map(|w| w.url().to_string())
.collect()