[router] change worker api to async instead of sync (#11566)
This commit is contained in:
360
sgl-router/src/core/job_queue.rs
Normal file
360
sgl-router/src/core/job_queue.rs
Normal file
@@ -0,0 +1,360 @@
|
||||
//! Async job queue for control plane operations
|
||||
//!
|
||||
//! Provides non-blocking worker management by queuing operations and processing
|
||||
//! them asynchronously in background worker tasks.
|
||||
|
||||
use crate::core::WorkerManager;
|
||||
use crate::protocols::worker_spec::{JobStatus, WorkerConfigRequest};
|
||||
use crate::server::AppContext;
|
||||
use dashmap::DashMap;
|
||||
use metrics::{counter, gauge, histogram};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::time::{Duration, SystemTime};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
/// Job types for control plane operations
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Job {
|
||||
AddWorker { config: Box<WorkerConfigRequest> },
|
||||
RemoveWorker { url: String },
|
||||
}
|
||||
|
||||
impl Job {
|
||||
/// Get job type as string for logging
|
||||
pub fn job_type(&self) -> &str {
|
||||
match self {
|
||||
Job::AddWorker { .. } => "AddWorker",
|
||||
Job::RemoveWorker { .. } => "RemoveWorker",
|
||||
}
|
||||
}
|
||||
|
||||
/// Get worker URL for logging
|
||||
pub fn worker_url(&self) -> &str {
|
||||
match self {
|
||||
Job::AddWorker { config } => &config.url,
|
||||
Job::RemoveWorker { url } => url,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl JobStatus {
|
||||
fn pending(job_type: &str, worker_url: &str) -> Self {
|
||||
Self {
|
||||
job_type: job_type.to_string(),
|
||||
worker_url: worker_url.to_string(),
|
||||
status: "pending".to_string(),
|
||||
message: None,
|
||||
timestamp: SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
}
|
||||
}
|
||||
|
||||
fn processing(job_type: &str, worker_url: &str) -> Self {
|
||||
Self {
|
||||
job_type: job_type.to_string(),
|
||||
worker_url: worker_url.to_string(),
|
||||
status: "processing".to_string(),
|
||||
message: None,
|
||||
timestamp: SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
}
|
||||
}
|
||||
|
||||
fn failed(job_type: &str, worker_url: &str, error: String) -> Self {
|
||||
Self {
|
||||
job_type: job_type.to_string(),
|
||||
worker_url: worker_url.to_string(),
|
||||
status: "failed".to_string(),
|
||||
message: Some(error),
|
||||
timestamp: SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Job queue configuration
#[derive(Clone, Debug)]
pub struct JobQueueConfig {
    /// Maximum pending jobs in queue
    pub queue_capacity: usize,
    /// Number of worker tasks processing jobs
    pub worker_count: usize,
}

impl Default for JobQueueConfig {
    /// Defaults: room for 1000 queued jobs, drained by two background workers.
    fn default() -> Self {
        JobQueueConfig {
            queue_capacity: 1000,
            worker_count: 2,
        }
    }
}
|
||||
|
||||
/// Job queue manager for worker validation and removal operations
|
||||
pub struct JobQueue {
|
||||
/// Channel for submitting jobs
|
||||
tx: mpsc::Sender<Job>,
|
||||
/// Weak reference to AppContext to avoid circular dependencies
|
||||
context: Weak<AppContext>,
|
||||
/// Job status tracking by worker URL
|
||||
status_map: Arc<DashMap<String, JobStatus>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for JobQueue {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("JobQueue")
|
||||
.field("status_count", &self.status_map.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl JobQueue {
|
||||
/// Create a new job queue with background workers (spawns tasks)
|
||||
///
|
||||
/// Takes a Weak reference to AppContext to avoid circular strong references.
|
||||
/// Spawns background worker tasks that will process jobs asynchronously.
|
||||
pub fn new(config: JobQueueConfig, context: Weak<AppContext>) -> Arc<Self> {
|
||||
let (tx, rx) = mpsc::channel(config.queue_capacity);
|
||||
|
||||
info!(
|
||||
"Initializing worker job queue: capacity={}, workers={}",
|
||||
config.queue_capacity, config.worker_count
|
||||
);
|
||||
|
||||
let rx = Arc::new(tokio::sync::Mutex::new(rx));
|
||||
let status_map = Arc::new(DashMap::new());
|
||||
|
||||
let queue = Arc::new(Self {
|
||||
tx,
|
||||
context: context.clone(),
|
||||
status_map: status_map.clone(),
|
||||
});
|
||||
|
||||
for i in 0..config.worker_count {
|
||||
let rx = Arc::clone(&rx);
|
||||
let context = context.clone();
|
||||
let status_map = status_map.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
Self::worker_loop(i, rx, context, status_map).await;
|
||||
});
|
||||
}
|
||||
|
||||
// Spawn cleanup task for old job statuses (TTL 5 minutes)
|
||||
let cleanup_status_map = status_map.clone();
|
||||
tokio::spawn(async move {
|
||||
Self::cleanup_old_statuses(cleanup_status_map).await;
|
||||
});
|
||||
|
||||
queue
|
||||
}
|
||||
|
||||
/// Submit a job
|
||||
pub async fn submit(&self, job: Job) -> Result<(), String> {
|
||||
// Check if context is still alive before accepting jobs
|
||||
if self.context.upgrade().is_none() {
|
||||
counter!("sgl_router_job_shutdown_rejected_total").increment(1);
|
||||
return Err("Job queue shutting down: AppContext dropped".to_string());
|
||||
}
|
||||
|
||||
// Extract values before moving job
|
||||
let job_type = job.job_type().to_string();
|
||||
let worker_url = job.worker_url().to_string();
|
||||
|
||||
// Record pending status
|
||||
self.status_map.insert(
|
||||
worker_url.clone(),
|
||||
JobStatus::pending(&job_type, &worker_url),
|
||||
);
|
||||
|
||||
match self.tx.send(job).await {
|
||||
Ok(_) => {
|
||||
let queue_depth = self.tx.max_capacity() - self.tx.capacity();
|
||||
gauge!("sgl_router_job_queue_depth").set(queue_depth as f64);
|
||||
|
||||
info!(
|
||||
"Job submitted: type={}, worker={}, queue_depth={}",
|
||||
job_type, worker_url, queue_depth
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
Err(_) => {
|
||||
counter!("sgl_router_job_queue_full_total").increment(1);
|
||||
// Remove status on failure
|
||||
self.status_map.remove(&worker_url);
|
||||
Err("Worker job queue full".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get job status by worker URL
|
||||
pub fn get_status(&self, worker_url: &str) -> Option<JobStatus> {
|
||||
self.status_map.get(worker_url).map(|entry| entry.clone())
|
||||
}
|
||||
|
||||
/// Remove job status (called when worker is deleted)
|
||||
pub fn remove_status(&self, worker_url: &str) {
|
||||
self.status_map.remove(worker_url);
|
||||
}
|
||||
|
||||
/// Worker loop that processes jobs
|
||||
async fn worker_loop(
|
||||
worker_id: usize,
|
||||
rx: Arc<tokio::sync::Mutex<mpsc::Receiver<Job>>>,
|
||||
context: Weak<AppContext>,
|
||||
status_map: Arc<DashMap<String, JobStatus>>,
|
||||
) {
|
||||
info!("Worker job queue worker {} started", worker_id);
|
||||
|
||||
loop {
|
||||
// Lock the receiver and try to receive a job
|
||||
let job = {
|
||||
let mut rx_guard = rx.lock().await;
|
||||
rx_guard.recv().await
|
||||
};
|
||||
|
||||
match job {
|
||||
Some(job) => {
|
||||
let job_type = job.job_type().to_string();
|
||||
let worker_url = job.worker_url().to_string();
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
// Update status to processing
|
||||
status_map.insert(
|
||||
worker_url.clone(),
|
||||
JobStatus::processing(&job_type, &worker_url),
|
||||
);
|
||||
|
||||
info!(
|
||||
"Worker {} processing job: type={}, worker={}",
|
||||
worker_id, job_type, worker_url
|
||||
);
|
||||
|
||||
// Upgrade weak reference to process job
|
||||
match context.upgrade() {
|
||||
Some(ctx) => {
|
||||
// Execute job
|
||||
let result = Self::execute_job(&job, &ctx).await;
|
||||
let duration = start.elapsed();
|
||||
|
||||
// Record metrics
|
||||
histogram!("sgl_router_job_duration_seconds", "job_type" => job_type.clone())
|
||||
.record(duration.as_secs_f64());
|
||||
|
||||
match result {
|
||||
Ok(message) => {
|
||||
counter!("sgl_router_job_success_total", "job_type" => job_type.clone())
|
||||
.increment(1);
|
||||
// Remove status on success - worker in registry is sufficient
|
||||
status_map.remove(&worker_url);
|
||||
info!(
|
||||
"Worker {} completed job: type={}, worker={}, duration={:.3}s, result={}",
|
||||
worker_id, job_type, worker_url, duration.as_secs_f64(), message
|
||||
);
|
||||
}
|
||||
Err(error) => {
|
||||
counter!("sgl_router_job_failure_total", "job_type" => job_type.clone())
|
||||
.increment(1);
|
||||
// Keep failed status for API to report error details
|
||||
status_map.insert(
|
||||
worker_url.clone(),
|
||||
JobStatus::failed(&job_type, &worker_url, error.clone()),
|
||||
);
|
||||
warn!(
|
||||
"Worker {} failed job: type={}, worker={}, duration={:.3}s, error={}",
|
||||
worker_id, job_type, worker_url, duration.as_secs_f64(), error
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let error_msg = "AppContext dropped".to_string();
|
||||
status_map.insert(
|
||||
worker_url.clone(),
|
||||
JobStatus::failed(&job_type, &worker_url, error_msg),
|
||||
);
|
||||
error!(
|
||||
"Worker {}: AppContext dropped, cannot process job: type={}, worker={}",
|
||||
worker_id, job_type, worker_url
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
warn!(
|
||||
"Worker job queue worker {} channel closed, stopping",
|
||||
worker_id
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
warn!("Worker job queue worker {} stopped", worker_id);
|
||||
}
|
||||
|
||||
/// Execute a specific job
|
||||
async fn execute_job(job: &Job, context: &Arc<AppContext>) -> Result<String, String> {
|
||||
match job {
|
||||
Job::AddWorker { config } => {
|
||||
// Register worker with is_healthy=false
|
||||
let worker =
|
||||
WorkerManager::add_worker_from_config(config.as_ref(), context).await?;
|
||||
|
||||
// Validate and activate
|
||||
WorkerManager::validate_and_activate_worker(&worker, context).await
|
||||
}
|
||||
Job::RemoveWorker { url } => {
|
||||
let result = WorkerManager::remove_worker(url, context);
|
||||
// Clean up job status when removing worker
|
||||
if let Some(queue) = context.worker_job_queue.get() {
|
||||
queue.remove_status(url);
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cleanup old job statuses (TTL 5 minutes)
|
||||
async fn cleanup_old_statuses(status_map: Arc<DashMap<String, JobStatus>>) {
|
||||
const CLEANUP_INTERVAL: Duration = Duration::from_secs(60); // Run every minute
|
||||
const STATUS_TTL: u64 = 300; // 5 minutes in seconds
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(CLEANUP_INTERVAL).await;
|
||||
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
|
||||
// Remove statuses older than TTL
|
||||
status_map.retain(|_key, value| now - value.timestamp < STATUS_TTL);
|
||||
|
||||
debug!(
|
||||
"Cleaned up old job statuses, remaining: {}",
|
||||
status_map.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config should match the documented capacity/worker counts.
    #[test]
    fn test_job_queue_config_default() {
        let JobQueueConfig {
            queue_capacity,
            worker_count,
        } = JobQueueConfig::default();
        assert_eq!(queue_capacity, 1000);
        assert_eq!(worker_count, 2);
    }
}
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
pub mod circuit_breaker;
|
||||
pub mod error;
|
||||
pub mod job_queue;
|
||||
pub mod retry;
|
||||
pub mod token_bucket;
|
||||
pub mod worker;
|
||||
@@ -19,10 +20,11 @@ pub use circuit_breaker::{
|
||||
CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState,
|
||||
};
|
||||
pub use error::{WorkerError, WorkerResult};
|
||||
pub use job_queue::{Job, JobQueue, JobQueueConfig};
|
||||
pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor};
|
||||
pub use worker::{
|
||||
start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig,
|
||||
Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
|
||||
start_health_checker, worker_to_info, BasicWorker, ConnectionMode, DPAwareWorker,
|
||||
HealthChecker, HealthConfig, Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
|
||||
};
|
||||
pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder};
|
||||
pub use worker_manager::{DpInfo, LoadMonitor, ServerInfo, WorkerManager};
|
||||
|
||||
@@ -3,6 +3,7 @@ use crate::core::CircuitState;
|
||||
use crate::core::{BasicWorkerBuilder, DPAwareWorkerBuilder};
|
||||
use crate::grpc_client::SglangSchedulerClient;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::protocols::worker_spec::WorkerInfo;
|
||||
use async_trait::async_trait;
|
||||
use futures;
|
||||
use serde_json;
|
||||
@@ -974,6 +975,39 @@ pub fn start_health_checker(
|
||||
HealthChecker { handle, shutdown }
|
||||
}
|
||||
|
||||
/// Helper to convert Worker trait object to WorkerInfo struct
|
||||
pub fn worker_to_info(worker: &Arc<dyn Worker>) -> WorkerInfo {
|
||||
let worker_type_str = match worker.worker_type() {
|
||||
WorkerType::Regular => "regular",
|
||||
WorkerType::Prefill { .. } => "prefill",
|
||||
WorkerType::Decode => "decode",
|
||||
};
|
||||
|
||||
let bootstrap_port = match worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
};
|
||||
|
||||
WorkerInfo {
|
||||
id: worker.url().to_string(),
|
||||
url: worker.url().to_string(),
|
||||
model_id: worker.model_id().to_string(),
|
||||
priority: worker.priority(),
|
||||
cost: worker.cost(),
|
||||
worker_type: worker_type_str.to_string(),
|
||||
is_healthy: worker.is_healthy(),
|
||||
load: worker.load(),
|
||||
connection_mode: format!("{:?}", worker.connection_mode()),
|
||||
tokenizer_path: worker.tokenizer_path().map(String::from),
|
||||
reasoning_parser: worker.reasoning_parser().map(String::from),
|
||||
tool_parser: worker.tool_parser().map(String::from),
|
||||
chat_template: worker.chat_template().map(String::from),
|
||||
bootstrap_port,
|
||||
metadata: worker.metadata().labels.clone(),
|
||||
job_status: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -551,31 +551,23 @@ impl WorkerManager {
|
||||
}
|
||||
|
||||
/// Add a worker from a configuration request
|
||||
///
|
||||
/// Registers worker immediately with healthy=false, returns worker for async validation
|
||||
pub async fn add_worker_from_config(
|
||||
config: &WorkerConfigRequest,
|
||||
context: &AppContext,
|
||||
) -> Result<String, String> {
|
||||
) -> Result<Arc<dyn Worker>, String> {
|
||||
// Check if worker already exists
|
||||
if context.worker_registry.get_by_url(&config.url).is_some() {
|
||||
return Err(format!("Worker {} already exists", config.url));
|
||||
}
|
||||
let mut labels = config.labels.clone();
|
||||
|
||||
let model_id = if let Some(ref model_id) = config.model_id {
|
||||
model_id.clone()
|
||||
} else {
|
||||
match Self::get_server_info(&config.url, config.api_key.as_deref()).await {
|
||||
Ok(info) => info
|
||||
.model_id
|
||||
.or_else(|| {
|
||||
info.model_path
|
||||
.as_ref()
|
||||
.and_then(|path| path.split('/').next_back().map(|s| s.to_string()))
|
||||
})
|
||||
.unwrap_or_else(|| "unknown".to_string()),
|
||||
Err(e) => {
|
||||
warn!("Failed to query server info from {}: {}", config.url, e);
|
||||
"unknown".to_string()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Use provided model_id or default to "unknown"
|
||||
let model_id = config
|
||||
.model_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
labels.insert("model_id".to_string(), model_id.clone());
|
||||
if let Some(priority) = config.priority {
|
||||
labels.insert("priority".to_string(), priority.to_string());
|
||||
@@ -614,18 +606,54 @@ impl WorkerManager {
|
||||
ConnectionMode::Http
|
||||
};
|
||||
|
||||
let policy_hint = labels.get("policy").cloned();
|
||||
let circuit_breaker_config = Self::convert_circuit_breaker_config(
|
||||
&context.router_config.effective_circuit_breaker_config(),
|
||||
);
|
||||
let health_config = Self::convert_health_config(&context.router_config.health_check);
|
||||
|
||||
Self::add_worker_internal(
|
||||
&config.url,
|
||||
// Create and register worker (starts with healthy=false)
|
||||
let worker = Self::create_basic_worker(
|
||||
config.url.clone(),
|
||||
worker_type,
|
||||
connection_mode,
|
||||
config.api_key.clone(),
|
||||
Some(labels),
|
||||
policy_hint.as_deref(),
|
||||
context,
|
||||
)
|
||||
.await
|
||||
Some(labels.clone()),
|
||||
circuit_breaker_config,
|
||||
health_config,
|
||||
);
|
||||
|
||||
worker.set_healthy(false);
|
||||
context.worker_registry.register(worker.clone());
|
||||
|
||||
let policy_hint = labels.get("policy").map(|s| s.as_str());
|
||||
context
|
||||
.policy_registry
|
||||
.on_worker_added(&model_id, policy_hint);
|
||||
|
||||
info!("Registered worker {} (initializing)", config.url);
|
||||
|
||||
// Return worker for async validation
|
||||
Ok(worker)
|
||||
}
|
||||
|
||||
/// Validate and activate a worker (for async validation after registration)
|
||||
pub async fn validate_and_activate_worker(
|
||||
worker: &Arc<dyn Worker>,
|
||||
context: &AppContext,
|
||||
) -> Result<String, String> {
|
||||
let url = worker.url();
|
||||
|
||||
// Perform health validation
|
||||
WorkerFactory::validate_health(url, context.router_config.worker_startup_timeout_secs)
|
||||
.await
|
||||
.map_err(|e| format!("Health check failed for {}: {}", url, e))?;
|
||||
|
||||
// Mark as healthy
|
||||
worker.set_healthy(true);
|
||||
|
||||
info!("Worker {} validated and activated", url);
|
||||
|
||||
Ok(format!("Worker {} is now healthy", url))
|
||||
}
|
||||
|
||||
/// Add a worker from URL (legacy endpoint)
|
||||
|
||||
@@ -71,6 +71,31 @@ pub fn init_metrics() {
|
||||
"Total requests processed by each worker"
|
||||
);
|
||||
|
||||
describe_gauge!(
|
||||
"sgl_router_job_queue_depth",
|
||||
"Current number of pending jobs in the queue"
|
||||
);
|
||||
describe_histogram!(
|
||||
"sgl_router_job_duration_seconds",
|
||||
"Job processing duration in seconds by job type"
|
||||
);
|
||||
describe_counter!(
|
||||
"sgl_router_job_success_total",
|
||||
"Total successful job completions by job type"
|
||||
);
|
||||
describe_counter!(
|
||||
"sgl_router_job_failure_total",
|
||||
"Total failed job completions by job type"
|
||||
);
|
||||
describe_counter!(
|
||||
"sgl_router_job_queue_full_total",
|
||||
"Total number of jobs rejected due to queue full"
|
||||
);
|
||||
describe_counter!(
|
||||
"sgl_router_job_shutdown_rejected_total",
|
||||
"Total number of jobs rejected due to shutdown"
|
||||
);
|
||||
|
||||
describe_counter!(
|
||||
"sgl_router_policy_decisions_total",
|
||||
"Total routing policy decisions by policy and worker"
|
||||
|
||||
@@ -100,9 +100,27 @@ pub struct WorkerInfo {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub chat_template: Option<String>,
|
||||
|
||||
/// Bootstrap port for prefill workers
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bootstrap_port: Option<u16>,
|
||||
|
||||
/// Additional metadata
|
||||
#[serde(skip_serializing_if = "HashMap::is_empty")]
|
||||
pub metadata: HashMap<String, String>,
|
||||
|
||||
/// Job status for async operations (if available)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub job_status: Option<JobStatus>,
|
||||
}
|
||||
|
||||
/// Job status for async control plane operations
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JobStatus {
|
||||
pub job_type: String,
|
||||
pub worker_url: String,
|
||||
pub status: String,
|
||||
pub message: Option<String>,
|
||||
pub timestamp: u64,
|
||||
}
|
||||
|
||||
/// Worker list response
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
use crate::{
|
||||
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
|
||||
core::{LoadMonitor, WorkerManager, WorkerRegistry, WorkerType},
|
||||
core::{
|
||||
worker_to_info, Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry,
|
||||
WorkerType,
|
||||
},
|
||||
data_connector::{
|
||||
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
|
||||
NoOpConversationStorage, NoOpResponseStorage, OracleConversationItemStorage,
|
||||
@@ -17,7 +20,7 @@ use crate::{
|
||||
RerankRequest, ResponsesGetParams, ResponsesRequest, V1RerankReqInput,
|
||||
},
|
||||
validated::ValidatedJson,
|
||||
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
|
||||
worker_spec::{WorkerConfigRequest, WorkerErrorResponse, WorkerInfo},
|
||||
},
|
||||
reasoning_parser::ParserFactory as ReasoningParserFactory,
|
||||
routers::{router_manager::RouterManager, RouterTrait},
|
||||
@@ -35,6 +38,7 @@ use axum::{
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::OnceLock;
|
||||
use std::{
|
||||
sync::atomic::{AtomicBool, Ordering},
|
||||
sync::Arc,
|
||||
@@ -62,6 +66,7 @@ pub struct AppContext {
|
||||
pub load_monitor: Option<Arc<LoadMonitor>>,
|
||||
pub configured_reasoning_parser: Option<String>,
|
||||
pub configured_tool_parser: Option<String>,
|
||||
pub worker_job_queue: Arc<OnceLock<Arc<JobQueue>>>,
|
||||
}
|
||||
|
||||
impl AppContext {
|
||||
@@ -178,6 +183,9 @@ impl AppContext {
|
||||
let configured_reasoning_parser = router_config.reasoning_parser.clone();
|
||||
let configured_tool_parser = router_config.tool_call_parser.clone();
|
||||
|
||||
// Create empty OnceLock for worker job queue (will be initialized in startup())
|
||||
let worker_job_queue = Arc::new(OnceLock::new());
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
router_config,
|
||||
@@ -194,6 +202,7 @@ impl AppContext {
|
||||
load_monitor,
|
||||
configured_reasoning_parser,
|
||||
configured_tool_parser,
|
||||
worker_job_queue,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -636,52 +645,42 @@ async fn create_worker(
|
||||
);
|
||||
}
|
||||
|
||||
let result = WorkerManager::add_worker_from_config(&config, &state.context).await;
|
||||
// Submit job for async processing
|
||||
let worker_url = config.url.clone();
|
||||
let job = Job::AddWorker {
|
||||
config: Box::new(config),
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(message) => {
|
||||
let response = WorkerApiResponse {
|
||||
success: true,
|
||||
message,
|
||||
worker: None,
|
||||
};
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
let job_queue = state
|
||||
.context
|
||||
.worker_job_queue
|
||||
.get()
|
||||
.expect("JobQueue not initialized");
|
||||
match job_queue.submit(job).await {
|
||||
Ok(_) => {
|
||||
let response = json!({
|
||||
"status": "accepted",
|
||||
"worker_id": worker_url,
|
||||
"message": "Worker addition queued for background processing"
|
||||
});
|
||||
(StatusCode::ACCEPTED, Json(response)).into_response()
|
||||
}
|
||||
Err(error) => {
|
||||
let error_response = WorkerErrorResponse {
|
||||
error,
|
||||
code: "ADD_WORKER_FAILED".to_string(),
|
||||
code: "INTERNAL_SERVER_ERROR".to_string(),
|
||||
};
|
||||
(StatusCode::BAD_REQUEST, Json(error_response)).into_response()
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
|
||||
let workers = state.context.worker_registry.get_all();
|
||||
let response = serde_json::json!({
|
||||
"workers": workers.iter().map(|worker| {
|
||||
let mut worker_info = serde_json::json!({
|
||||
"url": worker.url(),
|
||||
"model_id": worker.model_id(),
|
||||
"worker_type": match worker.worker_type() {
|
||||
WorkerType::Regular => "regular",
|
||||
WorkerType::Prefill { .. } => "prefill",
|
||||
WorkerType::Decode => "decode",
|
||||
},
|
||||
"is_healthy": worker.is_healthy(),
|
||||
"load": worker.load(),
|
||||
"connection_mode": format!("{:?}", worker.connection_mode()),
|
||||
"priority": worker.priority(),
|
||||
"cost": worker.cost(),
|
||||
});
|
||||
let worker_infos: Vec<WorkerInfo> = workers.iter().map(worker_to_info).collect();
|
||||
|
||||
if let WorkerType::Prefill { bootstrap_port } = worker.worker_type() {
|
||||
worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port);
|
||||
}
|
||||
|
||||
worker_info
|
||||
}).collect::<Vec<_>>(),
|
||||
let response = json!({
|
||||
"workers": worker_infos,
|
||||
"total": workers.len(),
|
||||
"stats": {
|
||||
"prefill_count": state.context.worker_registry.get_prefill_workers().len(),
|
||||
@@ -693,41 +692,77 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
|
||||
}
|
||||
|
||||
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
|
||||
let workers = WorkerManager::get_worker_urls(&state.context.worker_registry);
|
||||
if workers.contains(&url) {
|
||||
Json(json!({
|
||||
"url": url,
|
||||
"model_id": "unknown",
|
||||
"is_healthy": true
|
||||
}))
|
||||
.into_response()
|
||||
} else {
|
||||
let error = WorkerErrorResponse {
|
||||
error: format!("Worker {url} not found"),
|
||||
code: "WORKER_NOT_FOUND".to_string(),
|
||||
};
|
||||
(StatusCode::NOT_FOUND, Json(error)).into_response()
|
||||
let job_queue = state
|
||||
.context
|
||||
.worker_job_queue
|
||||
.get()
|
||||
.expect("JobQueue not initialized");
|
||||
|
||||
if let Some(worker) = state.context.worker_registry.get_by_url(&url) {
|
||||
// Worker exists in registry, get its full info and attach job status if any
|
||||
let mut worker_info = worker_to_info(&worker);
|
||||
if let Some(status) = job_queue.get_status(&url) {
|
||||
worker_info.job_status = Some(status);
|
||||
}
|
||||
return Json(worker_info).into_response();
|
||||
}
|
||||
|
||||
// Worker not in registry, check job queue for its status
|
||||
if let Some(status) = job_queue.get_status(&url) {
|
||||
// Create a partial WorkerInfo to report the job status
|
||||
let worker_info = WorkerInfo {
|
||||
id: url.clone(),
|
||||
url: url.clone(),
|
||||
model_id: "unknown".to_string(),
|
||||
priority: 0,
|
||||
cost: 1.0,
|
||||
worker_type: "unknown".to_string(),
|
||||
is_healthy: false,
|
||||
load: 0,
|
||||
connection_mode: "unknown".to_string(),
|
||||
tokenizer_path: None,
|
||||
reasoning_parser: None,
|
||||
tool_parser: None,
|
||||
chat_template: None,
|
||||
bootstrap_port: None,
|
||||
metadata: std::collections::HashMap::new(),
|
||||
job_status: Some(status),
|
||||
};
|
||||
return Json(worker_info).into_response();
|
||||
}
|
||||
|
||||
// Worker not found in registry or job queue
|
||||
let error = WorkerErrorResponse {
|
||||
error: format!("Worker {url} not found"),
|
||||
code: "WORKER_NOT_FOUND".to_string(),
|
||||
};
|
||||
(StatusCode::NOT_FOUND, Json(error)).into_response()
|
||||
}
|
||||
|
||||
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
|
||||
let result = WorkerManager::remove_worker(&url, &state.context);
|
||||
let worker_id = url.clone();
|
||||
let job = Job::RemoveWorker { url };
|
||||
|
||||
match result {
|
||||
Ok(message) => {
|
||||
let response = WorkerApiResponse {
|
||||
success: true,
|
||||
message,
|
||||
worker: None,
|
||||
};
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
let job_queue = state
|
||||
.context
|
||||
.worker_job_queue
|
||||
.get()
|
||||
.expect("JobQueue not initialized");
|
||||
match job_queue.submit(job).await {
|
||||
Ok(_) => {
|
||||
let response = json!({
|
||||
"status": "accepted",
|
||||
"worker_id": worker_id,
|
||||
"message": "Worker removal queued for background processing"
|
||||
});
|
||||
(StatusCode::ACCEPTED, Json(response)).into_response()
|
||||
}
|
||||
Err(error) => {
|
||||
let error_response = WorkerErrorResponse {
|
||||
error,
|
||||
code: "REMOVE_WORKER_FAILED".to_string(),
|
||||
code: "INTERNAL_SERVER_ERROR".to_string(),
|
||||
};
|
||||
(StatusCode::BAD_REQUEST, Json(error_response)).into_response()
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -898,6 +933,13 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
||||
|
||||
let app_context = Arc::new(app_context);
|
||||
|
||||
let weak_context = Arc::downgrade(&app_context);
|
||||
let worker_job_queue = JobQueue::new(JobQueueConfig::default(), weak_context);
|
||||
app_context
|
||||
.worker_job_queue
|
||||
.set(worker_job_queue)
|
||||
.expect("JobQueue should only be initialized once");
|
||||
|
||||
info!(
|
||||
"Initializing workers for routing mode: {:?}",
|
||||
config.router_config.mode
|
||||
|
||||
@@ -383,18 +383,32 @@ async fn handle_pod_event(
|
||||
api_key: None,
|
||||
};
|
||||
|
||||
let result = WorkerManager::add_worker_from_config(&config, &app_context).await;
|
||||
// Submit job for async worker addition
|
||||
use crate::core::Job;
|
||||
let job = Job::AddWorker {
|
||||
config: Box::new(config.clone()),
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
debug!("Worker added: {}", worker_url);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to add worker {} to router: {}", worker_url, e);
|
||||
if let Ok(mut tracker) = tracked_pods.lock() {
|
||||
tracker.remove(pod_info);
|
||||
if let Some(job_queue) = app_context.worker_job_queue.get() {
|
||||
match job_queue.submit(job).await {
|
||||
Ok(_) => {
|
||||
debug!("Worker addition job submitted for: {}", worker_url);
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"Failed to submit worker addition job for {}: {}",
|
||||
worker_url, e
|
||||
);
|
||||
if let Ok(mut tracker) = tracked_pods.lock() {
|
||||
tracker.remove(pod_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!(
|
||||
"JobQueue not initialized, skipping async worker addition for: {}",
|
||||
worker_url
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -529,6 +543,8 @@ mod tests {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Note: Using uninitialized queue for tests to avoid spawning background workers
|
||||
// Jobs submitted during tests will queue but not be processed
|
||||
Arc::new(AppContext {
|
||||
client: reqwest::Client::new(),
|
||||
router_config: router_config.clone(),
|
||||
@@ -549,6 +565,7 @@ mod tests {
|
||||
load_monitor: None,
|
||||
configured_reasoning_parser: None,
|
||||
configured_tool_parser: None,
|
||||
worker_job_queue: Arc::new(std::sync::OnceLock::new()),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -907,8 +924,6 @@ mod tests {
|
||||
};
|
||||
let port = 8080u16;
|
||||
|
||||
// This test validates the structure but won't actually add workers since
|
||||
// the test worker URL won't be reachable
|
||||
handle_pod_event(
|
||||
&pod_info,
|
||||
Arc::clone(&tracked_pods),
|
||||
@@ -918,8 +933,12 @@ mod tests {
|
||||
)
|
||||
.await;
|
||||
|
||||
// Pod should not be tracked since add_worker_from_config will fail for non-running server
|
||||
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
||||
// With fully async control plane, pod is tracked and job is queued
|
||||
// Worker registration and validation happen in background job
|
||||
assert!(tracked_pods.lock().unwrap().contains(&pod_info));
|
||||
|
||||
// Note: In tests with uninitialized queue, background jobs don't process
|
||||
// Worker won't appear in registry until background job runs (in production)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -945,8 +964,12 @@ mod tests {
|
||||
)
|
||||
.await;
|
||||
|
||||
// Pod should not be tracked since add_worker_from_config will fail for non-running server
|
||||
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
||||
// With fully async control plane, pod is tracked and job is queued
|
||||
// Worker registration and validation happen in background job
|
||||
assert!(tracked_pods.lock().unwrap().contains(&pod_info));
|
||||
|
||||
// Note: In tests with uninitialized queue, background jobs don't process
|
||||
// Worker won't appear in registry until background job runs (in production)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -1033,7 +1056,13 @@ mod tests {
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
||||
// With fully async control plane, pod is tracked and job is queued
|
||||
// In regular mode (pd_mode=false), worker_type defaults to Regular
|
||||
// Worker registration and validation happen in background job
|
||||
assert!(tracked_pods.lock().unwrap().contains(&pod_info));
|
||||
|
||||
// Note: In tests with uninitialized queue, background jobs don't process
|
||||
// Worker won't appear in registry until background job runs (in production)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -1059,8 +1088,12 @@ mod tests {
|
||||
)
|
||||
.await;
|
||||
|
||||
// Pod should not be tracked since add_worker_from_config will fail for non-running server
|
||||
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
||||
// With fully async control plane, pod is tracked and job is queued
|
||||
// Worker registration and validation happen in background job
|
||||
assert!(tracked_pods.lock().unwrap().contains(&pod_info));
|
||||
|
||||
// Note: In tests with uninitialized queue, background jobs don't process
|
||||
// Worker won't appear in registry until background job runs (in production)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
Reference in New Issue
Block a user