[router] refactor router and worker management 3/n (#10727)

This commit is contained in:
Simo Lin
2025-09-22 15:17:50 -04:00
committed by GitHub
parent 60dbbd086a
commit 97c3823931
25 changed files with 1427 additions and 2540 deletions

View File

@@ -8,7 +8,7 @@ use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterTrait, WorkerManagement};
use crate::routers::RouterTrait;
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserRegistry;
use async_trait::async_trait;
@@ -350,42 +350,3 @@ impl RouterTrait for GrpcPDRouter {
(StatusCode::SERVICE_UNAVAILABLE).into_response()
}
}
#[async_trait]
impl WorkerManagement for GrpcPDRouter {
async fn add_worker(
&self,
_worker_url: &str,
_api_key: &Option<String>,
) -> Result<String, String> {
Err("Not implemented".to_string())
}
fn remove_worker(&self, _worker_url: &str) {}
fn get_worker_urls(&self) -> Vec<String> {
let mut urls = Vec::new();
// Get gRPC prefill worker URLs only
let prefill_workers = self.worker_registry.get_workers_filtered(
None,
Some(WorkerType::Prefill {
bootstrap_port: None,
}),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false,
);
urls.extend(prefill_workers.iter().map(|w| w.url().to_string()));
// Get gRPC decode worker URLs only
let decode_workers = self.worker_registry.get_workers_filtered(
None,
Some(WorkerType::Decode),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false,
);
urls.extend(decode_workers.iter().map(|w| w.url().to_string()));
urls
}
}

View File

@@ -8,7 +8,7 @@ use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterTrait, WorkerManagement};
use crate::routers::RouterTrait;
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserRegistry;
use async_trait::async_trait;
@@ -279,29 +279,3 @@ impl RouterTrait for GrpcRouter {
(StatusCode::SERVICE_UNAVAILABLE).into_response()
}
}
#[async_trait]
impl WorkerManagement for GrpcRouter {
async fn add_worker(
&self,
_worker_url: &str,
_api_key: &Option<String>,
) -> Result<String, String> {
Err("Not implemented".to_string())
}
fn remove_worker(&self, _worker_url: &str) {}
fn get_worker_urls(&self) -> Vec<String> {
self.worker_registry
.get_workers_filtered(
None, // any model
Some(WorkerType::Regular),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false, // include all workers
)
.iter()
.map(|w| w.url().to_string())
.collect()
}
}

View File

