[router] refactor router and worker management 4/n (#10756)

Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
Simo Lin
2025-09-22 21:35:10 -04:00
committed by GitHub
parent 113f8f65a2
commit 89971c4c3c
4 changed files with 161 additions and 196 deletions

View File

@@ -14,10 +14,7 @@ use crate::{
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
},
reasoning_parser::ParserFactory,
routers::{
router_manager::{RouterId, RouterManager},
RouterFactory, RouterTrait,
},
routers::{router_manager::RouterManager, RouterTrait},
service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
tool_parser::ParserRegistry,
@@ -64,10 +61,8 @@ impl AppContext {
let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
// Initialize gRPC-specific components only when in gRPC mode
let (tokenizer, reasoning_parser_factory, tool_parser_registry) =
if router_config.connection_mode == ConnectionMode::Grpc {
// Get tokenizer path (required for gRPC mode)
let tokenizer_path = router_config
.tokenizer_path
.clone()
@@ -77,7 +72,6 @@ impl AppContext {
.to_string()
})?;
// Initialize all gRPC components
let tokenizer = Some(
tokenizer_factory::create_tokenizer(&tokenizer_path)
.map_err(|e| format!("Failed to create tokenizer: {e}"))?,
@@ -87,7 +81,6 @@ impl AppContext {
(tokenizer, reasoning_parser_factory, tool_parser_registry)
} else {
// HTTP mode doesn't need these components
(None, None, None)
};
@@ -96,7 +89,6 @@ impl AppContext {
let router_manager = None;
// Initialize response storage based on configuration
let response_storage: SharedResponseStorage = match router_config.history_backend {
HistoryBackend::Memory => Arc::new(MemoryResponseStorage::new()),
HistoryBackend::None => Arc::new(NoOpResponseStorage::new()),
@@ -125,12 +117,10 @@ pub struct AppState {
pub router_manager: Option<Arc<RouterManager>>,
}
// Fallback handler for unmatched routes
/// Catch-all handler: any request that matched no registered route
/// is answered with a plain 404 Not Found.
async fn sink_handler() -> Response {
    let status = StatusCode::NOT_FOUND;
    status.into_response()
}
// Health check endpoints
/// Liveness probe endpoint: defers entirely to the active router's
/// own liveness check and returns its response unchanged.
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
    let router = &state.router;
    router.liveness()
}
@@ -257,7 +247,6 @@ async fn v1_responses_delete(
Path(response_id): Path<String>,
headers: http::HeaderMap,
) -> Response {
// Python server does not support this yet
state
.router
.delete_response(Some(&headers), &response_id)
@@ -269,15 +258,12 @@ async fn v1_responses_list_input_items(
Path(response_id): Path<String>,
headers: http::HeaderMap,
) -> Response {
// Python server does not support this yet
state
.router
.list_response_input_items(Some(&headers), &response_id)
.await
}
// ---------- Worker management endpoints (Legacy) ----------
#[derive(Deserialize)]
struct AddWorkerQuery {
url: String,
@@ -288,7 +274,6 @@ async fn add_worker(
State(state): State<Arc<AppState>>,
Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
) -> Response {
// Use centralized WorkerManager with full context
let result = WorkerManager::add_worker(&url, &api_key, &state.context).await;
match result {
@@ -298,7 +283,6 @@ async fn add_worker(
}
/// Legacy worker-listing endpoint: reports the URLs of all registered
/// workers as a JSON object of the shape `{ "urls": [...] }`.
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
    // Query the shared worker registry through the centralized WorkerManager
    // rather than asking the router for its worker URLs.
    let registry = &state.context.worker_registry;
    let urls = WorkerManager::get_worker_urls(registry);
    let body = json!({ "urls": urls });
    Json(body).into_response()
}
@@ -307,7 +291,6 @@ async fn remove_worker(
State(state): State<Arc<AppState>>,
Query(AddWorkerQuery { url, .. }): Query<AddWorkerQuery>,
) -> Response {
// Use centralized WorkerManager with full context
let result = WorkerManager::remove_worker(&url, &state.context);
match result {
@@ -324,14 +307,10 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
state.router.get_worker_loads().await
}
// ---------- Worker management endpoints (RESTful) ----------
/// POST /workers - Add a new worker with full configuration
async fn create_worker(
State(state): State<Arc<AppState>>,
Json(config): Json<WorkerConfigRequest>,
) -> Response {
// In single router mode, use centralized WorkerManager with full context
let result = WorkerManager::add_worker_from_config(&config, &state.context).await;
match result {
@@ -353,9 +332,7 @@ async fn create_worker(
}
}
/// GET /workers - List all workers with details
async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
// In single router mode, get detailed worker info from registry
let workers = state.context.worker_registry.get_all();
let response = serde_json::json!({
"workers": workers.iter().map(|worker| {
@@ -374,7 +351,6 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
"cost": worker.cost(),
});
// Add bootstrap_port for Prefill workers
if let WorkerType::Prefill { bootstrap_port } = worker.worker_type() {
worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port);
}
@@ -391,7 +367,6 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
Json(response).into_response()
}
/// GET /workers/{url} - Get specific worker info
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
let workers = WorkerManager::get_worker_urls(&state.context.worker_registry);
if workers.contains(&url) {
@@ -410,9 +385,7 @@ async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>)
}
}
/// DELETE /workers/{url} - Remove a worker
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
// In single router mode, use centralized WorkerManager with full context
let result = WorkerManager::remove_worker(&url, &state.context);
match result {
@@ -447,14 +420,12 @@ pub struct ServerConfig {
pub request_id_headers: Option<Vec<String>>,
}
/// Build the Axum application with all routes and middleware
pub fn build_app(
app_state: Arc<AppState>,
max_payload_size: usize,
request_id_headers: Vec<String>,
cors_allowed_origins: Vec<String>,
) -> Router {
// Create routes
let protected_routes = Router::new()
.route("/generate", post(generate))
.route("/v1/chat/completions", post(v1_chat_completions))
@@ -494,20 +465,17 @@ pub fn build_app(
.route("/flush_cache", post(flush_cache))
.route("/get_loads", get(get_loads));
// Worker management routes
let worker_routes = Router::new()
.route("/workers", post(create_worker))
.route("/workers", get(list_workers_rest))
.route("/workers/{url}", get(get_worker))
.route("/workers/{url}", delete(delete_worker));
// Build app with all routes and middleware
Router::new()
.merge(protected_routes)
.merge(public_routes)
.merge(admin_routes)
.merge(worker_routes)
// Request body size limiting
.layer(tower_http::limit::RequestBodyLimitLayer::new(
max_payload_size,
))
@@ -519,7 +487,6 @@ pub fn build_app(
}
pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
// Only initialize logging if not already done (for Python bindings support)
static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);
let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
@@ -545,9 +512,8 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
None
};
// Initialize prometheus metrics exporter
if let Some(prometheus_config) = config.prometheus_config {
metrics::start_prometheus(prometheus_config);
if let Some(prometheus_config) = &config.prometheus_config {
metrics::start_prometheus(prometheus_config.clone());
}
info!(
@@ -569,7 +535,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
.build()
.expect("Failed to create HTTP client");
// Create the application context with all dependencies
let app_context = AppContext::new(
config.router_config.clone(),
client.clone(),
@@ -597,67 +562,9 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
worker_stats.total_workers, worker_stats.healthy_workers
);
// Create the appropriate router based on enable_igw flag
let (router, router_manager): (Arc<dyn RouterTrait>, Option<Arc<RouterManager>>) =
if config.router_config.enable_igw {
info!("Multi-router mode enabled (enable_igw=true)");
let router_manager = RouterManager::from_config(&config, &app_context).await?;
let router: Arc<dyn RouterTrait> = router_manager.clone();
// Create RouterManager with shared registries from AppContext
let router_manager = Arc::new(RouterManager::new(app_context.worker_registry.clone()));
// 1. HTTP Regular Router
match RouterFactory::create_regular_router(&app_context).await {
Ok(http_regular) => {
info!("Created HTTP Regular router");
router_manager.register_router(
RouterId::new("http-regular".to_string()),
Arc::from(http_regular),
);
}
Err(e) => {
warn!("Failed to create HTTP Regular router: {e}");
}
}
// 2. HTTP PD Router
match RouterFactory::create_pd_router(
None,
None,
&config.router_config.policy,
&app_context,
)
.await
{
Ok(http_pd) => {
info!("Created HTTP PD router");
router_manager
.register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd));
}
Err(e) => {
warn!("Failed to create HTTP PD router: {e}");
}
}
// TODO: Add gRPC routers once we have dynamic tokenizer loading
info!(
"RouterManager initialized with {} routers",
router_manager.router_count()
);
(
router_manager.clone() as Arc<dyn RouterTrait>,
Some(router_manager),
)
} else {
info!("Single router mode (enable_igw=false)");
// Create single router with the context
(
Arc::from(RouterFactory::create_router(&app_context).await?),
None,
)
};
// Start health checker for all workers in the registry
let _health_checker = app_context
.worker_registry
.start_health_checker(config.router_config.health_check.check_interval_secs);
@@ -666,14 +573,12 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
config.router_config.health_check.check_interval_secs
);
// Set up concurrency limiter with queue if configured
let (limiter, processor) = middleware::ConcurrencyLimiter::new(
app_context.rate_limiter.clone(),
config.router_config.queue_size,
Duration::from_secs(config.router_config.queue_timeout_secs),
);
// Start queue processor if enabled
if let Some(processor) = processor {
spawn(processor.run());
info!(
@@ -682,21 +587,18 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
);
}
// Create app state with router and context
let app_state = Arc::new(AppState {
router,
context: app_context.clone(),
concurrency_queue_tx: limiter.queue_tx.clone(),
router_manager,
router_manager: Some(router_manager),
});
// Start the service discovery if enabled
if let Some(service_discovery_config) = config.service_discovery_config {
if service_discovery_config.enabled {
let app_context_arc = Arc::clone(&app_state.context);
match start_service_discovery(service_discovery_config, app_context_arc).await {
Ok(handle) => {
info!("Service discovery started");
// Spawn a task to handle the service discovery thread
spawn(async move {
if let Err(e) = handle.await {
error!("Service discovery task failed: {:?}", e);
@@ -725,7 +627,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
]
});
// Build the application
let app = build_app(
app_state,
config.max_payload_size,
@@ -744,7 +645,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
Ok(())
}
// Graceful shutdown handler
async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c()
@@ -773,19 +673,16 @@ async fn shutdown_signal() {
}
}
// CORS Layer Creation
fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
use tower_http::cors::Any;
let cors = if allowed_origins.is_empty() {
// Allow all origins if none specified
tower_http::cors::CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any)
.expose_headers(Any)
} else {
// Restrict to specific origins
let origins: Vec<http::HeaderValue> = allowed_origins
.into_iter()
.filter_map(|origin| origin.parse().ok())