[router] add worker abstraction (#7960)

This commit is contained in:
Simo Lin
2025-07-11 20:17:48 -07:00
committed by GitHub
parent 2a2d3478af
commit f2d5c4920e
11 changed files with 960 additions and 410 deletions

View File

@@ -1,8 +1,9 @@
// PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
use crate::pd_types::{
Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDRouterError, PDSelectionPolicy,
api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError, PDSelectionPolicy,
};
use crate::tree::Tree;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
@@ -11,7 +12,6 @@ use futures_util::{StreamExt, TryStreamExt};
use metrics::{counter, histogram};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use tracing::{debug, error, info, warn};
@@ -21,49 +21,17 @@ use uuid::Uuid;
#[derive(Debug)]
pub struct PDRouter {
pub prefill_workers: Arc<RwLock<Vec<EngineInfo>>>,
pub decode_workers: Arc<RwLock<Vec<EngineInfo>>>,
pub prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
pub selection_policy: PDSelectionPolicy,
pub load_tracking: Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
pub prefill_tree: Option<Arc<Mutex<Tree>>>,
pub timeout_secs: u64,
pub interval_secs: u64,
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
pub http_client: reqwest::Client,
}
// RAII guard for load tracking to ensure cleanup even on panic
struct LoadGuard<'a> {
tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
urls: Vec<String>,
}
impl<'a> LoadGuard<'a> {
fn new(
tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
urls: Vec<String>,
) -> Self {
// Increment counters
for url in &urls {
let counter = tracking
.entry(url.clone())
.or_insert_with(|| Arc::new(AtomicUsize::new(0)));
counter.fetch_add(1, Ordering::Relaxed);
}
LoadGuard { tracking, urls }
}
}
impl Drop for LoadGuard<'_> {
fn drop(&mut self) {
// Guaranteed cleanup even on panic
for url in &self.urls {
if let Some(counter) = self.tracking.get(url) {
counter.fetch_sub(1, Ordering::Relaxed);
}
}
}
_prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>,
}
impl PDRouter {
@@ -73,9 +41,6 @@ impl PDRouter {
url: String,
bootstrap_port: Option<u16>,
) -> Result<String, PDRouterError> {
// Create EngineInfo for the new prefill server
let engine_info = EngineInfo::new_prefill(url.clone(), bootstrap_port);
// Wait for the new server to be healthy
crate::router::Router::wait_for_healthy_workers(
&[url.clone()],
@@ -84,6 +49,9 @@ impl PDRouter {
)
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
// Create Worker for the new prefill server
let worker = WorkerFactory::create_prefill(url.clone(), bootstrap_port);
// Add to prefill workers list
let mut workers = self
.prefill_workers
@@ -93,15 +61,11 @@ impl PDRouter {
})?;
// Check if already exists
if workers.iter().any(|w| w.url == url) {
if workers.iter().any(|w| w.url() == &url) {
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
}
workers.push(engine_info);
// Initialize load tracking
self.load_tracking
.insert(url.clone(), Arc::new(AtomicUsize::new(0)));
workers.push(worker);
// Add to cache tree if using cache-aware policy
if let Some(ref tree) = self.prefill_tree {
@@ -113,9 +77,6 @@ impl PDRouter {
}
pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
// Create EngineInfo for the new decode server
let engine_info = EngineInfo::new_decode(url.clone());
// Wait for the new server to be healthy
crate::router::Router::wait_for_healthy_workers(
&[url.clone()],
@@ -124,6 +85,9 @@ impl PDRouter {
)
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
// Create Worker for the new decode server
let worker = WorkerFactory::create_decode(url.clone());
// Add to decode workers list
let mut workers = self
.decode_workers
@@ -133,15 +97,14 @@ impl PDRouter {
})?;
// Check if already exists
if workers.iter().any(|w| w.url == url) {
if workers.iter().any(|w| w.url() == &url) {
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
}
workers.push(engine_info);
workers.push(worker);
// Initialize load tracking
self.load_tracking
.insert(url.clone(), Arc::new(AtomicUsize::new(0)));
// Worker tracks its own load internally
info!("Added decode server: {}", url);
Ok(format!("Successfully added decode server: {}", url))
@@ -157,7 +120,7 @@ impl PDRouter {
// Find and remove the server
let initial_len = workers.len();
workers.retain(|w| w.url != url);
workers.retain(|w| w.url() != url);
if workers.len() == initial_len {
return Err(PDRouterError::WorkerNotFound {
@@ -166,7 +129,7 @@ impl PDRouter {
}
// Remove from load tracking
self.load_tracking.remove(url);
// Worker load tracking is internal
// Remove from cache tree if using cache-aware policy
if let Some(ref tree) = self.prefill_tree {
@@ -174,7 +137,7 @@ impl PDRouter {
let mut tree_guard = tree.lock().unwrap();
*tree_guard = Tree::new();
for worker in workers.iter() {
tree_guard.insert("", &worker.url);
tree_guard.insert("", worker.url());
}
}
@@ -192,7 +155,7 @@ impl PDRouter {
// Find and remove the server
let initial_len = workers.len();
workers.retain(|w| w.url != url);
workers.retain(|w| w.url() != url);
if workers.len() == initial_len {
return Err(PDRouterError::WorkerNotFound {
@@ -200,9 +163,6 @@ impl PDRouter {
});
}
// Remove from load tracking
self.load_tracking.remove(url);
info!("Removed decode server: {}", url);
Ok(format!("Successfully removed decode server: {}", url))
}
@@ -214,41 +174,32 @@ impl PDRouter {
timeout_secs: u64,
interval_secs: u64,
) -> Result<Self, String> {
// Convert URLs to EngineInfo
let prefill_workers: Vec<EngineInfo> = prefill_urls
// Convert URLs to Worker trait objects
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
.into_iter()
.map(|(url, port)| EngineInfo::new_prefill(url, port))
.map(|(url, port)| WorkerFactory::create_prefill(url, port))
.collect();
let decode_workers: Vec<EngineInfo> = decode_urls
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
.into_iter()
.map(EngineInfo::new_decode)
.map(WorkerFactory::create_decode)
.collect();
// Wait for PD workers to be healthy
let all_urls: Vec<String> = prefill_workers
.iter()
.chain(decode_workers.iter())
.map(|engine| engine.url.clone())
.map(|worker| worker.url().to_string())
.collect();
crate::router::Router::wait_for_healthy_workers(&all_urls, timeout_secs, interval_secs)?;
// Initialize load tracking with atomic counters
let load_tracking = Arc::new(dashmap::DashMap::new());
for engine in &prefill_workers {
load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
}
for engine in &decode_workers {
load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
}
// Initialize cache-aware components if needed
let prefill_tree = match &selection_policy {
PDSelectionPolicy::CacheAware { .. } => {
let tree = Arc::new(Mutex::new(Tree::new()));
// Initialize tree with prefill workers
for engine in &prefill_workers {
tree.lock().unwrap().insert("", &engine.url);
for worker in &prefill_workers {
tree.lock().unwrap().insert("", worker.url());
}
Some(tree)
}
@@ -283,17 +234,27 @@ impl PDRouter {
None
};
let prefill_workers = Arc::new(RwLock::new(prefill_workers));
let decode_workers = Arc::new(RwLock::new(decode_workers));
// Start health checkers for both worker pools
let prefill_health_checker =
crate::core::start_health_checker(Arc::clone(&prefill_workers), interval_secs);
let decode_health_checker =
crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs);
Ok(PDRouter {
prefill_workers: Arc::new(RwLock::new(prefill_workers)),
decode_workers: Arc::new(RwLock::new(decode_workers)),
prefill_workers,
decode_workers,
selection_policy,
load_tracking,
prefill_tree,
timeout_secs,
interval_secs,
worker_loads,
load_monitor_handle,
http_client,
_prefill_health_checker: Some(prefill_health_checker),
_decode_health_checker: Some(decode_health_checker),
})
}
@@ -330,11 +291,13 @@ impl PDRouter {
// Log routing decision
info!(
"PD routing: {} -> prefill={}, decode={}",
route, prefill.url, decode.url
route,
prefill.url(),
decode.url()
);
// Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
error!("Failed to add bootstrap info: {}", e);
counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
return HttpResponse::InternalServerError()
@@ -356,8 +319,8 @@ impl PDRouter {
req,
json_with_bootstrap,
route,
&prefill,
&decode,
prefill.as_ref(),
decode.as_ref(),
is_stream,
return_logprob,
start,
@@ -397,11 +360,13 @@ impl PDRouter {
// Log routing decision
info!(
"PD routing: {} -> prefill={}, decode={}",
route, prefill.url, decode.url
route,
prefill.url(),
decode.url()
);
// Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
error!("Failed to add bootstrap info: {}", e);
counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
return HttpResponse::InternalServerError()
@@ -423,8 +388,8 @@ impl PDRouter {
req,
json_with_bootstrap,
route,
&prefill,
&decode,
prefill.as_ref(),
decode.as_ref(),
is_stream,
return_logprob,
start,
@@ -440,22 +405,23 @@ impl PDRouter {
req: &HttpRequest,
json_request: serde_json::Value,
route: &str,
prefill: &EngineInfo,
decode: &EngineInfo,
prefill: &dyn Worker,
decode: &dyn Worker,
is_stream: bool,
return_logprob: bool,
start_time: Instant,
) -> HttpResponse {
// Update load tracking for both workers
let _guard = LoadGuard::new(
&self.load_tracking,
vec![prefill.url.clone(), decode.url.clone()],
);
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
// Build requests using .json() method
let mut prefill_request = client.post(prefill.api_path(route)).json(&json_request);
let mut prefill_request = client
.post(api_path(prefill.url(), route))
.json(&json_request);
let mut decode_request = client.post(decode.api_path(route)).json(&json_request);
let mut decode_request = client
.post(api_path(decode.url(), route))
.json(&json_request);
// Copy headers from original request
for (name, value) in crate::router::copy_request_headers(req) {
@@ -474,9 +440,9 @@ impl PDRouter {
histogram!("sgl_router_pd_request_duration_seconds", "route" => route.to_string())
.record(duration.as_secs_f64());
counter!("sgl_router_pd_requests_total", "route" => route.to_string()).increment(1);
counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url.to_string())
counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url().to_string())
.increment(1);
counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url.to_string())
counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url().to_string())
.increment(1);
// Process decode response
@@ -486,10 +452,11 @@ impl PDRouter {
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
if !status.is_success() {
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string()).increment(1);
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string()).increment(1);
error!(
"Decode server {} returned error status: {}",
decode.url, status
decode.url(),
status
);
// Return the error response from decode server
@@ -508,9 +475,10 @@ impl PDRouter {
if let Err(e) = &prefill_result {
error!(
"Prefill server {} failed (non-critical): {}",
prefill.url, e
prefill.url(),
e
);
counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url.to_string()).increment(1);
counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url().to_string()).increment(1);
}
if is_stream {
@@ -559,7 +527,7 @@ impl PDRouter {
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming({
let decode_url = decode.url.clone();
let decode_url = decode.url().to_string();
res.bytes_stream().map_err(move |e| {
error!("Stream error from decode server {}: {}", decode_url, e);
counter!("sgl_router_pd_stream_errors_total", "worker" => decode_url.to_string()).increment(1);
@@ -587,7 +555,7 @@ impl PDRouter {
}
Err(e) => {
error!("Decode request failed: {}", e);
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string())
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string())
.increment(1);
HttpResponse::BadGateway().body(format!("Decode server error: {}", e))
}
@@ -652,7 +620,7 @@ impl PDRouter {
async fn select_pd_pair(
&self,
_client: &reqwest::Client,
) -> Result<(EngineInfo, EngineInfo), String> {
) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
// Check we have workers
if self
.prefill_workers
@@ -681,17 +649,17 @@ impl PDRouter {
}
}
fn select_random(&self) -> Result<(EngineInfo, EngineInfo), String> {
fn select_random(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone();
let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone();
let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone_worker();
let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone_worker();
Ok((prefill, decode))
}
async fn select_power_of_two(&self) -> Result<(EngineInfo, EngineInfo), String> {
async fn select_power_of_two(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
@@ -700,33 +668,45 @@ impl PDRouter {
let loads = self.worker_loads.borrow();
let p1_load = loads.get(&prefill_list[p1_idx].url).copied().unwrap_or(0);
let p2_load = loads.get(&prefill_list[p2_idx].url).copied().unwrap_or(0);
let d1_load = loads.get(&decode_list[d1_idx].url).copied().unwrap_or(0);
let d2_load = loads.get(&decode_list[d2_idx].url).copied().unwrap_or(0);
let p1_load = loads
.get(prefill_list[p1_idx].url())
.copied()
.unwrap_or(isize::MAX);
let p2_load = loads
.get(prefill_list[p2_idx].url())
.copied()
.unwrap_or(isize::MAX);
let d1_load = loads
.get(decode_list[d1_idx].url())
.copied()
.unwrap_or(isize::MAX);
let d2_load = loads
.get(decode_list[d2_idx].url())
.copied()
.unwrap_or(isize::MAX);
info!(
"Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}",
prefill_list[p1_idx].url,
prefill_list[p1_idx].url(),
p1_load,
prefill_list[p2_idx].url,
prefill_list[p2_idx].url(),
p2_load,
decode_list[d1_idx].url,
decode_list[d1_idx].url(),
d1_load,
decode_list[d2_idx].url,
decode_list[d2_idx].url(),
d2_load
);
let selected_prefill = if p1_load <= p2_load {
prefill_list[p1_idx].clone()
prefill_list[p1_idx].clone_worker()
} else {
prefill_list[p2_idx].clone()
prefill_list[p2_idx].clone_worker()
};
let selected_decode = if d1_load <= d2_load {
decode_list[d1_idx].clone()
decode_list[d1_idx].clone_worker()
} else {
decode_list[d2_idx].clone()
decode_list[d2_idx].clone_worker()
};
Ok((selected_prefill, selected_decode))
@@ -868,11 +848,11 @@ impl PDRouter {
let mut worker_infos = Vec::new();
for worker in self.prefill_workers.read().unwrap().iter() {
worker_infos.push((worker.url.clone(), "prefill"));
worker_infos.push((worker.url().to_string(), "prefill"));
}
for worker in self.decode_workers.read().unwrap().iter() {
worker_infos.push((worker.url.clone(), "decode"));
worker_infos.push((worker.url().to_string(), "decode"));
}
// Create tasks with URL tracking
@@ -922,7 +902,7 @@ impl PDRouter {
pub async fn get_server_info(&self, client: &reqwest::Client) -> HttpResponse {
// Get info from the first decode server to match sglang's server info format
let first_decode_url = if let Ok(workers) = self.decode_workers.read() {
workers.first().map(|w| w.url.clone())
workers.first().map(|w| w.url().to_string())
} else {
return HttpResponse::InternalServerError().body("Failed to access decode workers");
};
@@ -967,7 +947,7 @@ impl PDRouter {
pub async fn get_models(&self, client: &reqwest::Client, req: &HttpRequest) -> HttpResponse {
// Get first prefill worker URL to avoid holding lock across await
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
workers.first().map(|w| w.url.clone())
workers.first().map(|w| w.url().to_string())
} else {
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
};
@@ -1005,14 +985,14 @@ impl PDRouter {
.read()
.unwrap()
.iter()
.map(|w| w.url.clone())
.map(|w| w.url().to_string())
.collect();
let d_urls: Vec<_> = self
.decode_workers
.read()
.unwrap()
.iter()
.map(|w| w.url.clone())
.map(|w| w.url().to_string())
.collect();
let mut prefill_loads = Vec::new();
@@ -1048,7 +1028,7 @@ impl PDRouter {
// Get model info from the first prefill server (matches original Rust PDLB behavior)
// Get first prefill worker URL to avoid holding lock across await
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
workers.first().map(|w| w.url.clone())
workers.first().map(|w| w.url().to_string())
} else {
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
};
@@ -1084,13 +1064,13 @@ impl PDRouter {
// Flush cache on all prefill servers
for worker in self.prefill_workers.read().unwrap().iter() {
let url = format!("{}/flush_cache", worker.url);
let url = format!("{}/flush_cache", worker.url());
tasks.push(client.post(&url).send());
}
// Flush cache on all decode servers
for worker in self.decode_workers.read().unwrap().iter() {
let url = format!("{}/flush_cache", worker.url);
let url = format!("{}/flush_cache", worker.url());
tasks.push(client.post(&url).send());
}