[router] consolidate worker load monitoring (#10894)
This commit is contained in:
@@ -25,5 +25,5 @@ pub use worker::{
|
|||||||
Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
|
Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
|
||||||
};
|
};
|
||||||
pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder};
|
pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder};
|
||||||
pub use worker_manager::{DpInfo, ServerInfo, WorkerManager};
|
pub use worker_manager::{DpInfo, LoadMonitor, ServerInfo, WorkerManager};
|
||||||
pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats};
|
pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats};
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ use serde_json::Value;
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use tokio::sync::{watch, Mutex};
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
static HTTP_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
|
static HTTP_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
|
||||||
@@ -1177,6 +1179,139 @@ impl WorkerManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Load monitoring service that periodically fetches worker loads
|
||||||
|
pub struct LoadMonitor {
|
||||||
|
worker_registry: Arc<WorkerRegistry>,
|
||||||
|
policy_registry: Arc<PolicyRegistry>,
|
||||||
|
client: reqwest::Client,
|
||||||
|
interval: Duration,
|
||||||
|
tx: watch::Sender<HashMap<String, isize>>,
|
||||||
|
rx: watch::Receiver<HashMap<String, isize>>,
|
||||||
|
monitor_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LoadMonitor {
|
||||||
|
/// Create a new load monitor
|
||||||
|
pub fn new(
|
||||||
|
worker_registry: Arc<WorkerRegistry>,
|
||||||
|
policy_registry: Arc<PolicyRegistry>,
|
||||||
|
client: reqwest::Client,
|
||||||
|
interval_secs: u64,
|
||||||
|
) -> Self {
|
||||||
|
let (tx, rx) = watch::channel(HashMap::new());
|
||||||
|
|
||||||
|
Self {
|
||||||
|
worker_registry,
|
||||||
|
policy_registry,
|
||||||
|
client,
|
||||||
|
interval: Duration::from_secs(interval_secs),
|
||||||
|
tx,
|
||||||
|
rx,
|
||||||
|
monitor_handle: Arc::new(Mutex::new(None)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start monitoring worker loads
|
||||||
|
pub async fn start(&self) {
|
||||||
|
let mut handle_guard = self.monitor_handle.lock().await;
|
||||||
|
if handle_guard.is_some() {
|
||||||
|
debug!("Load monitoring already running");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Starting load monitoring with interval: {:?}",
|
||||||
|
self.interval
|
||||||
|
);
|
||||||
|
|
||||||
|
let worker_registry = Arc::clone(&self.worker_registry);
|
||||||
|
let policy_registry = Arc::clone(&self.policy_registry);
|
||||||
|
let client = self.client.clone();
|
||||||
|
let interval = self.interval;
|
||||||
|
let tx = self.tx.clone();
|
||||||
|
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
Self::monitor_loop(worker_registry, policy_registry, client, interval, tx).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
*handle_guard = Some(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stop monitoring worker loads
|
||||||
|
pub async fn stop(&self) {
|
||||||
|
let mut handle_guard = self.monitor_handle.lock().await;
|
||||||
|
if let Some(handle) = handle_guard.take() {
|
||||||
|
info!("Stopping load monitoring");
|
||||||
|
handle.abort();
|
||||||
|
let _ = handle.await; // Wait for task to finish
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a receiver for load updates
|
||||||
|
pub fn subscribe(&self) -> watch::Receiver<HashMap<String, isize>> {
|
||||||
|
self.rx.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The main monitoring loop
|
||||||
|
async fn monitor_loop(
|
||||||
|
worker_registry: Arc<WorkerRegistry>,
|
||||||
|
policy_registry: Arc<PolicyRegistry>,
|
||||||
|
client: reqwest::Client,
|
||||||
|
interval: Duration,
|
||||||
|
tx: watch::Sender<HashMap<String, isize>>,
|
||||||
|
) {
|
||||||
|
let mut interval_timer = tokio::time::interval(interval);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
interval_timer.tick().await;
|
||||||
|
|
||||||
|
let power_of_two_policies = policy_registry.get_all_power_of_two_policies();
|
||||||
|
|
||||||
|
if power_of_two_policies.is_empty() {
|
||||||
|
debug!("No PowerOfTwo policies found, skipping load fetch");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = WorkerManager::get_all_worker_loads(&worker_registry, &client).await;
|
||||||
|
|
||||||
|
let mut loads = HashMap::new();
|
||||||
|
for load_info in result.loads {
|
||||||
|
loads.insert(load_info.worker, load_info.load);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !loads.is_empty() {
|
||||||
|
debug!(
|
||||||
|
"Fetched loads from {} workers, updating {} PowerOfTwo policies",
|
||||||
|
loads.len(),
|
||||||
|
power_of_two_policies.len()
|
||||||
|
);
|
||||||
|
for policy in &power_of_two_policies {
|
||||||
|
policy.update_loads(&loads);
|
||||||
|
}
|
||||||
|
let _ = tx.send(loads);
|
||||||
|
} else {
|
||||||
|
warn!("No loads fetched from workers");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if monitoring is currently active
|
||||||
|
pub async fn is_running(&self) -> bool {
|
||||||
|
let handle_guard = self.monitor_handle.lock().await;
|
||||||
|
handle_guard.is_some()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for LoadMonitor {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Ok(mut handle_guard) = self.monitor_handle.try_lock() {
|
||||||
|
if let Some(handle) = handle_guard.take() {
|
||||||
|
handle.abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
@@ -257,6 +257,47 @@ impl PolicyRegistry {
|
|||||||
.unwrap_or_else(|| self.get_default_policy())
|
.unwrap_or_else(|| self.get_default_policy())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get all PowerOfTwo policies that need load updates
|
||||||
|
pub fn get_all_power_of_two_policies(&self) -> Vec<Arc<dyn LoadBalancingPolicy>> {
|
||||||
|
let mut power_of_two_policies = Vec::new();
|
||||||
|
|
||||||
|
if self.default_policy.name() == "power_of_two" {
|
||||||
|
power_of_two_policies.push(Arc::clone(&self.default_policy));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref policy) = *self.prefill_policy.read().unwrap() {
|
||||||
|
if policy.name() == "power_of_two" && !Arc::ptr_eq(policy, &self.default_policy) {
|
||||||
|
power_of_two_policies.push(Arc::clone(policy));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref policy) = *self.decode_policy.read().unwrap() {
|
||||||
|
if policy.name() == "power_of_two"
|
||||||
|
&& !Arc::ptr_eq(policy, &self.default_policy)
|
||||||
|
&& !self
|
||||||
|
.prefill_policy
|
||||||
|
.read()
|
||||||
|
.unwrap()
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|p| Arc::ptr_eq(p, policy))
|
||||||
|
{
|
||||||
|
power_of_two_policies.push(Arc::clone(policy));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let model_policies = self.model_policies.read().unwrap();
|
||||||
|
for policy in model_policies.values() {
|
||||||
|
if policy.name() == "power_of_two" {
|
||||||
|
let already_added = power_of_two_policies.iter().any(|p| Arc::ptr_eq(p, policy));
|
||||||
|
if !already_added {
|
||||||
|
power_of_two_policies.push(Arc::clone(policy));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
power_of_two_policies
|
||||||
|
}
|
||||||
|
|
||||||
/// Initialize cache-aware policy with workers if applicable
|
/// Initialize cache-aware policy with workers if applicable
|
||||||
/// This should be called after workers are registered for a model
|
/// This should be called after workers are registered for a model
|
||||||
pub fn init_cache_aware_policy(&self, model_id: &str, workers: &[Arc<dyn Worker>]) {
|
pub fn init_cache_aware_policy(&self, model_id: &str, workers: &[Arc<dyn Worker>]) {
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
use super::pd_types::api_path;
|
use super::pd_types::api_path;
|
||||||
use crate::config::types::RetryConfig;
|
use crate::config::types::RetryConfig;
|
||||||
use crate::core::{
|
use crate::core::{
|
||||||
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerLoadGuard, WorkerManager,
|
is_retryable_status, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
|
||||||
WorkerRegistry, WorkerType,
|
|
||||||
};
|
};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
||||||
@@ -23,18 +22,15 @@ use futures_util::StreamExt;
|
|||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::Instant;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, warn};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct PDRouter {
|
pub struct PDRouter {
|
||||||
pub worker_registry: Arc<WorkerRegistry>,
|
pub worker_registry: Arc<WorkerRegistry>,
|
||||||
pub policy_registry: Arc<PolicyRegistry>,
|
pub policy_registry: Arc<PolicyRegistry>,
|
||||||
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
|
||||||
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
|
||||||
pub client: Client,
|
pub client: Client,
|
||||||
pub retry_config: RetryConfig,
|
pub retry_config: RetryConfig,
|
||||||
pub api_key: Option<String>,
|
pub api_key: Option<String>,
|
||||||
@@ -124,71 +120,9 @@ impl PDRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
|
pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
|
||||||
let prefill_workers = ctx.worker_registry.get_workers_filtered(
|
|
||||||
None, // any model
|
|
||||||
Some(WorkerType::Prefill {
|
|
||||||
bootstrap_port: None,
|
|
||||||
}),
|
|
||||||
Some(ConnectionMode::Http),
|
|
||||||
false, // include all workers
|
|
||||||
);
|
|
||||||
|
|
||||||
let decode_workers = ctx.worker_registry.get_workers_filtered(
|
|
||||||
None, // any model
|
|
||||||
Some(WorkerType::Decode),
|
|
||||||
Some(ConnectionMode::Http),
|
|
||||||
false, // include all workers
|
|
||||||
);
|
|
||||||
|
|
||||||
let all_urls: Vec<String> = prefill_workers
|
|
||||||
.iter()
|
|
||||||
.chain(decode_workers.iter())
|
|
||||||
.map(|w| w.url().to_string())
|
|
||||||
.collect();
|
|
||||||
let all_api_keys: Vec<Option<String>> = prefill_workers
|
|
||||||
.iter()
|
|
||||||
.chain(decode_workers.iter())
|
|
||||||
.map(|w| w.api_key().clone())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
|
||||||
let worker_loads = Arc::new(rx);
|
|
||||||
|
|
||||||
let prefill_policy = ctx.policy_registry.get_prefill_policy();
|
|
||||||
let decode_policy = ctx.policy_registry.get_decode_policy();
|
|
||||||
|
|
||||||
let load_monitor_handle =
|
|
||||||
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
|
|
||||||
let monitor_urls = all_urls.clone();
|
|
||||||
let monitor_api_keys = all_api_keys.clone();
|
|
||||||
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
|
|
||||||
let monitor_client = ctx.client.clone();
|
|
||||||
let prefill_policy_clone = Arc::clone(&prefill_policy);
|
|
||||||
let decode_policy_clone = Arc::clone(&decode_policy);
|
|
||||||
|
|
||||||
Some(Arc::new(tokio::spawn(async move {
|
|
||||||
Self::monitor_worker_loads_with_client(
|
|
||||||
monitor_urls,
|
|
||||||
monitor_api_keys,
|
|
||||||
tx,
|
|
||||||
monitor_interval,
|
|
||||||
monitor_client,
|
|
||||||
prefill_policy_clone,
|
|
||||||
decode_policy_clone,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
})))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
// No longer need prefill drain channel - we'll wait for both responses
|
|
||||||
|
|
||||||
Ok(PDRouter {
|
Ok(PDRouter {
|
||||||
worker_registry: Arc::clone(&ctx.worker_registry),
|
worker_registry: Arc::clone(&ctx.worker_registry),
|
||||||
policy_registry: Arc::clone(&ctx.policy_registry),
|
policy_registry: Arc::clone(&ctx.policy_registry),
|
||||||
worker_loads,
|
|
||||||
load_monitor_handle,
|
|
||||||
client: ctx.client.clone(),
|
client: ctx.client.clone(),
|
||||||
retry_config: ctx.router_config.effective_retry_config(),
|
retry_config: ctx.router_config.effective_retry_config(),
|
||||||
api_key: ctx.router_config.api_key.clone(),
|
api_key: ctx.router_config.api_key.clone(),
|
||||||
@@ -708,55 +642,6 @@ impl PDRouter {
|
|||||||
Ok(available_workers[selected_idx].clone())
|
Ok(available_workers[selected_idx].clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn monitor_worker_loads_with_client(
|
|
||||||
worker_urls: Vec<String>,
|
|
||||||
worker_api_keys: Vec<Option<String>>,
|
|
||||||
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
|
|
||||||
interval_secs: u64,
|
|
||||||
client: Client,
|
|
||||||
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
|
||||||
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
|
||||||
) {
|
|
||||||
loop {
|
|
||||||
let mut loads = HashMap::new();
|
|
||||||
|
|
||||||
let futures: Vec<_> = worker_urls
|
|
||||||
.iter()
|
|
||||||
.zip(worker_api_keys.iter())
|
|
||||||
.map(|(url, api_key)| {
|
|
||||||
let client = client.clone();
|
|
||||||
let url = url.clone();
|
|
||||||
let api_key = api_key.clone();
|
|
||||||
async move {
|
|
||||||
let load =
|
|
||||||
WorkerManager::get_worker_load(&url, api_key.as_deref(), &client)
|
|
||||||
.await
|
|
||||||
.unwrap_or(0);
|
|
||||||
(url, load)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let results = futures_util::future::join_all(futures).await;
|
|
||||||
|
|
||||||
for (url, load) in results {
|
|
||||||
loads.insert(url, load);
|
|
||||||
}
|
|
||||||
|
|
||||||
debug!("Worker loads updated: {:?}", loads);
|
|
||||||
|
|
||||||
prefill_policy.update_loads(&loads);
|
|
||||||
decode_policy.update_loads(&loads);
|
|
||||||
|
|
||||||
if tx.send(loads).is_err() {
|
|
||||||
info!("Load monitor receiver dropped, shutting down monitor task");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
tokio::time::sleep(Duration::from_secs(interval_secs)).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn create_streaming_response(
|
fn create_streaming_response(
|
||||||
&self,
|
&self,
|
||||||
@@ -1375,8 +1260,6 @@ mod tests {
|
|||||||
PDRouter {
|
PDRouter {
|
||||||
worker_registry,
|
worker_registry,
|
||||||
policy_registry,
|
policy_registry,
|
||||||
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
|
|
||||||
load_monitor_handle: None,
|
|
||||||
client: Client::new(),
|
client: Client::new(),
|
||||||
retry_config: RetryConfig::default(),
|
retry_config: RetryConfig::default(),
|
||||||
api_key: Some("test_api_key".to_string()),
|
api_key: Some("test_api_key".to_string()),
|
||||||
@@ -1436,31 +1319,6 @@ mod tests {
|
|||||||
assert!(result.unwrap_err().contains("No prefill workers available"));
|
assert!(result.unwrap_err().contains("No prefill workers available"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_load_monitor_updates() {
|
|
||||||
let power_of_two_policy = Arc::new(crate::policies::PowerOfTwoPolicy::new());
|
|
||||||
let mut router = create_test_pd_router();
|
|
||||||
router
|
|
||||||
.policy_registry
|
|
||||||
.set_prefill_policy(power_of_two_policy.clone());
|
|
||||||
router
|
|
||||||
.policy_registry
|
|
||||||
.set_decode_policy(power_of_two_policy);
|
|
||||||
|
|
||||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
|
||||||
router.worker_loads = Arc::new(rx);
|
|
||||||
|
|
||||||
let mut loads = HashMap::new();
|
|
||||||
loads.insert("http://worker1".to_string(), 10);
|
|
||||||
loads.insert("http://worker2".to_string(), 5);
|
|
||||||
|
|
||||||
let _ = tx.send(loads.clone());
|
|
||||||
|
|
||||||
let received = router.worker_loads.borrow().clone();
|
|
||||||
assert_eq!(received.get("http://worker1"), Some(&10));
|
|
||||||
assert_eq!(received.get("http://worker2"), Some(&5));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_worker_load_metrics() {
|
fn test_worker_load_metrics() {
|
||||||
let prefill_worker = create_test_worker(
|
let prefill_worker = create_test_worker(
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
use crate::config::types::RetryConfig;
|
use crate::config::types::RetryConfig;
|
||||||
use crate::core::{
|
use crate::core::{
|
||||||
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerManager, WorkerRegistry,
|
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
|
||||||
WorkerType,
|
|
||||||
};
|
};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
use crate::policies::PolicyRegistry;
|
||||||
use crate::protocols::spec::{
|
use crate::protocols::spec::{
|
||||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest,
|
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest,
|
||||||
RerankRequest, RerankResponse, RerankResult, ResponsesGetParams, ResponsesRequest,
|
RerankRequest, RerankResponse, RerankResult, ResponsesGetParams, ResponsesRequest,
|
||||||
@@ -23,9 +22,8 @@ use axum::{
|
|||||||
};
|
};
|
||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::Instant;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{debug, error};
|
use tracing::{debug, error};
|
||||||
|
|
||||||
@@ -38,8 +36,6 @@ pub struct Router {
|
|||||||
dp_aware: bool,
|
dp_aware: bool,
|
||||||
enable_igw: bool,
|
enable_igw: bool,
|
||||||
retry_config: RetryConfig,
|
retry_config: RetryConfig,
|
||||||
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
|
||||||
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Router {
|
impl Router {
|
||||||
@@ -54,42 +50,6 @@ impl Router {
|
|||||||
|
|
||||||
RouterMetrics::set_active_workers(workers.len());
|
RouterMetrics::set_active_workers(workers.len());
|
||||||
|
|
||||||
let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
|
|
||||||
|
|
||||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
|
||||||
let worker_loads = Arc::new(rx);
|
|
||||||
|
|
||||||
let default_policy = ctx.policy_registry.get_default_policy();
|
|
||||||
|
|
||||||
let load_monitor_handle = if default_policy.name() == "power_of_two" {
|
|
||||||
let monitor_urls = worker_urls.clone();
|
|
||||||
let monitor_api_keys = monitor_urls
|
|
||||||
.iter()
|
|
||||||
.map(|url| {
|
|
||||||
ctx.worker_registry
|
|
||||||
.get_by_url(url)
|
|
||||||
.and_then(|w| w.api_key().clone())
|
|
||||||
})
|
|
||||||
.collect::<Vec<Option<String>>>();
|
|
||||||
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
|
|
||||||
let policy_clone = default_policy.clone();
|
|
||||||
let client_clone = ctx.client.clone();
|
|
||||||
|
|
||||||
Some(Arc::new(tokio::spawn(async move {
|
|
||||||
Self::monitor_worker_loads(
|
|
||||||
monitor_urls,
|
|
||||||
monitor_api_keys,
|
|
||||||
tx,
|
|
||||||
monitor_interval,
|
|
||||||
policy_clone,
|
|
||||||
client_clone,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
})))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Router {
|
Ok(Router {
|
||||||
worker_registry: ctx.worker_registry.clone(),
|
worker_registry: ctx.worker_registry.clone(),
|
||||||
policy_registry: ctx.policy_registry.clone(),
|
policy_registry: ctx.policy_registry.clone(),
|
||||||
@@ -97,8 +57,6 @@ impl Router {
|
|||||||
dp_aware: ctx.router_config.dp_aware,
|
dp_aware: ctx.router_config.dp_aware,
|
||||||
enable_igw: ctx.router_config.enable_igw,
|
enable_igw: ctx.router_config.enable_igw,
|
||||||
retry_config: ctx.router_config.effective_retry_config(),
|
retry_config: ctx.router_config.effective_retry_config(),
|
||||||
_worker_loads: worker_loads,
|
|
||||||
_load_monitor_handle: load_monitor_handle,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -661,42 +619,6 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Background task to monitor worker loads
|
|
||||||
async fn monitor_worker_loads(
|
|
||||||
worker_urls: Vec<String>,
|
|
||||||
worker_api_keys: Vec<Option<String>>,
|
|
||||||
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
|
|
||||||
interval_secs: u64,
|
|
||||||
policy: Arc<dyn LoadBalancingPolicy>,
|
|
||||||
client: Client,
|
|
||||||
) {
|
|
||||||
let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
|
|
||||||
|
|
||||||
loop {
|
|
||||||
interval.tick().await;
|
|
||||||
|
|
||||||
let mut loads = HashMap::new();
|
|
||||||
for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
|
|
||||||
// Use WorkerManager for consistent load fetching
|
|
||||||
if let Some(load) =
|
|
||||||
WorkerManager::get_worker_load(url, api_key.as_deref(), &client).await
|
|
||||||
{
|
|
||||||
loads.insert(url.clone(), load);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !loads.is_empty() {
|
|
||||||
// Update policy with new loads
|
|
||||||
policy.update_loads(&loads);
|
|
||||||
|
|
||||||
// Send to watchers
|
|
||||||
if let Err(e) = tx.send(loads) {
|
|
||||||
error!("Failed to send load update: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn build_rerank_response(
|
async fn build_rerank_response(
|
||||||
req: &RerankRequest,
|
req: &RerankRequest,
|
||||||
response: Response,
|
response: Response,
|
||||||
@@ -858,7 +780,6 @@ impl RouterTrait for Router {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::core::BasicWorkerBuilder;
|
use crate::core::BasicWorkerBuilder;
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
fn create_test_regular_router() -> Router {
|
fn create_test_regular_router() -> Router {
|
||||||
// Create registries
|
// Create registries
|
||||||
@@ -877,15 +798,12 @@ mod tests {
|
|||||||
worker_registry.register(Arc::new(worker1));
|
worker_registry.register(Arc::new(worker1));
|
||||||
worker_registry.register(Arc::new(worker2));
|
worker_registry.register(Arc::new(worker2));
|
||||||
|
|
||||||
let (_, rx) = tokio::sync::watch::channel(HashMap::new());
|
|
||||||
Router {
|
Router {
|
||||||
worker_registry,
|
worker_registry,
|
||||||
policy_registry,
|
policy_registry,
|
||||||
dp_aware: false,
|
dp_aware: false,
|
||||||
client: Client::new(),
|
client: Client::new(),
|
||||||
retry_config: RetryConfig::default(),
|
retry_config: RetryConfig::default(),
|
||||||
_worker_loads: Arc::new(rx),
|
|
||||||
_load_monitor_handle: None,
|
|
||||||
enable_igw: false,
|
enable_igw: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
|
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
|
||||||
core::{WorkerManager, WorkerRegistry, WorkerType},
|
core::{LoadMonitor, WorkerManager, WorkerRegistry, WorkerType},
|
||||||
data_connector::{
|
data_connector::{
|
||||||
MemoryResponseStorage, NoOpResponseStorage, OracleResponseStorage, SharedResponseStorage,
|
MemoryResponseStorage, NoOpResponseStorage, OracleResponseStorage, SharedResponseStorage,
|
||||||
},
|
},
|
||||||
@@ -51,6 +51,7 @@ pub struct AppContext {
|
|||||||
pub policy_registry: Arc<PolicyRegistry>,
|
pub policy_registry: Arc<PolicyRegistry>,
|
||||||
pub router_manager: Option<Arc<RouterManager>>,
|
pub router_manager: Option<Arc<RouterManager>>,
|
||||||
pub response_storage: SharedResponseStorage,
|
pub response_storage: SharedResponseStorage,
|
||||||
|
pub load_monitor: Option<Arc<LoadMonitor>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppContext {
|
impl AppContext {
|
||||||
@@ -107,6 +108,13 @@ impl AppContext {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let load_monitor = Some(Arc::new(LoadMonitor::new(
|
||||||
|
worker_registry.clone(),
|
||||||
|
policy_registry.clone(),
|
||||||
|
client.clone(),
|
||||||
|
router_config.worker_startup_check_interval_secs,
|
||||||
|
)));
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client,
|
client,
|
||||||
router_config,
|
router_config,
|
||||||
@@ -118,6 +126,7 @@ impl AppContext {
|
|||||||
policy_registry,
|
policy_registry,
|
||||||
router_manager,
|
router_manager,
|
||||||
response_storage,
|
response_storage,
|
||||||
|
load_monitor,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -727,6 +736,11 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
config.router_config.health_check.check_interval_secs
|
config.router_config.health_check.check_interval_secs
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if let Some(ref load_monitor) = app_context.load_monitor {
|
||||||
|
load_monitor.start().await;
|
||||||
|
info!("Started LoadMonitor for PowerOfTwo policies");
|
||||||
|
}
|
||||||
|
|
||||||
let (limiter, processor) = middleware::ConcurrencyLimiter::new(
|
let (limiter, processor) = middleware::ConcurrencyLimiter::new(
|
||||||
app_context.rate_limiter.clone(),
|
app_context.rate_limiter.clone(),
|
||||||
config.router_config.queue_size,
|
config.router_config.queue_size,
|
||||||
|
|||||||
@@ -584,6 +584,7 @@ mod tests {
|
|||||||
tool_parser_registry: None, // HTTP mode doesn't need tool parser
|
tool_parser_registry: None, // HTTP mode doesn't need tool parser
|
||||||
router_manager: None, // Test doesn't need router manager
|
router_manager: None, // Test doesn't need router manager
|
||||||
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
|
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
|
||||||
|
load_monitor: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user