[router] allow one router to support different model families and serving mode (#10244)
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig,
|
||||
RetryExecutor, Worker, WorkerFactory, WorkerType,
|
||||
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthConfig, RetryExecutor, Worker,
|
||||
WorkerRegistry, WorkerType,
|
||||
};
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::LoadBalancingPolicy;
|
||||
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, RerankRequest,
|
||||
RerankResponse, RerankResult, ResponsesRequest,
|
||||
@@ -22,7 +22,7 @@ use axum::{
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{debug, error, info, warn};
|
||||
@@ -30,8 +30,8 @@ use tracing::{debug, error, info, warn};
|
||||
/// Regular router that uses injected load balancing policies
|
||||
#[derive(Debug)]
|
||||
pub struct Router {
|
||||
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
worker_registry: Arc<WorkerRegistry>,
|
||||
policy_registry: Arc<PolicyRegistry>,
|
||||
client: Client,
|
||||
worker_startup_timeout_secs: u64,
|
||||
worker_startup_check_interval_secs: u64,
|
||||
@@ -41,7 +41,6 @@ pub struct Router {
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||
_health_checker: Option<HealthChecker>,
|
||||
}
|
||||
|
||||
impl Router {
|
||||
@@ -49,7 +48,6 @@ impl Router {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn new(
|
||||
worker_urls: Vec<String>,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
ctx: &Arc<crate::server::AppContext>,
|
||||
) -> Result<Self, String> {
|
||||
// Update active workers gauge
|
||||
@@ -82,45 +80,51 @@ impl Router {
|
||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||
};
|
||||
|
||||
// Create Worker trait objects from URLs with health check config
|
||||
let workers: Vec<Box<dyn Worker>> = worker_urls
|
||||
.iter()
|
||||
.map(|url| {
|
||||
let worker = BasicWorker::new(url.clone(), WorkerType::Regular)
|
||||
.with_circuit_breaker_config(core_cb_config.clone())
|
||||
.with_health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
});
|
||||
Box::new(worker) as Box<dyn Worker>
|
||||
})
|
||||
.collect();
|
||||
// Register workers in the registry
|
||||
// In IGW mode, we need to fetch model info from workers
|
||||
for url in &worker_urls {
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
// For now, create worker without model_id
|
||||
let worker = BasicWorker::new(url.clone(), WorkerType::Regular)
|
||||
.with_circuit_breaker_config(core_cb_config.clone())
|
||||
.with_health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
});
|
||||
|
||||
// Initialize policy with workers if needed (e.g., for cache-aware)
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.init_workers(&workers);
|
||||
let worker_arc = Arc::new(worker);
|
||||
ctx.worker_registry.register(worker_arc.clone());
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
let policy = ctx.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// If this is a cache-aware policy and it's the first worker for this model,
|
||||
// initialize it with the worker
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
let worker_dyn: Arc<dyn Worker> = worker_arc.clone();
|
||||
cache_aware.init_workers(std::slice::from_ref(&worker_dyn));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let workers = Arc::new(RwLock::new(workers));
|
||||
let health_checker = crate::core::start_health_checker(
|
||||
Arc::clone(&workers),
|
||||
ctx.router_config.worker_startup_check_interval_secs,
|
||||
);
|
||||
|
||||
// Setup load monitoring for PowerOfTwo policy
|
||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
let worker_loads = Arc::new(rx);
|
||||
|
||||
let load_monitor_handle = if policy.name() == "power_of_two" {
|
||||
// Check if default policy is power_of_two for load monitoring
|
||||
let default_policy = ctx.policy_registry.get_default_policy();
|
||||
let load_monitor_handle = if default_policy.name() == "power_of_two" {
|
||||
let monitor_urls = worker_urls.clone();
|
||||
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
|
||||
let policy_clone = Arc::clone(&policy);
|
||||
let policy_clone = default_policy.clone();
|
||||
let client_clone = ctx.client.clone();
|
||||
|
||||
Some(Arc::new(tokio::spawn(async move {
|
||||
@@ -138,8 +142,8 @@ impl Router {
|
||||
};
|
||||
|
||||
Ok(Router {
|
||||
workers,
|
||||
policy,
|
||||
worker_registry: ctx.worker_registry.clone(),
|
||||
policy_registry: ctx.policy_registry.clone(),
|
||||
client: ctx.client.clone(),
|
||||
worker_startup_timeout_secs: ctx.router_config.worker_startup_timeout_secs,
|
||||
worker_startup_check_interval_secs: ctx
|
||||
@@ -151,18 +155,21 @@ impl Router {
|
||||
circuit_breaker_config: core_cb_config,
|
||||
_worker_loads: worker_loads,
|
||||
_load_monitor_handle: load_monitor_handle,
|
||||
_health_checker: Some(health_checker),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the current list of worker URLs
|
||||
pub fn get_worker_urls(&self) -> Vec<String> {
|
||||
self.workers
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.collect()
|
||||
self.worker_registry.get_all_urls()
|
||||
}
|
||||
|
||||
/// Get worker URLs for a specific model
|
||||
pub fn get_worker_urls_for_model(&self, model_id: Option<&str>) -> Vec<String> {
|
||||
let workers = match model_id {
|
||||
Some(model) => self.worker_registry.get_by_model_fast(model),
|
||||
None => self.worker_registry.get_all(),
|
||||
};
|
||||
workers.iter().map(|w| w.url().to_string()).collect()
|
||||
}
|
||||
|
||||
pub async fn wait_for_healthy_workers(
|
||||
@@ -332,11 +339,27 @@ impl Router {
|
||||
}
|
||||
|
||||
fn select_first_worker(&self) -> Result<String, String> {
|
||||
let workers_guard = self.workers.read().unwrap();
|
||||
if workers_guard.is_empty() {
|
||||
let workers = self.worker_registry.get_all();
|
||||
if workers.is_empty() {
|
||||
Err("No workers are available".to_string())
|
||||
} else {
|
||||
Ok(workers_guard[0].url().to_string())
|
||||
Ok(workers[0].url().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn select_first_worker_for_model(&self, model_id: Option<&str>) -> Result<String, String> {
|
||||
let workers = match model_id {
|
||||
Some(model) => self.worker_registry.get_by_model_fast(model),
|
||||
None => self.worker_registry.get_all(),
|
||||
};
|
||||
if workers.is_empty() {
|
||||
Err(format!(
|
||||
"No workers are available for model: {:?}",
|
||||
model_id
|
||||
))
|
||||
} else {
|
||||
Ok(workers[0].url().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -447,20 +470,35 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
// New method to route typed requests directly
|
||||
/// Select worker considering circuit breaker state
|
||||
fn select_worker_with_circuit_breaker(&self, text: Option<&str>) -> Option<Box<dyn Worker>> {
|
||||
let workers = self.workers.read().ok()?;
|
||||
let available: Vec<Box<dyn Worker>> = workers
|
||||
/// Select worker for a specific model considering circuit breaker state
|
||||
fn select_worker_for_model(
|
||||
&self,
|
||||
model_id: Option<&str>,
|
||||
text: Option<&str>,
|
||||
) -> Option<Arc<dyn Worker>> {
|
||||
// Get workers for the specified model (O(1) lookup if model_id is provided)
|
||||
let workers = match model_id {
|
||||
Some(model) => self.worker_registry.get_by_model_fast(model),
|
||||
None => self.worker_registry.get_all(),
|
||||
};
|
||||
|
||||
let available: Vec<Arc<dyn Worker>> = workers
|
||||
.iter()
|
||||
.filter(|w| w.is_available())
|
||||
.map(|w| w.clone_worker())
|
||||
.cloned()
|
||||
.collect();
|
||||
if available.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let idx = self.policy.select_worker(&available, text)?;
|
||||
Some(available[idx].clone_worker())
|
||||
|
||||
// Get the appropriate policy for this model
|
||||
let policy = match model_id {
|
||||
Some(model) => self.policy_registry.get_policy_or_default(model),
|
||||
None => self.policy_registry.get_default_policy(),
|
||||
};
|
||||
|
||||
let idx = policy.select_worker(&available, text)?;
|
||||
Some(available[idx].clone())
|
||||
}
|
||||
|
||||
pub async fn route_typed_request<T: GenerationRequest + serde::Serialize + Clone>(
|
||||
@@ -468,6 +506,7 @@ impl Router {
|
||||
headers: Option<&HeaderMap>,
|
||||
typed_req: &T,
|
||||
route: &str,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
let start = Instant::now();
|
||||
let is_stream = typed_req.is_stream();
|
||||
@@ -477,7 +516,7 @@ impl Router {
|
||||
&self.retry_config,
|
||||
// operation per attempt
|
||||
|_: u32| async {
|
||||
let worker = match self.select_worker_with_circuit_breaker(Some(&text)) {
|
||||
let worker = match self.select_worker_for_model(model_id, Some(&text)) {
|
||||
Some(w) => w,
|
||||
None => {
|
||||
RouterMetrics::record_request_error(route, "no_available_workers");
|
||||
@@ -490,7 +529,13 @@ impl Router {
|
||||
};
|
||||
|
||||
// Optional load tracking for cache-aware policy
|
||||
let load_incremented = if self.policy.name() == "cache_aware" {
|
||||
// Get the policy for this model to check if it's cache-aware
|
||||
let policy = match model_id {
|
||||
Some(model) => self.policy_registry.get_policy_or_default(model),
|
||||
None => self.policy_registry.get_default_policy(),
|
||||
};
|
||||
|
||||
let load_incremented = if policy.name() == "cache_aware" {
|
||||
worker.increment_load();
|
||||
RouterMetrics::set_running_requests(worker.url(), worker.load());
|
||||
true
|
||||
@@ -654,11 +699,9 @@ impl Router {
|
||||
|
||||
// Decrement load on error if it was incremented
|
||||
if load_incremented {
|
||||
if let Ok(workers_guard) = self.workers.read() {
|
||||
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(worker_url, worker.load());
|
||||
}
|
||||
if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(worker_url, worker.load());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -687,13 +730,9 @@ impl Router {
|
||||
Err(e) => {
|
||||
// IMPORTANT: Decrement load on error before returning
|
||||
if load_incremented {
|
||||
if let Ok(workers_guard) = self.workers.read() {
|
||||
if let Some(worker) =
|
||||
workers_guard.iter().find(|w| w.url() == worker_url)
|
||||
{
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(worker_url, worker.load());
|
||||
}
|
||||
if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(worker_url, worker.load());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -704,18 +743,16 @@ impl Router {
|
||||
|
||||
// Decrement load counter for non-streaming requests if it was incremented
|
||||
if load_incremented {
|
||||
if let Ok(workers_guard) = self.workers.read() {
|
||||
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(worker_url, worker.load());
|
||||
}
|
||||
if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(worker_url, worker.load());
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
} else if load_incremented {
|
||||
// For streaming with load tracking, we need to manually decrement when done
|
||||
let workers = Arc::clone(&self.workers);
|
||||
let registry = Arc::clone(&self.worker_registry);
|
||||
let worker_url = worker_url.to_string();
|
||||
|
||||
// Preserve headers for streaming response
|
||||
@@ -739,17 +776,10 @@ impl Router {
|
||||
.windows(12)
|
||||
.any(|window| window == b"data: [DONE]")
|
||||
{
|
||||
if let Ok(workers_guard) = workers.read() {
|
||||
if let Some(worker) =
|
||||
workers_guard.iter().find(|w| w.url() == worker_url)
|
||||
{
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(
|
||||
&worker_url,
|
||||
worker.load(),
|
||||
);
|
||||
decremented = true;
|
||||
}
|
||||
if let Some(worker) = registry.get_by_url(&worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(&worker_url, worker.load());
|
||||
decremented = true;
|
||||
}
|
||||
}
|
||||
if tx.send(Ok(bytes)).is_err() {
|
||||
@@ -763,11 +793,9 @@ impl Router {
|
||||
}
|
||||
}
|
||||
if !decremented {
|
||||
if let Ok(workers_guard) = workers.read() {
|
||||
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(&worker_url, worker.load());
|
||||
}
|
||||
if let Some(worker) = registry.get_by_url(&worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(&worker_url, worker.load());
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -839,7 +867,6 @@ impl Router {
|
||||
match client.get(format!("{}/health", worker_url)).send().await {
|
||||
Ok(res) => {
|
||||
if res.status().is_success() {
|
||||
let mut workers_guard = self.workers.write().unwrap();
|
||||
if self.dp_aware {
|
||||
// Need to contact the worker to extract the dp_size,
|
||||
// and add them as multiple workers
|
||||
@@ -848,46 +875,77 @@ impl Router {
|
||||
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
|
||||
let mut worker_added: bool = false;
|
||||
for dp_url in &dp_url_vec {
|
||||
if workers_guard.iter().any(|w| w.url() == dp_url) {
|
||||
if self.worker_registry.get_by_url(dp_url).is_some() {
|
||||
warn!("Worker {} already exists", dp_url);
|
||||
continue;
|
||||
}
|
||||
info!("Added worker: {}", dp_url);
|
||||
let new_worker = WorkerFactory::create_regular_with_config(
|
||||
dp_url.to_string(),
|
||||
self.circuit_breaker_config.clone(),
|
||||
);
|
||||
workers_guard.push(new_worker);
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let new_worker =
|
||||
BasicWorker::new(dp_url.to_string(), WorkerType::Regular)
|
||||
.with_circuit_breaker_config(
|
||||
self.circuit_breaker_config.clone(),
|
||||
);
|
||||
|
||||
let worker_arc = Arc::new(new_worker);
|
||||
self.worker_registry.register(worker_arc.clone());
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
let policy = self.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// If this is a cache-aware policy, update it with all workers for this model
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>(
|
||||
) {
|
||||
let model_workers =
|
||||
self.worker_registry.get_by_model_fast(model_id);
|
||||
cache_aware.init_workers(&model_workers);
|
||||
}
|
||||
}
|
||||
|
||||
worker_added = true;
|
||||
}
|
||||
if !worker_added {
|
||||
return Err(format!("No worker added for {}", worker_url));
|
||||
}
|
||||
} else {
|
||||
if workers_guard.iter().any(|w| w.url() == worker_url) {
|
||||
if self.worker_registry.get_by_url(worker_url).is_some() {
|
||||
return Err(format!("Worker {} already exists", worker_url));
|
||||
}
|
||||
info!("Added worker: {}", worker_url);
|
||||
let new_worker = WorkerFactory::create_regular_with_config(
|
||||
worker_url.to_string(),
|
||||
self.circuit_breaker_config.clone(),
|
||||
);
|
||||
workers_guard.push(new_worker);
|
||||
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let new_worker =
|
||||
BasicWorker::new(worker_url.to_string(), WorkerType::Regular)
|
||||
.with_circuit_breaker_config(
|
||||
self.circuit_breaker_config.clone(),
|
||||
);
|
||||
|
||||
let worker_arc = Arc::new(new_worker);
|
||||
self.worker_registry.register(worker_arc.clone());
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
let policy = self.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// If this is a cache-aware policy, add this worker to it
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>(
|
||||
) {
|
||||
// Get all workers for this model
|
||||
let model_workers =
|
||||
self.worker_registry.get_by_model_fast(model_id);
|
||||
cache_aware.init_workers(&model_workers);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RouterMetrics::set_active_workers(workers_guard.len());
|
||||
|
||||
// If cache aware policy, initialize the worker in the tree
|
||||
if let Some(cache_aware) =
|
||||
self.policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
// Get updated workers after adding
|
||||
drop(workers_guard);
|
||||
let workers_guard = self.workers.read().unwrap();
|
||||
cache_aware.init_workers(&workers_guard);
|
||||
}
|
||||
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
|
||||
|
||||
return Ok(format!("Successfully added worker: {}", worker_url));
|
||||
} else {
|
||||
@@ -931,66 +989,73 @@ impl Router {
|
||||
if self.dp_aware {
|
||||
// remove dp-aware workers in a prefix-matching fashion
|
||||
// without contacting the remote worker
|
||||
let mut candidate_workers: Vec<String> = Vec::new();
|
||||
let mut removed_workers: Vec<String> = Vec::new();
|
||||
let worker_url_prefix = format!("{}@", worker_url);
|
||||
|
||||
{
|
||||
// find the candidate workers to be removed
|
||||
let workers_guard = self.workers.read().unwrap();
|
||||
for w in workers_guard.iter() {
|
||||
if w.url().starts_with(&worker_url_prefix) {
|
||||
candidate_workers.push(w.url().to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
// Find and remove all workers with matching prefix
|
||||
let all_workers = self.worker_registry.get_all();
|
||||
for w in all_workers.iter() {
|
||||
if w.url().starts_with(&worker_url_prefix) {
|
||||
// Get model_id before removing
|
||||
let model_id = w.model_id().to_string();
|
||||
|
||||
{
|
||||
// do the removing on the worker_urls
|
||||
let mut workers_guard = self.workers.write().unwrap();
|
||||
for dp_url in candidate_workers.iter() {
|
||||
if let Some(index) = workers_guard.iter().position(|w| w.url() == dp_url) {
|
||||
workers_guard.remove(index);
|
||||
info!("Removed worker: {}", dp_url);
|
||||
removed_workers.push(dp_url.to_string());
|
||||
if self.worker_registry.remove_by_url(w.url()).is_some() {
|
||||
info!("Removed worker: {}", w.url());
|
||||
removed_workers.push(w.url().to_string());
|
||||
|
||||
// Notify PolicyRegistry about the removed worker
|
||||
self.policy_registry.on_worker_removed(&model_id);
|
||||
} else {
|
||||
warn!("Worker {} not found, skipping removal", dp_url);
|
||||
continue;
|
||||
warn!("Worker {} not found, skipping removal", w.url());
|
||||
}
|
||||
}
|
||||
RouterMetrics::set_active_workers(workers_guard.len());
|
||||
}
|
||||
|
||||
// If cache aware policy, remove the workers from the tree
|
||||
if let Some(cache_aware) = self
|
||||
.policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
for dp_url in removed_workers.iter() {
|
||||
cache_aware.remove_worker(dp_url);
|
||||
info!("Removed worker from tree: {}", dp_url);
|
||||
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
|
||||
|
||||
// If any models are using cache aware policy, remove the workers from the tree
|
||||
// Check each removed worker's model and get its policy
|
||||
for dp_url in removed_workers.iter() {
|
||||
if let Some(worker) = self.worker_registry.get_by_url(dp_url) {
|
||||
let model_id = worker.model_id();
|
||||
if let Some(policy) = self.policy_registry.get_policy(model_id) {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker_by_url(dp_url);
|
||||
info!("Removed worker from cache-aware tree: {}", dp_url);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let mut workers_guard = self.workers.write().unwrap();
|
||||
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
|
||||
workers_guard.remove(index);
|
||||
info!("Removed worker: {}", worker_url);
|
||||
RouterMetrics::set_active_workers(workers_guard.len());
|
||||
// Get the worker first to extract model_id
|
||||
let model_id = if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
|
||||
worker.model_id().to_string()
|
||||
} else {
|
||||
warn!("Worker {} not found, skipping removal", worker_url);
|
||||
return;
|
||||
};
|
||||
|
||||
if self.worker_registry.remove_by_url(worker_url).is_some() {
|
||||
info!("Removed worker: {}", worker_url);
|
||||
|
||||
// Notify PolicyRegistry about the removed worker
|
||||
self.policy_registry.on_worker_removed(&model_id);
|
||||
|
||||
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
|
||||
}
|
||||
|
||||
// If cache aware policy, remove the workers from the tree
|
||||
if let Some(cache_aware) = self
|
||||
.policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker(worker_url);
|
||||
info!("Removed worker from tree: {}", worker_url);
|
||||
// If the model is using cache aware policy, remove the worker from the tree
|
||||
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker_by_url(worker_url);
|
||||
info!("Removed worker from cache-aware tree: {}", worker_url);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1171,7 +1236,7 @@ impl RouterTrait for Router {
|
||||
}
|
||||
|
||||
async fn health(&self, _req: Request<Body>) -> Response {
|
||||
let workers = self.workers.read().unwrap();
|
||||
let workers = self.worker_registry.get_all();
|
||||
let unhealthy_servers: Vec<_> = workers
|
||||
.iter()
|
||||
.filter(|w| !w.is_healthy())
|
||||
@@ -1209,16 +1274,19 @@ impl RouterTrait for Router {
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &GenerateRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
self.route_typed_request(headers, body, "/generate").await
|
||||
self.route_typed_request(headers, body, "/generate", model_id)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn route_chat(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ChatCompletionRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
self.route_typed_request(headers, body, "/v1/chat/completions")
|
||||
self.route_typed_request(headers, body, "/v1/chat/completions", model_id)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -1226,8 +1294,9 @@ impl RouterTrait for Router {
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &CompletionRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
self.route_typed_request(headers, body, "/v1/completions")
|
||||
self.route_typed_request(headers, body, "/v1/completions", model_id)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -1235,8 +1304,9 @@ impl RouterTrait for Router {
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ResponsesRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
self.route_typed_request(headers, body, "/v1/responses")
|
||||
self.route_typed_request(headers, body, "/v1/responses", model_id)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -1244,11 +1314,18 @@ impl RouterTrait for Router {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response {
|
||||
async fn route_rerank(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &RerankRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
if let Err(e) = body.validate() {
|
||||
return (StatusCode::BAD_REQUEST, e).into_response();
|
||||
}
|
||||
let response = self.route_typed_request(headers, body, "/v1/rerank").await;
|
||||
let response = self
|
||||
.route_typed_request(headers, body, "/v1/rerank", model_id)
|
||||
.await;
|
||||
if response.status().is_success() {
|
||||
match Self::build_rerank_response(body, response).await {
|
||||
Ok(rerank_response) => rerank_response,
|
||||
@@ -1340,19 +1417,15 @@ impl RouterTrait for Router {
|
||||
|
||||
fn readiness(&self) -> Response {
|
||||
// Regular router is ready if it has at least one healthy worker
|
||||
let healthy_count = self
|
||||
.workers
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.filter(|w| w.is_healthy())
|
||||
.count();
|
||||
let workers = self.worker_registry.get_all();
|
||||
let healthy_count = workers.iter().filter(|w| w.is_healthy()).count();
|
||||
let total_workers = workers.len();
|
||||
|
||||
if healthy_count > 0 {
|
||||
Json(serde_json::json!({
|
||||
"status": "ready",
|
||||
"healthy_workers": healthy_count,
|
||||
"total_workers": self.workers.read().unwrap().len()
|
||||
"total_workers": total_workers
|
||||
}))
|
||||
.into_response()
|
||||
} else {
|
||||
@@ -1361,7 +1434,7 @@ impl RouterTrait for Router {
|
||||
Json(serde_json::json!({
|
||||
"status": "not_ready",
|
||||
"reason": "no healthy workers available",
|
||||
"total_workers": self.workers.read().unwrap().len()
|
||||
"total_workers": total_workers
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
@@ -1372,18 +1445,25 @@ impl RouterTrait for Router {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::policies::RandomPolicy;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn create_test_regular_router() -> Router {
|
||||
let workers = vec![
|
||||
WorkerFactory::create_regular("http://worker1:8080".to_string()),
|
||||
WorkerFactory::create_regular("http://worker2:8080".to_string()),
|
||||
];
|
||||
// Create registries
|
||||
let worker_registry = Arc::new(WorkerRegistry::new());
|
||||
let policy_registry = Arc::new(PolicyRegistry::new(
|
||||
crate::config::types::PolicyConfig::RoundRobin,
|
||||
));
|
||||
|
||||
// Register test workers
|
||||
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
|
||||
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular);
|
||||
worker_registry.register(Arc::new(worker1));
|
||||
worker_registry.register(Arc::new(worker2));
|
||||
|
||||
let (_, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
Router {
|
||||
workers: Arc::new(RwLock::new(workers)),
|
||||
policy: Arc::new(RandomPolicy::new()),
|
||||
worker_registry,
|
||||
policy_registry,
|
||||
worker_startup_timeout_secs: 5,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
@@ -1393,7 +1473,6 @@ mod tests {
|
||||
circuit_breaker_config: CircuitBreakerConfig::default(),
|
||||
_worker_loads: Arc::new(rx),
|
||||
_load_monitor_handle: None,
|
||||
_health_checker: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1413,7 +1492,9 @@ mod tests {
|
||||
let result = router.select_first_worker();
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), "http://worker1:8080");
|
||||
let url = result.unwrap();
|
||||
// DashMap doesn't guarantee order, so just check we get one of the workers
|
||||
assert!(url == "http://worker1:8080" || url == "http://worker2:8080");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
Reference in New Issue
Block a user