From ddcba74b4df6ea31a4f823d4ef5bb73b8daac9f3 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 20 Oct 2025 17:00:22 -0700 Subject: [PATCH] [router] Worker Management Workflow Engine (#11868) --- sgl-router/src/core/job_queue.rs | 299 +++++-- sgl-router/src/core/mod.rs | 6 +- sgl-router/src/core/worker.rs | 84 -- sgl-router/src/core/workflow/definition.rs | 98 ++ sgl-router/src/core/workflow/engine.rs | 484 ++++++++++ sgl-router/src/core/workflow/event.rs | 188 ++++ sgl-router/src/core/workflow/executor.rs | 129 +++ sgl-router/src/core/workflow/mod.rs | 18 + sgl-router/src/core/workflow/state.rs | 175 ++++ sgl-router/src/core/workflow/steps/mod.rs | 12 + .../workflow/steps/worker_registration.rs | 837 ++++++++++++++++++ sgl-router/src/core/workflow/types.rs | 271 ++++++ sgl-router/src/metrics.rs | 33 + sgl-router/src/protocols/worker_spec.rs | 45 + sgl-router/src/server.rs | 46 +- sgl-router/src/service_discovery.rs | 33 +- sgl-router/tests/common/mod.rs | 4 +- sgl-router/tests/common/test_app.rs | 4 +- .../tests/policy_registry_integration.rs | 18 + sgl-router/tests/test_pd_routing.rs | 4 +- sgl-router/tests/workflow_test.rs | 320 +++++++ 21 files changed, 2937 insertions(+), 171 deletions(-) create mode 100644 sgl-router/src/core/workflow/definition.rs create mode 100644 sgl-router/src/core/workflow/engine.rs create mode 100644 sgl-router/src/core/workflow/event.rs create mode 100644 sgl-router/src/core/workflow/executor.rs create mode 100644 sgl-router/src/core/workflow/mod.rs create mode 100644 sgl-router/src/core/workflow/state.rs create mode 100644 sgl-router/src/core/workflow/steps/mod.rs create mode 100644 sgl-router/src/core/workflow/steps/worker_registration.rs create mode 100644 sgl-router/src/core/workflow/types.rs create mode 100644 sgl-router/tests/workflow_test.rs diff --git a/sgl-router/src/core/job_queue.rs b/sgl-router/src/core/job_queue.rs index b6d3e86e1..749eb9b87 100644 --- a/sgl-router/src/core/job_queue.rs +++ b/sgl-router/src/core/job_queue.rs @@ -4,17 +4,24 @@ //! them asynchronously in background worker tasks. use std::{ + collections::HashMap, sync::{Arc, Weak}, time::{Duration, SystemTime}, }; use dashmap::DashMap; -use metrics::{counter, gauge, histogram}; use tokio::sync::mpsc; use tracing::{debug, error, info, warn}; use crate::{ - core::WorkerManager, + config::{RouterConfig, RoutingMode}, + core::{ + workflow::{ + WorkflowContext, WorkflowEngine, WorkflowId, WorkflowInstanceId, WorkflowStatus, + }, + WorkerManager, + }, + metrics::RouterMetrics, protocols::worker_spec::{JobStatus, WorkerConfigRequest}, server::AppContext, }; @@ -24,6 +31,7 @@ use crate::{ pub enum Job { AddWorker { config: Box }, RemoveWorker { url: String }, + InitializeWorkersFromConfig { router_config: Box }, } impl Job { @@ -32,6 +40,7 @@ impl Job { match self { Job::AddWorker { .. } => "AddWorker", Job::RemoveWorker { .. } => "RemoveWorker", + Job::InitializeWorkersFromConfig { .. } => "InitializeWorkersFromConfig", } } @@ -40,6 +49,7 @@ impl Job { match self { Job::AddWorker { config } => &config.url, Job::RemoveWorker { url } => url, + Job::InitializeWorkersFromConfig { .. 
} => "startup", } } } @@ -98,7 +108,7 @@ impl Default for JobQueueConfig { fn default() -> Self { Self { queue_capacity: 1000, - worker_count: 2, + worker_count: 10, } } } @@ -166,7 +176,7 @@ impl JobQueue { pub async fn submit(&self, job: Job) -> Result<(), String> { // Check if context is still alive before accepting jobs if self.context.upgrade().is_none() { - counter!("sgl_router_job_shutdown_rejected_total").increment(1); + RouterMetrics::record_job_shutdown_rejected(); return Err("Job queue shutting down: AppContext dropped".to_string()); } @@ -183,8 +193,7 @@ impl JobQueue { match self.tx.send(job).await { Ok(_) => { let queue_depth = self.tx.max_capacity() - self.tx.capacity(); - gauge!("sgl_router_job_queue_depth").set(queue_depth as f64); - + RouterMetrics::set_job_queue_depth(queue_depth); info!( "Job submitted: type={}, worker={}, queue_depth={}", job_type, worker_url, queue_depth @@ -192,8 +201,7 @@ impl JobQueue { Ok(()) } Err(_) => { - counter!("sgl_router_job_queue_full_total").increment(1); - // Remove status on failure + RouterMetrics::record_job_queue_full(); self.status_map.remove(&worker_url); Err("Worker job queue full".to_string()) } @@ -246,39 +254,16 @@ impl JobQueue { // Upgrade weak reference to process job match context.upgrade() { Some(ctx) => { - // Execute job let result = Self::execute_job(&job, &ctx).await; let duration = start.elapsed(); - - // Record metrics - histogram!("sgl_router_job_duration_seconds", "job_type" => job_type.clone()) - .record(duration.as_secs_f64()); - - match result { - Ok(message) => { - counter!("sgl_router_job_success_total", "job_type" => job_type.clone()) - .increment(1); - // Remove status on success - worker in registry is sufficient - status_map.remove(&worker_url); - info!( - "Worker {} completed job: type={}, worker={}, duration={:.3}s, result={}", - worker_id, job_type, worker_url, duration.as_secs_f64(), message - ); - } - Err(error) => { - counter!("sgl_router_job_failure_total", "job_type" => job_type.clone()) - .increment(1); - // Keep failed status for API to report error details - status_map.insert( - worker_url.clone(), - JobStatus::failed(&job_type, &worker_url, error.clone()), - ); - warn!( - "Worker {} failed job: type={}, worker={}, duration={:.3}s, error={}", - worker_id, job_type, worker_url, duration.as_secs_f64(), error - ); - } - } + Self::record_job_completion( + &job_type, + &worker_url, + worker_id, + duration, + &result, + &status_map, + ); } None => { let error_msg = "AppContext dropped".to_string(); @@ -311,12 +296,28 @@ impl JobQueue { async fn execute_job(job: &Job, context: &Arc) -> Result { match job { Job::AddWorker { config } => { - // Register worker with is_healthy=false - let worker = - WorkerManager::add_worker_from_config(config.as_ref(), context).await?; + let engine = context + .workflow_engine + .get() + .ok_or_else(|| "Workflow engine not initialized".to_string())?; - // Validate and activate - WorkerManager::validate_and_activate_worker(&worker, context).await + let instance_id = Self::start_worker_workflow(engine, config, context).await?; + + info!( + "Started worker registration workflow for {} (instance: {})", + config.url, instance_id + ); + + let timeout_duration = + Duration::from_secs(context.router_config.worker_startup_timeout_secs + 30); + + Self::wait_for_workflow_completion( + engine, + instance_id, + &config.url, + timeout_duration, + ) + .await } Job::RemoveWorker { url } => { let result = WorkerManager::remove_worker(url, context); @@ -326,6 +327,204 @@ impl JobQueue { } 
result } + Job::InitializeWorkersFromConfig { router_config } => { + let api_key = router_config.api_key.clone(); + let mut worker_count = 0; + + // Create iterator of (url, worker_type, bootstrap_port) tuples based on mode + let workers: Vec<(String, &str, Option)> = match &router_config.mode { + RoutingMode::Regular { worker_urls } => worker_urls + .iter() + .map(|url| (url.clone(), "regular", None)) + .collect(), + RoutingMode::PrefillDecode { + prefill_urls, + decode_urls, + .. + } => { + let prefill_workers = prefill_urls + .iter() + .map(|(url, port)| (url.clone(), "prefill", *port)); + + let decode_workers = + decode_urls.iter().map(|url| (url.clone(), "decode", None)); + + prefill_workers.chain(decode_workers).collect() + } + RoutingMode::OpenAI { .. } => { + info!("OpenAI mode: no workers to initialize"); + return Ok("OpenAI mode: no workers to initialize".to_string()); + } + }; + + info!( + "Creating AddWorker jobs for {} workers from config", + workers.len() + ); + + // Process all workers with unified loop + for (url, worker_type, bootstrap_port) in workers { + let url_for_error = url.clone(); // Clone for error message + let config = WorkerConfigRequest { + url, + api_key: api_key.clone(), + worker_type: Some(worker_type.to_string()), + labels: HashMap::new(), + model_id: None, + priority: None, + cost: None, + tokenizer_path: None, + reasoning_parser: None, + tool_parser: None, + chat_template: None, + bootstrap_port, + health_check_timeout_secs: router_config.health_check.timeout_secs, + health_check_interval_secs: router_config.health_check.check_interval_secs, + health_success_threshold: router_config.health_check.success_threshold, + health_failure_threshold: router_config.health_check.failure_threshold, + max_connection_attempts: router_config.health_check.success_threshold * 10, + dp_aware: router_config.dp_aware, + }; + + let job = Job::AddWorker { + config: Box::new(config), + }; + + if let Some(queue) = context.worker_job_queue.get() { + queue.submit(job).await.map_err(|e| { + format!( + "Failed to submit AddWorker job for {} worker {}: {}", + worker_type, url_for_error, e + ) + })?; + worker_count += 1; + } else { + return Err("JobQueue not available".to_string()); + } + } + + Ok(format!("Submitted {} AddWorker jobs", worker_count)) + } + } + } + + /// Start a workflow and return its instance ID + async fn start_worker_workflow( + engine: &Arc, + config: &WorkerConfigRequest, + context: &Arc, + ) -> Result { + let mut workflow_context = WorkflowContext::new(WorkflowInstanceId::new()); + workflow_context.set("worker_config", config.clone()); + workflow_context.set_arc("app_context", Arc::clone(context)); + + engine + .start_workflow(WorkflowId::new("worker_registration"), workflow_context) + .await + .map_err(|e| format!("Failed to start worker registration workflow: {:?}", e)) + } + + /// Wait for workflow completion with adaptive polling + async fn wait_for_workflow_completion( + engine: &Arc, + instance_id: WorkflowInstanceId, + worker_url: &str, + timeout_duration: Duration, + ) -> Result { + let start = std::time::Instant::now(); + let mut poll_interval = Duration::from_millis(100); + let max_poll_interval = Duration::from_millis(2000); + let poll_backoff = Duration::from_millis(200); + + loop { + // Check timeout + if start.elapsed() > timeout_duration { + return Err(format!( + "Workflow timeout after {}s for worker {}", + timeout_duration.as_secs(), + worker_url + )); + } + + // Get workflow status + let state = engine + .get_status(instance_id) + .map_err(|e| 
format!("Failed to get workflow status: {:?}", e))?; + + let result = match state.status { + WorkflowStatus::Completed => Ok(format!( + "Worker {} registered and activated successfully via workflow", + worker_url + )), + WorkflowStatus::Failed => { + let current_step = state.current_step.as_ref(); + let step_name = current_step + .map(|s| s.to_string()) + .unwrap_or_else(|| "unknown".to_string()); + let error_msg = current_step + .and_then(|step_id| state.step_states.get(step_id)) + .and_then(|s| s.last_error.as_deref()) + .unwrap_or("Unknown error"); + Err(format!( + "Workflow failed at step {}: {}", + step_name, error_msg + )) + } + WorkflowStatus::Cancelled => { + Err(format!("Workflow cancelled for worker {}", worker_url)) + } + WorkflowStatus::Pending | WorkflowStatus::Paused | WorkflowStatus::Running => { + tokio::time::sleep(poll_interval).await; + poll_interval = (poll_interval + poll_backoff).min(max_poll_interval); + continue; + } + }; + + // Clean up terminal workflow states + engine.state_store().cleanup_if_terminal(instance_id); + return result; + } + } + + /// Record job completion metrics and update status + fn record_job_completion( + job_type: &str, + worker_url: &str, + worker_id: usize, + duration: Duration, + result: &Result, + status_map: &Arc>, + ) { + RouterMetrics::record_job_duration(job_type, duration); + + match result { + Ok(message) => { + RouterMetrics::record_job_success(job_type); + status_map.remove(worker_url); + info!( + "Worker {} completed job: type={}, worker={}, duration={:.3}s, result={}", + worker_id, + job_type, + worker_url, + duration.as_secs_f64(), + message + ); + } + Err(error) => { + RouterMetrics::record_job_failure(job_type); + status_map.insert( + worker_url.to_string(), + JobStatus::failed(job_type, worker_url, error.clone()), + ); + warn!( + "Worker {} failed job: type={}, worker={}, duration={:.3}s, error={}", + worker_id, + job_type, + worker_url, + duration.as_secs_f64(), + error + ); + } } } @@ -352,15 +551,3 @@ impl JobQueue { } } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_job_queue_config_default() { - let config = JobQueueConfig::default(); - assert_eq!(config.queue_capacity, 1000); - assert_eq!(config.worker_count, 2); - } -} diff --git a/sgl-router/src/core/mod.rs b/sgl-router/src/core/mod.rs index c1cb4ef3b..f521f3839 100644 --- a/sgl-router/src/core/mod.rs +++ b/sgl-router/src/core/mod.rs @@ -4,6 +4,7 @@ //! - Worker trait and implementations //! - Error types //! - Circuit breaker for reliability +//! - Workflow engine for multi-step operations //! 
- Common utilities pub mod circuit_breaker; @@ -15,6 +16,7 @@ pub mod worker; pub mod worker_builder; pub mod worker_manager; pub mod worker_registry; +pub mod workflow; pub use circuit_breaker::{ CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState, @@ -23,8 +25,8 @@ pub use error::{WorkerError, WorkerResult}; pub use job_queue::{Job, JobQueue, JobQueueConfig}; pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor}; pub use worker::{ - start_health_checker, worker_to_info, BasicWorker, ConnectionMode, DPAwareWorker, - HealthChecker, HealthConfig, Worker, WorkerFactory, WorkerLoadGuard, WorkerType, + worker_to_info, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig, + Worker, WorkerFactory, WorkerLoadGuard, WorkerType, }; pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder}; pub use worker_manager::{DpInfo, LoadMonitor, ServerInfo, WorkerManager}; diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index a08379471..3f5c05dbb 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -8,7 +8,6 @@ use std::{ }; use async_trait::async_trait; -use futures; use serde_json; use tokio::{sync::RwLock, time}; @@ -910,89 +909,6 @@ impl HealthChecker { } } -/// Start an async background health checker for a collection of workers -pub fn start_health_checker( - workers: Arc>>>, - check_interval_secs: u64, -) -> HealthChecker { - let shutdown = Arc::new(AtomicBool::new(false)); - let shutdown_clone = shutdown.clone(); - - let handle = tokio::spawn(async move { - let mut interval = time::interval(Duration::from_secs(check_interval_secs)); - - // Counter for periodic load reset (every 10 health check cycles) - let mut check_count = 0u64; - const LOAD_RESET_INTERVAL: u64 = 10; - - loop { - interval.tick().await; - - // Check for shutdown signal - if shutdown_clone.load(Ordering::Acquire) { - tracing::debug!("Health checker shutting down"); - break; - } - - check_count += 1; - - // Check health of all workers - let workers_to_check = match workers.read() { - Ok(guard) => guard.clone(), - Err(poisoned) => { - tracing::error!("Worker lock poisoned: {}", poisoned); - continue; - } - }; - - // Periodically reset load counters to prevent drift - // Only do this when we believe all workers should be idle - if check_count.is_multiple_of(LOAD_RESET_INTERVAL) { - let max_load = workers_to_check.iter().map(|w| w.load()).max().unwrap_or(0); - // Only reset if load appears to be very low (likely drift) - if max_load <= 2 { - tracing::debug!( - "Resetting load counters to prevent drift (max_load: {})", - max_load - ); - for worker in &workers_to_check { - worker.reset_load(); - } - } - } - - // Perform health checks concurrently - let health_checks = workers_to_check.iter().map(|worker| { - let worker_url = worker.url().to_string(); - let was_healthy = worker.is_healthy(); - - async move { - match worker.check_health_async().await { - Ok(_) => { - if !was_healthy { - tracing::info!("Worker {} is now healthy", worker_url); - } - } - Err(e) => { - if was_healthy { - tracing::warn!("Worker {} health check failed: {}", worker_url, e); - } else { - // Worker was already unhealthy, log at debug level - tracing::debug!("Worker {} remains unhealthy: {}", worker_url, e); - } - } - } - } - }); - - // Execute all health checks concurrently - futures::future::join_all(health_checks).await; - } - }); - - HealthChecker { handle, shutdown } -} - /// Helper to convert Worker trait object to WorkerInfo struct 
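///
/// A minimal call sketch (the `worker` handle is assumed to come from the
/// worker registry; marked `ignore` so it is not run as a doctest):
///
/// ```ignore
/// let info = worker_to_info(&worker);
/// ```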
pub fn worker_to_info(worker: &Arc) -> WorkerInfo { let worker_type_str = match worker.worker_type() { diff --git a/sgl-router/src/core/workflow/definition.rs b/sgl-router/src/core/workflow/definition.rs new file mode 100644 index 000000000..737d5735b --- /dev/null +++ b/sgl-router/src/core/workflow/definition.rs @@ -0,0 +1,98 @@ +//! Workflow definition types + +use std::{sync::Arc, time::Duration}; + +use super::{ + executor::StepExecutor, + types::{FailureAction, RetryPolicy, StepId, WorkflowId}, +}; + +/// Definition of a single step within a workflow +pub struct StepDefinition { + pub id: StepId, + pub name: String, + pub executor: Arc, + pub retry_policy: Option, + pub timeout: Option, + pub on_failure: FailureAction, +} + +impl StepDefinition { + pub fn new( + id: impl Into, + name: impl Into, + executor: Arc, + ) -> Self { + Self { + id: StepId::new(id.into()), + name: name.into(), + executor, + retry_policy: None, + timeout: None, + on_failure: FailureAction::FailWorkflow, + } + } + + pub fn with_retry(mut self, policy: RetryPolicy) -> Self { + self.retry_policy = Some(policy); + self + } + + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } + + pub fn with_failure_action(mut self, action: FailureAction) -> Self { + self.on_failure = action; + self + } +} + +/// Complete workflow definition +pub struct WorkflowDefinition { + pub id: WorkflowId, + pub name: String, + pub steps: Vec, + pub default_retry_policy: RetryPolicy, + pub default_timeout: Duration, +} + +impl WorkflowDefinition { + pub fn new(id: impl Into, name: impl Into) -> Self { + Self { + id: WorkflowId::new(id.into()), + name: name.into(), + steps: Vec::new(), + default_retry_policy: RetryPolicy::default(), + default_timeout: Duration::from_secs(300), // 5 minutes + } + } + + pub fn add_step(mut self, step: StepDefinition) -> Self { + self.steps.push(step); + self + } + + pub fn with_default_retry(mut self, policy: RetryPolicy) -> Self { + self.default_retry_policy = policy; + self + } + + pub fn with_default_timeout(mut self, timeout: Duration) -> Self { + self.default_timeout = timeout; + self + } + + /// Get the retry policy for a step (step-specific or default) + pub fn get_retry_policy<'a>(&'a self, step: &'a StepDefinition) -> &'a RetryPolicy { + step.retry_policy + .as_ref() + .unwrap_or(&self.default_retry_policy) + } + + /// Get the timeout for a step (step-specific or default) + pub fn get_timeout(&self, step: &StepDefinition) -> Duration { + step.timeout.unwrap_or(self.default_timeout) + } +} diff --git a/sgl-router/src/core/workflow/engine.rs b/sgl-router/src/core/workflow/engine.rs new file mode 100644 index 000000000..630b36304 --- /dev/null +++ b/sgl-router/src/core/workflow/engine.rs @@ -0,0 +1,484 @@ +//! 
Workflow execution engine + +use std::{collections::HashMap, sync::Arc, time::Duration}; + +use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; +use chrono::Utc; +use parking_lot::RwLock; +use tokio::time::timeout; + +use super::{ + definition::{StepDefinition, WorkflowDefinition}, + event::{EventBus, WorkflowEvent}, + state::WorkflowStateStore, + types::*, +}; + +/// Linear backoff implementation that increases delay by a fixed amount each retry +struct LinearBackoff { + current: Duration, + increment: Duration, + max: Duration, +} + +impl LinearBackoff { + fn new(increment: Duration, max: Duration) -> Self { + Self { + current: increment, + increment, + max, + } + } +} + +impl Backoff for LinearBackoff { + fn next_backoff(&mut self) -> Option { + let next = self.current; + self.current = (self.current + self.increment).min(self.max); + Some(next) + } + + fn reset(&mut self) { + self.current = self.increment; + } +} + +/// Main workflow execution engine +pub struct WorkflowEngine { + definitions: Arc>>>, + state_store: WorkflowStateStore, + event_bus: Arc, +} + +impl WorkflowEngine { + pub fn new() -> Self { + Self { + definitions: Arc::new(RwLock::new(HashMap::new())), + state_store: WorkflowStateStore::new(), + event_bus: Arc::new(EventBus::new()), + } + } + + /// Start a background task to periodically clean up old workflow states + /// + /// This prevents unbounded memory growth by removing completed/failed workflows + /// that are older than the specified TTL. + /// + /// # Arguments + /// + /// * `ttl` - Time-to-live for terminal workflows (default: 1 hour) + /// * `interval` - How often to run cleanup (default: 5 minutes) + /// + /// # Returns + /// + /// A join handle for the cleanup task that can be used to stop it. + pub fn start_cleanup_task( + &self, + ttl: Option, + interval: Option, + ) -> tokio::task::JoinHandle<()> { + let state_store = self.state_store.clone(); + let ttl = ttl.unwrap_or(Duration::from_secs(3600)); // 1 hour default + let interval = interval.unwrap_or(Duration::from_secs(300)); // 5 minutes default + + tokio::spawn(async move { + let mut ticker = tokio::time::interval(interval); + ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + loop { + ticker.tick().await; + state_store.cleanup_old_workflows(ttl); + } + }) + } + + /// Register a workflow definition + pub fn register_workflow(&self, definition: WorkflowDefinition) { + let id = definition.id.clone(); + self.definitions.write().insert(id, Arc::new(definition)); + } + + /// Get the event bus for subscribing to workflow events + pub fn event_bus(&self) -> Arc { + Arc::clone(&self.event_bus) + } + + /// Get the state store + pub fn state_store(&self) -> &WorkflowStateStore { + &self.state_store + } + + /// Start a new workflow instance + pub async fn start_workflow( + &self, + definition_id: WorkflowId, + context: WorkflowContext, + ) -> WorkflowResult { + // Get workflow definition + let definition = { + let definitions = self.definitions.read(); + definitions + .get(&definition_id) + .cloned() + .ok_or_else(|| WorkflowError::DefinitionNotFound(definition_id.clone()))? 
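// (A missing registration surfaces here as DefinitionNotFound rather than
// a panic.) End to end, callers follow the same pattern the job queue in
// this patch uses; a hedged sketch, with error handling elided and the
// create_worker_registration_workflow arguments omitted because its
// signature lives in steps/worker_registration.rs:
//
//     let engine = Arc::new(WorkflowEngine::new());
//     engine.register_workflow(create_worker_registration_workflow(/* ... */));
//     let mut ctx = WorkflowContext::new(WorkflowInstanceId::new());
//     ctx.set("worker_config", config.clone());
//     let instance_id = engine
//         .start_workflow(WorkflowId::new("worker_registration"), ctx)
//         .await?;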
+ }; + + // Create new workflow instance + let instance_id = context.instance_id; + let mut state = WorkflowState::new(instance_id, definition_id.clone()); + state.status = WorkflowStatus::Running; + state.context = context; + + // Initialize step states + for step in &definition.steps { + state + .step_states + .insert(step.id.clone(), StepState::default()); + } + + // Save initial state + self.state_store.save(state)?; + + // Emit workflow started event + self.event_bus + .publish(WorkflowEvent::WorkflowStarted { + instance_id, + definition_id, + }) + .await; + + // Execute workflow in background + let engine = self.clone_for_execution(); + let def = Arc::clone(&definition); + tokio::spawn(async move { + if let Err(e) = engine.execute_workflow(instance_id, def).await { + tracing::error!(instance_id = %instance_id, error = ?e, "Workflow execution failed"); + } + }); + + Ok(instance_id) + } + + /// Execute a workflow (internal) + async fn execute_workflow( + &self, + instance_id: WorkflowInstanceId, + definition: Arc, + ) -> WorkflowResult<()> { + let start_time = std::time::Instant::now(); + + for step in &definition.steps { + // Check if workflow was cancelled + let state = self.state_store.load(instance_id)?; + if state.status == WorkflowStatus::Cancelled { + self.event_bus + .publish(WorkflowEvent::WorkflowCancelled { instance_id }) + .await; + return Ok(()); + } + + // Execute step with retry + match self + .execute_step_with_retry(instance_id, step, &definition) + .await + { + Ok(StepResult::Success) => { + // Continue to next step + } + Ok(StepResult::Skip) => { + // Step was skipped, continue to next + continue; + } + Ok(StepResult::Failure) | Err(_) => { + // Handle failure based on failure action + match step.on_failure { + FailureAction::FailWorkflow => { + let error_msg = format!("Step {} failed", step.id); + self.state_store.update(instance_id, |s| { + s.status = WorkflowStatus::Failed; + })?; + + self.event_bus + .publish(WorkflowEvent::WorkflowFailed { + instance_id, + failed_step: step.id.clone(), + error: error_msg, + }) + .await; + + return Ok(()); + } + FailureAction::ContinueNextStep => { + // Mark step as skipped and continue + self.state_store.update(instance_id, |s| { + if let Some(step_state) = s.step_states.get_mut(&step.id) { + step_state.status = StepStatus::Skipped; + } + })?; + continue; + } + FailureAction::RetryIndefinitely => { + // This should not happen as execute_step_with_retry handles it + unreachable!("RetryIndefinitely should be handled in retry logic"); + } + } + } + } + } + + // Workflow completed successfully + self.state_store.update(instance_id, |s| { + s.status = WorkflowStatus::Completed; + })?; + + let duration = start_time.elapsed(); + self.event_bus + .publish(WorkflowEvent::WorkflowCompleted { + instance_id, + duration, + }) + .await; + + Ok(()) + } + + /// Execute a step with retry logic + async fn execute_step_with_retry( + &self, + instance_id: WorkflowInstanceId, + step: &StepDefinition, + definition: &WorkflowDefinition, + ) -> WorkflowResult { + let retry_policy = definition.get_retry_policy(step); + let step_timeout = definition.get_timeout(step); + + let mut attempt = 1; + let max_attempts = if matches!(step.on_failure, FailureAction::RetryIndefinitely) { + u32::MAX + } else { + retry_policy.max_attempts + }; + + let mut backoff = Self::create_backoff(&retry_policy.backoff); + + loop { + // Check for cancellation before starting/retrying step + { + let state = self.state_store.load(instance_id)?; + if state.status == 
WorkflowStatus::Cancelled { + return Err(WorkflowError::Cancelled(instance_id)); + } + } + + // Update step state + self.state_store.update(instance_id, |s| { + s.current_step = Some(step.id.clone()); + if let Some(step_state) = s.step_states.get_mut(&step.id) { + step_state.status = if attempt == 1 { + StepStatus::Running + } else { + StepStatus::Retrying + }; + step_state.attempt = attempt; + step_state.started_at = Some(Utc::now()); + } + })?; + + // Emit step started event + self.event_bus + .publish(WorkflowEvent::StepStarted { + instance_id, + step_id: step.id.clone(), + attempt, + }) + .await; + + // Get current context + let mut context = self.state_store.load(instance_id)?.context; + + // Execute step with timeout + let step_start = std::time::Instant::now(); + let result = timeout(step_timeout, step.executor.execute(&mut context)).await; + + let step_duration = step_start.elapsed(); + + // Save updated context + self.state_store.update(instance_id, |s| { + s.context = context.clone(); + })?; + + match result { + Ok(Ok(StepResult::Success)) => { + // Step succeeded + self.state_store.update(instance_id, |s| { + if let Some(step_state) = s.step_states.get_mut(&step.id) { + step_state.status = StepStatus::Succeeded; + step_state.completed_at = Some(Utc::now()); + } + })?; + + self.event_bus + .publish(WorkflowEvent::StepSucceeded { + instance_id, + step_id: step.id.clone(), + duration: step_duration, + }) + .await; + + // Call on_success hook + if let Err(e) = step.executor.on_success(&context).await { + tracing::warn!(step_id = %step.id, error = ?e, "on_success hook failed"); + } + + return Ok(StepResult::Success); + } + Ok(Ok(StepResult::Skip)) => { + return Ok(StepResult::Skip); + } + Ok(Ok(StepResult::Failure)) | Ok(Err(_)) | Err(_) => { + let (error_msg, should_retry) = match result { + Ok(Err(e)) => { + let msg = format!("{}", e); + let retryable = step.executor.is_retryable(&e); + (msg, retryable) + } + Err(_) => ( + format!("Step timeout after {:?}", step_timeout), + true, // Timeouts are retryable + ), + _ => ("Step failed".to_string(), false), + }; + + let will_retry = should_retry && attempt < max_attempts; + + // Update step state + self.state_store.update(instance_id, |s| { + if let Some(step_state) = s.step_states.get_mut(&step.id) { + step_state.status = if will_retry { + StepStatus::Retrying + } else { + StepStatus::Failed + }; + step_state.last_error = Some(error_msg.clone()); + if !will_retry { + step_state.completed_at = Some(Utc::now()); + } + } + })?; + + // Emit step failed event + self.event_bus + .publish(WorkflowEvent::StepFailed { + instance_id, + step_id: step.id.clone(), + error: error_msg.clone(), + will_retry, + }) + .await; + + if will_retry { + // Calculate backoff delay + let delay = backoff + .next_backoff() + .unwrap_or_else(|| Duration::from_secs(1)); + + self.event_bus + .publish(WorkflowEvent::StepRetrying { + instance_id, + step_id: step.id.clone(), + attempt: attempt + 1, + delay, + }) + .await; + + tokio::time::sleep(delay).await; + attempt += 1; + } else { + // No more retries, call on_failure hook + // Create a generic error for the hook + let hook_error = WorkflowError::StepFailed { + step_id: step.id.clone(), + message: error_msg, + }; + if let Err(hook_err) = step.executor.on_failure(&context, &hook_error).await + { + tracing::warn!(step_id = %step.id, error = ?hook_err, "on_failure hook failed"); + } + + return Ok(StepResult::Failure); + } + } + } + } + } + + /// Create a backoff instance from strategy + fn create_backoff(strategy: 
&BackoffStrategy) -> Box { + match strategy { + BackoffStrategy::Fixed(duration) => { + // For fixed backoff, use exponential with multiplier 1.0 + let backoff = ExponentialBackoffBuilder::new() + .with_initial_interval(*duration) + .with_multiplier(1.0) + .with_max_interval(*duration) + .with_max_elapsed_time(None) + .build(); + Box::new(backoff) + } + BackoffStrategy::Exponential { base, max } => { + let backoff = ExponentialBackoffBuilder::new() + .with_initial_interval(*base) + .with_max_interval(*max) + .with_max_elapsed_time(None) + .build(); + Box::new(backoff) + } + BackoffStrategy::Linear { increment, max } => { + // Use proper linear backoff: increment, 2*increment, 3*increment, ... + Box::new(LinearBackoff::new(*increment, *max)) + } + } + } + + /// Cancel a running workflow + pub async fn cancel_workflow(&self, instance_id: WorkflowInstanceId) -> WorkflowResult<()> { + self.state_store.update(instance_id, |s| { + s.status = WorkflowStatus::Cancelled; + })?; + + self.event_bus + .publish(WorkflowEvent::WorkflowCancelled { instance_id }) + .await; + + Ok(()) + } + + /// Get workflow status + pub fn get_status(&self, instance_id: WorkflowInstanceId) -> WorkflowResult { + self.state_store.load(instance_id) + } + + /// Clone engine for async execution + fn clone_for_execution(&self) -> Self { + Self { + definitions: Arc::clone(&self.definitions), + state_store: self.state_store.clone(), + event_bus: Arc::clone(&self.event_bus), + } + } +} + +impl Default for WorkflowEngine { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Debug for WorkflowEngine { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WorkflowEngine") + .field("definitions_count", &self.definitions.read().len()) + .field("state_count", &self.state_store.count()) + .finish() + } +} diff --git a/sgl-router/src/core/workflow/event.rs b/sgl-router/src/core/workflow/event.rs new file mode 100644 index 000000000..1fd5202e8 --- /dev/null +++ b/sgl-router/src/core/workflow/event.rs @@ -0,0 +1,188 @@ +//! 
Workflow event system for observability and monitoring + +use std::{sync::Arc, time::Duration}; + +use async_trait::async_trait; +use tokio::sync::RwLock; +use tracing::{error, info, warn}; + +use super::types::{StepId, WorkflowId, WorkflowInstanceId}; + +/// Events emitted by the workflow engine +#[derive(Debug, Clone)] +pub enum WorkflowEvent { + WorkflowStarted { + instance_id: WorkflowInstanceId, + definition_id: WorkflowId, + }, + StepStarted { + instance_id: WorkflowInstanceId, + step_id: StepId, + attempt: u32, + }, + StepSucceeded { + instance_id: WorkflowInstanceId, + step_id: StepId, + duration: Duration, + }, + StepFailed { + instance_id: WorkflowInstanceId, + step_id: StepId, + error: String, + will_retry: bool, + }, + StepRetrying { + instance_id: WorkflowInstanceId, + step_id: StepId, + attempt: u32, + delay: Duration, + }, + WorkflowCompleted { + instance_id: WorkflowInstanceId, + duration: Duration, + }, + WorkflowFailed { + instance_id: WorkflowInstanceId, + failed_step: StepId, + error: String, + }, + WorkflowCancelled { + instance_id: WorkflowInstanceId, + }, +} + +/// Trait for subscribing to workflow events +#[async_trait] +pub trait EventSubscriber: Send + Sync { + async fn on_event(&self, event: &WorkflowEvent); +} + +/// Event bus for publishing and subscribing to workflow events +pub struct EventBus { + subscribers: Arc>>>, +} + +impl EventBus { + pub fn new() -> Self { + Self { + subscribers: Arc::new(RwLock::new(Vec::new())), + } + } + + /// Subscribe to workflow events + pub async fn subscribe(&self, subscriber: Arc) { + self.subscribers.write().await.push(subscriber); + } + + /// Publish an event to all subscribers + pub async fn publish(&self, event: WorkflowEvent) { + let subscribers = self.subscribers.read().await; + for subscriber in subscribers.iter() { + subscriber.on_event(&event).await; + } + } +} + +impl Default for EventBus { + fn default() -> Self { + Self::new() + } +} + +/// Logging subscriber that logs events using tracing +pub struct LoggingSubscriber; + +#[async_trait] +impl EventSubscriber for LoggingSubscriber { + async fn on_event(&self, event: &WorkflowEvent) { + match event { + WorkflowEvent::WorkflowStarted { + instance_id, + definition_id, + } => { + info!( + instance_id = %instance_id, + definition_id = %definition_id, + "Workflow started" + ); + } + WorkflowEvent::StepStarted { + instance_id, + step_id, + attempt, + } => { + info!( + instance_id = %instance_id, + step_id = %step_id, + attempt = attempt, + "Step started" + ); + } + WorkflowEvent::StepSucceeded { + instance_id, + step_id, + duration, + } => { + info!( + instance_id = %instance_id, + step_id = %step_id, + duration_ms = duration.as_millis(), + "Step succeeded" + ); + } + WorkflowEvent::StepFailed { + instance_id, + step_id, + error, + will_retry, + } => { + warn!( + instance_id = %instance_id, + step_id = %step_id, + error = error, + will_retry = will_retry, + "Step failed" + ); + } + WorkflowEvent::StepRetrying { + instance_id, + step_id, + attempt, + delay, + } => { + info!( + instance_id = %instance_id, + step_id = %step_id, + attempt = attempt, + delay_ms = delay.as_millis(), + "Step retrying" + ); + } + WorkflowEvent::WorkflowCompleted { + instance_id, + duration, + } => { + info!( + instance_id = %instance_id, + duration_ms = duration.as_millis(), + "Workflow completed" + ); + } + WorkflowEvent::WorkflowFailed { + instance_id, + failed_step, + error, + } => { + error!( + instance_id = %instance_id, + failed_step = %failed_step, + error = error, + "Workflow failed" + 
); + } + WorkflowEvent::WorkflowCancelled { instance_id } => { + info!(instance_id = %instance_id, "Workflow cancelled"); + } + } + } +} diff --git a/sgl-router/src/core/workflow/executor.rs b/sgl-router/src/core/workflow/executor.rs new file mode 100644 index 000000000..1a27c1922 --- /dev/null +++ b/sgl-router/src/core/workflow/executor.rs @@ -0,0 +1,129 @@ +//! Step executor trait and implementations + +use async_trait::async_trait; + +use super::types::{StepResult, WorkflowContext, WorkflowError, WorkflowResult}; + +/// Trait for executing individual workflow steps +#[async_trait] +pub trait StepExecutor: Send + Sync { + /// Execute the step with the given context + async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult; + + /// Check if an error is retry-able + /// + /// Override this method to customize which errors should trigger retries. + /// By default, all errors are considered retry-able. + fn is_retryable(&self, _error: &WorkflowError) -> bool { + true + } + + /// Called when the step succeeds + /// + /// This hook allows steps to perform cleanup or additional actions + /// after successful execution. + async fn on_success(&self, _context: &WorkflowContext) -> WorkflowResult<()> { + Ok(()) + } + + /// Called when the step fails after all retries + /// + /// This hook allows steps to perform cleanup or compensation logic + /// when the step cannot complete successfully. + async fn on_failure( + &self, + _context: &WorkflowContext, + _error: &WorkflowError, + ) -> WorkflowResult<()> { + Ok(()) + } +} + +/// Simple function-based step executor +pub struct FunctionStep +where + F: Fn( + &mut WorkflowContext, + ) -> std::pin::Pin< + Box> + Send + '_>, + > + Send + + Sync, +{ + func: F, +} + +impl FunctionStep +where + F: Fn( + &mut WorkflowContext, + ) -> std::pin::Pin< + Box> + Send + '_>, + > + Send + + Sync, +{ + pub fn new(func: F) -> Self { + Self { func } + } +} + +#[async_trait] +impl StepExecutor for FunctionStep +where + F: Fn( + &mut WorkflowContext, + ) -> std::pin::Pin< + Box> + Send + '_>, + > + Send + + Sync, +{ + async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult { + (self.func)(context).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::workflow::types::WorkflowInstanceId; + + struct TestStep { + should_succeed: bool, + } + + #[async_trait] + impl StepExecutor for TestStep { + async fn execute(&self, _context: &mut WorkflowContext) -> WorkflowResult { + if self.should_succeed { + Ok(StepResult::Success) + } else { + Err(WorkflowError::StepFailed { + step_id: crate::core::workflow::types::StepId::new("test"), + message: "test error".to_string(), + }) + } + } + } + + #[tokio::test] + async fn test_step_executor_success() { + let step = TestStep { + should_succeed: true, + }; + let mut context = WorkflowContext::new(WorkflowInstanceId::new()); + + let result = step.execute(&mut context).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), StepResult::Success); + } + + #[tokio::test] + async fn test_step_executor_failure() { + let step = TestStep { + should_succeed: false, + }; + let mut context = WorkflowContext::new(WorkflowInstanceId::new()); + + let result = step.execute(&mut context).await; + assert!(result.is_err()); + } +} diff --git a/sgl-router/src/core/workflow/mod.rs b/sgl-router/src/core/workflow/mod.rs new file mode 100644 index 000000000..8e840b2ea --- /dev/null +++ b/sgl-router/src/core/workflow/mod.rs @@ -0,0 +1,18 @@ +//! 
Workflow engine for managing multi-step operations + +mod definition; +mod engine; +mod event; +mod executor; +mod state; +pub mod steps; +pub mod types; + +// Re-export main types +pub use definition::{StepDefinition, WorkflowDefinition}; +pub use engine::WorkflowEngine; +pub use event::{EventBus, EventSubscriber, LoggingSubscriber, WorkflowEvent}; +pub use executor::{FunctionStep, StepExecutor}; +pub use state::WorkflowStateStore; +pub use steps::create_worker_registration_workflow; +pub use types::*; diff --git a/sgl-router/src/core/workflow/state.rs b/sgl-router/src/core/workflow/state.rs new file mode 100644 index 000000000..2811903b1 --- /dev/null +++ b/sgl-router/src/core/workflow/state.rs @@ -0,0 +1,175 @@ +//! Workflow state management + +use std::{collections::HashMap, sync::Arc}; + +use parking_lot::RwLock; + +use super::types::{ + WorkflowError, WorkflowInstanceId, WorkflowResult, WorkflowState, WorkflowStatus, +}; + +/// In-memory state storage for workflow instances +#[derive(Clone)] +pub struct WorkflowStateStore { + states: Arc>>, +} + +impl WorkflowStateStore { + pub fn new() -> Self { + Self { + states: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Save workflow state + /// + /// # Warning + /// + /// This emits a warning if the workflow context contains unserializable data, + /// which would be lost if state persistence is later implemented. + pub fn save(&self, state: WorkflowState) -> WorkflowResult<()> { + if state.context.has_unserializable_data() { + tracing::warn!( + instance_id = %state.instance_id, + data_count = state.context.data_len(), + "Saving workflow state with {} unserializable context entries. \ + This data cannot be persisted and will be lost on restart.", + state.context.data_len() + ); + } + self.states.write().insert(state.instance_id, state); + Ok(()) + } + + /// Load workflow state by instance ID + pub fn load(&self, instance_id: WorkflowInstanceId) -> WorkflowResult { + self.states + .read() + .get(&instance_id) + .cloned() + .ok_or(WorkflowError::NotFound(instance_id)) + } + + /// List all active workflows (Running or Pending) + pub fn list_active(&self) -> WorkflowResult> { + let states = self.states.read(); + Ok(states + .values() + .filter(|s| matches!(s.status, WorkflowStatus::Running | WorkflowStatus::Pending)) + .cloned() + .collect()) + } + + /// List all workflows + pub fn list_all(&self) -> WorkflowResult> { + let states = self.states.read(); + Ok(states.values().cloned().collect()) + } + + /// Delete workflow state + pub fn delete(&self, instance_id: WorkflowInstanceId) -> WorkflowResult<()> { + self.states.write().remove(&instance_id); + Ok(()) + } + + /// Update workflow state using a closure + pub fn update(&self, instance_id: WorkflowInstanceId, f: F) -> WorkflowResult<()> + where + F: FnOnce(&mut WorkflowState), + { + let mut states = self.states.write(); + let state = states + .get_mut(&instance_id) + .ok_or(WorkflowError::NotFound(instance_id))?; + f(state); + state.updated_at = chrono::Utc::now(); + Ok(()) + } + + /// Get count of workflows by status + pub fn count_by_status(&self, status: WorkflowStatus) -> usize { + self.states + .read() + .values() + .filter(|s| s.status == status) + .count() + } + + /// Get total count of all workflows + pub fn count(&self) -> usize { + self.states.read().len() + } + + /// Clean up old completed/failed/cancelled workflows beyond a time threshold + /// + /// This prevents unbounded memory growth by removing workflow states that + /// have been in a terminal state (Completed, Failed, 
Cancelled) for longer + /// than the specified TTL (time-to-live). + /// + /// Active workflows (Running, Pending, Paused) are never cleaned up. + /// + /// # Arguments + /// + /// * `ttl` - Time-to-live for terminal workflows. Workflows in terminal states + /// older than this will be removed. + /// + /// # Returns + /// + /// The number of workflow states removed. + pub fn cleanup_old_workflows(&self, ttl: std::time::Duration) -> usize { + let now = chrono::Utc::now(); + let mut states = self.states.write(); + let initial_count = states.len(); + + states.retain(|_, state| { + // Keep active workflows + if matches!( + state.status, + WorkflowStatus::Running | WorkflowStatus::Pending | WorkflowStatus::Paused + ) { + return true; + } + + // For terminal workflows, check age + let age = now + .signed_duration_since(state.updated_at) + .to_std() + .unwrap_or_default(); + age < ttl + }); + + let removed_count = initial_count - states.len(); + if removed_count > 0 { + tracing::info!( + removed = removed_count, + remaining = states.len(), + "Cleaned up old workflow states" + ); + } + removed_count + } + + /// Clean up a specific completed workflow immediately + /// + /// This is useful for cleaning up workflows right after they complete + /// when you know they won't be queried again. + pub fn cleanup_if_terminal(&self, instance_id: WorkflowInstanceId) -> bool { + let mut states = self.states.write(); + if let Some(state) = states.get(&instance_id) { + if matches!( + state.status, + WorkflowStatus::Completed | WorkflowStatus::Failed | WorkflowStatus::Cancelled + ) { + states.remove(&instance_id); + return true; + } + } + false + } +} + +impl Default for WorkflowStateStore { + fn default() -> Self { + Self::new() + } +} diff --git a/sgl-router/src/core/workflow/steps/mod.rs b/sgl-router/src/core/workflow/steps/mod.rs new file mode 100644 index 000000000..7f4dc8252 --- /dev/null +++ b/sgl-router/src/core/workflow/steps/mod.rs @@ -0,0 +1,12 @@ +//! Workflow step implementations +//! +//! This module contains concrete step implementations for various workflows: +//! - Worker registration and activation +//! - Future: Tokenizer fetching, LoRA updates, etc. + +pub mod worker_registration; + +pub use worker_registration::{ + create_worker_registration_workflow, ActivateWorkerStep, CreateWorkerStep, + DetectConnectionModeStep, DiscoverMetadataStep, RegisterWorkerStep, UpdatePoliciesStep, +}; diff --git a/sgl-router/src/core/workflow/steps/worker_registration.rs b/sgl-router/src/core/workflow/steps/worker_registration.rs new file mode 100644 index 000000000..06916a866 --- /dev/null +++ b/sgl-router/src/core/workflow/steps/worker_registration.rs @@ -0,0 +1,837 @@ +//! Worker registration workflow steps +//! +//! Each step is atomic and performs a single operation in the worker registration process. +//! +//! Workflow order: +//! 1. DetectConnectionMode - Probe both HTTP and gRPC to determine connection mode +//! 2. DiscoverMetadata - Fetch metadata from the worker +//! 3. DiscoverDPInfo - Fetch DP (Data Parallel) information (only for DP-aware workers) +//! 4. CreateWorker - Build worker object(s) with merged config + metadata +//! 5. RegisterWorker - Register worker(s) in registry +//! 6. UpdatePolicies - Update policy registry with worker information +//! 7. 
ActivateWorker - Mark worker(s) as healthy + +use std::{collections::HashMap, sync::Arc, time::Duration}; + +use async_trait::async_trait; +use once_cell::sync::Lazy; +use reqwest::Client; +use serde_json::Value; +use tracing::{info, warn}; + +use crate::{ + core::{ + workflow::*, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, + DPAwareWorkerBuilder, DpInfo, HealthConfig, Worker, WorkerManager, WorkerType, + }, + grpc_client::SglangSchedulerClient, + protocols::worker_spec::WorkerConfigRequest, + server::AppContext, +}; + +// HTTP client for metadata fetching +static HTTP_CLIENT: Lazy = Lazy::new(|| { + Client::builder() + .timeout(Duration::from_secs(10)) + .build() + .expect("Failed to create HTTP client") +}); + +/// Helper: Strip protocol prefix from URL +fn strip_protocol(url: &str) -> String { + url.trim_start_matches("http://") + .trim_start_matches("https://") + .trim_start_matches("grpc://") + .to_string() +} + +/// Helper: Try HTTP health check +async fn try_http_health_check(url: &str, timeout_secs: u64) -> Result<(), String> { + let clean_url = strip_protocol(url); + let health_url = format!("http://{}/health", clean_url); + + HTTP_CLIENT + .get(&health_url) + .timeout(Duration::from_secs(timeout_secs)) + .send() + .await + .map_err(|e| format!("HTTP health check failed: {}", e))?; + + Ok(()) +} + +/// Helper: Try gRPC health check +async fn try_grpc_health_check(url: &str, timeout_secs: u64) -> Result<(), String> { + let grpc_url = if url.starts_with("grpc://") { + url.to_string() + } else { + format!("grpc://{}", strip_protocol(url)) + }; + + let connect_future = SglangSchedulerClient::connect(&grpc_url); + let client = tokio::time::timeout(Duration::from_secs(timeout_secs), connect_future) + .await + .map_err(|_| "gRPC connection timeout".to_string())? + .map_err(|e| format!("gRPC connection failed: {}", e))?; + + let health_future = client.health_check(); + tokio::time::timeout(Duration::from_secs(timeout_secs), health_future) + .await + .map_err(|_| "gRPC health check timeout".to_string())? 
+ .map_err(|e| format!("gRPC health check failed: {}", e))?; + + Ok(()) +} + +/// Helper: Fetch HTTP metadata +async fn fetch_http_metadata( + url: &str, + api_key: Option<&str>, +) -> Result, String> { + let clean_url = strip_protocol(url); + let info_url = if clean_url.starts_with("http://") || clean_url.starts_with("https://") { + format!("{}/get_server_info", clean_url) + } else { + format!("http://{}/get_server_info", clean_url) + }; + + let mut request = HTTP_CLIENT.get(&info_url); + if let Some(key) = api_key { + request = request.header("Authorization", format!("Bearer {}", key)); + } + + let response = request + .send() + .await + .map_err(|e| format!("Failed to fetch HTTP metadata: {}", e))?; + + let server_info: Value = response + .json() + .await + .map_err(|e| format!("Failed to parse HTTP metadata: {}", e))?; + + let mut labels = HashMap::new(); + + if let Some(model_path) = server_info.get("model_path").and_then(|v| v.as_str()) { + if !model_path.is_empty() { + labels.insert("model_path".to_string(), model_path.to_string()); + } + } + if let Some(tokenizer_path) = server_info.get("tokenizer_path").and_then(|v| v.as_str()) { + if !tokenizer_path.is_empty() { + labels.insert("tokenizer_path".to_string(), tokenizer_path.to_string()); + } + } + + Ok(labels) +} + +/// Helper: Fetch gRPC metadata +async fn fetch_grpc_metadata(url: &str) -> Result, String> { + let grpc_url = if url.starts_with("grpc://") { + url.to_string() + } else { + format!("grpc://{}", strip_protocol(url)) + }; + + let client = SglangSchedulerClient::connect(&grpc_url) + .await + .map_err(|e| format!("Failed to connect to gRPC: {}", e))?; + + let model_info = client + .get_model_info() + .await + .map_err(|e| format!("Failed to fetch gRPC metadata: {}", e))?; + + let mut labels = HashMap::new(); + + // Extract all available fields + if !model_info.model_path.is_empty() { + labels.insert("model_path".to_string(), model_info.model_path.clone()); + } + if !model_info.tokenizer_path.is_empty() { + labels.insert( + "tokenizer_path".to_string(), + model_info.tokenizer_path.clone(), + ); + } + if !model_info.served_model_name.is_empty() { + labels.insert( + "served_model_name".to_string(), + model_info.served_model_name.clone(), + ); + } + if !model_info.weight_version.is_empty() { + labels.insert( + "weight_version".to_string(), + model_info.weight_version.clone(), + ); + } + if !model_info.model_type.is_empty() { + labels.insert("model_type".to_string(), model_info.model_type.clone()); + } + if model_info.max_context_length > 0 { + labels.insert( + "max_context_length".to_string(), + model_info.max_context_length.to_string(), + ); + } + if model_info.max_req_input_len > 0 { + labels.insert( + "max_req_input_len".to_string(), + model_info.max_req_input_len.to_string(), + ); + } + if model_info.vocab_size > 0 { + labels.insert("vocab_size".to_string(), model_info.vocab_size.to_string()); + } + if model_info.is_generation { + labels.insert("is_generation".to_string(), "true".to_string()); + } + + Ok(labels) +} + +/// Step 1: Detect connection mode by probing both HTTP and gRPC +pub struct DetectConnectionModeStep; + +#[async_trait] +impl StepExecutor for DetectConnectionModeStep { + async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult { + let config: Arc = context + .get("worker_config") + .ok_or_else(|| WorkflowError::ContextValueNotFound("worker_config".to_string()))?; + + info!( + "Detecting connection mode for {} (timeout: {}s, max_attempts: {})", + config.url, config.health_check_timeout_secs, 
config.max_connection_attempts + ); + + // Try both protocols in parallel using configured timeout + let url = config.url.clone(); + let timeout = config.health_check_timeout_secs; + let (http_result, grpc_result) = tokio::join!( + try_http_health_check(&url, timeout), + try_grpc_health_check(&url, timeout) + ); + + let connection_mode = match (http_result, grpc_result) { + (Ok(_), _) => { + info!("{} detected as HTTP", config.url); + ConnectionMode::Http + } + (_, Ok(_)) => { + info!("{} detected as gRPC", config.url); + ConnectionMode::Grpc { port: None } + } + (Err(http_err), Err(grpc_err)) => { + return Err(WorkflowError::StepFailed { + step_id: StepId::new("detect_connection_mode"), + message: format!( + "Both HTTP and gRPC health checks failed for {}: HTTP: {}, gRPC: {}", + config.url, http_err, grpc_err + ), + }); + } + }; + + // Store connection mode in context + context.set("connection_mode", connection_mode); + + Ok(StepResult::Success) + } + + fn is_retryable(&self, _error: &WorkflowError) -> bool { + true // Connection issues are retryable + } +} + +/// Step 2: Discover metadata from worker +pub struct DiscoverMetadataStep; + +#[async_trait] +impl StepExecutor for DiscoverMetadataStep { + async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult { + let config: Arc = context + .get("worker_config") + .ok_or_else(|| WorkflowError::ContextValueNotFound("worker_config".to_string()))?; + let connection_mode: Arc = context + .get("connection_mode") + .ok_or_else(|| WorkflowError::ContextValueNotFound("connection_mode".to_string()))?; + + info!( + "Discovering metadata for {} ({:?})", + config.url, *connection_mode + ); + + let discovered_labels = match connection_mode.as_ref() { + ConnectionMode::Http => { + fetch_http_metadata(&config.url, config.api_key.as_deref()).await + } + ConnectionMode::Grpc { .. 
} => fetch_grpc_metadata(&config.url).await, + } + .unwrap_or_else(|e| { + warn!("Failed to fetch metadata for {}: {}", config.url, e); + HashMap::new() + }); + + info!( + "Discovered {} metadata labels for {}", + discovered_labels.len(), + config.url + ); + + // Store discovered labels in context + context.set("discovered_labels", discovered_labels); + + Ok(StepResult::Success) + } + + fn is_retryable(&self, _error: &WorkflowError) -> bool { + true // Metadata discovery failures are retryable + } +} + +/// Step 2.5: Discover DP (Data Parallel) information (only for DP-aware workers) +pub struct DiscoverDPInfoStep; + +#[async_trait] +impl StepExecutor for DiscoverDPInfoStep { + async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult { + let config: Arc = context + .get("worker_config") + .ok_or_else(|| WorkflowError::ContextValueNotFound("worker_config".to_string()))?; + + // Skip DP discovery if not DP-aware + if !config.dp_aware { + info!( + "Worker {} is not DP-aware, skipping DP discovery", + config.url + ); + return Ok(StepResult::Success); + } + + info!("Discovering DP info for {} (DP-aware)", config.url); + + // Get DP info from worker + let dp_info = WorkerManager::get_dp_info(&config.url, config.api_key.as_deref()) + .await + .map_err(|e| WorkflowError::StepFailed { + step_id: StepId::new("discover_dp_info"), + message: format!("Failed to get DP info: {}", e), + })?; + + info!( + "Discovered DP size {} for {} (model: {})", + dp_info.dp_size, config.url, dp_info.model_id + ); + + // Store DP info in context + context.set("dp_info", Arc::new(dp_info)); + + Ok(StepResult::Success) + } + + fn is_retryable(&self, _error: &WorkflowError) -> bool { + true // DP info discovery failures are retryable + } +} + +/// Step 3: Create worker object with merged configuration + metadata +pub struct CreateWorkerStep; + +#[async_trait] +impl StepExecutor for CreateWorkerStep { + async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult { + let config: Arc = context + .get("worker_config") + .ok_or_else(|| WorkflowError::ContextValueNotFound("worker_config".to_string()))?; + let app_context: Arc = context + .get("app_context") + .ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?; + let connection_mode: Arc = context + .get("connection_mode") + .ok_or_else(|| WorkflowError::ContextValueNotFound("connection_mode".to_string()))?; + let discovered_labels: Arc> = context + .get("discovered_labels") + .ok_or_else(|| WorkflowError::ContextValueNotFound("discovered_labels".to_string()))?; + + // Check if worker already exists + if app_context + .worker_registry + .get_by_url(&config.url) + .is_some() + { + return Err(WorkflowError::StepFailed { + step_id: StepId::new("create_worker"), + message: format!("Worker {} already exists", config.url), + }); + } + + // Build labels from config + let mut config_labels = config.labels.clone(); + if let Some(model_id) = &config.model_id { + config_labels.insert("model_id".to_string(), model_id.clone()); + } + if let Some(priority) = config.priority { + config_labels.insert("priority".to_string(), priority.to_string()); + } + if let Some(cost) = config.cost { + config_labels.insert("cost".to_string(), cost.to_string()); + } + if let Some(ref tokenizer_path) = config.tokenizer_path { + config_labels.insert("tokenizer_path".to_string(), tokenizer_path.clone()); + } + if let Some(ref reasoning_parser) = config.reasoning_parser { + config_labels.insert("reasoning_parser".to_string(), reasoning_parser.clone()); + } 
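// (The remaining optional fields are copied the same way; when these
// config labels are merged with the discovered labels below, a key set in
// both, e.g. "tokenizer_path", resolves to the config value, since config
// entries are inserted last.)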
+ if let Some(ref tool_parser) = config.tool_parser { + config_labels.insert("tool_parser".to_string(), tool_parser.clone()); + } + if let Some(ref chat_template) = config.chat_template { + config_labels.insert("chat_template".to_string(), chat_template.clone()); + } + + // Merge: discovered labels first, then config labels (config takes precedence) + let mut final_labels = discovered_labels.as_ref().clone(); + for (key, value) in &config_labels { + final_labels.insert(key.clone(), value.clone()); + } + + // Derive model_id if not already set + if !final_labels.contains_key("model_id") { + let derived_model_id = final_labels + .get("served_model_name") + .or_else(|| final_labels.get("model_path")) + .cloned(); + + if let Some(model_id) = derived_model_id { + info!("Derived model_id from metadata: {}", model_id); + final_labels.insert("model_id".to_string(), model_id); + } + } + + info!( + "Creating worker {} with {} discovered + {} config = {} final labels", + config.url, + discovered_labels.len(), + config_labels.len(), + final_labels.len() + ); + + // Parse worker type + let worker_type = config + .worker_type + .as_ref() + .map(|t| match t.as_str() { + "prefill" => WorkerType::Prefill { + bootstrap_port: config.bootstrap_port, + }, + "decode" => WorkerType::Decode, + _ => WorkerType::Regular, + }) + .unwrap_or(WorkerType::Regular); + + // Build circuit breaker config + let circuit_breaker_config = { + let cfg = app_context.router_config.effective_circuit_breaker_config(); + CircuitBreakerConfig { + failure_threshold: cfg.failure_threshold, + success_threshold: cfg.success_threshold, + timeout_duration: Duration::from_secs(cfg.timeout_duration_secs), + window_duration: Duration::from_secs(cfg.window_duration_secs), + } + }; + + // Build health config + let health_config = { + let cfg = &app_context.router_config.health_check; + HealthConfig { + timeout_secs: cfg.timeout_secs, + check_interval_secs: cfg.check_interval_secs, + endpoint: cfg.endpoint.clone(), + failure_threshold: cfg.failure_threshold, + success_threshold: cfg.success_threshold, + } + }; + + // Normalize URL: add protocol prefix only if missing + let normalized_url = if config.url.starts_with("http://") + || config.url.starts_with("https://") + || config.url.starts_with("grpc://") + { + // URL already has protocol, use as-is + config.url.clone() + } else { + // Bare IP:port format, add appropriate protocol based on detected mode + match connection_mode.as_ref() { + ConnectionMode::Http => format!("http://{}", config.url), + ConnectionMode::Grpc { .. 
} => format!("grpc://{}", config.url), + } + }; + + if normalized_url != config.url { + info!( + "Normalized worker URL: {} -> {} ({:?})", + config.url, + normalized_url, + connection_mode.as_ref() + ); + } + + // Handle DP-aware vs non-DP-aware workers + if config.dp_aware { + // DP-aware path: Create multiple workers (one per rank) + let dp_info: Arc = context + .get("dp_info") + .ok_or_else(|| WorkflowError::ContextValueNotFound("dp_info".to_string()))?; + + info!( + "Creating {} DP-aware workers for {} (dp_size: {})", + dp_info.dp_size, config.url, dp_info.dp_size + ); + + let mut workers = Vec::new(); + for rank in 0..dp_info.dp_size { + let mut builder = + DPAwareWorkerBuilder::new(normalized_url.clone(), rank, dp_info.dp_size) + .worker_type(worker_type.clone()) + .connection_mode(connection_mode.as_ref().clone()) + .circuit_breaker_config(circuit_breaker_config.clone()) + .health_config(health_config.clone()); + + if let Some(ref api_key) = config.api_key { + builder = builder.api_key(api_key.clone()); + } + + if !final_labels.is_empty() { + builder = builder.labels(final_labels.clone()); + } + + let worker = Arc::new(builder.build()) as Arc; + worker.set_healthy(false); + workers.push(worker); + + info!( + "Created DP-aware worker {}@{}/{} ({:?})", + config.url, + rank, + dp_info.dp_size, + connection_mode.as_ref() + ); + } + + // Store workers (plural) and labels in context + context.set("workers", Arc::new(workers)); + context.set("labels", final_labels); + + Ok(StepResult::Success) + } else { + // Non-DP-aware path: Create single worker + let mut builder = BasicWorkerBuilder::new(normalized_url.clone()) + .worker_type(worker_type) + .connection_mode(connection_mode.as_ref().clone()) + .circuit_breaker_config(circuit_breaker_config) + .health_config(health_config); + + if let Some(ref api_key) = config.api_key { + builder = builder.api_key(api_key.clone()); + } + + if !final_labels.is_empty() { + builder = builder.labels(final_labels.clone()); + } + + let worker = Arc::new(builder.build()) as Arc; + worker.set_healthy(false); + + info!( + "Created worker object for {} ({:?}) with {} labels", + config.url, + connection_mode.as_ref(), + final_labels.len() + ); + + // Store worker (singular) and labels in context + context.set("worker", worker); + context.set("labels", final_labels); + + Ok(StepResult::Success) + } + } + + fn is_retryable(&self, _error: &WorkflowError) -> bool { + false // Worker creation failures are not retryable (likely config issues) + } +} + +/// Step 4: Register worker(s) in registry +pub struct RegisterWorkerStep; + +#[async_trait] +impl StepExecutor for RegisterWorkerStep { + async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult { + let config: Arc = context + .get("worker_config") + .ok_or_else(|| WorkflowError::ContextValueNotFound("worker_config".to_string()))?; + let app_context: Arc = context + .get("app_context") + .ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?; + + // Check if we have multiple workers (DP-aware) or single worker + if config.dp_aware { + // DP-aware path: Register multiple workers + let workers: Arc>> = context + .get("workers") + .ok_or_else(|| WorkflowError::ContextValueNotFound("workers".to_string()))?; + + let mut worker_ids = Vec::new(); + for worker in workers.iter() { + let worker_id = app_context.worker_registry.register(Arc::clone(worker)); + worker_ids.push(worker_id.clone()); + info!( + "Registered DP-aware worker {} with ID {:?}", + config.url, worker_id + ); + } + + 
context.set("worker_ids", Arc::new(worker_ids)); + Ok(StepResult::Success) + } else { + // Non-DP-aware path: Register single worker + let worker: Arc> = context + .get("worker") + .ok_or_else(|| WorkflowError::ContextValueNotFound("worker".to_string()))?; + + let worker_id = app_context + .worker_registry + .register(Arc::clone(worker.as_ref())); + + info!("Registered worker {} with ID {:?}", config.url, worker_id); + context.set("worker_id", worker_id); + + Ok(StepResult::Success) + } + } + + fn is_retryable(&self, _error: &WorkflowError) -> bool { + false // Registration failures are not retryable + } +} + +/// Step 5: Update policy registry with worker information +pub struct UpdatePoliciesStep; + +#[async_trait] +impl StepExecutor for UpdatePoliciesStep { + async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult { + let config: Arc = context + .get("worker_config") + .ok_or_else(|| WorkflowError::ContextValueNotFound("worker_config".to_string()))?; + let labels: Arc> = context + .get("labels") + .ok_or_else(|| WorkflowError::ContextValueNotFound("labels".to_string()))?; + let app_context: Arc = context + .get("app_context") + .ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?; + + let policy_hint = labels.get("policy").map(|s| s.as_str()); + + // Check if we have multiple workers (DP-aware) or single worker + if config.dp_aware { + // DP-aware path: Update policies for multiple workers + let workers: Arc>> = context + .get("workers") + .ok_or_else(|| WorkflowError::ContextValueNotFound("workers".to_string()))?; + + // Get model_id from first worker (all DP workers have same model) + let model_id = workers[0].model_id().to_string(); + + // Notify policy registry for each worker + for _ in 0..workers.len() { + app_context + .policy_registry + .on_worker_added(&model_id, policy_hint); + } + + // Initialize cache-aware policy if needed + let all_workers = app_context.worker_registry.get_by_model_fast(&model_id); + if let Some(policy) = app_context.policy_registry.get_policy(&model_id) { + if policy.name() == "cache_aware" { + app_context + .policy_registry + .init_cache_aware_policy(&model_id, &all_workers); + } + } + + info!( + "Updated policies for {} DP-aware workers {} (model: {})", + workers.len(), + config.url, + model_id + ); + } else { + // Non-DP-aware path: Update policy for single worker + let worker: Arc> = context + .get("worker") + .ok_or_else(|| WorkflowError::ContextValueNotFound("worker".to_string()))?; + + let model_id = worker.model_id().to_string(); + + // Notify policy registry + app_context + .policy_registry + .on_worker_added(&model_id, policy_hint); + + // Initialize cache-aware policy if needed + let all_workers = app_context.worker_registry.get_by_model_fast(&model_id); + if let Some(policy) = app_context.policy_registry.get_policy(&model_id) { + if policy.name() == "cache_aware" { + app_context + .policy_registry + .init_cache_aware_policy(&model_id, &all_workers); + } + } + + info!( + "Updated policies for worker {} (model: {})", + config.url, model_id + ); + } + + Ok(StepResult::Success) + } + + fn is_retryable(&self, _error: &WorkflowError) -> bool { + false // Policy update failures are not retryable + } +} + +/// Step 6: Activate worker(s) by marking them as healthy +pub struct ActivateWorkerStep; + +#[async_trait] +impl StepExecutor for ActivateWorkerStep { + async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult { + let config: Arc = context + .get("worker_config") + .ok_or_else(|| 
WorkflowError::ContextValueNotFound("worker_config".to_string()))?; + + // Check if we have multiple workers (DP-aware) or single worker + if config.dp_aware { + // DP-aware path: Activate multiple workers + let workers: Arc>> = context + .get("workers") + .ok_or_else(|| WorkflowError::ContextValueNotFound("workers".to_string()))?; + + for worker in workers.iter() { + worker.set_healthy(true); + } + + info!( + "Activated {} DP-aware workers {} (marked as healthy)", + workers.len(), + config.url + ); + } else { + // Non-DP-aware path: Activate single worker + let worker: Arc> = context + .get("worker") + .ok_or_else(|| WorkflowError::ContextValueNotFound("worker".to_string()))?; + + worker.set_healthy(true); + + info!("Activated worker {} (marked as healthy)", config.url); + } + + Ok(StepResult::Success) + } + + fn is_retryable(&self, _error: &WorkflowError) -> bool { + false // Activation is just setting a flag, not retryable + } +} + +/// Create worker registration workflow definition +/// +/// Note: Actual health check timeouts and retry attempts are configured per-worker +/// via WorkerConfigRequest (populated from router config). The timeouts and retry +/// policies here serve as workflow-level bounds to prevent infinite waiting. +pub fn create_worker_registration_workflow() -> WorkflowDefinition { + WorkflowDefinition::new("worker_registration", "Worker Registration") + .add_step( + StepDefinition::new( + "detect_connection_mode", + "Detect Connection Mode", + Arc::new(DetectConnectionModeStep), + ) + .with_retry(RetryPolicy { + max_attempts: 100, + backoff: BackoffStrategy::Linear { + increment: Duration::from_secs(1), + max: Duration::from_secs(5), + }, + }) + // Workflow-level timeout (upper bound); step uses config.health_check_timeout_secs + .with_timeout(Duration::from_secs(7200)) // 2 hours max + .with_failure_action(FailureAction::FailWorkflow), + ) + .add_step( + StepDefinition::new( + "discover_metadata", + "Discover Metadata", + Arc::new(DiscoverMetadataStep), + ) + .with_retry(RetryPolicy { + max_attempts: 3, + backoff: BackoffStrategy::Fixed(Duration::from_secs(1)), + }) + .with_timeout(Duration::from_secs(10)) + .with_failure_action(FailureAction::ContinueNextStep), // Metadata discovery is optional + ) + .add_step( + StepDefinition::new( + "discover_dp_info", + "Discover DP Info", + Arc::new(DiscoverDPInfoStep), + ) + .with_retry(RetryPolicy { + max_attempts: 3, + backoff: BackoffStrategy::Fixed(Duration::from_secs(1)), + }) + .with_timeout(Duration::from_secs(10)) + .with_failure_action(FailureAction::FailWorkflow), // DP info is required for DP-aware workers + ) + .add_step( + StepDefinition::new("create_worker", "Create Worker", Arc::new(CreateWorkerStep)) + .with_timeout(Duration::from_secs(5)) + .with_failure_action(FailureAction::FailWorkflow), + ) + .add_step( + StepDefinition::new( + "register_worker", + "Register Worker", + Arc::new(RegisterWorkerStep), + ) + .with_timeout(Duration::from_secs(5)) + .with_failure_action(FailureAction::FailWorkflow), + ) + .add_step( + StepDefinition::new( + "update_policies", + "Update Policies", + Arc::new(UpdatePoliciesStep), + ) + .with_timeout(Duration::from_secs(5)) + .with_failure_action(FailureAction::ContinueNextStep), // Policy updates are optional + ) + .add_step( + StepDefinition::new( + "activate_worker", + "Activate Worker", + Arc::new(ActivateWorkerStep), + ) + .with_timeout(Duration::from_secs(5)) + .with_failure_action(FailureAction::FailWorkflow), + ) +} diff --git a/sgl-router/src/core/workflow/types.rs 
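The definition above is pure data; nothing runs until it is handed to an engine. A rough sketch of how a caller might drive the registration pipeline end to end, using only APIs that appear elsewhere in this patch (`WorkflowEngine::new`, `register_workflow`, `start_workflow`, `set_arc`); whether these items are re-exported at exactly these paths is an assumption, and error handling is elided:

```rust
use std::sync::Arc;

use sglang_router_rs::core::workflow::{
    WorkflowContext, WorkflowEngine, WorkflowInstanceId, WorkflowResult,
};
use sglang_router_rs::protocols::worker_spec::WorkerConfigRequest;
use sglang_router_rs::server::AppContext;

// Hypothetical driver: seed the context keys the steps read, then start.
// Assumes create_worker_registration_workflow (defined above) is in scope.
async fn run_registration(
    engine: &WorkflowEngine,
    config: Arc<WorkerConfigRequest>,
    app_context: Arc<AppContext>,
) -> WorkflowResult<WorkflowInstanceId> {
    let definition = create_worker_registration_workflow();
    let workflow_id = definition.id.clone();
    engine.register_workflow(definition);

    let mut context = WorkflowContext::new(WorkflowInstanceId::new());
    // set_arc avoids double-wrapping, matching the Arc<T> reads in the steps.
    context.set_arc("worker_config", config);
    context.set_arc("app_context", app_context);

    engine.start_workflow(workflow_id, context).await
}
```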
diff --git a/sgl-router/src/core/workflow/types.rs b/sgl-router/src/core/workflow/types.rs
new file mode 100644
index 000000000..e7af5c654
--- /dev/null
+++ b/sgl-router/src/core/workflow/types.rs
@@ -0,0 +1,271 @@
+//! Core workflow types and definitions
+
+use std::{any::Any, collections::HashMap, fmt, sync::Arc, time::Duration};
+
+use chrono::{DateTime, Utc};
+use serde::{Deserialize, Serialize};
+use uuid::Uuid;
+
+/// Unique identifier for a workflow definition
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub struct WorkflowId(String);
+
+impl WorkflowId {
+    pub fn new(id: impl Into<String>) -> Self {
+        Self(id.into())
+    }
+}
+
+impl fmt::Display for WorkflowId {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+/// Unique identifier for a workflow instance
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub struct WorkflowInstanceId(Uuid);
+
+impl WorkflowInstanceId {
+    pub fn new() -> Self {
+        Self(Uuid::new_v4())
+    }
+}
+
+impl Default for WorkflowInstanceId {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl fmt::Display for WorkflowInstanceId {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+/// Unique identifier for a workflow step
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub struct StepId(String);
+
+impl StepId {
+    pub fn new(id: impl Into<String>) -> Self {
+        Self(id.into())
+    }
+}
+
+impl fmt::Display for StepId {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+/// Retry policy configuration
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct RetryPolicy {
+    pub max_attempts: u32,
+    pub backoff: BackoffStrategy,
+}
+
+impl Default for RetryPolicy {
+    fn default() -> Self {
+        Self {
+            max_attempts: 3,
+            backoff: BackoffStrategy::Exponential {
+                base: Duration::from_secs(1),
+                max: Duration::from_secs(30),
+            },
+        }
+    }
+}
+
+/// Backoff strategy for retries
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum BackoffStrategy {
+    /// Fixed delay between retries
+    Fixed(Duration),
+    /// Exponential backoff with base and max duration
+    Exponential { base: Duration, max: Duration },
+    /// Linear backoff with increment and max duration
+    Linear { increment: Duration, max: Duration },
+}
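For orientation, one plausible reading of the three strategies as a pure function from (1-based) attempt number to sleep duration. The engine's real schedule lives in engine.rs, which is not shown in this excerpt, so the saturate-at-`max` behavior here is an assumption:

```rust
use std::time::Duration;

// Delay applied before retry `attempt` (attempt is 1-based).
fn delay_for(strategy: &BackoffStrategy, attempt: u32) -> Duration {
    match strategy {
        // Same delay every time.
        BackoffStrategy::Fixed(d) => *d,
        // base, 2*base, 4*base, ... capped at max.
        BackoffStrategy::Exponential { base, max } => {
            let factor = 2u32.saturating_pow(attempt.saturating_sub(1));
            (*base * factor).min(*max)
        }
        // increment, 2*increment, 3*increment, ... capped at max.
        BackoffStrategy::Linear { increment, max } => (*increment * attempt).min(*max),
    }
}
```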
+
+/// Action to take when a step fails
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+pub enum FailureAction {
+    /// Stop the entire workflow
+    FailWorkflow,
+    /// Skip this step and continue to the next
+    ContinueNextStep,
+    /// Keep retrying indefinitely until manual intervention
+    RetryIndefinitely,
+}
+
+/// Workflow execution status
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+pub enum WorkflowStatus {
+    Pending,
+    Running,
+    Paused,
+    Completed,
+    Failed,
+    Cancelled,
+}
+
+/// Step execution status
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+pub enum StepStatus {
+    Pending,
+    Running,
+    Succeeded,
+    Failed,
+    Retrying,
+    Skipped,
+}
+
+/// State of a workflow step
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct StepState {
+    pub status: StepStatus,
+    pub attempt: u32,
+    pub last_error: Option<String>,
+    pub started_at: Option<DateTime<Utc>>,
+    pub completed_at: Option<DateTime<Utc>>,
+}
+
+impl Default for StepState {
+    fn default() -> Self {
+        Self {
+            status: StepStatus::Pending,
+            attempt: 0,
+            last_error: None,
+            started_at: None,
+            completed_at: None,
+        }
+    }
+}
+
+/// Workflow instance state
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct WorkflowState {
+    pub instance_id: WorkflowInstanceId,
+    pub definition_id: WorkflowId,
+    pub status: WorkflowStatus,
+    pub current_step: Option<StepId>,
+    pub step_states: HashMap<StepId, StepState>,
+    pub context: WorkflowContext,
+    pub created_at: DateTime<Utc>,
+    pub updated_at: DateTime<Utc>,
+}
+
+impl WorkflowState {
+    pub fn new(instance_id: WorkflowInstanceId, definition_id: WorkflowId) -> Self {
+        let now = Utc::now();
+        Self {
+            instance_id,
+            definition_id,
+            status: WorkflowStatus::Pending,
+            current_step: None,
+            step_states: HashMap::new(),
+            context: WorkflowContext::new(instance_id),
+            created_at: now,
+            updated_at: now,
+        }
+    }
+}
+
+/// Shared context passed between workflow steps
+///
+/// # Serialization Warning
+///
+/// The `data` field contains type-erased values that cannot be serialized.
+/// This means workflow context is **not preserved** across:
+/// - Process restarts
+/// - State persistence to disk
+/// - Network serialization
+///
+/// The workflow engine only supports **in-memory execution**. If you need
+/// durable workflows, consider implementing a custom serializable context type.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct WorkflowContext {
+    pub instance_id: WorkflowInstanceId,
+    #[serde(skip)]
+    data: HashMap<String, Arc<dyn Any + Send + Sync>>,
+}
+
+impl WorkflowContext {
+    pub fn new(instance_id: WorkflowInstanceId) -> Self {
+        Self {
+            instance_id,
+            data: HashMap::new(),
+        }
+    }
+
+    /// Store a value in the context (will be wrapped in Arc)
+    pub fn set<T: Any + Send + Sync>(&mut self, key: impl Into<String>, value: T) {
+        self.data.insert(key.into(), Arc::new(value));
+    }
+
+    /// Store an Arc directly without double-wrapping
+    pub fn set_arc<T: Any + Send + Sync>(&mut self, key: impl Into<String>, value: Arc<T>) {
+        self.data.insert(key.into(), value);
+    }
+
+    /// Retrieve a value from the context
+    pub fn get<T: Any + Send + Sync>(&self, key: &str) -> Option<Arc<T>> {
+        self.data
+            .get(key)
+            .and_then(|v| v.clone().downcast::<T>().ok())
+    }
+
+    /// Check if the context has any data that would be lost during serialization
+    pub fn has_unserializable_data(&self) -> bool {
+        !self.data.is_empty()
+    }
+
+    /// Get the number of context entries (useful for debugging)
+    pub fn data_len(&self) -> usize {
+        self.data.len()
+    }
+}
+
+/// Result returned by a step execution
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum StepResult {
+    Success,
+    Failure,
+    Skip,
+}
+
+/// Error kinds for workflow operations
+#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
+pub enum WorkflowError {
+    #[error("Workflow not found: {0}")]
+    NotFound(WorkflowInstanceId),
+
+    #[error("Workflow definition not found: {0}")]
+    DefinitionNotFound(WorkflowId),
+
+    #[error("Step failed: {step_id} - {message}")]
+    StepFailed { step_id: StepId, message: String },
+
+    #[error("Step timeout: {step_id}")]
+    StepTimeout { step_id: StepId },
+
+    #[error("Workflow cancelled: {0}")]
+    Cancelled(WorkflowInstanceId),
+
+    #[error("Invalid state transition: {from:?} -> {to:?}")]
+    InvalidStateTransition {
+        from: WorkflowStatus,
+        to: WorkflowStatus,
+    },
+
+    #[error("Context value not found: {0}")]
+    ContextValueNotFound(String),
+
+    #[error("Context value type mismatch: {0}")]
+    ContextTypeMismatch(String),
+}
+
+pub type WorkflowResult<T> = Result<T, WorkflowError>;
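The type-erased context is the only channel between steps, so getters must name the exact stored type. A small demonstration of the `set`/`set_arc`/`get` contract; note that `set` wraps its argument in a fresh Arc, so an `Arc<T>` stored via `set` reads back as `Arc<Arc<T>>`, which is why RegisterWorkerStep above reads `Arc<Arc<dyn Worker>>`:

```rust
use std::sync::Arc;

fn context_demo() {
    let mut ctx = WorkflowContext::new(WorkflowInstanceId::new());

    // `set` wraps the value: stored as Arc<String>, read back as Arc<String>.
    ctx.set("greeting", "hello".to_string());
    let greeting: Arc<String> = ctx.get("greeting").unwrap();
    assert_eq!(*greeting, "hello");

    // `set_arc` stores the Arc as-is, avoiding a second layer of wrapping.
    ctx.set_arc("shared", Arc::new(42u32));
    let shared: Arc<u32> = ctx.get("shared").unwrap();
    assert_eq!(*shared, 42);

    // Lookups are typed: asking for the wrong type yields None, not a panic.
    assert!(ctx.get::<u64>("shared").is_none());
}
```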
diff --git a/sgl-router/src/metrics.rs b/sgl-router/src/metrics.rs
index b063a46b2..3f1d4b09e 100644
--- a/sgl-router/src/metrics.rs
+++ b/sgl-router/src/metrics.rs
@@ -530,6 +530,39 @@ impl RouterMetrics {
         )
         .increment(1);
     }
+
+    pub fn set_job_queue_depth(depth: usize) {
+        gauge!("sgl_router_job_queue_depth").set(depth as f64);
+    }
+
+    pub fn record_job_duration(job_type: &str, duration: Duration) {
+        histogram!("sgl_router_job_duration_seconds",
+            "job_type" => job_type.to_string()
+        )
+        .record(duration.as_secs_f64());
+    }
+
+    pub fn record_job_success(job_type: &str) {
+        counter!("sgl_router_job_success_total",
+            "job_type" => job_type.to_string()
+        )
+        .increment(1);
+    }
+
+    pub fn record_job_failure(job_type: &str) {
+        counter!("sgl_router_job_failure_total",
+            "job_type" => job_type.to_string()
+        )
+        .increment(1);
+    }
+
+    pub fn record_job_queue_full() {
+        counter!("sgl_router_job_queue_full_total").increment(1);
+    }
+
+    pub fn record_job_shutdown_rejected() {
+        counter!("sgl_router_job_shutdown_rejected_total").increment(1);
+    }
 }
 
 impl TokenizerMetrics {
diff --git a/sgl-router/src/protocols/worker_spec.rs b/sgl-router/src/protocols/worker_spec.rs
index 44d4297ee..cd286fd51 100644
--- a/sgl-router/src/protocols/worker_spec.rs
+++ b/sgl-router/src/protocols/worker_spec.rs
@@ -56,6 +56,51 @@ pub struct WorkerConfigRequest {
     /// Additional labels (optional)
     #[serde(default, skip_serializing_if = "HashMap::is_empty")]
     pub labels: HashMap<String, String>,
+
+    /// Health check timeout in seconds (default: 30)
+    #[serde(default = "default_health_check_timeout")]
+    pub health_check_timeout_secs: u64,
+
+    /// Health check interval in seconds (default: 60)
+    #[serde(default = "default_health_check_interval")]
+    pub health_check_interval_secs: u64,
+
+    /// Number of successful health checks needed to mark worker as healthy (default: 2)
+    #[serde(default = "default_health_success_threshold")]
+    pub health_success_threshold: u32,
+
+    /// Number of failed health checks before marking worker as unhealthy (default: 3)
+    #[serde(default = "default_health_failure_threshold")]
+    pub health_failure_threshold: u32,
+
+    /// Maximum connection attempts during worker registration (default: 20)
+    #[serde(default = "default_max_connection_attempts")]
+    pub max_connection_attempts: u32,
+
+    /// Enable data parallelism aware scheduling (default: false)
+    #[serde(default)]
+    pub dp_aware: bool,
+}
+
+// Default value functions for serde
+fn default_health_check_timeout() -> u64 {
+    30
+}
+
+fn default_health_check_interval() -> u64 {
+    60
+}
+
+fn default_health_success_threshold() -> u32 {
+    2
+}
+
+fn default_health_failure_threshold() -> u32 {
+    3
+}
+
+fn default_max_connection_attempts() -> u32 {
+    20
 }
 
 /// Worker information for API responses
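Because every new field carries a serde default, existing API clients keep working: a request that omits them deserializes with the documented values. A quick check of that, assuming `WorkerConfigRequest` derives `Deserialize` (as the serde attributes imply) and that `url` is the only required field (as the Option-typed fields suggest):

```rust
use sglang_router_rs::protocols::worker_spec::WorkerConfigRequest;

fn defaults_demo() {
    // Only `url` is provided; every new field should fall back to its default.
    let req: WorkerConfigRequest =
        serde_json::from_str(r#"{"url": "http://10.0.0.5:8000"}"#).unwrap();

    assert_eq!(req.health_check_timeout_secs, 30);
    assert_eq!(req.health_check_interval_secs, 60);
    assert_eq!(req.health_success_threshold, 2);
    assert_eq!(req.health_failure_threshold, 3);
    assert_eq!(req.max_connection_attempts, 20);
    assert!(!req.dp_aware);
}
```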
diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs
index 8be88e995..a57d44282 100644
--- a/sgl-router/src/server.rs
+++ b/sgl-router/src/server.rs
@@ -22,8 +22,8 @@ use tracing::{error, info, warn, Level};
 use crate::{
     config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
     core::{
-        worker_to_info, Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry,
-        WorkerType,
+        worker_to_info, workflow::WorkflowEngine, Job, JobQueue, JobQueueConfig, LoadMonitor,
+        WorkerManager, WorkerRegistry, WorkerType,
     },
     data_connector::{
         MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
@@ -77,6 +77,7 @@ pub struct AppContext {
     pub configured_reasoning_parser: Option<String>,
     pub configured_tool_parser: Option<String>,
     pub worker_job_queue: Arc<OnceLock<Arc<JobQueue>>>,
+    pub workflow_engine: Arc<OnceLock<Arc<WorkflowEngine>>>,
 }
 
 impl AppContext {
@@ -95,6 +96,7 @@ impl AppContext {
         conversation_item_storage: SharedConversationItemStorage,
         load_monitor: Option<Arc<LoadMonitor>>,
         worker_job_queue: Arc<OnceLock<Arc<JobQueue>>>,
+        workflow_engine: Arc<OnceLock<Arc<WorkflowEngine>>>,
     ) -> Self {
         let configured_reasoning_parser = router_config.reasoning_parser.clone();
         let configured_tool_parser = router_config.tool_call_parser.clone();
@@ -116,6 +118,7 @@ impl AppContext {
             configured_reasoning_parser,
             configured_tool_parser,
             worker_job_queue,
+            workflow_engine,
         }
     }
 }
@@ -979,8 +982,9 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
     let worker_job_queue = Arc::new(OnceLock::new());
+    let workflow_engine = Arc::new(OnceLock::new());
diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs
--- a/sgl-router/src/service_discovery.rs
+++ b/sgl-router/src/service_discovery.rs
     pub fn worker_url(&self, port: u16) -> String {
+        // Default to http:// prefix; workflow will detect actual protocol (HTTP vs gRPC)
         format!("http://{}:{}", self.ip, port)
     }
 }
@@ -382,10 +387,18 @@ async fn handle_pod_event(
         tool_parser: None,
         chat_template: None,
         api_key: None,
+        health_check_timeout_secs: app_context.router_config.health_check.timeout_secs,
+        health_check_interval_secs: app_context
+            .router_config
+            .health_check
+            .check_interval_secs,
+        health_success_threshold: app_context.router_config.health_check.success_threshold,
+        health_failure_threshold: app_context.router_config.health_check.failure_threshold,
+        max_connection_attempts: app_context.router_config.health_check.success_threshold
+            * 20,
+        dp_aware: false,
     };
 
-    // Submit job for async worker addition
-    use crate::core::Job;
     let job = Job::AddWorker {
         config: Box::new(config.clone()),
     };
@@ -568,6 +581,7 @@ mod tests {
             configured_reasoning_parser: None,
             configured_tool_parser: None,
             worker_job_queue: Arc::new(std::sync::OnceLock::new()),
+            workflow_engine: Arc::new(std::sync::OnceLock::new()),
         })
     }
 
@@ -815,19 +829,6 @@ mod tests {
         assert!(!not_running_pod.is_healthy());
     }
 
-    #[test]
-    fn test_pod_info_worker_url() {
-        let pod_info = PodInfo {
-            name: "p1".into(),
-            ip: "1.2.3.4".into(),
-            status: "Running".into(),
-            is_ready: true,
-            pod_type: None,
-            bootstrap_port: None,
-        };
-        assert_eq!(pod_info.worker_url(8080), "http://1.2.3.4:8080");
-    }
-
     #[test]
     fn test_pod_info_equality_with_pod_type() {
         let pod1 = PodInfo {
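AppContext now carries the engine behind the same OnceLock pattern as the job queue: constructed empty, initialized once during startup, and read back with `get()` by job handlers (as in `execute_job` earlier in this patch). A sketch of that lifecycle; the exact wiring in `startup` is abbreviated in this excerpt, so treat this as illustrative rather than the actual initialization code:

```rust
use std::sync::{Arc, OnceLock};

fn engine_lifecycle_sketch() {
    // 1. Created empty, before any router state exists.
    let slot: Arc<OnceLock<Arc<WorkflowEngine>>> = Arc::new(OnceLock::new());

    // 2. Initialized exactly once during startup; a second set() would fail.
    let engine = Arc::new(WorkflowEngine::new());
    slot.set(engine)
        .unwrap_or_else(|_| panic!("workflow engine initialized twice"));

    // 3. Consumers read it back, treating None as "not initialized yet".
    let engine = slot.get().expect("workflow engine not initialized");
    let _ = engine;
}
```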
diff --git a/sgl-router/tests/common/mod.rs b/sgl-router/tests/common/mod.rs
index ab92d4ed8..d35538875 100644
--- a/sgl-router/tests/common/mod.rs
+++ b/sgl-router/tests/common/mod.rs
@@ -62,8 +62,9 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
         config.worker_startup_check_interval_secs,
     )));
 
-    // Create empty OnceLock for worker job queue
+    // Create empty OnceLock for worker job queue and workflow engine
     let worker_job_queue = Arc::new(OnceLock::new());
+    let workflow_engine = Arc::new(OnceLock::new());
 
     Arc::new(AppContext::new(
         config,
@@ -79,6 +80,7 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
         conversation_item_storage,
         load_monitor,
         worker_job_queue,
+        workflow_engine,
     ))
 }
diff --git a/sgl-router/tests/common/test_app.rs b/sgl-router/tests/common/test_app.rs
index 76a04c086..c93bbeffb 100644
--- a/sgl-router/tests/common/test_app.rs
+++ b/sgl-router/tests/common/test_app.rs
@@ -53,8 +53,9 @@ pub fn create_test_app(
         router_config.worker_startup_check_interval_secs,
     )));
 
-    // Create empty OnceLock for worker job queue
+    // Create empty OnceLock for worker job queue and workflow engine
     let worker_job_queue = Arc::new(OnceLock::new());
+    let workflow_engine = Arc::new(OnceLock::new());
 
     // Create AppContext
     let app_context = Arc::new(AppContext::new(
@@ -71,6 +72,7 @@ pub fn create_test_app(
         conversation_item_storage,
         load_monitor,
         worker_job_queue,
+        workflow_engine,
     ));
 
     // Create AppState with the test router and context
diff --git a/sgl-router/tests/policy_registry_integration.rs b/sgl-router/tests/policy_registry_integration.rs
index 2f38ade3e..331df82dd 100644
--- a/sgl-router/tests/policy_registry_integration.rs
+++ b/sgl-router/tests/policy_registry_integration.rs
@@ -36,6 +36,12 @@ async fn test_policy_registry_with_router_manager() {
         reasoning_parser: None,
         tool_parser: None,
         chat_template: None,
+        health_check_timeout_secs: 30,
+        health_check_interval_secs: 60,
+        health_success_threshold: 2,
+        health_failure_threshold: 3,
+        max_connection_attempts: 20,
+        dp_aware: false,
     };
 
     // This would normally connect to a real worker, but for testing we'll just verify the structure
@@ -61,6 +67,12 @@ async fn test_policy_registry_with_router_manager() {
         reasoning_parser: None,
         tool_parser: None,
         chat_template: None,
+        health_check_timeout_secs: 30,
+        health_check_interval_secs: 60,
+        health_success_threshold: 2,
+        health_failure_threshold: 3,
+        max_connection_attempts: 20,
+        dp_aware: false,
     };
 
     // The second worker should use the same policy as the first (cache_aware)
@@ -82,6 +94,12 @@ async fn test_policy_registry_with_router_manager() {
         reasoning_parser: None,
         tool_parser: None,
         chat_template: None,
+        health_check_timeout_secs: 30,
+        health_check_interval_secs: 60,
+        health_success_threshold: 2,
+        health_failure_threshold: 3,
+        max_connection_attempts: 20,
+        dp_aware: false,
     };
 
     let _gpt_policy = policy_registry.get_policy("gpt-4");
diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs
index 6e054126c..fdb566fe8 100644
--- a/sgl-router/tests/test_pd_routing.rs
+++ b/sgl-router/tests/test_pd_routing.rs
@@ -238,8 +238,9 @@ mod test_pd_routing {
             config.worker_startup_check_interval_secs,
         )));
 
-        // Create empty OnceLock for worker job queue
+        // Create empty OnceLock for worker job queue and workflow engine
         let worker_job_queue = Arc::new(OnceLock::new());
+        let workflow_engine = Arc::new(OnceLock::new());
 
         Arc::new(sglang_router_rs::server::AppContext::new(
             config,
@@ -255,6 +256,7 @@ mod test_pd_routing {
             conversation_item_storage,
             load_monitor,
             worker_job_queue,
+            workflow_engine,
         ))
     };
     let result = RouterFactory::create_router(&app_context).await;
diff --git a/sgl-router/tests/workflow_test.rs b/sgl-router/tests/workflow_test.rs
new file mode 100644
index 000000000..6b2f58ee2
--- /dev/null
+++ b/sgl-router/tests/workflow_test.rs
@@ -0,0 +1,320 @@
+//! Integration tests for workflow engine
+
+use std::{
+    sync::{
+        atomic::{AtomicU32, Ordering},
+        Arc,
+    },
+    time::Duration,
+};
+
+use sglang_router_rs::core::workflow::*;
+use tokio::time::sleep;
+
+// Test step that counts invocations
+struct CountingStep {
+    counter: Arc<AtomicU32>,
+    should_succeed_after: u32,
+}
+
+#[async_trait::async_trait]
+impl StepExecutor for CountingStep {
+    async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
+        let count = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
+
+        // Store count in context
+        context.set("execution_count", count);
+
+        if count >= self.should_succeed_after {
+            Ok(StepResult::Success)
+        } else {
+            Err(WorkflowError::StepFailed {
+                step_id: StepId::new("counting_step"),
+                message: format!("Not ready yet, attempt {}", count),
+            })
+        }
+    }
+}
+
+// Test step that always succeeds
+struct AlwaysSucceedStep;
+
+#[async_trait::async_trait]
+impl StepExecutor for AlwaysSucceedStep {
+    async fn execute(&self, _context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
+        Ok(StepResult::Success)
+    }
+}
+
+#[tokio::test]
+async fn test_simple_workflow_execution() {
+    let engine = WorkflowEngine::new();
+
+    // Subscribe to events for logging
+    engine
+        .event_bus()
+        .subscribe(Arc::new(LoggingSubscriber))
+        .await;
+
+    // Create a simple workflow
+    let workflow = WorkflowDefinition::new("test_workflow", "Simple Test Workflow")
+        .add_step(StepDefinition::new(
+            "step1",
+            "First Step",
+            Arc::new(AlwaysSucceedStep),
+        ))
+        .add_step(StepDefinition::new(
+            "step2",
+            "Second Step",
+            Arc::new(AlwaysSucceedStep),
+        ));
+
+    let workflow_id = workflow.id.clone();
+    engine.register_workflow(workflow);
+
+    // Start workflow
+    let instance_id = engine
+        .start_workflow(workflow_id, WorkflowContext::new(WorkflowInstanceId::new()))
+        .await
+        .unwrap();
+
+    // Wait for completion
+    sleep(Duration::from_millis(100)).await;
+
+    // Check status
+    let state = engine.get_status(instance_id).unwrap();
+    assert_eq!(state.status, WorkflowStatus::Completed);
+    assert_eq!(state.step_states.len(), 2);
+}
+
+#[tokio::test]
+async fn test_workflow_with_retry() {
+    let engine = WorkflowEngine::new();
+    engine
+        .event_bus()
+        .subscribe(Arc::new(LoggingSubscriber))
+        .await;
+
+    let counter = Arc::new(AtomicU32::new(0));
+
+    // Create workflow with retry logic
+    let workflow = WorkflowDefinition::new("retry_workflow", "Workflow with Retry").add_step(
+        StepDefinition::new(
+            "retry_step",
+            "Step that retries",
+            Arc::new(CountingStep {
+                counter: Arc::clone(&counter),
+                should_succeed_after: 3,
+            }),
+        )
+        .with_retry(RetryPolicy {
+            max_attempts: 5,
+            backoff: BackoffStrategy::Fixed(Duration::from_millis(10)),
+        })
+        .with_timeout(Duration::from_secs(5)),
+    );
+
+    let workflow_id = workflow.id.clone();
+    engine.register_workflow(workflow);
+
+    // Start workflow
+    let instance_id = engine
+        .start_workflow(workflow_id, WorkflowContext::new(WorkflowInstanceId::new()))
+        .await
+        .unwrap();
+
+    // Wait for completion
+    sleep(Duration::from_millis(500)).await;
+
+    // Check that step was retried and eventually succeeded
+    let state = engine.get_status(instance_id).unwrap();
+    assert_eq!(state.status, WorkflowStatus::Completed);
+
+    let step_state = state.step_states.get(&StepId::new("retry_step")).unwrap();
+    assert_eq!(step_state.status, StepStatus::Succeeded);
+    assert_eq!(step_state.attempt, 3); // Should have taken 3 attempts
+
+    // Verify counter
+    assert_eq!(counter.load(Ordering::SeqCst), 3);
+}
+
+#[tokio::test]
+async fn test_workflow_failure_after_max_retries() {
+    let engine = WorkflowEngine::new();
+    engine
+        .event_bus()
+        .subscribe(Arc::new(LoggingSubscriber))
+        .await;
+
+    let counter = Arc::new(AtomicU32::new(0));
+
+    // Create workflow that will fail
+    let workflow = WorkflowDefinition::new("failing_workflow", "Workflow that Fails").add_step(
+        StepDefinition::new(
+            "failing_step",
+            "Step that always fails",
+            Arc::new(CountingStep {
+                counter: Arc::clone(&counter),
+                should_succeed_after: 10, // Will never succeed within max_attempts
+            }),
+        )
+        .with_retry(RetryPolicy {
+            max_attempts: 3,
+            backoff: BackoffStrategy::Fixed(Duration::from_millis(10)),
+        })
+        .with_failure_action(FailureAction::FailWorkflow),
+    );
+
+    let workflow_id = workflow.id.clone();
+    engine.register_workflow(workflow);
+
+    // Start workflow
+    let instance_id = engine
+        .start_workflow(workflow_id, WorkflowContext::new(WorkflowInstanceId::new()))
+        .await
+        .unwrap();
+
+    // Wait for completion
+    sleep(Duration::from_millis(500)).await;
+
+    // Check that workflow failed
+    let state = engine.get_status(instance_id).unwrap();
+    assert_eq!(state.status, WorkflowStatus::Failed);
+
+    let step_state = state.step_states.get(&StepId::new("failing_step")).unwrap();
+    assert_eq!(step_state.status, StepStatus::Failed);
+    assert_eq!(step_state.attempt, 3); // Should have tried 3 times
+
+    // Verify counter
+    assert_eq!(counter.load(Ordering::SeqCst), 3);
+}
+
+#[tokio::test]
+async fn test_workflow_continue_on_failure() {
+    let engine = WorkflowEngine::new();
+    engine
+        .event_bus()
+        .subscribe(Arc::new(LoggingSubscriber))
+        .await;
+
+    let counter = Arc::new(AtomicU32::new(0));
+
+    // Create workflow where first step fails but workflow continues
+    let workflow = WorkflowDefinition::new("continue_workflow", "Continue on Failure")
+        .add_step(
+            StepDefinition::new(
+                "failing_step",
+                "Step that fails",
+                Arc::new(CountingStep {
+                    counter: Arc::clone(&counter),
+                    should_succeed_after: 10,
+                }),
+            )
+            .with_retry(RetryPolicy {
+                max_attempts: 2,
+                backoff: BackoffStrategy::Fixed(Duration::from_millis(10)),
+            })
+            .with_failure_action(FailureAction::ContinueNextStep),
+        )
+        .add_step(StepDefinition::new(
+            "success_step",
+            "Step that succeeds",
+            Arc::new(AlwaysSucceedStep),
+        ));
+
+    let workflow_id = workflow.id.clone();
+    engine.register_workflow(workflow);
+
+    // Start workflow
+    let instance_id = engine
+        .start_workflow(workflow_id, WorkflowContext::new(WorkflowInstanceId::new()))
+        .await
+        .unwrap();
+
+    // Wait for completion
+    sleep(Duration::from_millis(500)).await;
+
+    // Workflow should complete despite first step failing
+    let state = engine.get_status(instance_id).unwrap();
+    assert_eq!(state.status, WorkflowStatus::Completed);
+
+    // First step should be skipped
+    let step1_state = state.step_states.get(&StepId::new("failing_step")).unwrap();
+    assert_eq!(step1_state.status, StepStatus::Skipped);
+
+    // Second step should succeed
+    let step2_state = state.step_states.get(&StepId::new("success_step")).unwrap();
+    assert_eq!(step2_state.status, StepStatus::Succeeded);
+}
+
+#[tokio::test]
+async fn test_workflow_context_sharing() {
+    let engine = WorkflowEngine::new();
+
+    struct ContextWriterStep {
+        key: String,
+        value: String,
+    }
+
+    #[async_trait::async_trait]
+    impl StepExecutor for ContextWriterStep {
+        async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
+            context.set(self.key.clone(), self.value.clone());
+            Ok(StepResult::Success)
+        }
+    }
+
+    struct ContextReaderStep {
+        key: String,
+        expected_value: String,
+    }
+
+    #[async_trait::async_trait]
+    impl StepExecutor for ContextReaderStep {
+        async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
+            let value: Arc<String> = context
+                .get(&self.key)
+                .ok_or_else(|| WorkflowError::ContextValueNotFound(self.key.clone()))?;
+
+            if *value == self.expected_value {
+                Ok(StepResult::Success)
+            } else {
+                Err(WorkflowError::StepFailed {
+                    step_id: StepId::new("reader"),
+                    message: format!("Expected {}, got {}", self.expected_value, value),
+                })
+            }
+        }
+    }
+
+    let workflow = WorkflowDefinition::new("context_workflow", "Context Sharing Test")
+        .add_step(StepDefinition::new(
+            "writer",
+            "Write to context",
+            Arc::new(ContextWriterStep {
+                key: "test_key".to_string(),
+                value: "test_value".to_string(),
+            }),
+        ))
+        .add_step(StepDefinition::new(
+            "reader",
+            "Read from context",
+            Arc::new(ContextReaderStep {
+                key: "test_key".to_string(),
+                expected_value: "test_value".to_string(),
+            }),
+        ));
+
+    let workflow_id = workflow.id.clone();
+    engine.register_workflow(workflow);
+
+    let instance_id = engine
+        .start_workflow(workflow_id, WorkflowContext::new(WorkflowInstanceId::new()))
+        .await
+        .unwrap();
+
+    sleep(Duration::from_millis(100)).await;
+
+    let state = engine.get_status(instance_id).unwrap();
+    assert_eq!(state.status, WorkflowStatus::Completed);
+}