[router] consolidate worker get loads (#10880)

This commit is contained in:
Simo Lin
2025-09-24 22:13:31 -04:00
committed by GitHub
parent fe531d6f4e
commit e738703547
10 changed files with 157 additions and 280 deletions

View File

@@ -12,7 +12,9 @@ use crate::core::{
Worker, WorkerFactory, WorkerRegistry, WorkerType,
};
use crate::policies::PolicyRegistry;
use crate::protocols::worker_spec::{FlushCacheResult, WorkerConfigRequest};
use crate::protocols::worker_spec::{
FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
};
use crate::server::AppContext;
use futures::future;
use once_cell::sync::Lazy;
@@ -1079,6 +1081,100 @@ impl WorkerManager {
message,
})
}
pub async fn get_worker_load(
url: &str,
api_key: Option<&str>,
client: &reqwest::Client,
) -> Option<isize> {
let load_url = format!("{}/get_load", url);
let mut request = client.get(&load_url);
if let Some(key) = api_key {
request = request.bearer_auth(key);
}
match request.send().await {
Ok(response) if response.status().is_success() => {
match response.json::<Value>().await {
Ok(json) => {
if let Some(load) = json.get("load").and_then(|v| v.as_i64()) {
debug!("Worker {} load: {}", url, load);
Some(load as isize)
} else {
warn!("Invalid load response from {}: {:?}", url, json);
None
}
}
Err(e) => {
warn!("Failed to parse load response from {}: {}", url, e);
None
}
}
}
Ok(response) => {
warn!(
"Failed to get load from {}: HTTP {}",
url,
response.status()
);
None
}
Err(e) => {
warn!("Failed to connect to {} for load check: {}", url, e);
None
}
}
}
/// Collects the load of every worker in `worker_registry` concurrently.
///
/// One probe future is built per registered worker and all of them are
/// awaited together. Workers whose connection mode is HTTP are queried via
/// `get_worker_load`; any other connection mode, and any failed query, is
/// reported with a load of `-1`.
///
/// The returned summary counts entries with `load >= 0` as `successful` and
/// the remaining entries as `failed`.
pub async fn get_all_worker_loads(
    worker_registry: &WorkerRegistry,
    client: &reqwest::Client,
) -> WorkerLoadsResult {
    let registered = worker_registry.get_all();
    let total_workers = registered.len();

    // Build one self-contained future per worker; each owns its own copies of
    // the URL, API key and client so nothing borrows from the registry
    // snapshot while the probes are in flight.
    let probes: Vec<_> = registered
        .iter()
        .map(|entry| {
            let endpoint = entry.url().to_string();
            let token = entry.api_key().clone();
            let role = match entry.worker_type() {
                WorkerType::Prefill { .. } => Some("prefill".to_string()),
                WorkerType::Decode => Some("decode".to_string()),
                WorkerType::Regular => None,
            };
            let over_http = matches!(entry.connection_mode(), ConnectionMode::Http);
            let http = client.clone();

            async move {
                let load = if over_http {
                    let reported =
                        Self::get_worker_load(&endpoint, token.as_deref(), &http).await;
                    reported.unwrap_or(-1)
                } else {
                    -1
                };

                WorkerLoadInfo {
                    worker: endpoint,
                    worker_type: role,
                    load,
                }
            }
        })
        .collect();

    let loads = futures::future::join_all(probes).await;

    // Every entry is either non-negative (success) or negative (failure), so
    // the failure count is simply the remainder.
    let successful = loads.iter().filter(|info| info.load >= 0).count();
    let failed = loads.len() - successful;

    WorkerLoadsResult {
        loads,
        total_workers,
        successful,
        failed,
    }
}
}
#[cfg(test)]