diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 1870404d9..8428794d8 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -1,35 +1,42 @@ -use crate::config::RouterConfig; -use crate::core::WorkerRegistry; -use crate::logging::{self, LoggingConfig}; -use crate::metrics::{self, PrometheusConfig}; -use crate::middleware::TokenBucket; -use crate::policies::PolicyRegistry; -use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, - ResponsesRequest, V1RerankReqInput, +use crate::{ + config::{ConnectionMode, RouterConfig}, + core::{WorkerRegistry, WorkerType}, + logging::{self, LoggingConfig}, + metrics::{self, PrometheusConfig}, + middleware::{self, QueuedRequest, TokenBucket}, + policies::PolicyRegistry, + protocols::{ + spec::{ + ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, + RerankRequest, ResponsesRequest, V1RerankReqInput, + }, + worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse}, + }, + reasoning_parser::ParserFactory, + routers::{ + router_manager::{RouterId, RouterManager}, + RouterFactory, RouterTrait, + }, + service_discovery::{start_service_discovery, ServiceDiscoveryConfig}, + tokenizer::{factory as tokenizer_factory, traits::Tokenizer}, + tool_parser::ParserRegistry, }; -use crate::protocols::worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse}; -use crate::reasoning_parser::ParserFactory; -use crate::routers::router_manager::{RouterId, RouterManager}; -use crate::routers::{RouterFactory, RouterTrait}; -use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; -use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer}; -use crate::tool_parser::ParserRegistry; use axum::{ extract::{Path, Query, Request, State}, http::StatusCode, response::{IntoResponse, Response}, routing::{delete, get, post}, - Json, Router, + serve, Json, Router, }; use reqwest::Client; -use std::collections::HashMap; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use std::time::Duration; -use tokio::net::TcpListener; -use tokio::signal; -use tokio::spawn; +use serde::Deserialize; +use serde_json::json; +use std::{ + sync::atomic::{AtomicBool, Ordering}, + sync::Arc, + time::Duration, +}; +use tokio::{net::TcpListener, signal, spawn}; use tracing::{error, info, warn, Level}; #[derive(Clone)] @@ -40,9 +47,9 @@ pub struct AppContext { pub tokenizer: Option>, pub reasoning_parser_factory: Option, pub tool_parser_registry: Option<&'static ParserRegistry>, - pub worker_registry: Arc, // Shared worker registry - pub policy_registry: Arc, // Shared policy registry - pub router_manager: Option>, // Only present when enable_igw=true + pub worker_registry: Arc, + pub policy_registry: Arc, + pub router_manager: Option>, } impl AppContext { @@ -57,7 +64,7 @@ impl AppContext { // Initialize gRPC-specific components only when in gRPC mode let (tokenizer, reasoning_parser_factory, tool_parser_registry) = - if router_config.connection_mode == crate::config::ConnectionMode::Grpc { + if router_config.connection_mode == ConnectionMode::Grpc { // Get tokenizer path (required for gRPC mode) let tokenizer_path = router_config .tokenizer_path @@ -71,7 +78,7 @@ impl AppContext { // Initialize all gRPC components let tokenizer = Some( tokenizer_factory::create_tokenizer(&tokenizer_path) - .map_err(|e| format!("Failed to create tokenizer: {}", e))?, + .map_err(|e| format!("Failed to create tokenizer: {e}"))?, ); let reasoning_parser_factory = Some(ParserFactory::new()); let tool_parser_registry = Some(ParserRegistry::new()); @@ -82,14 +89,10 @@ impl AppContext { (None, None, None) }; - // Initialize shared registries let worker_registry = Arc::new(WorkerRegistry::new()); - let policy_registry = Arc::new(PolicyRegistry::new( - router_config.policy.clone(), // Use default policy from config - )); + let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone())); - // Initialize RouterManager only when enable_igw is true - let router_manager = None; // Will be initialized in startup() based on config + let router_manager = None; Ok(Self { client, @@ -109,7 +112,7 @@ impl AppContext { pub struct AppState { pub router: Arc, pub context: Arc, - pub concurrency_queue_tx: Option>, + pub concurrency_queue_tx: Option>, } // Fallback handler for unmatched routes @@ -265,23 +268,18 @@ async fn v1_responses_list_input_items( .await } -// Worker management endpoints +// ---------- Worker management endpoints (Legacy) ---------- + +#[derive(Deserialize)] +struct UrlQuery { + url: String, +} + async fn add_worker( State(state): State>, - Query(params): Query>, + Query(UrlQuery { url }): Query, ) -> Response { - let worker_url = match params.get("url") { - Some(url) => url.to_string(), - None => { - return ( - StatusCode::BAD_REQUEST, - "Worker URL required. Provide 'url' query parameter", - ) - .into_response(); - } - }; - - match state.router.add_worker(&worker_url).await { + match state.router.add_worker(&url).await { Ok(message) => (StatusCode::OK, message).into_response(), Err(error) => (StatusCode::BAD_REQUEST, error).into_response(), } @@ -294,17 +292,12 @@ async fn list_workers(State(state): State>) -> Response { async fn remove_worker( State(state): State>, - Query(params): Query>, + Query(UrlQuery { url }): Query, ) -> Response { - let worker_url = match params.get("url") { - Some(url) => url.to_string(), - None => return StatusCode::BAD_REQUEST.into_response(), - }; - - state.router.remove_worker(&worker_url); + state.router.remove_worker(&url); ( StatusCode::OK, - format!("Successfully removed worker: {}", worker_url), + format!("Successfully removed worker: {url}"), ) .into_response() } @@ -317,7 +310,7 @@ async fn get_loads(State(state): State>, _req: Request) -> Respons state.router.get_worker_loads().await } -// New RESTful worker management endpoints (when enable_igw=true) +// ---------- Worker management endpoints (RESTful) ---------- /// POST /workers - Add a new worker with full configuration async fn create_worker( @@ -374,7 +367,7 @@ async fn list_workers_rest(State(state): State>) -> Response { }); // Add bootstrap_port for Prefill workers - if let crate::core::WorkerType::Prefill { bootstrap_port } = worker.worker_type() { + if let WorkerType::Prefill { bootstrap_port } = worker.worker_type() { worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port); } @@ -384,7 +377,7 @@ async fn list_workers_rest(State(state): State>) -> Response { "stats": { "prefill_count": state.context.worker_registry.get_prefill_workers().len(), "decode_count": state.context.worker_registry.get_decode_workers().len(), - "regular_count": state.context.worker_registry.get_by_type(&crate::core::WorkerType::Regular).len(), + "regular_count": state.context.worker_registry.get_by_type(&WorkerType::Regular).len(), } }); Json(response).into_response() @@ -392,33 +385,29 @@ async fn list_workers_rest(State(state): State>) -> Response { } /// GET /workers/{url} - Get specific worker info -async fn get_worker( - State(state): State>, - axum::extract::Path(url): axum::extract::Path, -) -> Response { +async fn get_worker(State(state): State>, Path(url): Path) -> Response { if let Some(router_manager) = &state.context.router_manager { if let Some(worker) = router_manager.get_worker(&url) { Json(worker).into_response() } else { let error = WorkerErrorResponse { - error: format!("Worker {} not found", url), + error: format!("Worker {url} not found"), code: "WORKER_NOT_FOUND".to_string(), }; (StatusCode::NOT_FOUND, Json(error)).into_response() } } else { - // In single router mode, check if worker exists let workers = state.router.get_worker_urls(); if workers.contains(&url) { - let worker_info = serde_json::json!({ + Json(json!({ "url": url, "model_id": "unknown", "is_healthy": true - }); - Json(worker_info).into_response() + })) + .into_response() } else { let error = WorkerErrorResponse { - error: format!("Worker {} not found", url), + error: format!("Worker {url} not found"), code: "WORKER_NOT_FOUND".to_string(), }; (StatusCode::NOT_FOUND, Json(error)).into_response() @@ -427,10 +416,7 @@ async fn get_worker( } /// DELETE /workers/{url} - Remove a worker -async fn delete_worker( - State(state): State>, - axum::extract::Path(url): axum::extract::Path, -) -> Response { +async fn delete_worker(State(state): State>, Path(url): Path) -> Response { if let Some(router_manager) = &state.context.router_manager { match router_manager.remove_worker_from_registry(&url) { Ok(response) => (StatusCode::OK, Json(response)).into_response(), @@ -441,7 +427,7 @@ async fn delete_worker( state.router.remove_worker(&url); let response = WorkerApiResponse { success: true, - message: format!("Worker {} removed successfully", url), + message: format!("Worker {url} removed successfully"), worker: None, }; (StatusCode::OK, Json(response)).into_response() @@ -489,7 +475,7 @@ pub fn build_app( ) .route_layer(axum::middleware::from_fn_with_state( app_state.clone(), - crate::middleware::concurrency_limit_middleware, + middleware::concurrency_limit_middleware, )); let public_routes = Router::new() @@ -513,7 +499,7 @@ pub fn build_app( .route("/workers", post(create_worker)) .route("/workers", get(list_workers_rest)) .route("/workers/{url}", get(get_worker)) - .route("/workers/{url}", axum::routing::delete(delete_worker)); + .route("/workers/{url}", delete(delete_worker)); // Build app with all routes and middleware Router::new() @@ -525,17 +511,10 @@ pub fn build_app( .layer(tower_http::limit::RequestBodyLimitLayer::new( max_payload_size, )) - // Logging layer - must be added BEFORE request ID layer in the code - // so it executes AFTER request ID layer at runtime (layers execute bottom-up) - // This way the TraceLayer can see the request ID that was added to extensions - .layer(crate::middleware::create_logging_layer()) - // Request ID layer - adds request ID to extensions first - .layer(crate::middleware::RequestIdLayer::new(request_id_headers)) - // CORS (should be outermost) + .layer(middleware::create_logging_layer()) + .layer(middleware::RequestIdLayer::new(request_id_headers)) .layer(create_cors_layer(cors_allowed_origins)) - // Fallback .fallback(sink_handler) - // State - apply last to get Router> .with_state(app_state) } @@ -551,7 +530,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box() { Ok(l) => Some(l), Err(_) => { - warn!("Invalid log level string: '{}'. Defaulting to INFO.", s); + warn!("Invalid log level string: '{s}'. Defaulting to INFO."); None } }) @@ -582,11 +561,11 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box Result<(), Box Result<(), Box { - warn!("Failed to create HTTP Regular router: {}", e); + warn!("Failed to create HTTP Regular router: {e}"); } } // 2. HTTP PD Router match RouterFactory::create_pd_router( - &[], // Empty prefill URLs - &[], // Empty decode URLs - None, // Use default prefill policy - None, // Use default decode policy + &[], + &[], + None, + None, &config.router_config.policy, &app_context, ) @@ -655,16 +631,11 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box { - warn!("Failed to create HTTP PD router: {}", e); + warn!("Failed to create HTTP PD router: {e}"); } } // TODO: Add gRPC routers once we have dynamic tokenizer loading - // Currently gRPC routers require tokenizer to be initialized first, - // but each model needs its own tokenizer. Once we implement dynamic - // tokenizer loading per model, we can enable gRPC routers here: - // - RouterType::GrpcRegular (RouterId: "grpc-regular") - // - RouterType::GrpcPd (RouterId: "grpc-pd") info!( "RouterManager initialized with {} routers", @@ -687,7 +658,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box Result<(), Box { - error!("Failed to start service discovery: {}", e); + error!("Failed to start service discovery: {e}"); warn!("Continuing without service discovery"); } } @@ -736,7 +707,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box Result<(), Box)?;