[router] consolidate worker get loads (#10880)
This commit is contained in:
@@ -12,7 +12,9 @@ use crate::core::{
|
||||
Worker, WorkerFactory, WorkerRegistry, WorkerType,
|
||||
};
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::worker_spec::{FlushCacheResult, WorkerConfigRequest};
|
||||
use crate::protocols::worker_spec::{
|
||||
FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
|
||||
};
|
||||
use crate::server::AppContext;
|
||||
use futures::future;
|
||||
use once_cell::sync::Lazy;
|
||||
@@ -1079,6 +1081,100 @@ impl WorkerManager {
|
||||
message,
|
||||
})
|
||||
}
|
||||
pub async fn get_worker_load(
|
||||
url: &str,
|
||||
api_key: Option<&str>,
|
||||
client: &reqwest::Client,
|
||||
) -> Option<isize> {
|
||||
let load_url = format!("{}/get_load", url);
|
||||
let mut request = client.get(&load_url);
|
||||
|
||||
if let Some(key) = api_key {
|
||||
request = request.bearer_auth(key);
|
||||
}
|
||||
|
||||
match request.send().await {
|
||||
Ok(response) if response.status().is_success() => {
|
||||
match response.json::<Value>().await {
|
||||
Ok(json) => {
|
||||
if let Some(load) = json.get("load").and_then(|v| v.as_i64()) {
|
||||
debug!("Worker {} load: {}", url, load);
|
||||
Some(load as isize)
|
||||
} else {
|
||||
warn!("Invalid load response from {}: {:?}", url, json);
|
||||
None
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse load response from {}: {}", url, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(response) => {
|
||||
warn!(
|
||||
"Failed to get load from {}: HTTP {}",
|
||||
url,
|
||||
response.status()
|
||||
);
|
||||
None
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to {} for load check: {}", url, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_all_worker_loads(
|
||||
worker_registry: &WorkerRegistry,
|
||||
client: &reqwest::Client,
|
||||
) -> WorkerLoadsResult {
|
||||
let workers = worker_registry.get_all();
|
||||
let total_workers = workers.len();
|
||||
|
||||
// Prepare tasks for parallel execution
|
||||
let mut tasks = Vec::new();
|
||||
for worker in &workers {
|
||||
let url = worker.url().to_string();
|
||||
let api_key = worker.api_key().clone();
|
||||
let worker_type = match worker.worker_type() {
|
||||
WorkerType::Regular => None,
|
||||
WorkerType::Prefill { .. } => Some("prefill".to_string()),
|
||||
WorkerType::Decode => Some("decode".to_string()),
|
||||
};
|
||||
let is_http = matches!(worker.connection_mode(), ConnectionMode::Http);
|
||||
let client = client.clone();
|
||||
|
||||
tasks.push(async move {
|
||||
let load = if is_http {
|
||||
Self::get_worker_load(&url, api_key.as_deref(), &client)
|
||||
.await
|
||||
.unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
|
||||
WorkerLoadInfo {
|
||||
worker: url,
|
||||
worker_type,
|
||||
load,
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let loads = futures::future::join_all(tasks).await;
|
||||
|
||||
let successful = loads.iter().filter(|l| l.load >= 0).count();
|
||||
let failed = loads.iter().filter(|l| l.load < 0).count();
|
||||
|
||||
WorkerLoadsResult {
|
||||
loads,
|
||||
total_workers,
|
||||
successful,
|
||||
failed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -215,3 +215,28 @@ pub struct FlushCacheResult {
|
||||
/// Human-readable summary message
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
/// Result from getting worker loads
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct WorkerLoadsResult {
|
||||
/// Worker URL and load pairs
|
||||
pub loads: Vec<WorkerLoadInfo>,
|
||||
/// Total number of workers
|
||||
pub total_workers: usize,
|
||||
/// Number of workers with successful load fetches
|
||||
pub successful: usize,
|
||||
/// Number of workers with failed load fetches
|
||||
pub failed: usize,
|
||||
}
|
||||
|
||||
/// Individual worker load information
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct WorkerLoadInfo {
|
||||
/// Worker URL
|
||||
pub worker: String,
|
||||
/// Worker type (regular, prefill, decode)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub worker_type: Option<String>,
|
||||
/// Current load (-1 indicates failure to fetch)
|
||||
pub load: isize,
|
||||
}
|
||||
|
||||
@@ -340,10 +340,6 @@ impl RouterTrait for GrpcPDRouter {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
"grpc_pd"
|
||||
}
|
||||
|
||||
@@ -787,10 +787,6 @@ impl RouterTrait for GrpcRouter {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
"grpc"
|
||||
}
|
||||
|
||||
@@ -1296,14 +1296,6 @@ impl super::super::RouterTrait for OpenAIRouter {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
(
|
||||
StatusCode::FORBIDDEN,
|
||||
"get_worker_loads not supported for OpenAI router",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
"openai"
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use super::pd_types::api_path;
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry,
|
||||
WorkerType,
|
||||
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerLoadGuard, WorkerManager,
|
||||
WorkerRegistry, WorkerType,
|
||||
};
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
||||
@@ -18,7 +18,6 @@ use axum::{
|
||||
extract::Request,
|
||||
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
@@ -53,26 +52,6 @@ struct PDRequestContext<'a> {
|
||||
}
|
||||
|
||||
impl PDRouter {
|
||||
fn _get_worker_url_and_key(&self, w: &Arc<dyn Worker>) -> (String, Option<String>) {
|
||||
(w.url().to_string(), w.api_key().clone())
|
||||
}
|
||||
|
||||
fn get_prefill_worker_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
|
||||
self.worker_registry
|
||||
.get_prefill_workers()
|
||||
.iter()
|
||||
.map(|w| self._get_worker_url_and_key(w))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn get_decode_worker_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
|
||||
self.worker_registry
|
||||
.get_decode_workers()
|
||||
.iter()
|
||||
.map(|w| self._get_worker_url_and_key(w))
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn proxy_to_first_prefill_worker(
|
||||
&self,
|
||||
endpoint: &str,
|
||||
@@ -749,7 +728,10 @@ impl PDRouter {
|
||||
let url = url.clone();
|
||||
let api_key = api_key.clone();
|
||||
async move {
|
||||
let load = get_worker_load(&client, &url, &api_key).await.unwrap_or(0);
|
||||
let load =
|
||||
WorkerManager::get_worker_load(&url, api_key.as_deref(), &client)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
(url, load)
|
||||
}
|
||||
})
|
||||
@@ -1083,49 +1065,6 @@ impl PDRouter {
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
async fn get_worker_load(
|
||||
client: &Client,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Option<isize> {
|
||||
let mut req_builder = client.get(format!("{}/get_load", worker_url));
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
match req_builder.send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<Value>(&bytes) {
|
||||
Ok(data) => data
|
||||
.get("load")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|v| v as isize),
|
||||
Err(e) => {
|
||||
debug!("Failed to parse load response from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
debug!("Failed to read load response from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
},
|
||||
Ok(res) => {
|
||||
debug!(
|
||||
"Worker {} returned non-success status: {}",
|
||||
worker_url,
|
||||
res.status()
|
||||
);
|
||||
None
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to get load from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RouterTrait for PDRouter {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
@@ -1418,44 +1357,6 @@ impl RouterTrait for PDRouter {
|
||||
self.execute_dual_dispatch(headers, body, context).await
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
let mut loads = HashMap::new();
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// Process prefill workers
|
||||
let prefill_urls_with_key = self.get_prefill_worker_urls_with_api_key();
|
||||
for (worker_url, api_key) in prefill_urls_with_key {
|
||||
match get_worker_load(&self.client, &worker_url, &api_key).await {
|
||||
Some(load) => {
|
||||
loads.insert(format!("prefill_{}", worker_url), load);
|
||||
}
|
||||
None => {
|
||||
errors.push(format!("Failed to get load from prefill {}", worker_url));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process decode workers
|
||||
let decode_urls_with_key = self.get_decode_worker_urls_with_api_key();
|
||||
for (worker_url, api_key) in decode_urls_with_key {
|
||||
match get_worker_load(&self.client, &worker_url, &api_key).await {
|
||||
Some(load) => {
|
||||
loads.insert(format!("decode_{}", worker_url), load);
|
||||
}
|
||||
None => {
|
||||
errors.push(format!("Failed to get load from decode {}", worker_url));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response_data = serde_json::json!({
|
||||
"loads": loads,
|
||||
"errors": errors
|
||||
});
|
||||
|
||||
(StatusCode::OK, Json(response_data)).into_response()
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
"pd"
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
|
||||
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerManager, WorkerRegistry,
|
||||
WorkerType,
|
||||
};
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
||||
@@ -660,58 +661,6 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_load(&self, worker_url: &str, api_key: &Option<String>) -> Option<isize> {
|
||||
let worker_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => {
|
||||
error!("Failed to extract dp_rank: {}", e);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
worker_url_prefix
|
||||
} else {
|
||||
worker_url
|
||||
};
|
||||
|
||||
let mut req_builder = self.client.get(format!("{}/get_load", worker_url));
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
|
||||
match req_builder.send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
Ok(data) => data
|
||||
.get("load")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|v| v as isize),
|
||||
Err(e) => {
|
||||
debug!("Failed to parse load response from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
debug!("Failed to read load response from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
},
|
||||
Ok(res) => {
|
||||
debug!(
|
||||
"Worker {} returned non-success status: {}",
|
||||
worker_url,
|
||||
res.status()
|
||||
);
|
||||
None
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to get load from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Background task to monitor worker loads
|
||||
async fn monitor_worker_loads(
|
||||
worker_urls: Vec<String>,
|
||||
@@ -728,7 +677,10 @@ impl Router {
|
||||
|
||||
let mut loads = HashMap::new();
|
||||
for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
|
||||
if let Some(load) = Self::get_worker_load_static(&client, url, api_key).await {
|
||||
// Use WorkerManager for consistent load fetching
|
||||
if let Some(load) =
|
||||
WorkerManager::get_worker_load(url, api_key.as_deref(), &client).await
|
||||
{
|
||||
loads.insert(url.clone(), load);
|
||||
}
|
||||
}
|
||||
@@ -745,62 +697,6 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
// Static version of get_worker_load for use in monitoring task
|
||||
async fn get_worker_load_static(
|
||||
client: &Client,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Option<isize> {
|
||||
let worker_url = if worker_url.contains("@") {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => {
|
||||
debug!("Failed to extract dp_rank: {}", e);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
worker_url_prefix
|
||||
} else {
|
||||
worker_url
|
||||
};
|
||||
|
||||
let mut req_builder = client.get(format!("{}/get_load", worker_url));
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
match req_builder.send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
Ok(data) => data
|
||||
.get("load")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|v| v as isize),
|
||||
Err(e) => {
|
||||
debug!("Failed to parse load response from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
debug!("Failed to read load response from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
},
|
||||
Ok(res) => {
|
||||
debug!(
|
||||
"Worker {} returned non-success status: {}",
|
||||
worker_url,
|
||||
res.status()
|
||||
);
|
||||
None
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to get load from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn build_rerank_response(
|
||||
req: &RerankRequest,
|
||||
response: Response,
|
||||
@@ -953,25 +849,6 @@ impl RouterTrait for Router {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
let urls_with_key = self.worker_registry.get_all_urls_with_api_key();
|
||||
let mut loads = Vec::new();
|
||||
|
||||
// Get loads from all workers
|
||||
for (url, api_key) in &urls_with_key {
|
||||
let load = self.get_worker_load(url, api_key).await.unwrap_or(-1);
|
||||
loads.push(serde_json::json!({
|
||||
"worker": url,
|
||||
"load": load
|
||||
}));
|
||||
}
|
||||
|
||||
Json(serde_json::json!({
|
||||
"workers": loads
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
"regular"
|
||||
}
|
||||
|
||||
@@ -126,9 +126,6 @@ pub trait RouterTrait: Send + Sync + Debug {
|
||||
model_id: Option<&str>,
|
||||
) -> Response;
|
||||
|
||||
/// Get worker loads (for monitoring)
|
||||
async fn get_worker_loads(&self) -> Response;
|
||||
|
||||
/// Get router type name
|
||||
fn router_type(&self) -> &'static str;
|
||||
|
||||
|
||||
@@ -508,30 +508,6 @@ impl RouterTrait for RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
let workers = self.worker_registry.get_all();
|
||||
let loads: Vec<serde_json::Value> = workers
|
||||
.iter()
|
||||
.map(|w| {
|
||||
serde_json::json!({
|
||||
"url": w.url(),
|
||||
"model": w.model_id(),
|
||||
"load": w.load(),
|
||||
"is_healthy": w.is_healthy()
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
serde_json::json!({
|
||||
"workers": loads
|
||||
})
|
||||
.to_string(),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
"manager"
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ use axum::{
|
||||
};
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use serde_json::{json, Value};
|
||||
use std::{
|
||||
sync::atomic::{AtomicBool, Ordering},
|
||||
sync::Arc,
|
||||
@@ -400,7 +400,28 @@ async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Respo
|
||||
}
|
||||
|
||||
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
||||
state.router.get_worker_loads().await
|
||||
let result =
|
||||
WorkerManager::get_all_worker_loads(&state.context.worker_registry, &state.context.client)
|
||||
.await;
|
||||
|
||||
let loads: Vec<Value> = result
|
||||
.loads
|
||||
.iter()
|
||||
.map(|info| {
|
||||
json!({
|
||||
"worker": &info.worker,
|
||||
"load": info.load
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"workers": loads
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn create_worker(
|
||||
|
||||
Reference in New Issue
Block a user