[router] consolidate worker get loads (#10880)

This commit is contained in:
Simo Lin
2025-09-24 22:13:31 -04:00
committed by GitHub
parent fe531d6f4e
commit e738703547
10 changed files with 157 additions and 280 deletions

View File

@@ -1,6 +1,7 @@
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerManager, WorkerRegistry,
WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
@@ -660,58 +661,6 @@ impl Router {
}
}
async fn get_worker_load(&self, worker_url: &str, api_key: &Option<String>) -> Option<isize> {
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return None;
}
};
worker_url_prefix
} else {
worker_url
};
let mut req_builder = self.client.get(format!("{}/get_load", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data
.get("load")
.and_then(|v| v.as_i64())
.map(|v| v as isize),
Err(e) => {
debug!("Failed to parse load response from {}: {}", worker_url, e);
None
}
},
Err(e) => {
debug!("Failed to read load response from {}: {}", worker_url, e);
None
}
},
Ok(res) => {
debug!(
"Worker {} returned non-success status: {}",
worker_url,
res.status()
);
None
}
Err(e) => {
debug!("Failed to get load from {}: {}", worker_url, e);
None
}
}
}
// Background task to monitor worker loads
async fn monitor_worker_loads(
worker_urls: Vec<String>,
@@ -728,7 +677,10 @@ impl Router {
let mut loads = HashMap::new();
for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
if let Some(load) = Self::get_worker_load_static(&client, url, api_key).await {
// Use WorkerManager for consistent load fetching
if let Some(load) =
WorkerManager::get_worker_load(url, api_key.as_deref(), &client).await
{
loads.insert(url.clone(), load);
}
}
@@ -745,62 +697,6 @@ impl Router {
}
}
// Static version of get_worker_load for use in monitoring task
async fn get_worker_load_static(
client: &Client,
worker_url: &str,
api_key: &Option<String>,
) -> Option<isize> {
let worker_url = if worker_url.contains("@") {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
debug!("Failed to extract dp_rank: {}", e);
return None;
}
};
worker_url_prefix
} else {
worker_url
};
let mut req_builder = client.get(format!("{}/get_load", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data
.get("load")
.and_then(|v| v.as_i64())
.map(|v| v as isize),
Err(e) => {
debug!("Failed to parse load response from {}: {}", worker_url, e);
None
}
},
Err(e) => {
debug!("Failed to read load response from {}: {}", worker_url, e);
None
}
},
Ok(res) => {
debug!(
"Worker {} returned non-success status: {}",
worker_url,
res.status()
);
None
}
Err(e) => {
debug!("Failed to get load from {}: {}", worker_url, e);
None
}
}
}
async fn build_rerank_response(
req: &RerankRequest,
response: Response,
@@ -953,25 +849,6 @@ impl RouterTrait for Router {
}
}
async fn get_worker_loads(&self) -> Response {
let urls_with_key = self.worker_registry.get_all_urls_with_api_key();
let mut loads = Vec::new();
// Get loads from all workers
for (url, api_key) in &urls_with_key {
let load = self.get_worker_load(url, api_key).await.unwrap_or(-1);
loads.push(serde_json::json!({
"worker": url,
"load": load
}));
}
Json(serde_json::json!({
"workers": loads
}))
.into_response()
}
fn router_type(&self) -> &'static str {
"regular"
}