@@ -65,25 +65,6 @@ impl OpenAIRouter {
}
}
#[async_trait]
impl super::super::WorkerManagement for OpenAIRouter {
async fn add_worker(
&self,
_worker_url: &str,
_api_key: &Option<String>,
) -> Result<String, String> {
Err("Cannot add workers to OpenAI router".to_string())
}
fn remove_worker(&self, _worker_url: &str) {
// No-op for OpenAI router
}
fn get_worker_urls(&self) -> Vec<String> {
vec![self.base_url.clone()]
}
}
#[async_trait]
impl super::super::RouterTrait for OpenAIRouter {
fn as_any(&self) -> &dyn Any {

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,6 @@
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, RetryExecutor,
Worker, WorkerRegistry, WorkerType,
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
@@ -10,7 +9,7 @@ use crate::protocols::spec::{
RerankRequest, RerankResponse, RerankResult, ResponsesRequest,
};
use crate::routers::header_utils;
use crate::routers::{RouterTrait, WorkerManagement};
use crate::routers::RouterTrait;
use axum::body::to_bytes;
use axum::{
body::Body,
@@ -27,7 +26,7 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info, warn};
use tracing::{debug, error};
/// Regular router that uses injected load balancing policies
#[derive(Debug)]
@@ -35,13 +34,8 @@ pub struct Router {
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
client: Client,
worker_startup_timeout_secs: u64,
worker_startup_check_interval_secs: u64,
dp_aware: bool,
#[allow(dead_code)]
api_key: Option<String>,
retry_config: RetryConfig,
circuit_breaker_config: CircuitBreakerConfig,
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
}
@@ -56,30 +50,15 @@ impl Router {
false, // include all workers
);
// Update active workers gauge
RouterMetrics::set_active_workers(workers.len());
// Get worker URLs for monitoring
let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Cache-aware policies are initialized in WorkerInitializer
// Setup load monitoring for PowerOfTwo policy
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
// Get default policy to check if we need load monitoring
let default_policy = ctx.policy_registry.get_default_policy();
// Check if default policy is power_of_two for load monitoring
let load_monitor_handle = if default_policy.name() == "power_of_two" {
let monitor_urls = worker_urls.clone();
let monitor_api_keys = monitor_urls
@@ -113,201 +92,13 @@ impl Router {
worker_registry: ctx.worker_registry.clone(),
policy_registry: ctx.policy_registry.clone(),
client: ctx.client.clone(),
worker_startup_timeout_secs: ctx.router_config.worker_startup_timeout_secs,
worker_startup_check_interval_secs: ctx
.router_config
.worker_startup_check_interval_secs,
dp_aware: ctx.router_config.dp_aware,
api_key: ctx.router_config.api_key.clone(),
retry_config: ctx.router_config.effective_retry_config(),
circuit_breaker_config: core_cb_config,
_worker_loads: worker_loads,
_load_monitor_handle: load_monitor_handle,
})
}
/// Get the current list of worker URLs
pub fn get_worker_urls(&self) -> Vec<String> {
self.worker_registry.get_all_urls()
}
/// Get worker URLs for a specific model
pub fn get_worker_urls_for_model(&self, model_id: Option<&str>) -> Vec<String> {
let workers = self.worker_registry.get_workers_filtered(
model_id,
Some(WorkerType::Regular),
Some(ConnectionMode::Http),
false, // get all workers
);
workers.iter().map(|w| w.url().to_string()).collect()
}
pub async fn wait_for_healthy_workers(
worker_urls: &[String],
worker_startup_timeout_secs: u64,
worker_startup_check_interval_secs: u64,
) -> Result<(), String> {
if worker_urls.is_empty() {
return Err(
"Timeout waiting for workers to become healthy: no workers provided".to_string(),
);
}
// Perform health check asynchronously
Self::wait_for_healthy_workers_async(
worker_urls,
worker_startup_timeout_secs,
worker_startup_check_interval_secs,
)
.await
}
async fn wait_for_healthy_workers_async(
worker_urls: &[String],
worker_startup_timeout_secs: u64,
worker_startup_check_interval_secs: u64,
) -> Result<(), String> {
info!(
"Waiting for {} workers to become healthy (timeout: {}s)",
worker_urls.len(),
worker_startup_timeout_secs
);
let start_time = std::time::Instant::now();
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(2))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
loop {
if start_time.elapsed() > Duration::from_secs(worker_startup_timeout_secs) {
error!(
"Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
worker_startup_timeout_secs, worker_urls
);
return Err(format!(
"Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
worker_startup_timeout_secs, worker_urls
));
}
// Perform all health checks concurrently
let mut health_checks = Vec::new();
for url in worker_urls {
let client_clone = client.clone();
let url_clone = url.clone();
let check_health = tokio::spawn(async move {
let health_url = format!("{}/health", url_clone);
match client_clone.get(&health_url).send().await {
Ok(res) => {
if res.status().is_success() {
None
} else {
Some((url_clone, format!("status: {}", res.status())))
}
}
Err(_) => Some((url_clone, "not ready".to_string())),
}
});
health_checks.push(check_health);
}
// Wait for all health checks to complete
let results = futures::future::join_all(health_checks).await;
let mut all_healthy = true;
let mut unhealthy_workers = Vec::new();
for result in results {
match result {
Ok(None) => {
// Worker is healthy
}
Ok(Some((url, reason))) => {
all_healthy = false;
unhealthy_workers.push((url, reason));
}
Err(e) => {
all_healthy = false;
unhealthy_workers
.push(("unknown".to_string(), format!("task error: {}", e)));
}
}
}
if all_healthy {
info!("All {} workers are healthy", worker_urls.len());
return Ok(());
} else {
debug!(
"Waiting for {} workers to become healthy ({} unhealthy: {:?})",
worker_urls.len(),
unhealthy_workers.len(),
unhealthy_workers
);
tokio::time::sleep(Duration::from_secs(worker_startup_check_interval_secs)).await;
}
}
}
fn get_worker_dp_size(worker_url: &str, api_key: &Option<String>) -> Result<usize, String> {
let sync_client = reqwest::blocking::Client::new();
let mut req_builder = sync_client.get(format!("{}/get_server_info", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send() {
Ok(res) => {
if res.status().is_success() {
let server_info = res
.text()
.map_err(|e| format!("failed to read text from response: {}", e))?;
let server_info: serde_json::Value = serde_json::from_str(&server_info)
.map_err(|e| format!("failed to decode JSON: {}", e))?;
let dp_size = server_info
.get("dp_size")
.and_then(|v| v.as_u64())
.ok_or_else(|| String::from("dp_size not found or not an u64"))?;
Ok(if dp_size > usize::MAX as u64 {
return Err(format!("dp_size is too large: {}", dp_size));
} else {
dp_size as usize
})
} else {
Err(format!("unexpected status code: {}", res.status()))
}
}
Err(e) => Err(format!("error response: {}", e)),
}
}
// Given a list of workers, return a list of workers with dp_rank as suffix
fn get_dp_aware_workers(
worker_urls: &[String],
api_key: &Option<String>,
) -> Result<Vec<String>, String> {
let mut dp_aware_workers: Vec<String> = Vec::new();
for url in worker_urls {
match Self::get_worker_dp_size(url, api_key) {
Ok(dp_size) => {
for i in 0..dp_size {
dp_aware_workers.push(format!("{}@{}", url, i));
}
}
Err(e) => return Err(format!("Failed to get DP size for {}: {}", url, e)),
}
}
Ok(dp_aware_workers)
}
fn select_first_worker(&self) -> Result<String, String> {
let workers = self.worker_registry.get_all();
if workers.is_empty() {
@@ -317,65 +108,6 @@ impl Router {
}
}
pub async fn send_health_check(&self, worker_url: &str) -> Response {
let health_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
match Self::extract_dp_rank(worker_url) {
Ok((worker_url_prefix, _dp_rank)) => worker_url_prefix,
Err(e) => {
error!("Failed to extract dp_rank for health check: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to extract dp_rank: {}", e),
)
.into_response();
}
}
} else {
worker_url
};
let request_builder = self.client.get(format!("{}/health", health_url));
let response = match request_builder.send().await {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
match res.bytes().await {
Ok(body) => (status, body).into_response(),
Err(e) => {
error!(
worker_url = %health_url,
error = %e,
"Failed to read health response body"
);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response body: {}", e),
)
.into_response()
}
}
}
Err(e) => {
error!(
worker_url = %health_url,
error = %e,
"Failed to send health request to worker"
);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to send request to worker {}: {}", health_url, e),
)
.into_response()
}
};
// Don't record metrics for health checks
response
}
// Helper method to proxy GET requests to the first available worker
async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
let headers = header_utils::copy_request_headers(&req);
@@ -575,14 +307,15 @@ impl Router {
) -> Response {
// TODO: currently the sglang worker is using in-memory state management, so this implementation has to fan out to all workers.
// Eventually, we need to have router to manage the chat history with a proper database, will update this implementation accordingly.
let worker_urls = self.get_worker_urls();
if worker_urls.is_empty() {
let workers = self.worker_registry.get_all();
if workers.is_empty() {
return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
}
let mut last_response: Option<Response> = None;
for worker_url in worker_urls {
let base = self.worker_base_url(&worker_url);
for worker in workers {
let worker_url = worker.url();
let base = self.worker_base_url(worker_url);
let url = format!("{}/{}", base, endpoint);
let mut request_builder = match method {
@@ -597,6 +330,11 @@ impl Router {
}
};
if let Some(api_key) = worker.api_key() {
request_builder =
request_builder.header("Authorization", format!("Bearer {}", api_key));
}
if let Some(hdrs) = headers {
for (name, value) in hdrs {
let name_lc = name.as_str().to_lowercase();
@@ -691,6 +429,12 @@ impl Router {
is_stream: bool,
load_incremented: bool, // Whether load was incremented for this request
) -> Response {
// Get the worker's API key if available
let api_key = self
.worker_registry
.get_by_url(worker_url)
.and_then(|w| w.api_key().clone());
let mut request_builder = if self.dp_aware {
let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
@@ -704,7 +448,6 @@ impl Router {
}
};
// Parse the request body
let mut json_val = match serde_json::to_value(typed_req) {
Ok(j) => j,
Err(e) => {
@@ -716,7 +459,6 @@ impl Router {
}
};
// Insert the data_parallel_rank field
if let Some(map) = json_val.as_object_mut() {
map.insert(
String::from("data_parallel_rank"),
@@ -743,6 +485,10 @@ impl Router {
.json(typed_req) // Use json() directly with typed request
};
if let Some(key) = api_key {
request_builder = request_builder.header("Authorization", format!("Bearer {}", key));
}
// Copy all headers from original request if provided
if let Some(headers) = headers {
for (name, value) in headers {
@@ -909,215 +655,6 @@ impl Router {
}
}
pub async fn add_worker(
&self,
worker_url: &str,
api_key: &Option<String>,
) -> Result<String, String> {
let start_time = std::time::Instant::now();
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(self.worker_startup_timeout_secs))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
loop {
if start_time.elapsed() > Duration::from_secs(self.worker_startup_timeout_secs) {
error!(
"Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
self.worker_startup_timeout_secs, worker_url
);
return Err(format!(
"Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
self.worker_startup_timeout_secs, worker_url
));
}
match client.get(format!("{}/health", worker_url)).send().await {
Ok(res) => {
if res.status().is_success() {
if self.dp_aware {
// Need to contact the worker to extract the dp_size,
// and add them as multiple workers
let url_vec = vec![String::from(worker_url)];
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, api_key)
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
let mut worker_added: bool = false;
for dp_url in &dp_url_vec {
if self.worker_registry.get_by_url(dp_url).is_some() {
warn!("Worker {} already exists", dp_url);
continue;
}
info!("Added worker: {}", dp_url);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let new_worker_builder =
BasicWorkerBuilder::new(dp_url.to_string())
.worker_type(WorkerType::Regular)
.circuit_breaker_config(
self.circuit_breaker_config.clone(),
);
let new_worker = if let Some(api_key) = api_key {
new_worker_builder.api_key(api_key).build()
} else {
new_worker_builder.build()
};
let worker_arc = Arc::new(new_worker);
self.worker_registry.register(worker_arc.clone());
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
self.policy_registry.on_worker_added(model_id, None);
// Initialize cache-aware policy if applicable
let model_workers = self.worker_registry.get_workers_filtered(
Some(model_id),
Some(WorkerType::Regular),
Some(ConnectionMode::Http),
false,
);
self.policy_registry
.init_cache_aware_policy(model_id, &model_workers);
worker_added = true;
}
if !worker_added {
return Err(format!("No worker added for {}", worker_url));
}
} else {
if self.worker_registry.get_by_url(worker_url).is_some() {
return Err(format!("Worker {} already exists", worker_url));
}
info!("Added worker: {}", worker_url);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let new_worker_builder =
BasicWorkerBuilder::new(worker_url.to_string())
.worker_type(WorkerType::Regular)
.circuit_breaker_config(self.circuit_breaker_config.clone());
let new_worker = if let Some(api_key) = api_key {
new_worker_builder.api_key(api_key).build()
} else {
new_worker_builder.build()
};
let worker_arc = Arc::new(new_worker);
self.worker_registry.register(worker_arc.clone());
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
self.policy_registry.on_worker_added(model_id, None);
// Initialize cache-aware policy if applicable
let model_workers = self.worker_registry.get_workers_filtered(
Some(model_id),
Some(WorkerType::Regular),
Some(ConnectionMode::Http),
false,
);
self.policy_registry
.init_cache_aware_policy(model_id, &model_workers);
}
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
return Ok(format!("Successfully added worker: {}", worker_url));
} else {
debug!(
"Worker {} health check pending - status: {}",
worker_url,
res.status()
);
// if the url does not have http or https prefix, warn users
if !worker_url.starts_with("http://") && !worker_url.starts_with("https://")
{
warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
}
tokio::time::sleep(Duration::from_secs(
self.worker_startup_check_interval_secs,
))
.await;
continue;
}
}
Err(e) => {
debug!("Worker {} health check pending - error: {}", worker_url, e);
// if the url does not have http or https prefix, warn users
if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
}
tokio::time::sleep(Duration::from_secs(
self.worker_startup_check_interval_secs,
))
.await;
continue;
}
}
}
}
pub fn remove_worker(&self, worker_url: &str) {
if self.dp_aware {
// remove dp-aware workers in a prefix-matching fashion
// without contacting the remote worker
let mut removed_workers: Vec<String> = Vec::new();
let worker_url_prefix = format!("{}@", worker_url);
// Find and remove all workers with matching prefix
let all_workers = self.worker_registry.get_all();
for w in all_workers.iter() {
if w.url().starts_with(&worker_url_prefix) {
// Get model_id before removing
let model_id = w.model_id().to_string();
if self.worker_registry.remove_by_url(w.url()).is_some() {
info!("Removed worker: {}", w.url());
removed_workers.push(w.url().to_string());
// Notify PolicyRegistry about the removed worker
self.policy_registry.on_worker_removed(&model_id);
} else {
warn!("Worker {} not found, skipping removal", w.url());
}
}
}
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
for dp_url in removed_workers.iter() {
if let Some(worker) = self.worker_registry.get_by_url(dp_url) {
let model_id = worker.model_id();
self.policy_registry
.remove_worker_from_cache_aware(model_id, dp_url);
}
}
} else {
// Get the worker first to extract model_id
let model_id = if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
worker.model_id().to_string()
} else {
warn!("Worker {} not found, skipping removal", worker_url);
return;
};
if self.worker_registry.remove_by_url(worker_url).is_some() {
info!("Removed worker: {}", worker_url);
// Notify PolicyRegistry about the removed worker
self.policy_registry.on_worker_removed(&model_id);
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
}
self.policy_registry
.remove_worker_from_cache_aware(&model_id, worker_url);
}
}
async fn get_worker_load(&self, worker_url: &str, api_key: &Option<String>) -> Option<isize> {
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
@@ -1205,7 +742,7 @@ impl Router {
// Static version of get_worker_load for use in monitoring task
async fn get_worker_load_static(
client: &reqwest::Client,
client: &Client,
worker_url: &str,
api_key: &Option<String>,
) -> Option<isize> {
@@ -1281,25 +818,6 @@ impl Router {
use async_trait::async_trait;
#[async_trait]
impl WorkerManagement for Router {
async fn add_worker(
&self,
worker_url: &str,
api_key: &Option<String>,
) -> Result<String, String> {
Router::add_worker(self, worker_url, api_key).await
}
fn remove_worker(&self, worker_url: &str) {
Router::remove_worker(self, worker_url)
}
fn get_worker_urls(&self) -> Vec<String> {
Router::get_worker_urls(self)
}
}
#[async_trait]
impl RouterTrait for Router {
fn as_any(&self) -> &dyn std::any::Any {
@@ -1445,12 +963,19 @@ impl RouterTrait for Router {
}
async fn flush_cache(&self) -> Response {
// Get all worker URLs
let worker_urls = self.get_worker_urls();
// Get all workers
let workers = self.worker_registry.get_all();
let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
// Send requests to all workers concurrently without headers
let mut tasks = Vec::new();
for worker_url in &worker_urls {
// Get the worker's API key if available
let api_key = self
.worker_registry
.get_by_url(worker_url)
.and_then(|w| w.api_key().clone());
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
@@ -1468,7 +993,13 @@ impl RouterTrait for Router {
} else {
worker_url
};
let request_builder = self.client.post(format!("{}/flush_cache", worker_url));
let mut request_builder = self.client.post(format!("{}/flush_cache", worker_url));
if let Some(key) = api_key {
request_builder =
request_builder.header("Authorization", format!("Bearer {}", key));
}
tasks.push(request_builder.send());
}
@@ -1546,6 +1077,7 @@ impl RouterTrait for Router {
#[cfg(test)]
mod tests {
use super::*;
use crate::core::BasicWorkerBuilder;
use std::collections::HashMap;
fn create_test_regular_router() -> Router {
@@ -1558,11 +1090,9 @@ mod tests {
// Register test workers
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
worker_registry.register(Arc::new(worker1));
worker_registry.register(Arc::new(worker2));
@@ -1571,13 +1101,9 @@ mod tests {
Router {
worker_registry,
policy_registry,
worker_startup_timeout_secs: 5,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
client: Client::new(),
retry_config: RetryConfig::default(),
circuit_breaker_config: CircuitBreakerConfig::default(),
_worker_loads: Arc::new(rx),
_load_monitor_handle: None,
}
@@ -1586,7 +1112,8 @@ mod tests {
#[test]
fn test_router_get_worker_urls_regular() {
let router = create_test_regular_router();
let urls = router.get_worker_urls();
let workers = router.worker_registry.get_all();
let urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
assert_eq!(urls.len(), 2);
assert!(urls.contains(&"http://worker1:8080".to_string()));
@@ -1603,21 +1130,4 @@ mod tests {
// DashMap doesn't guarantee order, so just check we get one of the workers
assert!(url == "http://worker1:8080" || url == "http://worker2:8080");
}
#[tokio::test]
async fn test_wait_for_healthy_workers_empty_list() {
// Empty list will return error immediately
let result = Router::wait_for_healthy_workers(&[], 1, 1).await;
assert!(result.is_err());
assert!(result.unwrap_err().contains("no workers provided"));
}
#[tokio::test]
async fn test_wait_for_healthy_workers_invalid_urls() {
// This test will timeout quickly since the URLs are invalid
let result =
Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1).await;
assert!(result.is_err());
assert!(result.unwrap_err().contains("Timeout"));
}
}

View File

@@ -19,39 +19,18 @@ pub mod grpc;
pub mod header_utils;
pub mod http;
pub mod router_manager;
pub mod worker_initializer;
pub use factory::RouterFactory;
pub use worker_initializer::WorkerInitializer;
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
pub use http::{openai_router, pd_router, pd_types, router};
/// Worker management trait for administrative operations
///
/// This trait is separate from RouterTrait to allow Send futures
/// for use in service discovery and other background tasks
#[async_trait]
pub trait WorkerManagement: Send + Sync {
/// Add a worker to the router
async fn add_worker(
&self,
worker_url: &str,
api_key: &Option<String>,
) -> Result<String, String>;
/// Remove a worker from the router
fn remove_worker(&self, worker_url: &str);
/// Get all worker URLs
fn get_worker_urls(&self) -> Vec<String>;
}
/// Core trait for all router implementations
///
/// This trait provides a unified interface for routing requests,
/// regardless of whether it's a regular router or PD router.
#[async_trait]
pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
pub trait RouterTrait: Send + Sync + Debug {
/// Get a reference to self as Any for downcasting
fn as_any(&self) -> &dyn std::any::Any;

View File

@@ -4,17 +4,12 @@
//! - Single Router Mode (enable_igw=false): Router owns workers directly
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
use crate::config::RouterConfig;
use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig, Worker, WorkerRegistry, WorkerType};
use crate::core::{Worker, WorkerRegistry, WorkerType};
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponsesRequest,
};
use crate::protocols::worker_spec::{
ServerInfo, WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse, WorkerInfo,
WorkerListResponse, WorkerStats, WorkerTypeStats,
};
use crate::routers::{RouterTrait, WorkerManagement};
use crate::routers::RouterTrait;
use async_trait::async_trait;
use axum::{
body::Body,
@@ -24,7 +19,7 @@ use axum::{
};
use dashmap::DashMap;
use std::sync::Arc;
use tracing::{info, warn};
use tracing::info;
/// Router identifier
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
@@ -45,48 +40,28 @@ pub struct RouterManager {
/// Worker registry (single source of truth in multi-router mode)
worker_registry: Arc<WorkerRegistry>,
/// Policy registry for managing model-to-policy mappings
policy_registry: Arc<crate::policies::PolicyRegistry>,
/// All routers managed by this manager
/// RouterId examples: "http-regular", "http-pd", "grpc-regular", "grpc-pd"
routers: Arc<DashMap<RouterId, Arc<dyn RouterTrait>>>,
/// Default router for requests without specific routing
default_router: Arc<std::sync::RwLock<Option<RouterId>>>,
/// HTTP client for querying worker info
client: reqwest::Client,
/// Configuration
#[allow(dead_code)] // May be used in future enhancements
config: RouterConfig,
}
impl RouterManager {
/// Create a new router manager with shared registries
pub fn new(
config: RouterConfig,
client: reqwest::Client,
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<crate::policies::PolicyRegistry>,
) -> Self {
pub fn new(worker_registry: Arc<WorkerRegistry>) -> Self {
Self {
worker_registry,
policy_registry,
routers: Arc::new(DashMap::new()),
default_router: Arc::new(std::sync::RwLock::new(None)),
client,
config,
}
}
/// Register a router with the manager
pub fn register_router(&self, id: RouterId, router: Arc<dyn RouterTrait>) {
// Store router
self.routers.insert(id.clone(), router);
// Set as default if first router
let mut default_router = self.default_router.write().unwrap();
if default_router.is_none() {
*default_router = Some(id.clone());
@@ -107,11 +82,9 @@ impl RouterManager {
/// Get router for a specific model based on worker types
pub fn get_router_for_model(&self, model_id: &str) -> Option<Arc<dyn RouterTrait>> {
// Query workers for this model from registry
let workers = self.worker_registry.get_by_model(model_id);
if !workers.is_empty() {
// Determine router based on worker types
let has_pd_workers = workers.iter().any(|w| {
matches!(
w.worker_type(),
@@ -125,13 +98,11 @@ impl RouterManager {
RouterId::new("http-regular".to_string())
};
// Return the router if it exists
if let Some(router) = self.routers.get(&router_id) {
return Some(router.clone());
}
}
// Fall back to default router
let default_router = self.default_router.read().unwrap();
if let Some(ref default_id) = *default_router {
self.routers.get(default_id).map(|r| r.clone())
@@ -149,277 +120,12 @@ impl RouterManager {
}
}
/// Add a worker to the registry
pub async fn add_worker(
&self,
config: WorkerConfigRequest,
) -> Result<WorkerApiResponse, WorkerErrorResponse> {
// Build labels from configuration
let mut labels = config.labels.clone();
// Query server info if model_id not provided
let model_id = if let Some(model_id) = config.model_id {
model_id
} else {
match self.query_server_info(&config.url, &config.api_key).await {
Ok(info) => {
// Extract model_id from server info
info.model_id
.or_else(|| {
info.model_path
.as_ref()
.and_then(|path| path.split('/').next_back().map(|s| s.to_string()))
})
.unwrap_or_else(|| "unknown".to_string())
}
Err(e) => {
warn!("Failed to query server info from {}: {}", config.url, e);
"unknown".to_string()
}
}
};
// Add configuration to labels
labels.insert("model_id".to_string(), model_id.clone());
if let Some(priority) = config.priority {
labels.insert("priority".to_string(), priority.to_string());
}
if let Some(cost) = config.cost {
labels.insert("cost".to_string(), cost.to_string());
}
// Add gRPC-specific configuration if provided
if let Some(tokenizer_path) = config.tokenizer_path {
labels.insert("tokenizer_path".to_string(), tokenizer_path);
}
if let Some(reasoning_parser) = config.reasoning_parser {
labels.insert("reasoning_parser".to_string(), reasoning_parser);
}
if let Some(tool_parser) = config.tool_parser {
labels.insert("tool_parser".to_string(), tool_parser);
}
if let Some(chat_template) = config.chat_template {
labels.insert("chat_template".to_string(), chat_template);
}
let worker = match config.worker_type.as_deref() {
Some("prefill") => {
let mut builder = BasicWorkerBuilder::new(config.url.clone())
.worker_type(WorkerType::Prefill {
bootstrap_port: config.bootstrap_port,
})
.labels(labels.clone())
.circuit_breaker_config(CircuitBreakerConfig::default());
if let Some(api_key) = config.api_key.clone() {
builder = builder.api_key(api_key);
}
Box::new(builder.build()) as Box<dyn Worker>
}
Some("decode") => {
let mut builder = BasicWorkerBuilder::new(config.url.clone())
.worker_type(WorkerType::Decode)
.labels(labels.clone())
.circuit_breaker_config(CircuitBreakerConfig::default());
if let Some(api_key) = config.api_key.clone() {
builder = builder.api_key(api_key);
}
Box::new(builder.build()) as Box<dyn Worker>
}
_ => {
let mut builder = BasicWorkerBuilder::new(config.url.clone())
.worker_type(WorkerType::Regular)
.labels(labels.clone())
.circuit_breaker_config(CircuitBreakerConfig::default());
if let Some(api_key) = config.api_key.clone() {
builder = builder.api_key(api_key);
}
Box::new(builder.build()) as Box<dyn Worker>
}
};
// Register worker
let worker_arc: Arc<dyn Worker> = Arc::from(worker);
let worker_id = self.worker_registry.register(worker_arc.clone());
// Notify PolicyRegistry about the new worker
// Extract policy hint from labels if provided
let policy_hint = labels.get("policy").map(|s| s.as_str());
let policy = self.policy_registry.on_worker_added(&model_id, policy_hint);
// Log which type of router would handle this worker (for debugging)
let expected_router = match config.worker_type.as_deref() {
Some("prefill") | Some("decode") => "http-pd",
_ => "http-regular",
};
info!(
"Worker for model '{}' would be handled by '{}' router based on type",
model_id, expected_router
);
info!(
"Added worker {} with URL {} for model {} using policy {}",
worker_id.as_str(),
config.url,
model_id,
policy.name()
);
// Return worker info
let worker_info = self.worker_to_info(worker_id.as_str(), &worker_arc);
Ok(WorkerApiResponse {
success: true,
message: format!("Worker {} added successfully", worker_id.as_str()),
worker: Some(worker_info),
})
}
/// Remove a worker from the registry
pub fn remove_worker_from_registry(
&self,
url: &str,
) -> Result<WorkerApiResponse, WorkerErrorResponse> {
// Get worker to extract model_id before removing
let model_id = self
.worker_registry
.get_by_url(url)
.map(|worker| worker.model_id().to_string());
if let Some(_worker) = self.worker_registry.remove_by_url(url) {
// Notify PolicyRegistry about worker removal
if let Some(ref model_id) = model_id {
self.policy_registry.on_worker_removed(model_id);
info!("Removed worker with URL {} for model {}", url, model_id);
} else {
info!("Removed worker with URL {}", url);
}
Ok(WorkerApiResponse {
success: true,
message: format!("Worker {} removed successfully", url),
worker: None,
})
} else {
Err(WorkerErrorResponse {
error: format!("Worker with URL {} not found", url),
code: "WORKER_NOT_FOUND".to_string(),
})
}
}
/// List all workers
pub fn list_workers(&self) -> WorkerListResponse {
let workers = self.worker_registry.get_all_with_ids();
let worker_infos: Vec<WorkerInfo> = workers
.iter()
.map(|(id, w)| self.worker_to_info(id.as_str(), w))
.collect();
let total = worker_infos.len();
// Get stats from the worker registry
let registry_stats = self.worker_registry.stats();
// Convert WorkerRegistryStats to WorkerStats
let stats = WorkerStats {
total_workers: registry_stats.total_workers,
healthy_workers: registry_stats.healthy_workers,
total_models: registry_stats.total_models,
total_load: registry_stats.total_load,
by_type: WorkerTypeStats {
regular: registry_stats.regular_workers,
prefill: registry_stats.prefill_workers,
decode: registry_stats.decode_workers,
},
};
WorkerListResponse {
workers: worker_infos,
total,
stats,
}
}
/// Get worker by URL
pub fn get_worker(&self, url: &str) -> Option<WorkerInfo> {
self.worker_registry
.get_by_url(url)
.map(|w| self.worker_to_info("unknown", &w))
}
/// Query server info from a worker URL
async fn query_server_info(
&self,
url: &str,
api_key: &Option<String>,
) -> Result<ServerInfo, String> {
let info_url = format!("{}/get_server_info", url.trim_end_matches('/'));
let mut req_builder = self.client.get(&info_url);
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(response) => {
if response.status().is_success() {
response
.json::<ServerInfo>()
.await
.map_err(|e| format!("Failed to parse server info: {}", e))
} else {
Err(format!("Server returned status: {}", response.status()))
}
}
Err(e) => Err(format!("Failed to connect to server: {}", e)),
}
}
/// Convert Worker to WorkerInfo
fn worker_to_info(&self, id: &str, worker: &Arc<dyn Worker>) -> WorkerInfo {
let metadata = worker.metadata();
WorkerInfo {
id: id.to_string(),
url: worker.url().to_string(),
model_id: worker.model_id().to_string(),
priority: worker.priority(),
cost: worker.cost(),
worker_type: match worker.worker_type() {
WorkerType::Regular => "regular".to_string(),
WorkerType::Prefill { .. } => "prefill".to_string(),
WorkerType::Decode => "decode".to_string(),
},
is_healthy: worker.is_healthy(),
load: worker.load(),
connection_mode: format!("{:?}", worker.connection_mode()),
tokenizer_path: worker.tokenizer_path().map(|s| s.to_string()),
reasoning_parser: worker.reasoning_parser().map(|s| s.to_string()),
tool_parser: worker.tool_parser().map(|s| s.to_string()),
chat_template: worker.chat_template().map(|s| s.to_string()),
metadata: metadata.labels.clone(),
}
}
/// Get the appropriate router for a request based on headers and request content
pub fn select_router_for_request(
&self,
headers: Option<&HeaderMap>,
model_id: Option<&str>,
) -> Option<Arc<dyn RouterTrait>> {
// Extract priority and cost preferences from headers if available
let _priority_threshold = headers.and_then(|h| {
h.get("x-worker-priority")
.and_then(|v| v.to_str().ok())
@@ -432,7 +138,6 @@ impl RouterManager {
.and_then(|s| s.parse::<f32>().ok())
});
// Check if PD (prefill-decode) mode is preferred from headers
let prefer_pd = headers
.and_then(|h| {
h.get("x-prefer-pd")
@@ -441,7 +146,6 @@ impl RouterManager {
})
.unwrap_or(false);
// If model specified, use get_router_for_model
let candidate_routers = if let Some(model) = model_id {
if let Some(router) = self.get_router_for_model(model) {
vec![router]
@@ -449,7 +153,6 @@ impl RouterManager {
Vec::new()
}
} else {
// No model specified, consider all routers
self.routers
.iter()
.map(|entry| entry.value().clone())
@@ -457,23 +160,20 @@ impl RouterManager {
};
if candidate_routers.is_empty() {
// No routers found for the specified model
return None;
}
// Score routers based on worker attributes and request preferences
let mut best_router = None;
let mut best_score = 0.0;
for router in candidate_routers {
let mut score = 1.0;
// Check if this is a PD router
let is_pd = router.is_pd_mode();
if prefer_pd && is_pd {
score += 2.0; // Bonus for matching PD preference
score += 2.0;
} else if !prefer_pd && !is_pd {
score += 1.0; // Bonus for matching regular preference
score += 1.0;
}
// Get workers for this router and evaluate based on priority/cost
@@ -495,49 +195,6 @@ impl RouterManager {
}
}
/// RouterManager implements RouterTrait to act as a meta-router
/// that delegates requests to the appropriate underlying router
#[async_trait]
impl WorkerManagement for RouterManager {
/// Add a worker - in multi-router mode, this adds to the registry
async fn add_worker(
&self,
worker_url: &str,
api_key: &Option<String>,
) -> Result<String, String> {
// Create a basic worker config request
let config = WorkerConfigRequest {
url: worker_url.to_string(),
api_key: api_key.clone(),
model_id: None,
worker_type: None,
priority: None,
cost: None,
labels: std::collections::HashMap::new(),
bootstrap_port: None,
tokenizer_path: None,
reasoning_parser: None,
tool_parser: None,
chat_template: None,
};
match self.add_worker(config).await {
Ok(response) => Ok(response.message),
Err(e) => Err(e.error),
}
}
/// Remove a worker from the registry
fn remove_worker(&self, worker_url: &str) {
let _ = self.remove_worker_from_registry(worker_url);
}
/// Get all worker URLs from the registry
fn get_worker_urls(&self) -> Vec<String> {
self.worker_registry.get_all_urls()
}
}
#[async_trait]
impl RouterTrait for RouterManager {
fn as_any(&self) -> &dyn std::any::Any {
@@ -639,7 +296,6 @@ impl RouterTrait for RouterManager {
body: &ChatCompletionRequest,
_model_id: Option<&str>,
) -> Response {
// Select router based on headers and model
let router = self.select_router_for_request(headers, Some(&body.model));
if let Some(router) = router {
@@ -662,7 +318,6 @@ impl RouterTrait for RouterManager {
body: &CompletionRequest,
_model_id: Option<&str>,
) -> Response {
// Select router based on headers and model
let router = self.select_router_for_request(headers, Some(&body.model));
if let Some(router) = router {
@@ -746,7 +401,6 @@ impl RouterTrait for RouterManager {
body: &EmbeddingRequest,
_model_id: Option<&str>,
) -> Response {
// Select router based on headers and model
let router = self.select_router_for_request(headers, Some(&body.model));
if let Some(router) = router {

View File

@@ -1,497 +0,0 @@
// Worker Initialization Module
// Separates worker lifecycle management from router construction
use crate::config::types::{ConnectionMode as ConfigConnectionMode, RouterConfig, RoutingMode};
use crate::core::{
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, HealthConfig, Worker, WorkerRegistry,
WorkerType,
};
use crate::policies::PolicyRegistry;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tracing::{info, warn};
/// WorkerInitializer handles the creation and registration of workers
/// based on routing configuration, separating this concern from router constructors
pub struct WorkerInitializer;
impl WorkerInitializer {
/// Initialize workers based on configuration and register them in the WorkerRegistry
pub async fn initialize_workers(
config: &RouterConfig,
worker_registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!("Initializing workers for routing mode: {:?}", config.mode);
match &config.mode {
RoutingMode::Regular { worker_urls } => {
// use router's api_key, repeat for each worker
let worker_api_keys: Vec<Option<String>> =
worker_urls.iter().map(|_| config.api_key.clone()).collect();
Self::create_regular_workers(
worker_urls,
&worker_api_keys,
&config.connection_mode,
config,
worker_registry,
policy_registry,
)
.await?;
}
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
..
} => {
// use router's api_key, repeat for each prefill/decode worker
let prefill_api_keys: Vec<Option<String>> = prefill_urls
.iter()
.map(|_| config.api_key.clone())
.collect();
let decode_api_keys: Vec<Option<String>> =
decode_urls.iter().map(|_| config.api_key.clone()).collect();
Self::create_prefill_workers(
prefill_urls,
&prefill_api_keys,
&config.connection_mode,
config,
worker_registry,
policy_registry,
)
.await?;
Self::create_decode_workers(
decode_urls,
&decode_api_keys,
&config.connection_mode,
config,
worker_registry,
policy_registry,
)
.await?;
}
RoutingMode::OpenAI { .. } => {
info!("OpenAI routing mode - no local workers to initialize");
}
}
// Wait for workers to be healthy if any were registered
if worker_registry.stats().total_workers > 0 {
Self::wait_for_healthy_workers(
worker_registry,
config.worker_startup_timeout_secs,
config.worker_startup_check_interval_secs,
)
.await?;
}
Ok(())
}
/// Create regular workers for standard routing mode
async fn create_regular_workers(
urls: &[String],
api_keys: &[Option<String>],
config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!("Creating {} regular workers", urls.len());
// Convert config connection mode to core connection mode
let connection_mode = Self::convert_connection_mode(config_connection_mode, urls.first());
// Convert circuit breaker config
let circuit_breaker_config = config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Convert health check config
let health_config = HealthConfig {
timeout_secs: config.health_check.timeout_secs,
check_interval_secs: config.health_check.check_interval_secs,
endpoint: config.health_check.endpoint.clone(),
failure_threshold: config.health_check.failure_threshold,
success_threshold: config.health_check.success_threshold,
};
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for (url, api_key) in urls.iter().zip(api_keys.iter()) {
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker_builder = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Regular)
.connection_mode(connection_mode.clone())
.circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone());
let worker = if let Some(api_key) = api_key.clone() {
worker_builder.api_key(api_key).build()
} else {
worker_builder.build()
};
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id();
let worker_id = registry.register(Arc::clone(&worker_arc));
info!("Registered regular worker {} with ID {:?}", url, worker_id);
// Track workers by model for cache-aware policy initialization
registered_workers
.entry(model_id.to_string())
.or_default()
.push(Arc::clone(&worker_arc));
// Notify policy registry about the worker
if let Some(policy_reg) = policy_registry {
policy_reg.on_worker_added(model_id, None);
}
}
// Initialize cache-aware policies with all workers for each model
if let Some(policy_reg) = policy_registry {
for (model_id, workers) in registered_workers {
policy_reg.init_cache_aware_policy(&model_id, &workers);
}
}
Ok(())
}
/// Create prefill workers for disaggregated routing mode
async fn create_prefill_workers(
prefill_entries: &[(String, Option<u16>)],
api_keys: &[Option<String>],
config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!("Creating {} prefill workers", prefill_entries.len());
// Convert config connection mode to core connection mode
let connection_mode = Self::convert_connection_mode(
config_connection_mode,
prefill_entries.first().map(|(url, _)| url),
);
// Convert circuit breaker config
let circuit_breaker_config = config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Convert health check config
let health_config = HealthConfig {
timeout_secs: config.health_check.timeout_secs,
check_interval_secs: config.health_check.check_interval_secs,
endpoint: config.health_check.endpoint.clone(),
failure_threshold: config.health_check.failure_threshold,
success_threshold: config.health_check.success_threshold,
};
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for ((url, bootstrap_port), api_key) in prefill_entries.iter().zip(api_keys.iter()) {
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker_builder = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Prefill {
bootstrap_port: *bootstrap_port,
})
.connection_mode(connection_mode.clone())
.circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone());
let worker = if let Some(api_key) = api_key.clone() {
worker_builder.api_key(api_key).build()
} else {
worker_builder.build()
};
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id();
let worker_id = registry.register(Arc::clone(&worker_arc));
info!("Registered prefill worker {} with ID {:?}", url, worker_id);
// Track workers by model for cache-aware policy initialization
registered_workers
.entry(model_id.to_string())
.or_default()
.push(Arc::clone(&worker_arc));
// Notify policy registry about the worker
if let Some(policy_reg) = policy_registry {
policy_reg.on_worker_added(model_id, None);
}
}
// Initialize cache-aware policies for PD mode
if let Some(policy_reg) = policy_registry {
// Collect all prefill workers
let all_prefill_workers: Vec<Arc<dyn Worker>> = registered_workers
.values()
.flat_map(|workers| workers.iter().cloned())
.collect();
// Initialize PD policies (will handle both prefill and decode, but we only have prefill here)
policy_reg.init_pd_cache_aware_policies(&all_prefill_workers, &[]);
}
Ok(())
}
/// Create decode workers for disaggregated routing mode
async fn create_decode_workers(
urls: &[String],
api_keys: &[Option<String>],
config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!("Creating {} decode workers", urls.len());
// Convert config connection mode to core connection mode
let connection_mode = Self::convert_connection_mode(config_connection_mode, urls.first());
// Convert circuit breaker config
let circuit_breaker_config = config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Convert health check config
let health_config = HealthConfig {
timeout_secs: config.health_check.timeout_secs,
check_interval_secs: config.health_check.check_interval_secs,
endpoint: config.health_check.endpoint.clone(),
failure_threshold: config.health_check.failure_threshold,
success_threshold: config.health_check.success_threshold,
};
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for (url, api_key) in urls.iter().zip(api_keys.iter()) {
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker_builder = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Decode)
.connection_mode(connection_mode.clone())
.circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone());
let worker = if let Some(api_key) = api_key.clone() {
worker_builder.api_key(api_key).build()
} else {
worker_builder.build()
};
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id();
let worker_id = registry.register(Arc::clone(&worker_arc));
info!("Registered decode worker {} with ID {:?}", url, worker_id);
// Track workers by model for cache-aware policy initialization
registered_workers
.entry(model_id.to_string())
.or_default()
.push(Arc::clone(&worker_arc));
// Notify policy registry about the worker
if let Some(policy_reg) = policy_registry {
policy_reg.on_worker_added(model_id, None);
}
}
// Initialize cache-aware policies for PD mode
if let Some(policy_reg) = policy_registry {
// Collect all decode workers
let all_decode_workers: Vec<Arc<dyn Worker>> = registered_workers
.values()
.flat_map(|workers| workers.iter().cloned())
.collect();
// Initialize PD policies (will handle both prefill and decode, but we only have decode here)
policy_reg.init_pd_cache_aware_policies(&[], &all_decode_workers);
}
Ok(())
}
/// Convert config connection mode to core connection mode
fn convert_connection_mode(
config_mode: &ConfigConnectionMode,
_sample_url: Option<&String>,
) -> ConnectionMode {
match config_mode {
ConfigConnectionMode::Http => ConnectionMode::Http,
ConfigConnectionMode::Grpc => ConnectionMode::Grpc { port: None },
}
}
/// Wait for workers to become healthy
async fn wait_for_healthy_workers(
registry: &Arc<WorkerRegistry>,
timeout_secs: u64,
check_interval_secs: u64,
) -> Result<(), String> {
let timeout = Duration::from_secs(timeout_secs);
let check_interval = Duration::from_secs(check_interval_secs);
let start_time = std::time::Instant::now();
info!(
"Waiting for workers to become healthy (timeout: {}s)",
timeout_secs
);
loop {
let stats = registry.stats();
if stats.healthy_workers > 0 {
info!(
"Workers healthy: {}/{} workers are ready",
stats.healthy_workers, stats.total_workers
);
// If we have at least one healthy worker, we can proceed
// This allows partial degradation rather than total failure
return Ok(());
}
if start_time.elapsed() > timeout {
let error_msg = format!(
"Timeout waiting for workers to become healthy after {}s. Total workers: {}, Healthy: {}",
timeout_secs, stats.total_workers, stats.healthy_workers
);
warn!("{}", error_msg);
// If we have workers but none are healthy, it's still a failure
if stats.total_workers > 0 {
return Err(error_msg);
} else {
// No workers at all might be OK for some configurations
warn!("No workers registered, proceeding anyway");
return Ok(());
}
}
tokio::time::sleep(check_interval).await;
}
}
/// Initialize workers for gRPC connections specifically
/// This is used when gRPC clients are pre-connected
pub async fn initialize_grpc_workers(
worker_urls: &[String],
worker_type: WorkerType,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
grpc_clients: &mut HashMap<String, crate::grpc::SglangSchedulerClient>,
) -> Result<(), String> {
info!(
"Creating {} gRPC workers of type {:?}",
worker_urls.len(),
worker_type
);
// Convert circuit breaker config
let circuit_breaker_config = config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Convert health check config
let health_config = HealthConfig {
timeout_secs: config.health_check.timeout_secs,
check_interval_secs: config.health_check.check_interval_secs,
endpoint: config.health_check.endpoint.clone(),
failure_threshold: config.health_check.failure_threshold,
success_threshold: config.health_check.success_threshold,
};
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for url in worker_urls {
if let Some(client) = grpc_clients.remove(url) {
let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(worker_type.clone())
.connection_mode(ConnectionMode::Grpc { port: None })
.circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone())
.grpc_client(client)
.build();
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id();
let worker_id = registry.register(Arc::clone(&worker_arc));
info!("Registered gRPC worker {} with ID {:?}", url, worker_id);
// Track workers by model for cache-aware policy initialization
registered_workers
.entry(model_id.to_string())
.or_default()
.push(Arc::clone(&worker_arc));
// Notify policy registry about the worker
if let Some(policy_reg) = policy_registry {
policy_reg.on_worker_added(model_id, None);
}
} else {
warn!("No gRPC client available for worker {}, skipping", url);
}
}
// Initialize cache-aware policies with all workers for each model
if let Some(policy_reg) = policy_registry {
for (model_id, workers) in registered_workers {
policy_reg.init_cache_aware_policy(&model_id, &workers);
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_convert_connection_mode() {
// HTTP mode
assert!(matches!(
WorkerInitializer::convert_connection_mode(
&ConfigConnectionMode::Http,
Some(&"http://localhost:8080".to_string())
),
ConnectionMode::Http
));
// gRPC mode
assert!(matches!(
WorkerInitializer::convert_connection_mode(
&ConfigConnectionMode::Grpc,
Some(&"grpc://localhost:50051".to_string())
),
ConnectionMode::Grpc { .. }
));
// No URL provided
assert!(matches!(
WorkerInitializer::convert_connection_mode(&ConfigConnectionMode::Http, None),
ConnectionMode::Http
));
}
}