[router] allow one router to support different model families and serving modes (#10244)

This commit is contained in:
Simo Lin
2025-09-12 19:18:27 -04:00
committed by GitHub
parent 321fecab74
commit 2f173ea074
28 changed files with 3528 additions and 837 deletions

View File

@@ -1,10 +1,10 @@
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig,
RetryExecutor, Worker, WorkerFactory, WorkerType,
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthConfig, RetryExecutor, Worker,
WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, RerankRequest,
RerankResponse, RerankResult, ResponsesRequest,
@@ -22,7 +22,7 @@ use axum::{
use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info, warn};
@@ -30,8 +30,8 @@ use tracing::{debug, error, info, warn};
/// Regular router that uses injected load balancing policies
#[derive(Debug)]
pub struct Router {
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
policy: Arc<dyn LoadBalancingPolicy>,
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
client: Client,
worker_startup_timeout_secs: u64,
worker_startup_check_interval_secs: u64,
@@ -41,7 +41,6 @@ pub struct Router {
circuit_breaker_config: CircuitBreakerConfig,
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
_health_checker: Option<HealthChecker>,
}
impl Router {
@@ -49,7 +48,6 @@ impl Router {
#[allow(clippy::too_many_arguments)]
pub async fn new(
worker_urls: Vec<String>,
policy: Arc<dyn LoadBalancingPolicy>,
ctx: &Arc<crate::server::AppContext>,
) -> Result<Self, String> {
// Update active workers gauge
@@ -82,45 +80,51 @@ impl Router {
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Create Worker trait objects from URLs with health check config
let workers: Vec<Box<dyn Worker>> = worker_urls
.iter()
.map(|url| {
let worker = BasicWorker::new(url.clone(), WorkerType::Regular)
.with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
});
Box::new(worker) as Box<dyn Worker>
})
.collect();
// Register workers in the registry
// In IGW mode, we need to fetch model info from workers
for url in &worker_urls {
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
// For now, create worker without model_id
let worker = BasicWorker::new(url.clone(), WorkerType::Regular)
.with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
});
// Initialize policy with workers if needed (e.g., for cache-aware)
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.init_workers(&workers);
let worker_arc = Arc::new(worker);
ctx.worker_registry.register(worker_arc.clone());
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = ctx.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy and it's the first worker for this model,
// initialize it with the worker
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
let worker_dyn: Arc<dyn Worker> = worker_arc.clone();
cache_aware.init_workers(std::slice::from_ref(&worker_dyn));
}
}
}
let workers = Arc::new(RwLock::new(workers));
let health_checker = crate::core::start_health_checker(
Arc::clone(&workers),
ctx.router_config.worker_startup_check_interval_secs,
);
// Setup load monitoring for PowerOfTwo policy
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
let load_monitor_handle = if policy.name() == "power_of_two" {
// Check if default policy is power_of_two for load monitoring
let default_policy = ctx.policy_registry.get_default_policy();
let load_monitor_handle = if default_policy.name() == "power_of_two" {
let monitor_urls = worker_urls.clone();
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
let policy_clone = Arc::clone(&policy);
let policy_clone = default_policy.clone();
let client_clone = ctx.client.clone();
Some(Arc::new(tokio::spawn(async move {
@@ -138,8 +142,8 @@ impl Router {
};
Ok(Router {
workers,
policy,
worker_registry: ctx.worker_registry.clone(),
policy_registry: ctx.policy_registry.clone(),
client: ctx.client.clone(),
worker_startup_timeout_secs: ctx.router_config.worker_startup_timeout_secs,
worker_startup_check_interval_secs: ctx
@@ -151,18 +155,21 @@ impl Router {
circuit_breaker_config: core_cb_config,
_worker_loads: worker_loads,
_load_monitor_handle: load_monitor_handle,
_health_checker: Some(health_checker),
})
}
/// Get the current list of worker URLs
pub fn get_worker_urls(&self) -> Vec<String> {
self.workers
.read()
.unwrap()
.iter()
.map(|w| w.url().to_string())
.collect()
self.worker_registry.get_all_urls()
}
/// Collect the URLs of the workers registered for `model_id`.
///
/// With `Some(model)` only the workers the registry returns for that model are
/// listed; with `None` every registered worker is listed.
pub fn get_worker_urls_for_model(&self, model_id: Option<&str>) -> Vec<String> {
    // Pick the candidate set first, then turn each worker into its URL string.
    let candidates = if let Some(model) = model_id {
        self.worker_registry.get_by_model_fast(model)
    } else {
        self.worker_registry.get_all()
    };

    let mut urls = Vec::with_capacity(candidates.len());
    for worker in candidates.iter() {
        urls.push(worker.url().to_string());
    }
    urls
}
pub async fn wait_for_healthy_workers(
@@ -332,11 +339,27 @@ impl Router {
}
fn select_first_worker(&self) -> Result<String, String> {
let workers_guard = self.workers.read().unwrap();
if workers_guard.is_empty() {
let workers = self.worker_registry.get_all();
if workers.is_empty() {
Err("No workers are available".to_string())
} else {
Ok(workers_guard[0].url().to_string())
Ok(workers[0].url().to_string())
}
}
/// Return the URL of the first worker the registry lists for `model_id`.
///
/// `Some(model)` restricts the lookup to that model's workers; `None` considers
/// every registered worker. Yields an `Err` carrying a descriptive message when
/// the lookup comes back empty.
// NOTE(review): no caller is visible in this chunk, which is presumably why the
// dead_code allowance is present — confirm before removing it.
#[allow(dead_code)]
fn select_first_worker_for_model(&self, model_id: Option<&str>) -> Result<String, String> {
    let candidates = if let Some(model) = model_id {
        self.worker_registry.get_by_model_fast(model)
    } else {
        self.worker_registry.get_all()
    };

    // "First" is whatever order the registry hands back; no ordering is imposed here.
    match candidates.first() {
        Some(worker) => Ok(worker.url().to_string()),
        None => Err(format!(
            "No workers are available for model: {:?}",
            model_id
        )),
    }
}
@@ -447,20 +470,35 @@ impl Router {
}
}
// New method to route typed requests directly
/// Select worker considering circuit breaker state
fn select_worker_with_circuit_breaker(&self, text: Option<&str>) -> Option<Box<dyn Worker>> {
let workers = self.workers.read().ok()?;
let available: Vec<Box<dyn Worker>> = workers
/// Select worker for a specific model considering circuit breaker state
fn select_worker_for_model(
&self,
model_id: Option<&str>,
text: Option<&str>,
) -> Option<Arc<dyn Worker>> {
// Get workers for the specified model (O(1) lookup if model_id is provided)
let workers = match model_id {
Some(model) => self.worker_registry.get_by_model_fast(model),
None => self.worker_registry.get_all(),
};
let available: Vec<Arc<dyn Worker>> = workers
.iter()
.filter(|w| w.is_available())
.map(|w| w.clone_worker())
.cloned()
.collect();
if available.is_empty() {
return None;
}
let idx = self.policy.select_worker(&available, text)?;
Some(available[idx].clone_worker())
// Get the appropriate policy for this model
let policy = match model_id {
Some(model) => self.policy_registry.get_policy_or_default(model),
None => self.policy_registry.get_default_policy(),
};
let idx = policy.select_worker(&available, text)?;
Some(available[idx].clone())
}
pub async fn route_typed_request<T: GenerationRequest + serde::Serialize + Clone>(
@@ -468,6 +506,7 @@ impl Router {
headers: Option<&HeaderMap>,
typed_req: &T,
route: &str,
model_id: Option<&str>,
) -> Response {
let start = Instant::now();
let is_stream = typed_req.is_stream();
@@ -477,7 +516,7 @@ impl Router {
&self.retry_config,
// operation per attempt
|_: u32| async {
let worker = match self.select_worker_with_circuit_breaker(Some(&text)) {
let worker = match self.select_worker_for_model(model_id, Some(&text)) {
Some(w) => w,
None => {
RouterMetrics::record_request_error(route, "no_available_workers");
@@ -490,7 +529,13 @@ impl Router {
};
// Optional load tracking for cache-aware policy
let load_incremented = if self.policy.name() == "cache_aware" {
// Get the policy for this model to check if it's cache-aware
let policy = match model_id {
Some(model) => self.policy_registry.get_policy_or_default(model),
None => self.policy_registry.get_default_policy(),
};
let load_incremented = if policy.name() == "cache_aware" {
worker.increment_load();
RouterMetrics::set_running_requests(worker.url(), worker.load());
true
@@ -654,11 +699,9 @@ impl Router {
// Decrement load on error if it was incremented
if load_incremented {
if let Ok(workers_guard) = self.workers.read() {
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(worker_url, worker.load());
}
if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(worker_url, worker.load());
}
}
@@ -687,13 +730,9 @@ impl Router {
Err(e) => {
// IMPORTANT: Decrement load on error before returning
if load_incremented {
if let Ok(workers_guard) = self.workers.read() {
if let Some(worker) =
workers_guard.iter().find(|w| w.url() == worker_url)
{
worker.decrement_load();
RouterMetrics::set_running_requests(worker_url, worker.load());
}
if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(worker_url, worker.load());
}
}
@@ -704,18 +743,16 @@ impl Router {
// Decrement load counter for non-streaming requests if it was incremented
if load_incremented {
if let Ok(workers_guard) = self.workers.read() {
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(worker_url, worker.load());
}
if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(worker_url, worker.load());
}
}
response
} else if load_incremented {
// For streaming with load tracking, we need to manually decrement when done
let workers = Arc::clone(&self.workers);
let registry = Arc::clone(&self.worker_registry);
let worker_url = worker_url.to_string();
// Preserve headers for streaming response
@@ -739,17 +776,10 @@ impl Router {
.windows(12)
.any(|window| window == b"data: [DONE]")
{
if let Ok(workers_guard) = workers.read() {
if let Some(worker) =
workers_guard.iter().find(|w| w.url() == worker_url)
{
worker.decrement_load();
RouterMetrics::set_running_requests(
&worker_url,
worker.load(),
);
decremented = true;
}
if let Some(worker) = registry.get_by_url(&worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(&worker_url, worker.load());
decremented = true;
}
}
if tx.send(Ok(bytes)).is_err() {
@@ -763,11 +793,9 @@ impl Router {
}
}
if !decremented {
if let Ok(workers_guard) = workers.read() {
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(&worker_url, worker.load());
}
if let Some(worker) = registry.get_by_url(&worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(&worker_url, worker.load());
}
}
});
@@ -839,7 +867,6 @@ impl Router {
match client.get(format!("{}/health", worker_url)).send().await {
Ok(res) => {
if res.status().is_success() {
let mut workers_guard = self.workers.write().unwrap();
if self.dp_aware {
// Need to contact the worker to extract the dp_size,
// and add them as multiple workers
@@ -848,46 +875,77 @@ impl Router {
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
let mut worker_added: bool = false;
for dp_url in &dp_url_vec {
if workers_guard.iter().any(|w| w.url() == dp_url) {
if self.worker_registry.get_by_url(dp_url).is_some() {
warn!("Worker {} already exists", dp_url);
continue;
}
info!("Added worker: {}", dp_url);
let new_worker = WorkerFactory::create_regular_with_config(
dp_url.to_string(),
self.circuit_breaker_config.clone(),
);
workers_guard.push(new_worker);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let new_worker =
BasicWorker::new(dp_url.to_string(), WorkerType::Regular)
.with_circuit_breaker_config(
self.circuit_breaker_config.clone(),
);
let worker_arc = Arc::new(new_worker);
self.worker_registry.register(worker_arc.clone());
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, update it with all workers for this model
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>(
) {
let model_workers =
self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
worker_added = true;
}
if !worker_added {
return Err(format!("No worker added for {}", worker_url));
}
} else {
if workers_guard.iter().any(|w| w.url() == worker_url) {
if self.worker_registry.get_by_url(worker_url).is_some() {
return Err(format!("Worker {} already exists", worker_url));
}
info!("Added worker: {}", worker_url);
let new_worker = WorkerFactory::create_regular_with_config(
worker_url.to_string(),
self.circuit_breaker_config.clone(),
);
workers_guard.push(new_worker);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let new_worker =
BasicWorker::new(worker_url.to_string(), WorkerType::Regular)
.with_circuit_breaker_config(
self.circuit_breaker_config.clone(),
);
let worker_arc = Arc::new(new_worker);
self.worker_registry.register(worker_arc.clone());
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, add this worker to it
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>(
) {
// Get all workers for this model
let model_workers =
self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
}
RouterMetrics::set_active_workers(workers_guard.len());
// If cache aware policy, initialize the worker in the tree
if let Some(cache_aware) =
self.policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
// Get updated workers after adding
drop(workers_guard);
let workers_guard = self.workers.read().unwrap();
cache_aware.init_workers(&workers_guard);
}
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
return Ok(format!("Successfully added worker: {}", worker_url));
} else {
@@ -931,66 +989,73 @@ impl Router {
if self.dp_aware {
// remove dp-aware workers in a prefix-matching fashion
// without contacting the remote worker
let mut candidate_workers: Vec<String> = Vec::new();
let mut removed_workers: Vec<String> = Vec::new();
let worker_url_prefix = format!("{}@", worker_url);
{
// find the candidate workers to be removed
let workers_guard = self.workers.read().unwrap();
for w in workers_guard.iter() {
if w.url().starts_with(&worker_url_prefix) {
candidate_workers.push(w.url().to_string());
}
}
}
// Find and remove all workers with matching prefix
let all_workers = self.worker_registry.get_all();
for w in all_workers.iter() {
if w.url().starts_with(&worker_url_prefix) {
// Get model_id before removing
let model_id = w.model_id().to_string();
{
// do the removing on the worker_urls
let mut workers_guard = self.workers.write().unwrap();
for dp_url in candidate_workers.iter() {
if let Some(index) = workers_guard.iter().position(|w| w.url() == dp_url) {
workers_guard.remove(index);
info!("Removed worker: {}", dp_url);
removed_workers.push(dp_url.to_string());
if self.worker_registry.remove_by_url(w.url()).is_some() {
info!("Removed worker: {}", w.url());
removed_workers.push(w.url().to_string());
// Notify PolicyRegistry about the removed worker
self.policy_registry.on_worker_removed(&model_id);
} else {
warn!("Worker {} not found, skipping removal", dp_url);
continue;
warn!("Worker {} not found, skipping removal", w.url());
}
}
RouterMetrics::set_active_workers(workers_guard.len());
}
// If cache aware policy, remove the workers from the tree
if let Some(cache_aware) = self
.policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
for dp_url in removed_workers.iter() {
cache_aware.remove_worker(dp_url);
info!("Removed worker from tree: {}", dp_url);
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
// If any models are using cache aware policy, remove the workers from the tree
// Check each removed worker's model and get its policy
for dp_url in removed_workers.iter() {
if let Some(worker) = self.worker_registry.get_by_url(dp_url) {
let model_id = worker.model_id();
if let Some(policy) = self.policy_registry.get_policy(model_id) {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(dp_url);
info!("Removed worker from cache-aware tree: {}", dp_url);
}
}
}
}
} else {
let mut workers_guard = self.workers.write().unwrap();
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
workers_guard.remove(index);
info!("Removed worker: {}", worker_url);
RouterMetrics::set_active_workers(workers_guard.len());
// Get the worker first to extract model_id
let model_id = if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
worker.model_id().to_string()
} else {
warn!("Worker {} not found, skipping removal", worker_url);
return;
};
if self.worker_registry.remove_by_url(worker_url).is_some() {
info!("Removed worker: {}", worker_url);
// Notify PolicyRegistry about the removed worker
self.policy_registry.on_worker_removed(&model_id);
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
}
// If cache aware policy, remove the workers from the tree
if let Some(cache_aware) = self
.policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker(worker_url);
info!("Removed worker from tree: {}", worker_url);
// If the model is using cache aware policy, remove the worker from the tree
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(worker_url);
info!("Removed worker from cache-aware tree: {}", worker_url);
}
}
}
}
@@ -1171,7 +1236,7 @@ impl RouterTrait for Router {
}
async fn health(&self, _req: Request<Body>) -> Response {
let workers = self.workers.read().unwrap();
let workers = self.worker_registry.get_all();
let unhealthy_servers: Vec<_> = workers
.iter()
.filter(|w| !w.is_healthy())
@@ -1209,16 +1274,19 @@ impl RouterTrait for Router {
&self,
headers: Option<&HeaderMap>,
body: &GenerateRequest,
model_id: Option<&str>,
) -> Response {
self.route_typed_request(headers, body, "/generate").await
self.route_typed_request(headers, body, "/generate", model_id)
.await
}
async fn route_chat(
&self,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
model_id: Option<&str>,
) -> Response {
self.route_typed_request(headers, body, "/v1/chat/completions")
self.route_typed_request(headers, body, "/v1/chat/completions", model_id)
.await
}
@@ -1226,8 +1294,9 @@ impl RouterTrait for Router {
&self,
headers: Option<&HeaderMap>,
body: &CompletionRequest,
model_id: Option<&str>,
) -> Response {
self.route_typed_request(headers, body, "/v1/completions")
self.route_typed_request(headers, body, "/v1/completions", model_id)
.await
}
@@ -1235,8 +1304,9 @@ impl RouterTrait for Router {
&self,
headers: Option<&HeaderMap>,
body: &ResponsesRequest,
model_id: Option<&str>,
) -> Response {
self.route_typed_request(headers, body, "/v1/responses")
self.route_typed_request(headers, body, "/v1/responses", model_id)
.await
}
@@ -1244,11 +1314,18 @@ impl RouterTrait for Router {
todo!()
}
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response {
async fn route_rerank(
&self,
headers: Option<&HeaderMap>,
body: &RerankRequest,
model_id: Option<&str>,
) -> Response {
if let Err(e) = body.validate() {
return (StatusCode::BAD_REQUEST, e).into_response();
}
let response = self.route_typed_request(headers, body, "/v1/rerank").await;
let response = self
.route_typed_request(headers, body, "/v1/rerank", model_id)
.await;
if response.status().is_success() {
match Self::build_rerank_response(body, response).await {
Ok(rerank_response) => rerank_response,
@@ -1340,19 +1417,15 @@ impl RouterTrait for Router {
fn readiness(&self) -> Response {
// Regular router is ready if it has at least one healthy worker
let healthy_count = self
.workers
.read()
.unwrap()
.iter()
.filter(|w| w.is_healthy())
.count();
let workers = self.worker_registry.get_all();
let healthy_count = workers.iter().filter(|w| w.is_healthy()).count();
let total_workers = workers.len();
if healthy_count > 0 {
Json(serde_json::json!({
"status": "ready",
"healthy_workers": healthy_count,
"total_workers": self.workers.read().unwrap().len()
"total_workers": total_workers
}))
.into_response()
} else {
@@ -1361,7 +1434,7 @@ impl RouterTrait for Router {
Json(serde_json::json!({
"status": "not_ready",
"reason": "no healthy workers available",
"total_workers": self.workers.read().unwrap().len()
"total_workers": total_workers
})),
)
.into_response()
@@ -1372,18 +1445,25 @@ impl RouterTrait for Router {
#[cfg(test)]
mod tests {
use super::*;
use crate::policies::RandomPolicy;
use std::collections::HashMap;
fn create_test_regular_router() -> Router {
let workers = vec![
WorkerFactory::create_regular("http://worker1:8080".to_string()),
WorkerFactory::create_regular("http://worker2:8080".to_string()),
];
// Create registries
let worker_registry = Arc::new(WorkerRegistry::new());
let policy_registry = Arc::new(PolicyRegistry::new(
crate::config::types::PolicyConfig::RoundRobin,
));
// Register test workers
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular);
worker_registry.register(Arc::new(worker1));
worker_registry.register(Arc::new(worker2));
let (_, rx) = tokio::sync::watch::channel(HashMap::new());
Router {
workers: Arc::new(RwLock::new(workers)),
policy: Arc::new(RandomPolicy::new()),
worker_registry,
policy_registry,
worker_startup_timeout_secs: 5,
worker_startup_check_interval_secs: 1,
dp_aware: false,
@@ -1393,7 +1473,6 @@ mod tests {
circuit_breaker_config: CircuitBreakerConfig::default(),
_worker_loads: Arc::new(rx),
_load_monitor_handle: None,
_health_checker: None,
}
}
@@ -1413,7 +1492,9 @@ mod tests {
let result = router.select_first_worker();
assert!(result.is_ok());
assert_eq!(result.unwrap(), "http://worker1:8080");
let url = result.unwrap();
// DashMap doesn't guarantee order, so just check we get one of the workers
assert!(url == "http://worker1:8080" || url == "http://worker2:8080");
}
#[tokio::test]