[router] refactor router and worker management 3/n (#10727)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, RetryExecutor,
|
||||
Worker, WorkerRegistry, WorkerType,
|
||||
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
|
||||
};
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
||||
@@ -10,7 +9,7 @@ use crate::protocols::spec::{
|
||||
RerankRequest, RerankResponse, RerankResult, ResponsesRequest,
|
||||
};
|
||||
use crate::routers::header_utils;
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
use crate::routers::RouterTrait;
|
||||
use axum::body::to_bytes;
|
||||
use axum::{
|
||||
body::Body,
|
||||
@@ -27,7 +26,7 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{debug, error, info, warn};
|
||||
use tracing::{debug, error};
|
||||
|
||||
/// Regular router that uses injected load balancing policies
|
||||
#[derive(Debug)]
|
||||
@@ -35,13 +34,8 @@ pub struct Router {
|
||||
worker_registry: Arc<WorkerRegistry>,
|
||||
policy_registry: Arc<PolicyRegistry>,
|
||||
client: Client,
|
||||
worker_startup_timeout_secs: u64,
|
||||
worker_startup_check_interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
#[allow(dead_code)]
|
||||
api_key: Option<String>,
|
||||
retry_config: RetryConfig,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||
}
|
||||
@@ -56,30 +50,15 @@ impl Router {
|
||||
false, // include all workers
|
||||
);
|
||||
|
||||
// Update active workers gauge
|
||||
RouterMetrics::set_active_workers(workers.len());
|
||||
|
||||
// Get worker URLs for monitoring
|
||||
let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
|
||||
|
||||
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
|
||||
let core_cb_config = CircuitBreakerConfig {
|
||||
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||
success_threshold: circuit_breaker_config.success_threshold,
|
||||
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
|
||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||
};
|
||||
|
||||
// Cache-aware policies are initialized in WorkerInitializer
|
||||
// Setup load monitoring for PowerOfTwo policy
|
||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
let worker_loads = Arc::new(rx);
|
||||
|
||||
// Get default policy to check if we need load monitoring
|
||||
let default_policy = ctx.policy_registry.get_default_policy();
|
||||
|
||||
// Check if default policy is power_of_two for load monitoring
|
||||
let load_monitor_handle = if default_policy.name() == "power_of_two" {
|
||||
let monitor_urls = worker_urls.clone();
|
||||
let monitor_api_keys = monitor_urls
|
||||
@@ -113,201 +92,13 @@ impl Router {
|
||||
worker_registry: ctx.worker_registry.clone(),
|
||||
policy_registry: ctx.policy_registry.clone(),
|
||||
client: ctx.client.clone(),
|
||||
worker_startup_timeout_secs: ctx.router_config.worker_startup_timeout_secs,
|
||||
worker_startup_check_interval_secs: ctx
|
||||
.router_config
|
||||
.worker_startup_check_interval_secs,
|
||||
dp_aware: ctx.router_config.dp_aware,
|
||||
api_key: ctx.router_config.api_key.clone(),
|
||||
retry_config: ctx.router_config.effective_retry_config(),
|
||||
circuit_breaker_config: core_cb_config,
|
||||
_worker_loads: worker_loads,
|
||||
_load_monitor_handle: load_monitor_handle,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the current list of worker URLs
|
||||
pub fn get_worker_urls(&self) -> Vec<String> {
|
||||
self.worker_registry.get_all_urls()
|
||||
}
|
||||
|
||||
/// Get worker URLs for a specific model
|
||||
pub fn get_worker_urls_for_model(&self, model_id: Option<&str>) -> Vec<String> {
|
||||
let workers = self.worker_registry.get_workers_filtered(
|
||||
model_id,
|
||||
Some(WorkerType::Regular),
|
||||
Some(ConnectionMode::Http),
|
||||
false, // get all workers
|
||||
);
|
||||
workers.iter().map(|w| w.url().to_string()).collect()
|
||||
}
|
||||
|
||||
pub async fn wait_for_healthy_workers(
|
||||
worker_urls: &[String],
|
||||
worker_startup_timeout_secs: u64,
|
||||
worker_startup_check_interval_secs: u64,
|
||||
) -> Result<(), String> {
|
||||
if worker_urls.is_empty() {
|
||||
return Err(
|
||||
"Timeout waiting for workers to become healthy: no workers provided".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
// Perform health check asynchronously
|
||||
Self::wait_for_healthy_workers_async(
|
||||
worker_urls,
|
||||
worker_startup_timeout_secs,
|
||||
worker_startup_check_interval_secs,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn wait_for_healthy_workers_async(
|
||||
worker_urls: &[String],
|
||||
worker_startup_timeout_secs: u64,
|
||||
worker_startup_check_interval_secs: u64,
|
||||
) -> Result<(), String> {
|
||||
info!(
|
||||
"Waiting for {} workers to become healthy (timeout: {}s)",
|
||||
worker_urls.len(),
|
||||
worker_startup_timeout_secs
|
||||
);
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(2))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
loop {
|
||||
if start_time.elapsed() > Duration::from_secs(worker_startup_timeout_secs) {
|
||||
error!(
|
||||
"Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
|
||||
worker_startup_timeout_secs, worker_urls
|
||||
);
|
||||
return Err(format!(
|
||||
"Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
|
||||
worker_startup_timeout_secs, worker_urls
|
||||
));
|
||||
}
|
||||
|
||||
// Perform all health checks concurrently
|
||||
let mut health_checks = Vec::new();
|
||||
for url in worker_urls {
|
||||
let client_clone = client.clone();
|
||||
let url_clone = url.clone();
|
||||
|
||||
let check_health = tokio::spawn(async move {
|
||||
let health_url = format!("{}/health", url_clone);
|
||||
match client_clone.get(&health_url).send().await {
|
||||
Ok(res) => {
|
||||
if res.status().is_success() {
|
||||
None
|
||||
} else {
|
||||
Some((url_clone, format!("status: {}", res.status())))
|
||||
}
|
||||
}
|
||||
Err(_) => Some((url_clone, "not ready".to_string())),
|
||||
}
|
||||
});
|
||||
|
||||
health_checks.push(check_health);
|
||||
}
|
||||
|
||||
// Wait for all health checks to complete
|
||||
let results = futures::future::join_all(health_checks).await;
|
||||
|
||||
let mut all_healthy = true;
|
||||
let mut unhealthy_workers = Vec::new();
|
||||
|
||||
for result in results {
|
||||
match result {
|
||||
Ok(None) => {
|
||||
// Worker is healthy
|
||||
}
|
||||
Ok(Some((url, reason))) => {
|
||||
all_healthy = false;
|
||||
unhealthy_workers.push((url, reason));
|
||||
}
|
||||
Err(e) => {
|
||||
all_healthy = false;
|
||||
unhealthy_workers
|
||||
.push(("unknown".to_string(), format!("task error: {}", e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if all_healthy {
|
||||
info!("All {} workers are healthy", worker_urls.len());
|
||||
return Ok(());
|
||||
} else {
|
||||
debug!(
|
||||
"Waiting for {} workers to become healthy ({} unhealthy: {:?})",
|
||||
worker_urls.len(),
|
||||
unhealthy_workers.len(),
|
||||
unhealthy_workers
|
||||
);
|
||||
tokio::time::sleep(Duration::from_secs(worker_startup_check_interval_secs)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_worker_dp_size(worker_url: &str, api_key: &Option<String>) -> Result<usize, String> {
|
||||
let sync_client = reqwest::blocking::Client::new();
|
||||
let mut req_builder = sync_client.get(format!("{}/get_server_info", worker_url));
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
|
||||
match req_builder.send() {
|
||||
Ok(res) => {
|
||||
if res.status().is_success() {
|
||||
let server_info = res
|
||||
.text()
|
||||
.map_err(|e| format!("failed to read text from response: {}", e))?;
|
||||
|
||||
let server_info: serde_json::Value = serde_json::from_str(&server_info)
|
||||
.map_err(|e| format!("failed to decode JSON: {}", e))?;
|
||||
|
||||
let dp_size = server_info
|
||||
.get("dp_size")
|
||||
.and_then(|v| v.as_u64())
|
||||
.ok_or_else(|| String::from("dp_size not found or not an u64"))?;
|
||||
|
||||
Ok(if dp_size > usize::MAX as u64 {
|
||||
return Err(format!("dp_size is too large: {}", dp_size));
|
||||
} else {
|
||||
dp_size as usize
|
||||
})
|
||||
} else {
|
||||
Err(format!("unexpected status code: {}", res.status()))
|
||||
}
|
||||
}
|
||||
Err(e) => Err(format!("error response: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
// Given a list of workers, return a list of workers with dp_rank as suffix
|
||||
fn get_dp_aware_workers(
|
||||
worker_urls: &[String],
|
||||
api_key: &Option<String>,
|
||||
) -> Result<Vec<String>, String> {
|
||||
let mut dp_aware_workers: Vec<String> = Vec::new();
|
||||
|
||||
for url in worker_urls {
|
||||
match Self::get_worker_dp_size(url, api_key) {
|
||||
Ok(dp_size) => {
|
||||
for i in 0..dp_size {
|
||||
dp_aware_workers.push(format!("{}@{}", url, i));
|
||||
}
|
||||
}
|
||||
Err(e) => return Err(format!("Failed to get DP size for {}: {}", url, e)),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(dp_aware_workers)
|
||||
}
|
||||
|
||||
fn select_first_worker(&self) -> Result<String, String> {
|
||||
let workers = self.worker_registry.get_all();
|
||||
if workers.is_empty() {
|
||||
@@ -317,65 +108,6 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_health_check(&self, worker_url: &str) -> Response {
|
||||
let health_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
match Self::extract_dp_rank(worker_url) {
|
||||
Ok((worker_url_prefix, _dp_rank)) => worker_url_prefix,
|
||||
Err(e) => {
|
||||
error!("Failed to extract dp_rank for health check: {}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to extract dp_rank: {}", e),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
worker_url
|
||||
};
|
||||
|
||||
let request_builder = self.client.get(format!("{}/health", health_url));
|
||||
|
||||
let response = match request_builder.send().await {
|
||||
Ok(res) => {
|
||||
let status = StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
match res.bytes().await {
|
||||
Ok(body) => (status, body).into_response(),
|
||||
Err(e) => {
|
||||
error!(
|
||||
worker_url = %health_url,
|
||||
error = %e,
|
||||
"Failed to read health response body"
|
||||
);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to read response body: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
worker_url = %health_url,
|
||||
error = %e,
|
||||
"Failed to send health request to worker"
|
||||
);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to send request to worker {}: {}", health_url, e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
// Don't record metrics for health checks
|
||||
response
|
||||
}
|
||||
|
||||
// Helper method to proxy GET requests to the first available worker
|
||||
async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
|
||||
let headers = header_utils::copy_request_headers(&req);
|
||||
@@ -575,14 +307,15 @@ impl Router {
|
||||
) -> Response {
|
||||
// TODO: currently the sglang worker is using in-memory state management, so this implementation has to fan out to all workers.
|
||||
// Eventually, we need to have router to manage the chat history with a proper database, will update this implementation accordingly.
|
||||
let worker_urls = self.get_worker_urls();
|
||||
if worker_urls.is_empty() {
|
||||
let workers = self.worker_registry.get_all();
|
||||
if workers.is_empty() {
|
||||
return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
|
||||
}
|
||||
|
||||
let mut last_response: Option<Response> = None;
|
||||
for worker_url in worker_urls {
|
||||
let base = self.worker_base_url(&worker_url);
|
||||
for worker in workers {
|
||||
let worker_url = worker.url();
|
||||
let base = self.worker_base_url(worker_url);
|
||||
|
||||
let url = format!("{}/{}", base, endpoint);
|
||||
let mut request_builder = match method {
|
||||
@@ -597,6 +330,11 @@ impl Router {
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(api_key) = worker.api_key() {
|
||||
request_builder =
|
||||
request_builder.header("Authorization", format!("Bearer {}", api_key));
|
||||
}
|
||||
|
||||
if let Some(hdrs) = headers {
|
||||
for (name, value) in hdrs {
|
||||
let name_lc = name.as_str().to_lowercase();
|
||||
@@ -691,6 +429,12 @@ impl Router {
|
||||
is_stream: bool,
|
||||
load_incremented: bool, // Whether load was incremented for this request
|
||||
) -> Response {
|
||||
// Get the worker's API key if available
|
||||
let api_key = self
|
||||
.worker_registry
|
||||
.get_by_url(worker_url)
|
||||
.and_then(|w| w.api_key().clone());
|
||||
|
||||
let mut request_builder = if self.dp_aware {
|
||||
let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
Ok(tup) => tup,
|
||||
@@ -704,7 +448,6 @@ impl Router {
|
||||
}
|
||||
};
|
||||
|
||||
// Parse the request body
|
||||
let mut json_val = match serde_json::to_value(typed_req) {
|
||||
Ok(j) => j,
|
||||
Err(e) => {
|
||||
@@ -716,7 +459,6 @@ impl Router {
|
||||
}
|
||||
};
|
||||
|
||||
// Insert the data_parallel_rank field
|
||||
if let Some(map) = json_val.as_object_mut() {
|
||||
map.insert(
|
||||
String::from("data_parallel_rank"),
|
||||
@@ -743,6 +485,10 @@ impl Router {
|
||||
.json(typed_req) // Use json() directly with typed request
|
||||
};
|
||||
|
||||
if let Some(key) = api_key {
|
||||
request_builder = request_builder.header("Authorization", format!("Bearer {}", key));
|
||||
}
|
||||
|
||||
// Copy all headers from original request if provided
|
||||
if let Some(headers) = headers {
|
||||
for (name, value) in headers {
|
||||
@@ -909,215 +655,6 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_worker(
|
||||
&self,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
let start_time = std::time::Instant::now();
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(self.worker_startup_timeout_secs))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
loop {
|
||||
if start_time.elapsed() > Duration::from_secs(self.worker_startup_timeout_secs) {
|
||||
error!(
|
||||
"Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
|
||||
self.worker_startup_timeout_secs, worker_url
|
||||
);
|
||||
return Err(format!(
|
||||
"Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
|
||||
self.worker_startup_timeout_secs, worker_url
|
||||
));
|
||||
}
|
||||
|
||||
match client.get(format!("{}/health", worker_url)).send().await {
|
||||
Ok(res) => {
|
||||
if res.status().is_success() {
|
||||
if self.dp_aware {
|
||||
// Need to contact the worker to extract the dp_size,
|
||||
// and add them as multiple workers
|
||||
let url_vec = vec![String::from(worker_url)];
|
||||
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, api_key)
|
||||
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
|
||||
let mut worker_added: bool = false;
|
||||
for dp_url in &dp_url_vec {
|
||||
if self.worker_registry.get_by_url(dp_url).is_some() {
|
||||
warn!("Worker {} already exists", dp_url);
|
||||
continue;
|
||||
}
|
||||
info!("Added worker: {}", dp_url);
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let new_worker_builder =
|
||||
BasicWorkerBuilder::new(dp_url.to_string())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.circuit_breaker_config(
|
||||
self.circuit_breaker_config.clone(),
|
||||
);
|
||||
|
||||
let new_worker = if let Some(api_key) = api_key {
|
||||
new_worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
new_worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc = Arc::new(new_worker);
|
||||
self.worker_registry.register(worker_arc.clone());
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
self.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// Initialize cache-aware policy if applicable
|
||||
let model_workers = self.worker_registry.get_workers_filtered(
|
||||
Some(model_id),
|
||||
Some(WorkerType::Regular),
|
||||
Some(ConnectionMode::Http),
|
||||
false,
|
||||
);
|
||||
self.policy_registry
|
||||
.init_cache_aware_policy(model_id, &model_workers);
|
||||
|
||||
worker_added = true;
|
||||
}
|
||||
if !worker_added {
|
||||
return Err(format!("No worker added for {}", worker_url));
|
||||
}
|
||||
} else {
|
||||
if self.worker_registry.get_by_url(worker_url).is_some() {
|
||||
return Err(format!("Worker {} already exists", worker_url));
|
||||
}
|
||||
info!("Added worker: {}", worker_url);
|
||||
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let new_worker_builder =
|
||||
BasicWorkerBuilder::new(worker_url.to_string())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone());
|
||||
|
||||
let new_worker = if let Some(api_key) = api_key {
|
||||
new_worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
new_worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc = Arc::new(new_worker);
|
||||
self.worker_registry.register(worker_arc.clone());
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
self.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// Initialize cache-aware policy if applicable
|
||||
let model_workers = self.worker_registry.get_workers_filtered(
|
||||
Some(model_id),
|
||||
Some(WorkerType::Regular),
|
||||
Some(ConnectionMode::Http),
|
||||
false,
|
||||
);
|
||||
self.policy_registry
|
||||
.init_cache_aware_policy(model_id, &model_workers);
|
||||
}
|
||||
|
||||
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
|
||||
|
||||
return Ok(format!("Successfully added worker: {}", worker_url));
|
||||
} else {
|
||||
debug!(
|
||||
"Worker {} health check pending - status: {}",
|
||||
worker_url,
|
||||
res.status()
|
||||
);
|
||||
// if the url does not have http or https prefix, warn users
|
||||
if !worker_url.starts_with("http://") && !worker_url.starts_with("https://")
|
||||
{
|
||||
warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(
|
||||
self.worker_startup_check_interval_secs,
|
||||
))
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Worker {} health check pending - error: {}", worker_url, e);
|
||||
|
||||
// if the url does not have http or https prefix, warn users
|
||||
if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
|
||||
warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(
|
||||
self.worker_startup_check_interval_secs,
|
||||
))
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove_worker(&self, worker_url: &str) {
|
||||
if self.dp_aware {
|
||||
// remove dp-aware workers in a prefix-matching fashion
|
||||
// without contacting the remote worker
|
||||
let mut removed_workers: Vec<String> = Vec::new();
|
||||
let worker_url_prefix = format!("{}@", worker_url);
|
||||
|
||||
// Find and remove all workers with matching prefix
|
||||
let all_workers = self.worker_registry.get_all();
|
||||
for w in all_workers.iter() {
|
||||
if w.url().starts_with(&worker_url_prefix) {
|
||||
// Get model_id before removing
|
||||
let model_id = w.model_id().to_string();
|
||||
|
||||
if self.worker_registry.remove_by_url(w.url()).is_some() {
|
||||
info!("Removed worker: {}", w.url());
|
||||
removed_workers.push(w.url().to_string());
|
||||
|
||||
// Notify PolicyRegistry about the removed worker
|
||||
self.policy_registry.on_worker_removed(&model_id);
|
||||
} else {
|
||||
warn!("Worker {} not found, skipping removal", w.url());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
|
||||
|
||||
for dp_url in removed_workers.iter() {
|
||||
if let Some(worker) = self.worker_registry.get_by_url(dp_url) {
|
||||
let model_id = worker.model_id();
|
||||
self.policy_registry
|
||||
.remove_worker_from_cache_aware(model_id, dp_url);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Get the worker first to extract model_id
|
||||
let model_id = if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
|
||||
worker.model_id().to_string()
|
||||
} else {
|
||||
warn!("Worker {} not found, skipping removal", worker_url);
|
||||
return;
|
||||
};
|
||||
|
||||
if self.worker_registry.remove_by_url(worker_url).is_some() {
|
||||
info!("Removed worker: {}", worker_url);
|
||||
|
||||
// Notify PolicyRegistry about the removed worker
|
||||
self.policy_registry.on_worker_removed(&model_id);
|
||||
|
||||
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
|
||||
}
|
||||
|
||||
self.policy_registry
|
||||
.remove_worker_from_cache_aware(&model_id, worker_url);
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_load(&self, worker_url: &str, api_key: &Option<String>) -> Option<isize> {
|
||||
let worker_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
@@ -1205,7 +742,7 @@ impl Router {
|
||||
|
||||
// Static version of get_worker_load for use in monitoring task
|
||||
async fn get_worker_load_static(
|
||||
client: &reqwest::Client,
|
||||
client: &Client,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Option<isize> {
|
||||
@@ -1281,25 +818,6 @@ impl Router {
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerManagement for Router {
|
||||
async fn add_worker(
|
||||
&self,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
Router::add_worker(self, worker_url, api_key).await
|
||||
}
|
||||
|
||||
fn remove_worker(&self, worker_url: &str) {
|
||||
Router::remove_worker(self, worker_url)
|
||||
}
|
||||
|
||||
fn get_worker_urls(&self) -> Vec<String> {
|
||||
Router::get_worker_urls(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RouterTrait for Router {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
@@ -1445,12 +963,19 @@ impl RouterTrait for Router {
|
||||
}
|
||||
|
||||
async fn flush_cache(&self) -> Response {
|
||||
// Get all worker URLs
|
||||
let worker_urls = self.get_worker_urls();
|
||||
// Get all workers
|
||||
let workers = self.worker_registry.get_all();
|
||||
let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
|
||||
|
||||
// Send requests to all workers concurrently without headers
|
||||
let mut tasks = Vec::new();
|
||||
for worker_url in &worker_urls {
|
||||
// Get the worker's API key if available
|
||||
let api_key = self
|
||||
.worker_registry
|
||||
.get_by_url(worker_url)
|
||||
.and_then(|w| w.api_key().clone());
|
||||
|
||||
let worker_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
@@ -1468,7 +993,13 @@ impl RouterTrait for Router {
|
||||
} else {
|
||||
worker_url
|
||||
};
|
||||
let request_builder = self.client.post(format!("{}/flush_cache", worker_url));
|
||||
let mut request_builder = self.client.post(format!("{}/flush_cache", worker_url));
|
||||
|
||||
if let Some(key) = api_key {
|
||||
request_builder =
|
||||
request_builder.header("Authorization", format!("Bearer {}", key));
|
||||
}
|
||||
|
||||
tasks.push(request_builder.send());
|
||||
}
|
||||
|
||||
@@ -1546,6 +1077,7 @@ impl RouterTrait for Router {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn create_test_regular_router() -> Router {
|
||||
@@ -1558,11 +1090,9 @@ mod tests {
|
||||
// Register test workers
|
||||
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
worker_registry.register(Arc::new(worker1));
|
||||
worker_registry.register(Arc::new(worker2));
|
||||
@@ -1571,13 +1101,9 @@ mod tests {
|
||||
Router {
|
||||
worker_registry,
|
||||
policy_registry,
|
||||
worker_startup_timeout_secs: 5,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
client: Client::new(),
|
||||
retry_config: RetryConfig::default(),
|
||||
circuit_breaker_config: CircuitBreakerConfig::default(),
|
||||
_worker_loads: Arc::new(rx),
|
||||
_load_monitor_handle: None,
|
||||
}
|
||||
@@ -1586,7 +1112,8 @@ mod tests {
|
||||
#[test]
|
||||
fn test_router_get_worker_urls_regular() {
|
||||
let router = create_test_regular_router();
|
||||
let urls = router.get_worker_urls();
|
||||
let workers = router.worker_registry.get_all();
|
||||
let urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
|
||||
|
||||
assert_eq!(urls.len(), 2);
|
||||
assert!(urls.contains(&"http://worker1:8080".to_string()));
|
||||
@@ -1603,21 +1130,4 @@ mod tests {
|
||||
// DashMap doesn't guarantee order, so just check we get one of the workers
|
||||
assert!(url == "http://worker1:8080" || url == "http://worker2:8080");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_wait_for_healthy_workers_empty_list() {
|
||||
// Empty list will return error immediately
|
||||
let result = Router::wait_for_healthy_workers(&[], 1, 1).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("no workers provided"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_wait_for_healthy_workers_invalid_urls() {
|
||||
// This test will timeout quickly since the URLs are invalid
|
||||
let result =
|
||||
Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("Timeout"));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user