sglang/sgl-router/src/core/worker_manager.rs

//! Unified Worker Management Module
//!
//! Handles all aspects of worker lifecycle including discovery, initialization,
//! runtime management, and health monitoring.

use std::{collections::HashMap, sync::Arc, time::Duration};

use futures::future;
use serde_json::Value;
use tokio::{
    sync::{watch, Mutex},
    task::JoinHandle,
};
use tracing::{debug, error, info, warn};

use crate::{
    core::{ConnectionMode, WorkerRegistry, WorkerType},
    policies::PolicyRegistry,
    protocols::worker_spec::{FlushCacheResult, WorkerLoadInfo, WorkerLoadsResult},
};

/// Unified worker management
pub struct WorkerManager;

impl WorkerManager {
    /// Collect the URLs of all currently registered workers.
    pub fn get_worker_urls(registry: &Arc<WorkerRegistry>) -> Vec<String> {
        registry
            .get_all()
            .iter()
            .map(|w| w.url().to_string())
            .collect()
    }

    /// Flush cache on all workers
    ///
    /// Sends a POST request to the /flush_cache endpoint on every HTTP worker.
    /// Returns detailed results showing which workers succeeded and which failed.
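    ///
    /// A usage sketch (illustrative only; the registry and client are assumed
    /// to be set up elsewhere):
    /// ```ignore
    /// let result = WorkerManager::flush_cache_all(&registry, &client).await?;
    /// println!(
    ///     "{} ok, {} failed: {}",
    ///     result.successful.len(),
    ///     result.failed.len(),
    ///     result.message
    /// );
    /// ```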
    pub async fn flush_cache_all(
        worker_registry: &WorkerRegistry,
        client: &reqwest::Client,
    ) -> Result<FlushCacheResult, String> {
        warn!("Flushing cache for ALL workers - this may impact performance temporarily");
        let workers = worker_registry.get_all();
        let http_workers: Vec<_> = workers
            .iter()
            .filter(|w| matches!(w.connection_mode(), ConnectionMode::Http))
            .collect();
        if http_workers.is_empty() {
            return Ok(FlushCacheResult {
                successful: vec![],
                failed: vec![],
                total_workers: workers.len(),
                http_workers: 0,
                message: "No HTTP workers available for cache flush".to_string(),
            });
        }
        info!(
            "Flushing cache on {} HTTP workers (out of {} total workers)",
            http_workers.len(),
            workers.len()
        );
        let mut tasks = Vec::new();
        for worker in &http_workers {
            let url = worker.url().to_string();
            let flush_url = format!("{}/flush_cache", url);
            let mut request = client.post(&flush_url);
            if let Some(api_key) = worker.api_key() {
                request = request.header("Authorization", format!("Bearer {}", api_key));
            }
            let worker_url = url.clone();
            tasks.push(async move {
                let result = request.send().await;
                (worker_url, result)
            });
        }
        let results = future::join_all(tasks).await;
        let mut successful = Vec::new();
        let mut failed = Vec::new();
        for (url, result) in results {
            match result {
                Ok(response) if response.status().is_success() => {
                    debug!("Successfully flushed cache on worker: {}", url);
                    successful.push(url);
                }
                Ok(response) => {
                    let error = format!("HTTP {}", response.status());
                    warn!("Failed to flush cache on worker {}: {}", url, error);
                    failed.push((url, error));
                }
                Err(e) => {
                    let error = e.to_string();
                    error!("Failed to connect to worker {}: {}", url, error);
                    failed.push((url, error));
                }
            }
        }
        let message = if failed.is_empty() {
            format!(
                "Successfully flushed cache on all {} HTTP workers",
                successful.len()
            )
        } else {
            format!(
                "Cache flush completed: {} succeeded, {} failed (out of {} HTTP workers)",
                successful.len(),
                failed.len(),
                http_workers.len()
            )
        };
        info!("{}", message);
        Ok(FlushCacheResult {
            successful,
            failed,
            total_workers: workers.len(),
            http_workers: http_workers.len(),
            message,
        })
    }
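
    /// Query a single worker's `/get_load` endpoint and return its total token
    /// load, or `None` if the request or parsing fails.
    ///
    /// A call sketch (illustrative; the worker URL is a placeholder):
    /// ```ignore
    /// let load = WorkerManager::get_worker_load("http://worker-0:30000", None, &client).await;
    /// let load = load.unwrap_or(-1); // -1 conventionally marks "unknown" below
    /// ```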
    pub async fn get_worker_load(
        url: &str,
        api_key: Option<&str>,
        client: &reqwest::Client,
    ) -> Option<isize> {
        let load_url = format!("{}/get_load", url);
        let mut request = client.get(&load_url);
        if let Some(key) = api_key {
            request = request.bearer_auth(key);
        }
        match request.send().await {
            Ok(response) if response.status().is_success() => {
                match response.json::<Value>().await {
                    Ok(json) => {
                        // The /get_load endpoint returns an array of load info objects
                        // (one per DP rank). Each object has:
                        //   {dp_rank, num_reqs, num_waiting_reqs, num_tokens}
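                        //
                        // Illustrative payload (shape assumed from the parsing below):
                        //   [{"dp_rank": 0, "num_tokens": 4096, ...},
                        //    {"dp_rank": 1, "num_tokens": 2048, ...}]
                        // Loads are summed across ranks into a single per-worker scalar.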
                        if let Some(array) = json.as_array() {
                            let total_tokens: i64 = array
                                .iter()
                                .filter_map(|entry| {
                                    entry.get("num_tokens").and_then(|v| v.as_i64())
                                })
                                .sum();
                            debug!("Worker {} load (total tokens): {}", url, total_tokens);
                            Some(total_tokens as isize)
                        } else {
                            warn!(
                                "Invalid load response from {}: expected array, got {:?}",
                                url, json
                            );
                            None
                        }
                    }
                    Err(e) => {
                        warn!("Failed to parse load response from {}: {}", url, e);
                        None
                    }
                }
            }
            Ok(response) => {
                warn!(
                    "Failed to get load from {}: HTTP {}",
                    url,
                    response.status()
                );
                None
            }
            Err(e) => {
                warn!("Failed to connect to {} for load check: {}", url, e);
                None
            }
        }
    }
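
    /// Fetch loads from all registered workers in parallel. Non-HTTP workers
    /// and failed requests are reported with a load of -1.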
    pub async fn get_all_worker_loads(
        worker_registry: &WorkerRegistry,
        client: &reqwest::Client,
    ) -> WorkerLoadsResult {
        let workers = worker_registry.get_all();
        let total_workers = workers.len();
        // Prepare tasks for parallel execution
        let mut tasks = Vec::new();
        for worker in &workers {
            let url = worker.url().to_string();
            let api_key = worker.api_key().clone();
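            // Tag prefill/decode workers for the response; regular workers
            // carry no type label.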
            let worker_type = match worker.worker_type() {
                WorkerType::Regular => None,
                WorkerType::Prefill { .. } => Some("prefill".to_string()),
                WorkerType::Decode => Some("decode".to_string()),
            };
            let is_http = matches!(worker.connection_mode(), ConnectionMode::Http);
            let client = client.clone();
            tasks.push(async move {
                let load = if is_http {
                    Self::get_worker_load(&url, api_key.as_deref(), &client)
                        .await
                        .unwrap_or(-1)
                } else {
                    -1
                };
                WorkerLoadInfo {
                    worker: url,
                    worker_type,
                    load,
                }
            });
        }
        let loads = future::join_all(tasks).await;
        let successful = loads.iter().filter(|l| l.load >= 0).count();
        let failed = loads.iter().filter(|l| l.load < 0).count();
        WorkerLoadsResult {
            loads,
            total_workers,
            successful,
            failed,
        }
    }
}

/// Load monitoring service that periodically fetches worker loads
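///
/// A lifecycle sketch (illustrative; the registries and client are assumed to
/// be constructed elsewhere):
/// ```ignore
/// let monitor = LoadMonitor::new(worker_registry, policy_registry, client, 5);
/// monitor.start().await;
/// let mut rx = monitor.subscribe();
/// // ... read the latest load map with rx.borrow() ...
/// monitor.stop().await;
/// ```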
pub struct LoadMonitor {
    worker_registry: Arc<WorkerRegistry>,
    policy_registry: Arc<PolicyRegistry>,
    client: reqwest::Client,
    interval: Duration,
    // Latest load snapshot (worker URL -> load), broadcast to subscribers.
    tx: watch::Sender<HashMap<String, isize>>,
    rx: watch::Receiver<HashMap<String, isize>>,
    monitor_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
}

impl LoadMonitor {
    /// Create a new load monitor
    pub fn new(
        worker_registry: Arc<WorkerRegistry>,
        policy_registry: Arc<PolicyRegistry>,
        client: reqwest::Client,
        interval_secs: u64,
    ) -> Self {
        let (tx, rx) = watch::channel(HashMap::new());
        Self {
            worker_registry,
            policy_registry,
            client,
            interval: Duration::from_secs(interval_secs),
            tx,
            rx,
            monitor_handle: Arc::new(Mutex::new(None)),
        }
    }

    /// Start monitoring worker loads
    pub async fn start(&self) {
        let mut handle_guard = self.monitor_handle.lock().await;
        if handle_guard.is_some() {
            debug!("Load monitoring already running");
            return;
        }
        info!(
            "Starting load monitoring with interval: {:?}",
            self.interval
        );
        let worker_registry = Arc::clone(&self.worker_registry);
        let policy_registry = Arc::clone(&self.policy_registry);
        let client = self.client.clone();
        let interval = self.interval;
        let tx = self.tx.clone();
        let handle = tokio::spawn(async move {
            Self::monitor_loop(worker_registry, policy_registry, client, interval, tx).await;
        });
        *handle_guard = Some(handle);
    }

    /// Stop monitoring worker loads
    pub async fn stop(&self) {
        let mut handle_guard = self.monitor_handle.lock().await;
        if let Some(handle) = handle_guard.take() {
            info!("Stopping load monitoring");
            handle.abort();
            let _ = handle.await; // Wait for the aborted task to finish
        }
    }

    /// Get a receiver for load updates
    pub fn subscribe(&self) -> watch::Receiver<HashMap<String, isize>> {
        self.rx.clone()
    }

    /// The main monitoring loop
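    ///
    /// Ticks every `interval`. When no PowerOfTwo policies are registered the
    /// fetch is skipped entirely, since they are the only consumers of these
    /// load updates in this loop.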
    async fn monitor_loop(
        worker_registry: Arc<WorkerRegistry>,
        policy_registry: Arc<PolicyRegistry>,
        client: reqwest::Client,
        interval: Duration,
        tx: watch::Sender<HashMap<String, isize>>,
    ) {
        // Note: the first tick of `tokio::time::interval` completes immediately.
        let mut interval_timer = tokio::time::interval(interval);
        loop {
            interval_timer.tick().await;
            let power_of_two_policies = policy_registry.get_all_power_of_two_policies();
            if power_of_two_policies.is_empty() {
                debug!("No PowerOfTwo policies found, skipping load fetch");
                continue;
            }
            let result = WorkerManager::get_all_worker_loads(&worker_registry, &client).await;
            let mut loads = HashMap::new();
            for load_info in result.loads {
                loads.insert(load_info.worker, load_info.load);
            }
            if !loads.is_empty() {
                debug!(
                    "Fetched loads from {} workers, updating {} PowerOfTwo policies",
                    loads.len(),
                    power_of_two_policies.len()
                );
                for policy in &power_of_two_policies {
                    policy.update_loads(&loads);
                }
                let _ = tx.send(loads);
            } else {
                warn!("No loads fetched from workers");
            }
        }
    }

    /// Check if monitoring is currently active
    pub async fn is_running(&self) -> bool {
        let handle_guard = self.monitor_handle.lock().await;
        handle_guard.is_some()
    }
}

impl Drop for LoadMonitor {
    fn drop(&mut self) {
        // `drop` cannot be async, so cleanup is best-effort: if the mutex is
        // contended we skip the abort rather than block.
        if let Ok(mut handle_guard) = self.monitor_handle.try_lock() {
            if let Some(handle) = handle_guard.take() {
                handle.abort();
            }
        }
    }
}