[router] consolidate worker get loads (#10880)
This commit is contained in:
@@ -12,7 +12,9 @@ use crate::core::{
|
||||
Worker, WorkerFactory, WorkerRegistry, WorkerType,
|
||||
};
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::worker_spec::{FlushCacheResult, WorkerConfigRequest};
|
||||
use crate::protocols::worker_spec::{
|
||||
FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
|
||||
};
|
||||
use crate::server::AppContext;
|
||||
use futures::future;
|
||||
use once_cell::sync::Lazy;
|
||||
@@ -1079,6 +1081,100 @@ impl WorkerManager {
|
||||
message,
|
||||
})
|
||||
}
|
||||
pub async fn get_worker_load(
|
||||
url: &str,
|
||||
api_key: Option<&str>,
|
||||
client: &reqwest::Client,
|
||||
) -> Option<isize> {
|
||||
let load_url = format!("{}/get_load", url);
|
||||
let mut request = client.get(&load_url);
|
||||
|
||||
if let Some(key) = api_key {
|
||||
request = request.bearer_auth(key);
|
||||
}
|
||||
|
||||
match request.send().await {
|
||||
Ok(response) if response.status().is_success() => {
|
||||
match response.json::<Value>().await {
|
||||
Ok(json) => {
|
||||
if let Some(load) = json.get("load").and_then(|v| v.as_i64()) {
|
||||
debug!("Worker {} load: {}", url, load);
|
||||
Some(load as isize)
|
||||
} else {
|
||||
warn!("Invalid load response from {}: {:?}", url, json);
|
||||
None
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse load response from {}: {}", url, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(response) => {
|
||||
warn!(
|
||||
"Failed to get load from {}: HTTP {}",
|
||||
url,
|
||||
response.status()
|
||||
);
|
||||
None
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to {} for load check: {}", url, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_all_worker_loads(
|
||||
worker_registry: &WorkerRegistry,
|
||||
client: &reqwest::Client,
|
||||
) -> WorkerLoadsResult {
|
||||
let workers = worker_registry.get_all();
|
||||
let total_workers = workers.len();
|
||||
|
||||
// Prepare tasks for parallel execution
|
||||
let mut tasks = Vec::new();
|
||||
for worker in &workers {
|
||||
let url = worker.url().to_string();
|
||||
let api_key = worker.api_key().clone();
|
||||
let worker_type = match worker.worker_type() {
|
||||
WorkerType::Regular => None,
|
||||
WorkerType::Prefill { .. } => Some("prefill".to_string()),
|
||||
WorkerType::Decode => Some("decode".to_string()),
|
||||
};
|
||||
let is_http = matches!(worker.connection_mode(), ConnectionMode::Http);
|
||||
let client = client.clone();
|
||||
|
||||
tasks.push(async move {
|
||||
let load = if is_http {
|
||||
Self::get_worker_load(&url, api_key.as_deref(), &client)
|
||||
.await
|
||||
.unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
|
||||
WorkerLoadInfo {
|
||||
worker: url,
|
||||
worker_type,
|
||||
load,
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let loads = futures::future::join_all(tasks).await;
|
||||
|
||||
let successful = loads.iter().filter(|l| l.load >= 0).count();
|
||||
let failed = loads.iter().filter(|l| l.load < 0).count();
|
||||
|
||||
WorkerLoadsResult {
|
||||
loads,
|
||||
total_workers,
|
||||
successful,
|
||||
failed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Reference in New Issue
Block a user