diff --git a/sgl-router/py_test/e2e/conftest.py b/sgl-router/py_test/e2e/conftest.py index 460195816..4e0b241c2 100644 --- a/sgl-router/py_test/e2e/conftest.py +++ b/sgl-router/py_test/e2e/conftest.py @@ -85,6 +85,8 @@ def _popen_launch_router( str(prom_port), "--router-prometheus-host", "127.0.0.1", + "--router-log-level", + "warn", ] proc = subprocess.Popen(cmd) diff --git a/sgl-router/py_test/e2e/test_e2e_embeddings.py b/sgl-router/py_test/e2e/test_e2e_embeddings.py index 538d4df6f..1b852ef7e 100644 --- a/sgl-router/py_test/e2e/test_e2e_embeddings.py +++ b/sgl-router/py_test/e2e/test_e2e_embeddings.py @@ -1,9 +1,31 @@ +import time from types import SimpleNamespace import pytest import requests +def _wait_for_workers( + base_url: str, expected_count: int, timeout: float = 60.0, headers: dict = None +) -> None: + """Poll /workers endpoint until expected number of workers are registered.""" + start = time.perf_counter() + with requests.Session() as session: + while time.perf_counter() - start < timeout: + try: + r = session.get(f"{base_url}/workers", headers=headers, timeout=5) + if r.status_code == 200: + workers = r.json().get("workers", []) + if len(workers) >= expected_count: + return + except requests.RequestException: + pass + time.sleep(0.5) + raise TimeoutError( + f"Expected {expected_count} workers at {base_url}, timed out after {timeout}s" + ) + + @pytest.mark.e2e def test_embeddings_basic( e2e_router_only_rr, e2e_primary_embedding_worker, e2e_embedding_model @@ -12,8 +34,11 @@ def test_embeddings_basic( worker_url = e2e_primary_embedding_worker.url # Attach embedding worker to router-only instance - r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180) - r.raise_for_status() + r = requests.post(f"{base}/workers", json={"url": worker_url}, timeout=180) + assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}" + + # Wait for worker to be registered + _wait_for_workers(base, expected_count=1, timeout=60.0) # Simple embedding request with two inputs payload = { diff --git a/sgl-router/py_test/e2e/test_pd_router.py b/sgl-router/py_test/e2e/test_pd_router.py index c0ca06c3c..eccbad4d1 100644 --- a/sgl-router/py_test/e2e/test_pd_router.py +++ b/sgl-router/py_test/e2e/test_pd_router.py @@ -198,6 +198,8 @@ def pd_cluster(e2e_model: str): "--policy", "round_robin", "--pd-disaggregation", + "--log-level", + "warn", ] for url, bport in prefill: cmd += ["--prefill", url, str(bport)] diff --git a/sgl-router/py_test/e2e/test_regular_router.py b/sgl-router/py_test/e2e/test_regular_router.py index e6c9bdf3b..effb39ef4 100644 --- a/sgl-router/py_test/e2e/test_regular_router.py +++ b/sgl-router/py_test/e2e/test_regular_router.py @@ -8,13 +8,39 @@ import requests from sglang.test.run_eval import run_eval +def _wait_for_workers( + base_url: str, expected_count: int, timeout: float = 60.0, headers: dict = None +) -> None: + """Poll /workers endpoint until expected number of workers are registered.""" + start = time.perf_counter() + with requests.Session() as session: + while time.perf_counter() - start < timeout: + try: + r = session.get(f"{base_url}/workers", headers=headers, timeout=5) + if r.status_code == 200: + workers = r.json().get("workers", []) + if len(workers) >= expected_count: + return + except requests.RequestException: + pass + time.sleep(0.5) + raise TimeoutError( + f"Expected {expected_count} workers at {base_url}, timed out after {timeout}s" + ) + + @pytest.mark.e2e def test_mmlu(e2e_router_only_rr, e2e_two_workers_dp2, e2e_model): # Attach two dp=2 workers (total 4 GPUs) to a fresh router-only instance base = e2e_router_only_rr.url for w in e2e_two_workers_dp2: - r = requests.post(f"{base}/add_worker", params={"url": w.url}, timeout=180) - r.raise_for_status() + r = requests.post(f"{base}/workers", json={"url": w.url}, timeout=180) + assert ( + r.status_code == 202 + ), f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}" + + # Wait for workers to be registered + _wait_for_workers(base, expected_count=2, timeout=60.0) args = SimpleNamespace( base_url=base, @@ -35,8 +61,13 @@ def test_genai_bench( """Attach a worker to the regular router and run a short genai-bench.""" base = e2e_router_only_rr.url for w in e2e_two_workers_dp2: - r = requests.post(f"{base}/add_worker", params={"url": w.url}, timeout=180) - r.raise_for_status() + r = requests.post(f"{base}/workers", json={"url": w.url}, timeout=180) + assert ( + r.status_code == 202 + ), f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}" + + # Wait for workers to be registered + _wait_for_workers(base, expected_count=2, timeout=60.0) genai_bench_runner( router_url=base, @@ -59,8 +90,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_ base = e2e_router_only_rr.url worker_url = e2e_primary_worker.url - r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180) - r.raise_for_status() + r = requests.post(f"{base}/workers", json={"url": worker_url}, timeout=180) + assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}" + + # Wait for worker to be registered + _wait_for_workers(base, expected_count=1, timeout=60.0) with requests.Session() as s: for i in range(8): @@ -77,8 +111,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_ r.raise_for_status() # Remove the worker - r = requests.post(f"{base}/remove_worker", params={"url": worker_url}, timeout=60) - r.raise_for_status() + from urllib.parse import quote + + encoded_url = quote(worker_url, safe="") + r = requests.delete(f"{base}/workers/{encoded_url}", timeout=60) + assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}" @pytest.mark.e2e @@ -86,8 +123,11 @@ def test_lazy_fault_tolerance_live(e2e_router_only_rr, e2e_primary_worker, e2e_m base = e2e_router_only_rr.url worker = e2e_primary_worker - r = requests.post(f"{base}/add_worker", params={"url": worker.url}, timeout=180) - r.raise_for_status() + r = requests.post(f"{base}/workers", json={"url": worker.url}, timeout=180) + assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}" + + # Wait for worker to be registered + _wait_for_workers(base, expected_count=1, timeout=60.0) def killer(): time.sleep(10) @@ -129,20 +169,30 @@ def test_dp_aware_worker_expansion_and_api_key( # Attach worker; router should expand to dp_size logical workers r = requests.post( - f"{router_url}/add_worker", - params={"url": worker_url, "api_key": api_key}, + f"{router_url}/workers", + json={"url": worker_url, "api_key": api_key}, headers={"Authorization": f"Bearer {api_key}"}, timeout=180, ) - r.raise_for_status() + assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}" + # Wait for workers to be registered and expanded + _wait_for_workers( + router_url, + expected_count=2, + timeout=60.0, + headers={"Authorization": f"Bearer {api_key}"}, + ) + + # Verify the expanded workers have correct URLs r = requests.get( - f"{router_url}/list_workers", + f"{router_url}/workers", headers={"Authorization": f"Bearer {api_key}"}, timeout=30, ) r.raise_for_status() - urls = r.json().get("urls", []) + workers = r.json().get("workers", []) + urls = [w["url"] for w in workers] assert len(urls) == 2 assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"} diff --git a/sgl-router/py_test/e2e_grpc/fixtures.py b/sgl-router/py_test/e2e_grpc/fixtures.py index f7047394c..869c70167 100644 --- a/sgl-router/py_test/e2e_grpc/fixtures.py +++ b/sgl-router/py_test/e2e_grpc/fixtures.py @@ -267,6 +267,8 @@ def popen_launch_workers_and_router( policy, "--model-path", model, + "--log-level", + "warn", ] # Add worker URLs diff --git a/sgl-router/py_test/fixtures/router_manager.py b/sgl-router/py_test/fixtures/router_manager.py index c536a0015..436e42989 100644 --- a/sgl-router/py_test/fixtures/router_manager.py +++ b/sgl-router/py_test/fixtures/router_manager.py @@ -133,19 +133,90 @@ class RouterManager: time.sleep(0.2) raise TimeoutError(f"Router at {base_url} did not become healthy") - def add_worker(self, base_url: str, worker_url: str) -> None: - r = requests.post(f"{base_url}/add_worker", params={"url": worker_url}) - assert r.status_code == 200, f"add_worker failed: {r.status_code} {r.text}" + def add_worker(self, base_url: str, worker_url: str, timeout: float = 30.0) -> None: + r = requests.post(f"{base_url}/workers", json={"url": worker_url}) + assert ( + r.status_code == 202 + ), f"add_worker failed: {r.status_code} {r.text}" # ACCEPTED status - def remove_worker(self, base_url: str, worker_url: str) -> None: - r = requests.post(f"{base_url}/remove_worker", params={"url": worker_url}) - assert r.status_code == 200, f"remove_worker failed: {r.status_code} {r.text}" + # Poll until worker is actually added and healthy + from urllib.parse import quote + + encoded_url = quote(worker_url, safe="") + start = time.time() + with requests.Session() as s: + while time.time() - start < timeout: + try: + r = s.get(f"{base_url}/workers/{encoded_url}", timeout=2) + if r.status_code == 200: + data = r.json() + # Check if registration job failed + job_status = data.get("job_status") + if job_status and job_status.get("state") == "failed": + raise RuntimeError( + f"Worker registration failed: {job_status.get('message', 'Unknown error')}" + ) + # Check if worker is healthy and registered (not just in job queue) + if data.get("is_healthy", False): + return + # Worker not ready yet, continue polling + except requests.RequestException: + pass + time.sleep(0.1) + raise TimeoutError( + f"Worker {worker_url} was not added and healthy after {timeout}s" + ) + + def remove_worker( + self, base_url: str, worker_url: str, timeout: float = 30.0 + ) -> None: + # URL encode the worker_url for path parameter + from urllib.parse import quote + + encoded_url = quote(worker_url, safe="") + r = requests.delete(f"{base_url}/workers/{encoded_url}") + assert ( + r.status_code == 202 + ), f"remove_worker failed: {r.status_code} {r.text}" # ACCEPTED status + + # Poll until worker is actually removed (GET returns 404) or timeout + start = time.time() + last_status = None + with requests.Session() as s: + while time.time() - start < timeout: + try: + r = s.get(f"{base_url}/workers/{encoded_url}", timeout=2) + if r.status_code == 404: + # Worker successfully removed + return + elif r.status_code == 200: + # Check if removal job failed + data = r.json() + job_status = data.get("job_status") + if job_status: + last_status = job_status + if job_status.get("state") == "failed": + raise RuntimeError( + f"Worker removal failed: {job_status.get('message', 'Unknown error')}" + ) + # Worker still being processed, continue polling + except requests.RequestException: + pass + time.sleep(0.1) + + # Provide detailed timeout error with last known status + error_msg = f"Worker {worker_url} was not removed after {timeout}s" + if last_status: + error_msg += f". Last job status: {last_status}" + raise TimeoutError(error_msg) def list_workers(self, base_url: str) -> list[str]: - r = requests.get(f"{base_url}/list_workers") + r = requests.get(f"{base_url}/workers") assert r.status_code == 200, f"list_workers failed: {r.status_code} {r.text}" data = r.json() - return data.get("urls", []) + # Extract URLs from WorkerInfo objects + workers = data.get("workers", []) + return [w["url"] for w in workers] def stop_all(self): for p in self._children: diff --git a/sgl-router/py_test/integration/conftest.py b/sgl-router/py_test/integration/conftest.py index 21b9369d7..0dc7bc3c3 100644 --- a/sgl-router/py_test/integration/conftest.py +++ b/sgl-router/py_test/integration/conftest.py @@ -2,7 +2,7 @@ import os import subprocess import time from pathlib import Path -from typing import Dict, Iterable, List, Tuple +from typing import Dict, Iterable, List, Optional, Tuple import pytest import requests @@ -84,7 +84,7 @@ def mock_workers(): procs: List[subprocess.Popen] = [] - def _start(n: int, args: List[str] | None = None): + def _start(n: int, args: Optional[List[str]] = None): args = args or [] new_procs: List[subprocess.Popen] = [] urls: List[str] = [] diff --git a/sgl-router/src/core/job_queue.rs b/sgl-router/src/core/job_queue.rs index 05b39e4c2..8335420e9 100644 --- a/sgl-router/src/core/job_queue.rs +++ b/sgl-router/src/core/job_queue.rs @@ -15,11 +15,9 @@ use tracing::{debug, error, info, warn}; use crate::{ config::{RouterConfig, RoutingMode}, - core::{ - workflow::{ - WorkflowContext, WorkflowEngine, WorkflowId, WorkflowInstanceId, WorkflowStatus, - }, - WorkerManager, + core::workflow::{ + steps::WorkerRemovalRequest, WorkflowContext, WorkflowEngine, WorkflowId, + WorkflowInstanceId, WorkflowStatus, }, metrics::RouterMetrics, protocols::worker_spec::{JobStatus, WorkerConfigRequest}, @@ -320,11 +318,29 @@ impl JobQueue { .await } Job::RemoveWorker { url } => { - let result = WorkerManager::remove_worker(url, context); + let engine = context + .workflow_engine + .get() + .ok_or_else(|| "Workflow engine not initialized".to_string())?; + + let instance_id = Self::start_worker_removal_workflow(engine, url, context).await?; + + debug!( + "Started worker removal workflow for {} (instance: {})", + url, instance_id + ); + + let timeout_duration = Duration::from_secs(30); + + let result = + Self::wait_for_workflow_completion(engine, instance_id, url, timeout_duration) + .await; + // Clean up job status when removing worker if let Some(queue) = context.worker_job_queue.get() { queue.remove_status(url); } + result } Job::InitializeWorkersFromConfig { router_config } => { @@ -424,6 +440,27 @@ impl JobQueue { .map_err(|e| format!("Failed to start worker registration workflow: {:?}", e)) } + /// Start worker removal workflow + async fn start_worker_removal_workflow( + engine: &Arc, + url: &str, + context: &Arc, + ) -> Result { + let removal_request = WorkerRemovalRequest { + url: url.to_string(), + dp_aware: context.router_config.dp_aware, + }; + + let mut workflow_context = WorkflowContext::new(WorkflowInstanceId::new()); + workflow_context.set("removal_request", removal_request); + workflow_context.set_arc("app_context", Arc::clone(context)); + + engine + .start_workflow(WorkflowId::new("worker_removal"), workflow_context) + .await + .map_err(|e| format!("Failed to start worker removal workflow: {:?}", e)) + } + /// Wait for workflow completion with adaptive polling async fn wait_for_workflow_completion( engine: &Arc, diff --git a/sgl-router/src/core/mod.rs b/sgl-router/src/core/mod.rs index f521f3839..bc42e0684 100644 --- a/sgl-router/src/core/mod.rs +++ b/sgl-router/src/core/mod.rs @@ -29,5 +29,5 @@ pub use worker::{ Worker, WorkerFactory, WorkerLoadGuard, WorkerType, }; pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder}; -pub use worker_manager::{DpInfo, LoadMonitor, ServerInfo, WorkerManager}; +pub use worker_manager::{LoadMonitor, WorkerManager}; pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats}; diff --git a/sgl-router/src/core/worker_manager.rs b/sgl-router/src/core/worker_manager.rs index 58d397da6..f2cd76a75 100644 --- a/sgl-router/src/core/worker_manager.rs +++ b/sgl-router/src/core/worker_manager.rs @@ -6,8 +6,6 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; use futures::future; -use once_cell::sync::Lazy; -use serde::{Deserialize, Serialize}; use serde_json::Value; use tokio::{ sync::{watch, Mutex}, @@ -16,698 +14,15 @@ use tokio::{ use tracing::{debug, error, info, warn}; use crate::{ - config::types::{ - CircuitBreakerConfig as ConfigCircuitBreakerConfig, ConnectionMode as ConfigConnectionMode, - HealthCheckConfig, RouterConfig, RoutingMode, - }, - core::{ - BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder, - HealthConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType, - }, - grpc_client::SglangSchedulerClient, + core::{ConnectionMode, WorkerRegistry, WorkerType}, policies::PolicyRegistry, - protocols::worker_spec::{ - FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult, - }, - server::AppContext, + protocols::worker_spec::{FlushCacheResult, WorkerLoadInfo, WorkerLoadsResult}, }; -static HTTP_CLIENT: Lazy = Lazy::new(|| { - reqwest::Client::builder() - .timeout(Duration::from_secs(10)) - .build() - .expect("Failed to create HTTP client") -}); - -/// Server information returned from worker endpoints -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ServerInfo { - pub model_id: Option, - pub model_path: Option, - pub dp_size: Option, - pub version: Option, - pub max_batch_size: Option, - pub max_total_tokens: Option, - pub max_prefill_tokens: Option, - pub max_running_requests: Option, - pub max_num_reqs: Option, -} - -/// DP (Data Parallel) information for a worker -#[derive(Debug, Clone)] -pub struct DpInfo { - pub dp_size: usize, - pub model_id: String, -} - -/// Worker discovery results gathered from backend endpoints -struct WorkerDiscovery { - labels: HashMap, - grpc_client: Option, -} - -impl WorkerDiscovery { - fn new() -> Self { - Self { - labels: HashMap::new(), - grpc_client: None, - } - } -} - /// Unified worker management pub struct WorkerManager; impl WorkerManager { - /// Get server info from /get_server_info endpoint - pub async fn get_server_info(url: &str, api_key: Option<&str>) -> Result { - let base_url = url.trim_end_matches('/'); - - let server_info_url = format!("{}/get_server_info", base_url); - let mut req = HTTP_CLIENT.get(&server_info_url); - if let Some(key) = api_key { - req = req.bearer_auth(key); - } - - let response = req - .send() - .await - .map_err(|e| format!("Failed to connect to {}: {}", server_info_url, e))?; - - if !response.status().is_success() { - return Err(format!( - "Server returned status {} from {}", - response.status(), - server_info_url - )); - } - - let json = response - .json::() - .await - .map_err(|e| format!("Failed to parse response from {}: {}", server_info_url, e))?; - - info!( - "Successfully retrieved server info from {}", - server_info_url - ); - Self::parse_server_info(json) - } - - /// Get model info from /get_model_info endpoint - pub async fn get_model_info(url: &str, api_key: Option<&str>) -> Result { - let base_url = url.trim_end_matches('/'); - - let model_info_url = format!("{}/get_model_info", base_url); - let mut req = HTTP_CLIENT.get(&model_info_url); - if let Some(key) = api_key { - req = req.bearer_auth(key); - } - - let response = req - .send() - .await - .map_err(|e| format!("Failed to connect to {}: {}", model_info_url, e))?; - - if !response.status().is_success() { - return Err(format!( - "Server returned status {} from {}", - response.status(), - model_info_url - )); - } - - let json = response - .json::() - .await - .map_err(|e| format!("Failed to parse response from {}: {}", model_info_url, e))?; - - info!("Successfully retrieved model info from {}", model_info_url); - Ok(json) - } - - /// Get DP info for a worker URL - pub async fn get_dp_info(url: &str, api_key: Option<&str>) -> Result { - let info = Self::get_server_info(url, api_key).await?; - - let dp_size = info - .dp_size - .ok_or_else(|| format!("No dp_size in response from {}", url))?; - - let model_id = info - .model_id - .or_else(|| { - info.model_path - .and_then(|path| path.split('/').next_back().map(|s| s.to_string())) - }) - .unwrap_or_else(|| "unknown".to_string()); - - Ok(DpInfo { dp_size, model_id }) - } - - /// Generate DP-aware worker URLs - pub async fn get_dp_aware_urls( - base_urls: &[String], - api_key: Option<&str>, - ) -> Result, String> { - let mut dp_urls = Vec::new(); - - for base_url in base_urls { - match Self::get_dp_info(base_url, api_key).await { - Ok(dp_info) => { - info!( - "Discovered DP size {} for {} (model: {})", - dp_info.dp_size, base_url, dp_info.model_id - ); - - for rank in 0..dp_info.dp_size { - dp_urls.push(format!("{}@{}", base_url, rank)); - } - } - Err(e) => { - return Err(format!("Failed to get DP info from {}: {}", base_url, e)); - } - } - } - - Ok(dp_urls) - } - - /// Initialize workers from configuration at startup - pub async fn initialize_workers( - config: &RouterConfig, - registry: &Arc, - policy_registry: Option<&Arc>, - ) -> Result<(), String> { - info!("Starting worker initialization"); - - // Determine connection mode from config - let connection_mode = &config.connection_mode; - - match &config.mode { - RoutingMode::Regular { worker_urls } => match connection_mode { - ConfigConnectionMode::Http => { - Self::initialize_regular_workers( - worker_urls, - config, - registry, - policy_registry, - ) - .await?; - } - ConfigConnectionMode::Grpc => { - Self::initialize_grpc_workers(worker_urls, config, registry, policy_registry) - .await?; - } - }, - RoutingMode::PrefillDecode { - prefill_urls, - decode_urls, - .. - } => match connection_mode { - ConfigConnectionMode::Http => { - let prefill_entries: Vec<(&String, &Option)> = - prefill_urls.iter().map(|(url, port)| (url, port)).collect(); - - Self::initialize_prefill_workers( - &prefill_entries, - config, - registry, - policy_registry, - ) - .await?; - Self::initialize_decode_workers(decode_urls, config, registry, policy_registry) - .await?; - } - ConfigConnectionMode::Grpc => { - Self::initialize_grpc_pd_workers( - prefill_urls, - decode_urls, - config, - registry, - policy_registry, - ) - .await?; - } - }, - RoutingMode::OpenAI { .. } => { - info!("OpenAI routing mode - no workers to initialize"); - } - } - - Self::wait_for_healthy_workers( - registry, - config.worker_startup_timeout_secs, - config.health_check.check_interval_secs, - ) - .await?; - - info!("Worker initialization completed successfully"); - Ok(()) - } - - /// Initialize regular workers - async fn initialize_regular_workers( - urls: &[String], - config: &RouterConfig, - registry: &Arc, - policy_registry: Option<&Arc>, - ) -> Result<(), String> { - info!("Creating {} regular workers", urls.len()); - - let connection_mode = Self::convert_connection_mode(&config.connection_mode, urls.first()); - let circuit_breaker_config = - Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config()); - let health_config = Self::convert_health_config(&config.health_check); - - let mut registered_workers: HashMap>> = HashMap::new(); - - for url in urls { - if config.dp_aware { - match Self::get_dp_info(url, config.api_key.as_deref()).await { - Ok(dp_info) => { - info!( - "Discovered DP-aware worker {} with size {}", - url, dp_info.dp_size - ); - - for rank in 0..dp_info.dp_size { - let mut builder = - DPAwareWorkerBuilder::new(url.clone(), rank, dp_info.dp_size) - .worker_type(WorkerType::Regular) - .connection_mode(connection_mode.clone()) - .circuit_breaker_config(circuit_breaker_config.clone()) - .health_config(health_config.clone()); - - if let Some(ref key) = config.api_key { - builder = builder.api_key(key.clone()); - } - - let worker = Arc::new(builder.build()) as Arc; - - let model_id = worker.model_id(); - let worker_id = registry.register(Arc::clone(&worker)); - info!( - "Registered DP-aware worker {}@{} with ID {:?}", - url, rank, worker_id - ); - - registered_workers - .entry(model_id.to_string()) - .or_default() - .push(Arc::clone(&worker)); - - if let Some(policy_reg) = policy_registry { - policy_reg.on_worker_added(model_id, None); - } - } - } - Err(e) => { - return Err(format!( - "Failed to get DP info for worker {}: {}. DP-aware mode requires all workers to support DP.", - url, e - )); - } - } - } else { - let worker = Self::create_basic_worker( - url.clone(), - WorkerType::Regular, - connection_mode.clone(), - config.api_key.clone(), - None, - circuit_breaker_config.clone(), - health_config.clone(), - ) - .await; - Self::register_worker(worker, registry, &mut registered_workers, policy_registry); - } - } - - Self::initialize_cache_policies(®istered_workers, registry, policy_registry); - Ok(()) - } - - /// Initialize prefill workers for PD mode - async fn initialize_prefill_workers( - prefill_entries: &[(&String, &Option)], - config: &RouterConfig, - registry: &Arc, - policy_registry: Option<&Arc>, - ) -> Result<(), String> { - info!("Creating {} prefill workers", prefill_entries.len()); - - let connection_mode = Self::convert_connection_mode( - &config.connection_mode, - prefill_entries.first().map(|(url, _)| *url), - ); - let circuit_breaker_config = - Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config()); - let health_config = Self::convert_health_config(&config.health_check); - - let mut registered_workers: HashMap>> = HashMap::new(); - - // TODO: Add proper DP-aware support for prefill workers in PD mode - if config.dp_aware { - warn!("DP-aware mode is not yet supported for prefill workers in PD mode. Creating regular prefill workers instead."); - } - - for (url, bootstrap_port) in prefill_entries { - let worker_type = WorkerType::Prefill { - bootstrap_port: **bootstrap_port, - }; - let worker = Self::create_basic_worker( - (*url).clone(), - worker_type, - connection_mode.clone(), - config.api_key.clone(), - None, - circuit_breaker_config.clone(), - health_config.clone(), - ) - .await; - Self::register_worker(worker, registry, &mut registered_workers, policy_registry); - } - - if let Some(policy_reg) = policy_registry { - let all_prefill_workers: Vec> = registered_workers - .values() - .flat_map(|workers| workers.iter().cloned()) - .collect(); - policy_reg.init_pd_cache_aware_policies(&all_prefill_workers, &[]); - } - - Ok(()) - } - - /// Initialize decode workers for PD mode - async fn initialize_decode_workers( - urls: &[String], - config: &RouterConfig, - registry: &Arc, - policy_registry: Option<&Arc>, - ) -> Result<(), String> { - info!("Creating {} decode workers", urls.len()); - - let connection_mode = Self::convert_connection_mode(&config.connection_mode, urls.first()); - let circuit_breaker_config = - Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config()); - let health_config = Self::convert_health_config(&config.health_check); - - let mut registered_workers: HashMap>> = HashMap::new(); - - // TODO: Add proper DP-aware support for decode workers in PD mode - if config.dp_aware { - warn!("DP-aware mode is not yet supported for decode workers in PD mode. Creating regular decode workers instead."); - } - - for url in urls { - let worker = Self::create_basic_worker( - url.clone(), - WorkerType::Decode, - connection_mode.clone(), - config.api_key.clone(), - None, - circuit_breaker_config.clone(), - health_config.clone(), - ) - .await; - Self::register_worker(worker, registry, &mut registered_workers, policy_registry); - } - - if let Some(policy_reg) = policy_registry { - let all_decode_workers: Vec> = registered_workers - .values() - .flat_map(|workers| workers.iter().cloned()) - .collect(); - policy_reg.init_pd_cache_aware_policies(&[], &all_decode_workers); - } - - Ok(()) - } - - /// Initialize gRPC workers for regular mode - async fn initialize_grpc_workers( - urls: &[String], - config: &RouterConfig, - registry: &Arc, - policy_registry: Option<&Arc>, - ) -> Result<(), String> { - info!("Creating {} gRPC regular workers", urls.len()); - - let circuit_breaker_config = - Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config()); - let health_config = Self::convert_health_config(&config.health_check); - let connection_mode = ConnectionMode::Grpc { port: None }; - - let mut registered_workers: HashMap>> = HashMap::new(); - - for url in urls { - let worker = Self::create_basic_worker( - url.clone(), - WorkerType::Regular, - connection_mode.clone(), - config.api_key.clone(), - None, - circuit_breaker_config.clone(), - health_config.clone(), - ) - .await; - Self::register_worker(worker, registry, &mut registered_workers, policy_registry); - info!( - "Registered gRPC worker at {} (will connect on first use)", - url - ); - } - - Self::initialize_cache_policies(®istered_workers, registry, policy_registry); - Ok(()) - } - - /// Initialize gRPC PD (Prefill-Decode) workers - async fn initialize_grpc_pd_workers( - prefill_urls: &[(String, Option)], - decode_urls: &[String], - config: &RouterConfig, - registry: &Arc, - policy_registry: Option<&Arc>, - ) -> Result<(), String> { - info!( - "Creating {} gRPC prefill workers and {} gRPC decode workers", - prefill_urls.len(), - decode_urls.len() - ); - - let circuit_breaker_config = - Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config()); - let health_config = Self::convert_health_config(&config.health_check); - - let mut registered_prefill_workers: HashMap>> = HashMap::new(); - let mut registered_decode_workers: HashMap>> = HashMap::new(); - - for (url, bootstrap_port) in prefill_urls { - let worker_type = WorkerType::Prefill { - bootstrap_port: *bootstrap_port, - }; - let connection_mode = ConnectionMode::Grpc { - port: *bootstrap_port, - }; - - let worker = Self::create_basic_worker( - url.clone(), - worker_type, - connection_mode, - config.api_key.clone(), - None, - circuit_breaker_config.clone(), - health_config.clone(), - ) - .await; - Self::register_worker( - worker, - registry, - &mut registered_prefill_workers, - policy_registry, - ); - info!( - "Registered gRPC prefill worker at {} (will connect on first use)", - url - ); - } - - // Create decode workers - for url in decode_urls { - let connection_mode = ConnectionMode::Grpc { port: None }; - - let worker = Self::create_basic_worker( - url.clone(), - WorkerType::Decode, - connection_mode, - config.api_key.clone(), - None, - circuit_breaker_config.clone(), - health_config.clone(), - ) - .await; - Self::register_worker( - worker, - registry, - &mut registered_decode_workers, - policy_registry, - ); - info!( - "Registered gRPC decode worker at {} (will connect on first use)", - url - ); - } - - if let Some(policy_reg) = policy_registry { - let all_prefill_workers: Vec> = registered_prefill_workers - .values() - .flat_map(|workers| workers.iter().cloned()) - .collect(); - let all_decode_workers: Vec> = registered_decode_workers - .values() - .flat_map(|workers| workers.iter().cloned()) - .collect(); - policy_reg.init_pd_cache_aware_policies(&all_prefill_workers, &all_decode_workers); - } - - Ok(()) - } - - /// Add a worker from a configuration request - /// - /// Registers worker immediately with healthy=false, returns worker for async validation - pub async fn add_worker_from_config( - config: &WorkerConfigRequest, - context: &AppContext, - ) -> Result, String> { - // Check if worker already exists - if context.worker_registry.get_by_url(&config.url).is_some() { - return Err(format!("Worker {} already exists", config.url)); - } - let mut labels = config.labels.clone(); - - if let Some(model_id) = &config.model_id { - labels.insert("model_id".to_string(), model_id.clone()); - } - if let Some(priority) = config.priority { - labels.insert("priority".to_string(), priority.to_string()); - } - if let Some(cost) = config.cost { - labels.insert("cost".to_string(), cost.to_string()); - } - if let Some(ref tokenizer_path) = config.tokenizer_path { - labels.insert("tokenizer_path".to_string(), tokenizer_path.clone()); - } - if let Some(ref reasoning_parser) = config.reasoning_parser { - labels.insert("reasoning_parser".to_string(), reasoning_parser.clone()); - } - if let Some(ref tool_parser) = config.tool_parser { - labels.insert("tool_parser".to_string(), tool_parser.clone()); - } - if let Some(ref chat_template) = config.chat_template { - labels.insert("chat_template".to_string(), chat_template.clone()); - } - - let worker_type = config - .worker_type - .as_ref() - .map(|t| match t.as_str() { - "prefill" => WorkerType::Prefill { - bootstrap_port: config.bootstrap_port, - }, - "decode" => WorkerType::Decode, - _ => WorkerType::Regular, - }) - .unwrap_or(WorkerType::Regular); - - let connection_mode = if config.url.starts_with("grpc://") { - ConnectionMode::Grpc { port: None } - } else { - ConnectionMode::Http - }; - - let circuit_breaker_config = Self::convert_circuit_breaker_config( - &context.router_config.effective_circuit_breaker_config(), - ); - let health_config = Self::convert_health_config(&context.router_config.health_check); - - // Create and register worker (starts with healthy=false) - let worker = Self::create_basic_worker( - config.url.clone(), - worker_type, - connection_mode, - config.api_key.clone(), - Some(labels.clone()), - circuit_breaker_config, - health_config, - ) - .await; - - worker.set_healthy(false); - context.worker_registry.register(worker.clone()); - - let policy_hint = labels.get("policy").map(|s| s.as_str()); - let model_id = worker.model_id().to_string(); - context - .policy_registry - .on_worker_added(&model_id, policy_hint); - - info!("Registered worker {} (initializing)", config.url); - - // Return worker for async validation - Ok(worker) - } - - /// Validate and activate a worker (for async validation after registration) - pub async fn validate_and_activate_worker( - worker: &Arc, - context: &AppContext, - ) -> Result { - let url = worker.url(); - - // Perform health validation - WorkerFactory::validate_health(url, context.router_config.worker_startup_timeout_secs) - .await - .map_err(|e| format!("Health check failed for {}: {}", url, e))?; - - // Mark as healthy - worker.set_healthy(true); - - info!("Worker {} validated and activated", url); - - Ok(format!("Worker {} is now healthy", url)) - } - - /// Add a worker from URL (legacy endpoint) - pub async fn add_worker( - url: &str, - api_key: &Option, - context: &AppContext, - ) -> Result { - Self::add_worker_internal( - url, - WorkerType::Regular, - ConnectionMode::Http, - api_key.clone(), - None, - None, - context, - ) - .await - } - - /// Remove a worker - pub fn remove_worker(url: &str, context: &AppContext) -> Result { - if context.router_config.dp_aware { - Self::remove_dp_aware_workers(url, context) - } else { - Self::remove_single_worker(url, context) - } - } - pub fn get_worker_urls(registry: &Arc) -> Vec { registry .get_all() @@ -716,757 +31,6 @@ impl WorkerManager { .collect() } - /// Internal method to add a worker with all parameters - async fn add_worker_internal( - worker_url: &str, - worker_type: WorkerType, - connection_mode: ConnectionMode, - api_key: Option, - labels: Option>, - policy_hint: Option<&str>, - context: &AppContext, - ) -> Result { - WorkerFactory::validate_health( - worker_url, - context.router_config.worker_startup_timeout_secs, - ) - .await - .map_err(|e| format!("Health check failed: {}", e))?; - - let circuit_breaker_config = Self::convert_circuit_breaker_config( - &context.router_config.effective_circuit_breaker_config(), - ); - let health_config = Self::convert_health_config(&context.router_config.health_check); - - if context.router_config.dp_aware { - let dp_urls = Self::get_dp_aware_urls( - &[worker_url.to_string()], - context.router_config.api_key.as_deref(), - ) - .await?; - let mut workers_added = 0; - let mut model_workers: HashMap>> = HashMap::new(); - - let dp_size_for_base = dp_urls.len(); - - for (rank, dp_url) in dp_urls.iter().enumerate() { - if context.worker_registry.get_by_url(dp_url).is_some() { - info!("Worker {} already exists, skipping", dp_url); - continue; - } - - let base_url = dp_url.split('@').next().unwrap().to_string(); - let mut builder = DPAwareWorkerBuilder::new(base_url, rank, dp_size_for_base) - .worker_type(worker_type.clone()) - .connection_mode(connection_mode.clone()) - .circuit_breaker_config(circuit_breaker_config.clone()) - .health_config(health_config.clone()); - - if let Some(ref key) = api_key { - builder = builder.api_key(key.clone()); - } - - if let Some(ref worker_labels) = labels { - builder = builder.labels(worker_labels.clone()); - } - - let worker = Arc::new(builder.build()) as Arc; - - let model_id = worker.model_id().to_string(); - context.worker_registry.register(worker.clone()); - workers_added += 1; - - model_workers - .entry(model_id.clone()) - .or_default() - .push(worker); - - context - .policy_registry - .on_worker_added(&model_id, policy_hint); - } - - for model_id in model_workers.keys() { - let all_model_workers = context.worker_registry.get_by_model_fast(model_id); - if let Some(policy) = context.policy_registry.get_policy(model_id) { - if policy.name() == "cache_aware" { - context - .policy_registry - .init_cache_aware_policy(model_id, &all_model_workers); - } - } - } - - if workers_added == 0 { - Ok(format!("All DP workers already exist for {}", worker_url)) - } else { - Ok(format!( - "Added {} DP-aware workers for {}", - workers_added, worker_url - )) - } - } else { - if context.worker_registry.get_by_url(worker_url).is_some() { - return Err(format!("Worker {} already exists", worker_url)); - } - - let worker = Self::create_basic_worker( - worker_url.to_string(), - worker_type, - connection_mode, - api_key, - labels, - circuit_breaker_config, - health_config, - ) - .await; - - let model_id = worker.model_id().to_string(); - context.worker_registry.register(worker.clone()); - context - .policy_registry - .on_worker_added(&model_id, policy_hint); - - let workers = context.worker_registry.get_by_model_fast(&model_id); - if let Some(policy) = context.policy_registry.get_policy(&model_id) { - if policy.name() == "cache_aware" { - context - .policy_registry - .init_cache_aware_policy(&model_id, &workers); - } - } - - Ok(format!("Worker {} added successfully", worker_url)) - } - } - - /// Remove a single worker - fn remove_single_worker(worker_url: &str, context: &AppContext) -> Result { - let worker = context - .worker_registry - .get_by_url(worker_url) - .ok_or_else(|| format!("Worker {} not found", worker_url))?; - let model_id = worker.model_id().to_string(); - - context - .policy_registry - .remove_worker_from_cache_aware(&model_id, worker_url); - context.worker_registry.remove_by_url(worker_url); - context.policy_registry.on_worker_removed(&model_id); - - let remaining_workers = context.worker_registry.get_by_model_fast(&model_id); - if let Some(policy) = context.policy_registry.get_policy(&model_id) { - if policy.name() == "cache_aware" && !remaining_workers.is_empty() { - context - .policy_registry - .init_cache_aware_policy(&model_id, &remaining_workers); - } - } - - Ok(format!("Worker {} removed successfully", worker_url)) - } - - /// Remove DP-aware workers with prefix matching - fn remove_dp_aware_workers(worker_url: &str, context: &AppContext) -> Result { - let worker_url_prefix = format!("{}@", worker_url); - let mut removed_workers = Vec::new(); - let mut affected_models = std::collections::HashSet::new(); - - let all_workers = context.worker_registry.get_all(); - for worker in all_workers.iter() { - if worker.url().starts_with(&worker_url_prefix) { - let model_id = worker.model_id().to_string(); - affected_models.insert(model_id.clone()); - - context - .policy_registry - .remove_worker_from_cache_aware(&model_id, worker.url()); - - if context - .worker_registry - .remove_by_url(worker.url()) - .is_some() - { - removed_workers.push(worker.url().to_string()); - context.policy_registry.on_worker_removed(&model_id); - } - } - } - - for model_id in affected_models { - let remaining_workers = context.worker_registry.get_by_model_fast(&model_id); - if let Some(policy) = context.policy_registry.get_policy(&model_id) { - if policy.name() == "cache_aware" && !remaining_workers.is_empty() { - context - .policy_registry - .init_cache_aware_policy(&model_id, &remaining_workers); - } - } - } - - if removed_workers.is_empty() { - Err(format!( - "No workers found with prefix {}", - worker_url_prefix - )) - } else { - Ok(format!( - "Removed {} DP-aware workers: {:?}", - removed_workers.len(), - removed_workers - )) - } - } - - /// Create a basic worker - async fn create_basic_worker( - url: String, - worker_type: WorkerType, - connection_mode: ConnectionMode, - api_key: Option, - labels: Option>, - circuit_breaker_config: CircuitBreakerConfig, - health_config: HealthConfig, - ) -> Arc { - let discovery = - Self::discover_worker_metadata(&url, &connection_mode, api_key.as_deref()).await; - - let mut final_labels = discovery.labels; - if let Some(custom_labels) = labels { - for (key, value) in custom_labels { - final_labels.insert(key, value); - } - } - - let mut builder = BasicWorkerBuilder::new(url) - .worker_type(worker_type) - .connection_mode(connection_mode) - .circuit_breaker_config(circuit_breaker_config) - .health_config(health_config); - - if let Some(key) = api_key { - builder = builder.api_key(key); - } - - if !final_labels.is_empty() { - builder = builder.labels(final_labels); - } - - if let Some(client) = discovery.grpc_client { - builder = builder.grpc_client(client); - } - - let worker = builder.build(); - Arc::new(worker) as Arc - } - - /// Register a worker and update policies - fn register_worker( - worker: Arc, - registry: &Arc, - registered_workers: &mut HashMap>>, - policy_registry: Option<&Arc>, - ) { - let model_id = worker.model_id(); - let url = worker.url(); - let worker_id = registry.register(Arc::clone(&worker)); - info!("Registered worker {} with ID {:?}", url, worker_id); - - registered_workers - .entry(model_id.to_string()) - .or_default() - .push(Arc::clone(&worker)); - - if let Some(policy_reg) = policy_registry { - policy_reg.on_worker_added(model_id, None); - } - } - - /// Initialize cache-aware policies - fn initialize_cache_policies( - registered_workers: &HashMap>>, - registry: &Arc, - policy_registry: Option<&Arc>, - ) { - if let Some(policy_reg) = policy_registry { - for model_id in registered_workers.keys() { - let all_model_workers = registry.get_by_model_fast(model_id); - if let Some(policy) = policy_reg.get_policy(model_id) { - if policy.name() == "cache_aware" { - policy_reg.init_cache_aware_policy(model_id, &all_model_workers); - } - } - } - } - } - - /// Wait for workers to become healthy - async fn wait_for_healthy_workers( - registry: &Arc, - timeout_secs: u64, - check_interval_secs: u64, - ) -> Result<(), String> { - let timeout = Duration::from_secs(timeout_secs); - let check_interval = Duration::from_secs(check_interval_secs); - let start_time = std::time::Instant::now(); - - info!( - "Waiting for workers to become healthy (timeout: {}s)", - timeout_secs - ); - - let workers = registry.get_all(); - if workers.is_empty() { - info!("No workers to wait for, continuing"); - return Ok(()); - } - - // Mark all workers as unhealthy initially - info!( - "Marking {} workers as unhealthy before health checks", - workers.len() - ); - for worker in &workers { - worker.set_healthy(false); - } - - loop { - // 1. Filter unhealthy workers - let workers = registry.get_all(); - let unhealthy_workers: Vec<_> = workers - .iter() - .filter(|w| !w.is_healthy()) - .cloned() - .collect(); - - // 2. If all workers are healthy, return immediately - if unhealthy_workers.is_empty() { - let healthy_urls: Vec<_> = workers.iter().map(|w| w.url().to_string()).collect(); - info!( - "All {} workers are healthy: {:?}", - workers.len(), - healthy_urls - ); - return Ok(()); - } - - // Check timeout - if start_time.elapsed() > timeout { - let healthy_workers: Vec<_> = workers - .iter() - .filter(|w| w.is_healthy()) - .map(|w| w.url().to_string()) - .collect(); - let unhealthy_urls: Vec<_> = unhealthy_workers - .iter() - .map(|w| w.url().to_string()) - .collect(); - - error!( - "Workers failed to become healthy after {}s. Unhealthy: {:?}, Healthy: {:?}", - timeout_secs, unhealthy_urls, healthy_workers - ); - return Err(format!( - "Workers failed to become healthy after {}s. Unhealthy: {:?}", - timeout_secs, unhealthy_urls - )); - } - - let unhealthy_urls: Vec<_> = unhealthy_workers - .iter() - .map(|w| w.url().to_string()) - .collect(); - - info!( - "Waiting for {} workers to become healthy. Unhealthy: {:?}", - unhealthy_workers.len(), - unhealthy_urls - ); - - // 3. Check health of all unhealthy workers in parallel - let health_check_futures: Vec<_> = unhealthy_workers - .iter() - .map(|worker| { - let w = worker.clone(); - let url = worker.url().to_string(); - async move { - match w.check_health_async().await { - Ok(_) => { - w.set_healthy(true); - debug!("Worker {} now healthy", url); - } - Err(e) => { - debug!("Worker {} health check failed: {}", url, e); - } - } - } - }) - .collect(); - - future::join_all(health_check_futures).await; - - // 4. Check if all workers are now healthy after health checks - let still_unhealthy: Vec<_> = workers.iter().filter(|w| !w.is_healthy()).collect(); - - // 5. If all workers are now healthy, return immediately without sleeping - if still_unhealthy.is_empty() { - let healthy_urls: Vec<_> = workers.iter().map(|w| w.url().to_string()).collect(); - info!( - "All {} workers are healthy: {:?}", - workers.len(), - healthy_urls - ); - return Ok(()); - } - - // 6. Otherwise, sleep before next iteration - tokio::time::sleep(check_interval).await; - } - } - - /// Gather worker metadata directly from the backend before registration. - async fn discover_worker_metadata( - url: &str, - connection_mode: &ConnectionMode, - api_key: Option<&str>, - ) -> WorkerDiscovery { - match connection_mode { - ConnectionMode::Http => Self::discover_http_metadata(url, api_key).await, - ConnectionMode::Grpc { .. } => Self::discover_grpc_metadata(url).await, - } - } - - async fn discover_http_metadata(url: &str, api_key: Option<&str>) -> WorkerDiscovery { - let mut discovery = WorkerDiscovery::new(); - - match Self::get_model_info(url, api_key).await { - Ok(model_info) => { - if let Some(model_path) = model_info.get("model_path").and_then(|v| v.as_str()) { - if !model_path.is_empty() { - discovery - .labels - .insert("model_path".to_string(), model_path.to_string()); - } - } - if let Some(tokenizer_path) = - model_info.get("tokenizer_path").and_then(|v| v.as_str()) - { - if !tokenizer_path.is_empty() { - discovery - .labels - .insert("tokenizer_path".to_string(), tokenizer_path.to_string()); - } - } - if let Some(served_model_name) = - model_info.get("served_model_name").and_then(|v| v.as_str()) - { - if !served_model_name.is_empty() { - discovery.labels.insert( - "served_model_name".to_string(), - served_model_name.to_string(), - ); - } - } - if let Some(weight_version) = - model_info.get("weight_version").and_then(|v| v.as_str()) - { - if !weight_version.is_empty() { - discovery - .labels - .insert("weight_version".to_string(), weight_version.to_string()); - } - } - if let Some(model_type) = model_info.get("model_type").and_then(|v| v.as_str()) { - if !model_type.is_empty() { - discovery - .labels - .insert("model_type".to_string(), model_type.to_string()); - } - } - if let Some(is_generation) = - model_info.get("is_generation").and_then(|v| v.as_bool()) - { - discovery - .labels - .insert("is_generation".to_string(), is_generation.to_string()); - } - if let Some(preferred_sampling_params) = model_info - .get("preferred_sampling_params") - .and_then(|v| v.as_str()) - { - if !preferred_sampling_params.is_empty() { - discovery.labels.insert( - "preferred_sampling_params".to_string(), - preferred_sampling_params.to_string(), - ); - } - } - if let Some(max_context_length) = model_info - .get("max_context_length") - .and_then(|v| v.as_i64()) - { - discovery.labels.insert( - "max_context_length".to_string(), - max_context_length.to_string(), - ); - } - if let Some(max_req_input_len) = - model_info.get("max_req_input_len").and_then(|v| v.as_i64()) - { - discovery.labels.insert( - "max_req_input_len".to_string(), - max_req_input_len.to_string(), - ); - } - } - Err(e) => { - warn!( - "Worker discovery: failed to fetch HTTP model info from {}: {}", - url, e - ); - } - } - - match Self::get_server_info(url, api_key).await { - Ok(server_info) => { - if let Some(model_id) = server_info.model_id { - if !model_id.is_empty() { - discovery.labels.insert("model_id".to_string(), model_id); - } - } - if let Some(model_path) = server_info.model_path { - if !model_path.is_empty() { - discovery - .labels - .insert("model_path".to_string(), model_path); - } - } - if let Some(version) = server_info.version { - if !version.is_empty() { - discovery - .labels - .insert("server_version".to_string(), version); - } - } - if let Some(max_total_tokens) = server_info.max_total_tokens { - discovery - .labels - .insert("max_total_tokens".to_string(), max_total_tokens.to_string()); - } - if let Some(max_prefill_tokens) = server_info.max_prefill_tokens { - discovery.labels.insert( - "max_prefill_tokens".to_string(), - max_prefill_tokens.to_string(), - ); - } - if let Some(max_running_requests) = server_info.max_running_requests { - discovery.labels.insert( - "max_running_requests".to_string(), - max_running_requests.to_string(), - ); - } - } - Err(e) => { - warn!( - "Worker discovery: failed to fetch HTTP server info from {}: {}", - url, e - ); - } - } - - Self::finalize_model_id(&mut discovery.labels); - - discovery - } - - async fn discover_grpc_metadata(url: &str) -> WorkerDiscovery { - let mut discovery = WorkerDiscovery::new(); - - let client = match SglangSchedulerClient::connect(url).await { - Ok(client) => client, - Err(e) => { - warn!( - "Worker discovery: failed to connect to gRPC worker {}: {}", - url, e - ); - return discovery; - } - }; - - match client.get_model_info().await { - Ok(model_info) => { - if !model_info.model_path.is_empty() { - discovery - .labels - .insert("model_path".to_string(), model_info.model_path.clone()); - } - if !model_info.tokenizer_path.is_empty() { - discovery.labels.insert( - "tokenizer_path".to_string(), - model_info.tokenizer_path.clone(), - ); - } - if !model_info.served_model_name.is_empty() { - discovery.labels.insert( - "served_model_name".to_string(), - model_info.served_model_name.clone(), - ); - discovery - .labels - .insert("model_id".to_string(), model_info.served_model_name); - } - if !model_info.weight_version.is_empty() { - discovery.labels.insert( - "weight_version".to_string(), - model_info.weight_version.clone(), - ); - } - if !model_info.model_type.is_empty() { - discovery - .labels - .insert("model_type".to_string(), model_info.model_type.clone()); - } - if !model_info.preferred_sampling_params.is_empty() { - discovery.labels.insert( - "preferred_sampling_params".to_string(), - model_info.preferred_sampling_params.clone(), - ); - } - discovery.labels.insert( - "is_generation".to_string(), - model_info.is_generation.to_string(), - ); - if model_info.max_context_length > 0 { - discovery.labels.insert( - "max_context_length".to_string(), - model_info.max_context_length.to_string(), - ); - } - if model_info.max_req_input_len > 0 { - discovery.labels.insert( - "max_req_input_len".to_string(), - model_info.max_req_input_len.to_string(), - ); - } - if model_info.vocab_size > 0 { - discovery - .labels - .insert("vocab_size".to_string(), model_info.vocab_size.to_string()); - } - } - Err(e) => { - warn!( - "Worker discovery: failed to fetch gRPC model info from {}: {}", - url, e - ); - } - } - - if !discovery.labels.contains_key("model_id") { - Self::finalize_model_id(&mut discovery.labels); - } - - discovery.grpc_client = Some(client); - discovery - } - - fn finalize_model_id(labels: &mut HashMap) { - let has_model_id = labels - .get("model_id") - .map(|v| !v.trim().is_empty()) - .unwrap_or(false); - if has_model_id { - return; - } - - if let Some(served_name) = labels.get("served_model_name") { - if !served_name.trim().is_empty() { - labels.insert("model_id".to_string(), served_name.clone()); - return; - } - } - - if let Some(model_path) = labels.get("model_path") { - if !model_path.trim().is_empty() { - labels.insert("model_id".to_string(), model_path.clone()); - } - } - } - - /// Parse server info from JSON response - fn parse_server_info(json: Value) -> Result { - Ok(ServerInfo { - model_id: json - .get("model_id") - .and_then(|v| v.as_str()) - .map(String::from) - .or_else(|| json.get("model").and_then(|v| v.as_str()).map(String::from)), - model_path: json - .get("model_path") - .and_then(|v| v.as_str()) - .map(String::from), - dp_size: json - .get("dp_size") - .and_then(|v| v.as_u64()) - .map(|v| v as usize), - version: json - .get("version") - .and_then(|v| v.as_str()) - .map(String::from), - max_batch_size: json - .get("max_batch_size") - .and_then(|v| v.as_u64()) - .map(|v| v as usize), - max_total_tokens: json - .get("max_total_tokens") - .and_then(|v| v.as_u64()) - .map(|v| v as usize), - max_prefill_tokens: json - .get("max_prefill_tokens") - .and_then(|v| v.as_u64()) - .map(|v| v as usize), - max_running_requests: json - .get("max_running_requests") - .and_then(|v| v.as_u64()) - .map(|v| v as usize), - max_num_reqs: json - .get("max_num_reqs") - .and_then(|v| v.as_u64()) - .map(|v| v as usize), - }) - } - - /// Convert config connection mode to core connection mode - fn convert_connection_mode( - config_mode: &ConfigConnectionMode, - _sample_url: Option<&String>, - ) -> ConnectionMode { - match config_mode { - ConfigConnectionMode::Http => ConnectionMode::Http, - ConfigConnectionMode::Grpc => ConnectionMode::Grpc { port: None }, - } - } - - /// Convert config circuit breaker to core circuit breaker - fn convert_circuit_breaker_config(config: &ConfigCircuitBreakerConfig) -> CircuitBreakerConfig { - CircuitBreakerConfig { - failure_threshold: config.failure_threshold, - success_threshold: config.success_threshold, - timeout_duration: Duration::from_secs(config.timeout_duration_secs), - window_duration: Duration::from_secs(config.window_duration_secs), - } - } - - /// Convert config health check to core health config - fn convert_health_config(config: &HealthCheckConfig) -> HealthConfig { - HealthConfig { - timeout_secs: config.timeout_secs, - check_interval_secs: config.check_interval_secs, - endpoint: config.endpoint.clone(), - failure_threshold: config.failure_threshold, - success_threshold: config.success_threshold, - } - } /// Flush cache on all workers /// /// Sends a POST request to /flush_cache endpoint on all HTTP workers. @@ -1804,69 +368,3 @@ impl Drop for LoadMonitor { } } } - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use super::*; - - #[test] - fn test_parse_server_info() { - let json = serde_json::json!({ - "model_id": "llama-3", - "model_path": "/models/llama-3", - "dp_size": 4, - "version": "0.1.0" - }); - - let info = WorkerManager::parse_server_info(json).unwrap(); - assert_eq!(info.model_id, Some("llama-3".to_string())); - assert_eq!(info.dp_size, Some(4)); - } - - #[test] - fn test_parse_server_info_with_fallback() { - let json = serde_json::json!({ - "model": "gpt-4", - "dp_size": 2 - }); - - let info = WorkerManager::parse_server_info(json).unwrap(); - assert_eq!(info.model_id, Some("gpt-4".to_string())); - assert_eq!(info.dp_size, Some(2)); - } - - #[test] - fn test_parse_server_info_minimal() { - let json = serde_json::json!({}); - let info = WorkerManager::parse_server_info(json).unwrap(); - assert_eq!(info.model_id, None); - assert_eq!(info.dp_size, None); - } - - #[test] - fn test_finalize_model_id_prefers_existing() { - let mut labels = HashMap::new(); - labels.insert("model_id".to_string(), "manual-id".to_string()); - labels.insert("served_model_name".to_string(), "auto-id".to_string()); - WorkerManager::finalize_model_id(&mut labels); - assert_eq!(labels.get("model_id").unwrap(), "manual-id"); - } - - #[test] - fn test_finalize_model_id_prefers_served_name() { - let mut labels = HashMap::new(); - labels.insert("served_model_name".to_string(), "served-name".to_string()); - WorkerManager::finalize_model_id(&mut labels); - assert_eq!(labels.get("model_id").unwrap(), "served-name"); - } - - #[test] - fn test_finalize_model_id_falls_back_to_path() { - let mut labels = HashMap::new(); - labels.insert("model_path".to_string(), "/models/alpha".to_string()); - WorkerManager::finalize_model_id(&mut labels); - assert_eq!(labels.get("model_id").unwrap(), "/models/alpha"); - } -} diff --git a/sgl-router/src/core/workflow/mod.rs b/sgl-router/src/core/workflow/mod.rs index 8e840b2ea..f1c3293a2 100644 --- a/sgl-router/src/core/workflow/mod.rs +++ b/sgl-router/src/core/workflow/mod.rs @@ -14,5 +14,5 @@ pub use engine::WorkflowEngine; pub use event::{EventBus, EventSubscriber, LoggingSubscriber, WorkflowEvent}; pub use executor::{FunctionStep, StepExecutor}; pub use state::WorkflowStateStore; -pub use steps::create_worker_registration_workflow; +pub use steps::{create_worker_registration_workflow, create_worker_removal_workflow}; pub use types::*; diff --git a/sgl-router/src/core/workflow/steps/mod.rs b/sgl-router/src/core/workflow/steps/mod.rs index 7f4dc8252..9de153023 100644 --- a/sgl-router/src/core/workflow/steps/mod.rs +++ b/sgl-router/src/core/workflow/steps/mod.rs @@ -2,11 +2,17 @@ //! //! This module contains concrete step implementations for various workflows: //! - Worker registration and activation +//! - Worker removal //! - Future: Tokenizer fetching, LoRA updates, etc. pub mod worker_registration; +pub mod worker_removal; pub use worker_registration::{ create_worker_registration_workflow, ActivateWorkerStep, CreateWorkerStep, DetectConnectionModeStep, DiscoverMetadataStep, RegisterWorkerStep, UpdatePoliciesStep, }; +pub use worker_removal::{ + create_worker_removal_workflow, FindWorkersToRemoveStep, RemoveFromPolicyRegistryStep, + RemoveFromWorkerRegistryStep, UpdateRemainingPoliciesStep, WorkerRemovalRequest, +}; diff --git a/sgl-router/src/core/workflow/steps/worker_registration.rs b/sgl-router/src/core/workflow/steps/worker_registration.rs index 894326c19..e3ce5d0f2 100644 --- a/sgl-router/src/core/workflow/steps/worker_registration.rs +++ b/sgl-router/src/core/workflow/steps/worker_registration.rs @@ -16,13 +16,14 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; use async_trait::async_trait; use once_cell::sync::Lazy; use reqwest::Client; +use serde::{Deserialize, Serialize}; use serde_json::Value; use tracing::{debug, info, warn}; use crate::{ core::{ workflow::*, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, - DPAwareWorkerBuilder, DpInfo, HealthConfig, Worker, WorkerManager, WorkerType, + DPAwareWorkerBuilder, HealthConfig, Worker, WorkerType, }, grpc_client::SglangSchedulerClient, protocols::worker_spec::WorkerConfigRequest, @@ -37,6 +38,82 @@ static HTTP_CLIENT: Lazy = Lazy::new(|| { .expect("Failed to create HTTP client") }); +/// Server information returned from worker endpoints +#[derive(Debug, Clone, Deserialize, Serialize)] +struct ServerInfo { + #[serde(alias = "model")] + model_id: Option, + model_path: Option, + dp_size: Option, + version: Option, + max_batch_size: Option, + max_total_tokens: Option, + max_prefill_tokens: Option, + max_running_requests: Option, + max_num_reqs: Option, +} + +#[derive(Debug, Clone)] +pub struct DpInfo { + pub dp_size: usize, + pub model_id: String, +} + +/// Parse server info from JSON response using serde +fn parse_server_info(json: Value) -> Result { + serde_json::from_value(json).map_err(|e| format!("Failed to parse server info: {}", e)) +} + +/// Get server info from /get_server_info endpoint +async fn get_server_info(url: &str, api_key: Option<&str>) -> Result { + let base_url = url.trim_end_matches('/'); + let server_info_url = format!("{}/get_server_info", base_url); + + let mut req = HTTP_CLIENT.get(&server_info_url); + if let Some(key) = api_key { + req = req.bearer_auth(key); + } + + let response = req + .send() + .await + .map_err(|e| format!("Failed to connect to {}: {}", server_info_url, e))?; + + if !response.status().is_success() { + return Err(format!( + "Server returned status {} from {}", + response.status(), + server_info_url + )); + } + + let json = response + .json::() + .await + .map_err(|e| format!("Failed to parse response from {}: {}", server_info_url, e))?; + + parse_server_info(json) +} + +/// Get DP info for a worker URL +async fn get_dp_info(url: &str, api_key: Option<&str>) -> Result { + let info = get_server_info(url, api_key).await?; + + let dp_size = info + .dp_size + .ok_or_else(|| format!("No dp_size in response from {}", url))?; + + let model_id = info + .model_id + .or_else(|| { + info.model_path + .and_then(|path| path.split('/').next_back().map(|s| s.to_string())) + }) + .unwrap_or_else(|| "unknown".to_string()); + + Ok(DpInfo { dp_size, model_id }) +} + /// Helper: Strip protocol prefix from URL fn strip_protocol(url: &str) -> String { url.trim_start_matches("http://") @@ -83,49 +160,6 @@ async fn try_grpc_health_check(url: &str, timeout_secs: u64) -> Result<(), Strin Ok(()) } -/// Helper: Fetch HTTP metadata -async fn fetch_http_metadata( - url: &str, - api_key: Option<&str>, -) -> Result, String> { - let clean_url = strip_protocol(url); - let info_url = if clean_url.starts_with("http://") || clean_url.starts_with("https://") { - format!("{}/get_server_info", clean_url) - } else { - format!("http://{}/get_server_info", clean_url) - }; - - let mut request = HTTP_CLIENT.get(&info_url); - if let Some(key) = api_key { - request = request.header("Authorization", format!("Bearer {}", key)); - } - - let response = request - .send() - .await - .map_err(|e| format!("Failed to fetch HTTP metadata: {}", e))?; - - let server_info: Value = response - .json() - .await - .map_err(|e| format!("Failed to parse HTTP metadata: {}", e))?; - - let mut labels = HashMap::new(); - - if let Some(model_path) = server_info.get("model_path").and_then(|v| v.as_str()) { - if !model_path.is_empty() { - labels.insert("model_path".to_string(), model_path.to_string()); - } - } - if let Some(tokenizer_path) = server_info.get("tokenizer_path").and_then(|v| v.as_str()) { - if !tokenizer_path.is_empty() { - labels.insert("tokenizer_path".to_string(), tokenizer_path.to_string()); - } - } - - Ok(labels) -} - /// Helper: Fetch gRPC metadata async fn fetch_grpc_metadata(url: &str) -> Result, String> { let grpc_url = if url.starts_with("grpc://") { @@ -266,7 +300,18 @@ impl StepExecutor for DiscoverMetadataStep { let discovered_labels = match connection_mode.as_ref() { ConnectionMode::Http => { - fetch_http_metadata(&config.url, config.api_key.as_deref()).await + match get_server_info(&config.url, config.api_key.as_deref()).await { + Ok(server_info) => { + let mut labels = HashMap::new(); + if let Some(model_path) = server_info.model_path { + if !model_path.is_empty() { + labels.insert("model_path".to_string(), model_path); + } + } + Ok(labels) + } + Err(e) => Err(e), + } } ConnectionMode::Grpc { .. } => fetch_grpc_metadata(&config.url).await, } @@ -314,7 +359,7 @@ impl StepExecutor for DiscoverDPInfoStep { debug!("Discovering DP info for {} (DP-aware)", config.url); // Get DP info from worker - let dp_info = WorkerManager::get_dp_info(&config.url, config.api_key.as_deref()) + let dp_info = get_dp_info(&config.url, config.api_key.as_deref()) .await .map_err(|e| WorkflowError::StepFailed { step_id: StepId::new("discover_dp_info"), @@ -327,7 +372,7 @@ impl StepExecutor for DiscoverDPInfoStep { ); // Store DP info in context - context.set("dp_info", Arc::new(dp_info)); + context.set("dp_info", dp_info); Ok(StepResult::Success) } @@ -522,7 +567,7 @@ impl StepExecutor for CreateWorkerStep { } // Store workers (plural) and labels in context - context.set("workers", Arc::new(workers)); + context.set("workers", workers); context.set("labels", final_labels); Ok(StepResult::Success) @@ -595,7 +640,7 @@ impl StepExecutor for RegisterWorkerStep { ); } - context.set("worker_ids", Arc::new(worker_ids)); + context.set("worker_ids", worker_ids); Ok(StepResult::Success) } else { // Non-DP-aware path: Register single worker diff --git a/sgl-router/src/core/workflow/steps/worker_removal.rs b/sgl-router/src/core/workflow/steps/worker_removal.rs new file mode 100644 index 000000000..a1cfc351b --- /dev/null +++ b/sgl-router/src/core/workflow/steps/worker_removal.rs @@ -0,0 +1,310 @@ +//! Worker Removal Workflow Steps +//! +//! This module implements the workflow steps for removing workers from the router. +//! Handles both single worker removal and DP-aware worker removal with prefix matching. +//! +//! Steps: +//! 1. FindWorkersToRemove - Identify workers to remove based on URL (handles DP-aware prefix matching) +//! 2. RemoveFromPolicyRegistry - Remove workers from policy registry and cache-aware policies +//! 3. RemoveFromWorkerRegistry - Remove workers from worker registry +//! 4. UpdateRemainingPolicies - Update cache-aware policies for remaining workers + +use std::{collections::HashSet, sync::Arc}; + +use async_trait::async_trait; +use tracing::{debug, info}; + +use crate::{ + core::{workflow::*, Worker}, + server::AppContext, +}; + +/// Request structure for worker removal +#[derive(Debug, Clone)] +pub struct WorkerRemovalRequest { + pub url: String, + pub dp_aware: bool, +} + +/// Step 1: Find workers to remove based on URL +pub struct FindWorkersToRemoveStep; + +#[async_trait] +impl StepExecutor for FindWorkersToRemoveStep { + async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult { + let request: Arc = context + .get("removal_request") + .ok_or_else(|| WorkflowError::ContextValueNotFound("removal_request".to_string()))?; + let app_context: Arc = context + .get("app_context") + .ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?; + + debug!( + "Finding workers to remove for {} (dp_aware: {})", + request.url, request.dp_aware + ); + + let workers_to_remove: Vec> = if request.dp_aware { + // DP-aware: Find all workers with matching prefix + let worker_url_prefix = format!("{}@", request.url); + let all_workers = app_context.worker_registry.get_all(); + + all_workers + .iter() + .filter(|worker| worker.url().starts_with(&worker_url_prefix)) + .cloned() + .collect() + } else { + // Non-DP-aware: Find single worker by exact URL + match app_context.worker_registry.get_by_url(&request.url) { + Some(worker) => vec![worker], + None => Vec::new(), + } + }; + + if workers_to_remove.is_empty() { + let error_msg = if request.dp_aware { + format!("No workers found with prefix {}@", request.url) + } else { + format!("Worker {} not found", request.url) + }; + return Err(WorkflowError::StepFailed { + step_id: StepId::new("find_workers_to_remove"), + message: error_msg, + }); + } + + debug!( + "Found {} worker(s) to remove for {}", + workers_to_remove.len(), + request.url + ); + + // Store workers and their model IDs for subsequent steps + let worker_urls: Vec = workers_to_remove + .iter() + .map(|w| w.url().to_string()) + .collect(); + + let affected_models: HashSet = workers_to_remove + .iter() + .map(|w| w.model_id().to_string()) + .collect(); + + context.set("workers_to_remove", workers_to_remove); + context.set("worker_urls", worker_urls); + context.set("affected_models", affected_models); + + Ok(StepResult::Success) + } + + fn is_retryable(&self, _error: &WorkflowError) -> bool { + false // Worker not found is not retryable + } +} + +/// Step 2: Remove workers from policy registry +pub struct RemoveFromPolicyRegistryStep; + +#[async_trait] +impl StepExecutor for RemoveFromPolicyRegistryStep { + async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult { + let app_context: Arc = context + .get("app_context") + .ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?; + let workers_to_remove: Arc>> = context + .get("workers_to_remove") + .ok_or_else(|| WorkflowError::ContextValueNotFound("workers_to_remove".to_string()))?; + + debug!( + "Removing {} worker(s) from policy registry", + workers_to_remove.len() + ); + + for worker in workers_to_remove.iter() { + let model_id = worker.model_id().to_string(); + let worker_url = worker.url(); + + // Remove from cache-aware policy + app_context + .policy_registry + .remove_worker_from_cache_aware(&model_id, worker_url); + + // Notify policy registry + app_context.policy_registry.on_worker_removed(&model_id); + + debug!( + "Removed worker {} from policy registry (model: {})", + worker_url, model_id + ); + } + + Ok(StepResult::Success) + } + + fn is_retryable(&self, _error: &WorkflowError) -> bool { + false // Policy removal is not retryable + } +} + +/// Step 3: Remove workers from worker registry +pub struct RemoveFromWorkerRegistryStep; + +#[async_trait] +impl StepExecutor for RemoveFromWorkerRegistryStep { + async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult { + let app_context: Arc = context + .get("app_context") + .ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?; + let worker_urls: Arc> = context + .get("worker_urls") + .ok_or_else(|| WorkflowError::ContextValueNotFound("worker_urls".to_string()))?; + + debug!( + "Removing {} worker(s) from worker registry", + worker_urls.len() + ); + + let mut removed_count = 0; + for worker_url in worker_urls.iter() { + if app_context + .worker_registry + .remove_by_url(worker_url) + .is_some() + { + removed_count += 1; + debug!("Removed worker {} from registry", worker_url); + } + } + + if removed_count != worker_urls.len() { + return Err(WorkflowError::StepFailed { + step_id: StepId::new("remove_from_worker_registry"), + message: format!( + "Expected to remove {} workers but only removed {}", + worker_urls.len(), + removed_count + ), + }); + } + + Ok(StepResult::Success) + } + + fn is_retryable(&self, _error: &WorkflowError) -> bool { + false // Worker removal is not retryable + } +} + +/// Step 4: Update cache-aware policies for remaining workers +pub struct UpdateRemainingPoliciesStep; + +#[async_trait] +impl StepExecutor for UpdateRemainingPoliciesStep { + async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult { + let app_context: Arc = context + .get("app_context") + .ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?; + let affected_models: Arc> = context + .get("affected_models") + .ok_or_else(|| WorkflowError::ContextValueNotFound("affected_models".to_string()))?; + let worker_urls: Arc> = context + .get("worker_urls") + .ok_or_else(|| WorkflowError::ContextValueNotFound("worker_urls".to_string()))?; + + debug!( + "Updating cache-aware policies for {} affected model(s)", + affected_models.len() + ); + + for model_id in affected_models.iter() { + let remaining_workers = app_context.worker_registry.get_by_model_fast(model_id); + + if let Some(policy) = app_context.policy_registry.get_policy(model_id) { + if policy.name() == "cache_aware" && !remaining_workers.is_empty() { + app_context + .policy_registry + .init_cache_aware_policy(model_id, &remaining_workers); + + debug!( + "Updated cache-aware policy for model {} ({} remaining workers)", + model_id, + remaining_workers.len() + ); + } + } + } + + // Log final result at info level + if worker_urls.len() == 1 { + info!("Removed worker {}", worker_urls[0]); + } else { + info!( + "Removed {} DP-aware workers: {:?}", + worker_urls.len(), + worker_urls + ); + } + + Ok(StepResult::Success) + } + + fn is_retryable(&self, _error: &WorkflowError) -> bool { + false // Policy update is not retryable + } +} + +/// Create a worker removal workflow definition +pub fn create_worker_removal_workflow() -> WorkflowDefinition { + use std::time::Duration; + + WorkflowDefinition::new("worker_removal", "Remove worker from router") + .add_step( + StepDefinition::new( + "find_workers_to_remove", + "Find workers to remove", + Arc::new(FindWorkersToRemoveStep), + ) + .with_timeout(Duration::from_secs(10)) + .with_retry(RetryPolicy { + max_attempts: 1, + backoff: BackoffStrategy::Fixed(Duration::from_secs(0)), + }), + ) + .add_step( + StepDefinition::new( + "remove_from_policy_registry", + "Remove workers from policy registry", + Arc::new(RemoveFromPolicyRegistryStep), + ) + .with_timeout(Duration::from_secs(10)) + .with_retry(RetryPolicy { + max_attempts: 1, + backoff: BackoffStrategy::Fixed(Duration::from_secs(0)), + }), + ) + .add_step( + StepDefinition::new( + "remove_from_worker_registry", + "Remove workers from worker registry", + Arc::new(RemoveFromWorkerRegistryStep), + ) + .with_timeout(Duration::from_secs(10)) + .with_retry(RetryPolicy { + max_attempts: 1, + backoff: BackoffStrategy::Fixed(Duration::from_secs(0)), + }), + ) + .add_step( + StepDefinition::new( + "update_remaining_policies", + "Update cache-aware policies for remaining workers", + Arc::new(UpdateRemainingPoliciesStep), + ) + .with_timeout(Duration::from_secs(10)) + .with_retry(RetryPolicy { + max_attempts: 1, + backoff: BackoffStrategy::Fixed(Duration::from_secs(0)), + }), + ) +} diff --git a/sgl-router/src/routers/grpc/responses/conversions.rs b/sgl-router/src/routers/grpc/responses/conversions.rs index b55b8c1ca..e9428662d 100644 --- a/sgl-router/src/routers/grpc/responses/conversions.rs +++ b/sgl-router/src/routers/grpc/responses/conversions.rs @@ -149,7 +149,11 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result, -} - -async fn add_worker( - State(state): State>, - Query(AddWorkerQuery { url, api_key }): Query, -) -> Response { - // Warn if router has API key but worker is being added without one - if state.context.router_config.api_key.is_some() && api_key.is_none() { - warn!( - "Adding worker {} without API key while router has API key configured. \ - Worker will be accessible without authentication. \ - If the worker requires the same API key as the router, please specify it explicitly.", - url - ); - } - - let result = WorkerManager::add_worker(&url, &api_key, &state.context).await; - - match result { - Ok(message) => (StatusCode::OK, message).into_response(), - Err(error) => (StatusCode::BAD_REQUEST, error).into_response(), - } -} - -async fn list_workers(State(state): State>) -> Response { - let worker_list = WorkerManager::get_worker_urls(&state.context.worker_registry); - Json(json!({ "urls": worker_list })).into_response() -} - -async fn remove_worker( - State(state): State>, - Query(AddWorkerQuery { url, .. }): Query, -) -> Response { - let result = WorkerManager::remove_worker(&url, &state.context); - - match result { - Ok(message) => (StatusCode::OK, message).into_response(), - Err(error) => (StatusCode::BAD_REQUEST, error).into_response(), - } -} - async fn flush_cache(State(state): State>, _req: Request) -> Response { match WorkerManager::flush_cache_all(&state.context.worker_registry, &state.context.client) .await @@ -566,6 +525,12 @@ async fn create_worker( ); } + // Populate dp_aware from router's configuration + let config = WorkerConfigRequest { + dp_aware: state.context.router_config.dp_aware, + ..config + }; + // Submit job for async processing let worker_url = config.url.clone(); let job = Job::AddWorker { @@ -761,9 +726,6 @@ pub fn build_app( .route("/get_server_info", get(get_server_info)); let admin_routes = Router::new() - .route("/add_worker", post(add_worker)) - .route("/remove_worker", post(remove_worker)) - .route("/list_workers", get(list_workers)) .route("/flush_cache", post(flush_cache)) .route("/get_loads", get(get_loads)) .route_layer(axum::middleware::from_fn_with_state( @@ -1018,15 +980,16 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box= expected_count { + break; + } + + if start.elapsed() > timeout_duration { + panic!( + "Timeout waiting for {} workers to become healthy (only {} ready)", + expected_count, healthy_workers + ); + } + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + } } // Create router let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = Arc::from(router); - // Wait for router to discover workers - if !workers.is_empty() { - tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; - } - Self { workers, router, @@ -711,221 +740,6 @@ mod model_info_tests { } } -#[cfg(test)] -mod worker_management_tests { - use super::*; - - #[tokio::test] - async fn test_add_new_worker() { - let ctx = TestContext::new(vec![]).await; - let app = ctx.create_app().await; - - // Start a mock worker - let mut worker = MockWorker::new(MockWorkerConfig { - port: 18301, - worker_type: WorkerType::Regular, - health_status: HealthStatus::Healthy, - response_delay_ms: 0, - fail_rate: 0.0, - }); - let url = worker.start().await.unwrap(); - - // Add the worker - let req = Request::builder() - .method("POST") - .uri(format!("/add_worker?url={}", url)) - .body(Body::empty()) - .unwrap(); - - let resp = app.clone().oneshot(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - - // List workers to verify - let req = Request::builder() - .method("GET") - .uri("/list_workers") - .body(Body::empty()) - .unwrap(); - - let resp = app.oneshot(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - - let body = axum::body::to_bytes(resp.into_body(), usize::MAX) - .await - .unwrap(); - let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); - let workers = body_json["urls"].as_array().unwrap(); - assert!(workers.iter().any(|w| w.as_str().unwrap() == url)); - - worker.stop().await; - ctx.shutdown().await; - } - - #[tokio::test] - async fn test_remove_existing_worker() { - let ctx = TestContext::new(vec![MockWorkerConfig { - port: 18302, - worker_type: WorkerType::Regular, - health_status: HealthStatus::Healthy, - response_delay_ms: 0, - fail_rate: 0.0, - }]) - .await; - - let app = ctx.create_app().await; - - // Get the worker URL - let req = Request::builder() - .method("GET") - .uri("/list_workers") - .body(Body::empty()) - .unwrap(); - let resp = app.clone().oneshot(req).await.unwrap(); - let body = axum::body::to_bytes(resp.into_body(), usize::MAX) - .await - .unwrap(); - let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); - let workers = body_json["urls"].as_array().unwrap(); - let worker_url = workers[0].as_str().unwrap(); - - // Remove the worker - let req = Request::builder() - .method("POST") - .uri(format!("/remove_worker?url={}", worker_url)) - .body(Body::empty()) - .unwrap(); - - let resp = app.clone().oneshot(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - - let req = Request::builder() - .method("GET") - .uri("/list_workers") - .body(Body::empty()) - .unwrap(); - let resp = app.oneshot(req).await.unwrap(); - let body = axum::body::to_bytes(resp.into_body(), usize::MAX) - .await - .unwrap(); - let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); - let workers = body_json["urls"].as_array().unwrap(); - assert!(workers.is_empty()); - - ctx.shutdown().await; - } - - #[tokio::test] - async fn test_add_worker_invalid_url() { - let ctx = TestContext::new(vec![]).await; - let app = ctx.create_app().await; - - // Invalid URL format - let req = Request::builder() - .method("POST") - .uri("/add_worker?url=not-a-valid-url") - .body(Body::empty()) - .unwrap(); - - let resp = app.clone().oneshot(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - - // Missing URL parameter - let req = Request::builder() - .method("POST") - .uri("/add_worker") - .body(Body::empty()) - .unwrap(); - - let resp = app.clone().oneshot(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - - // Empty URL - let req = Request::builder() - .method("POST") - .uri("/add_worker?url=") - .body(Body::empty()) - .unwrap(); - - let resp = app.oneshot(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - - ctx.shutdown().await; - } - - #[tokio::test] - async fn test_add_duplicate_worker() { - // Start a mock worker - let mut worker = MockWorker::new(MockWorkerConfig { - port: 18303, - worker_type: WorkerType::Regular, - health_status: HealthStatus::Healthy, - response_delay_ms: 0, - fail_rate: 0.0, - }); - let url = worker.start().await.unwrap(); - - let ctx = TestContext::new(vec![]).await; - let app = ctx.create_app().await; - - // Add worker first time - let req = Request::builder() - .method("POST") - .uri(format!("/add_worker?url={}", url)) - .body(Body::empty()) - .unwrap(); - let resp = app.clone().oneshot(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - - tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; - - // Try to add same worker again - let req = Request::builder() - .method("POST") - .uri(format!("/add_worker?url={}", url)) - .body(Body::empty()) - .unwrap(); - let resp = app.oneshot(req).await.unwrap(); - // Should return error for duplicate - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - - worker.stop().await; - ctx.shutdown().await; - } - - #[tokio::test] - async fn test_add_unhealthy_worker() { - // Start unhealthy worker - let mut worker = MockWorker::new(MockWorkerConfig { - port: 18304, - worker_type: WorkerType::Regular, - health_status: HealthStatus::Unhealthy, - response_delay_ms: 0, - fail_rate: 0.0, - }); - let url = worker.start().await.unwrap(); - - let ctx = TestContext::new(vec![]).await; - let app = ctx.create_app().await; - - // Try to add unhealthy worker - let req = Request::builder() - .method("POST") - .uri(format!("/add_worker?url={}", url)) - .body(Body::empty()) - .unwrap(); - let resp = app.oneshot(req).await.unwrap(); - - // Router should reject unhealthy workers - assert!( - resp.status() == StatusCode::BAD_REQUEST - || resp.status() == StatusCode::SERVICE_UNAVAILABLE - ); - - worker.stop().await; - ctx.shutdown().await; - } -} - #[cfg(test)] mod router_policy_tests { use super::*; diff --git a/sgl-router/tests/common/mod.rs b/sgl-router/tests/common/mod.rs index d35538875..23a37facb 100644 --- a/sgl-router/tests/common/mod.rs +++ b/sgl-router/tests/common/mod.rs @@ -66,7 +66,7 @@ pub fn create_test_context(config: RouterConfig) -> Arc { let worker_job_queue = Arc::new(OnceLock::new()); let workflow_engine = Arc::new(OnceLock::new()); - Arc::new(AppContext::new( + let app_context = Arc::new(AppContext::new( config, client, rate_limiter, @@ -81,7 +81,32 @@ pub fn create_test_context(config: RouterConfig) -> Arc { load_monitor, worker_job_queue, workflow_engine, - )) + )); + + // Initialize JobQueue after AppContext is created + let weak_context = Arc::downgrade(&app_context); + let job_queue = sglang_router_rs::core::JobQueue::new( + sglang_router_rs::core::JobQueueConfig::default(), + weak_context, + ); + app_context + .worker_job_queue + .set(job_queue) + .expect("JobQueue should only be initialized once"); + + // Initialize WorkflowEngine and register workflows + use sglang_router_rs::core::workflow::{ + create_worker_registration_workflow, create_worker_removal_workflow, WorkflowEngine, + }; + let engine = Arc::new(WorkflowEngine::new()); + engine.register_workflow(create_worker_registration_workflow()); + engine.register_workflow(create_worker_removal_workflow()); + app_context + .workflow_engine + .set(engine) + .expect("WorkflowEngine should only be initialized once"); + + app_context } // Tokenizer download configuration diff --git a/sgl-router/tests/request_formats_test.rs b/sgl-router/tests/request_formats_test.rs index 823ca5fbe..0d23d7468 100644 --- a/sgl-router/tests/request_formats_test.rs +++ b/sgl-router/tests/request_formats_test.rs @@ -7,7 +7,6 @@ use reqwest::Client; use serde_json::json; use sglang_router_rs::{ config::{RouterConfig, RoutingMode}, - core::WorkerManager, routers::{RouterFactory, RouterTrait}, }; @@ -51,13 +50,6 @@ impl TestContext { let app_context = common::create_test_context(config.clone()); - // Initialize workers in the registry before creating router - if !worker_urls.is_empty() { - WorkerManager::initialize_workers(&config, &app_context.worker_registry, None) - .await - .expect("Failed to initialize workers"); - } - let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = Arc::from(router); diff --git a/sgl-router/tests/streaming_tests.rs b/sgl-router/tests/streaming_tests.rs index 81b7443e5..8e0c7b540 100644 --- a/sgl-router/tests/streaming_tests.rs +++ b/sgl-router/tests/streaming_tests.rs @@ -8,7 +8,6 @@ use reqwest::Client; use serde_json::json; use sglang_router_rs::{ config::{RouterConfig, RoutingMode}, - core::WorkerManager, routers::{RouterFactory, RouterTrait}, }; @@ -52,13 +51,6 @@ impl TestContext { let app_context = common::create_test_context(config.clone()); - // Initialize workers in the registry before creating router - if !worker_urls.is_empty() { - WorkerManager::initialize_workers(&config, &app_context.worker_registry, None) - .await - .expect("Failed to initialize workers"); - } - let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = Arc::from(router);