[router] Worker Management Workflow Engine (#11868)
This commit is contained in:
@@ -4,17 +4,24 @@
|
||||
//! them asynchronously in background worker tasks.
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, Weak},
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use metrics::{counter, gauge, histogram};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::{
|
||||
core::WorkerManager,
|
||||
config::{RouterConfig, RoutingMode},
|
||||
core::{
|
||||
workflow::{
|
||||
WorkflowContext, WorkflowEngine, WorkflowId, WorkflowInstanceId, WorkflowStatus,
|
||||
},
|
||||
WorkerManager,
|
||||
},
|
||||
metrics::RouterMetrics,
|
||||
protocols::worker_spec::{JobStatus, WorkerConfigRequest},
|
||||
server::AppContext,
|
||||
};
|
||||
@@ -24,6 +31,7 @@ use crate::{
|
||||
pub enum Job {
|
||||
AddWorker { config: Box<WorkerConfigRequest> },
|
||||
RemoveWorker { url: String },
|
||||
InitializeWorkersFromConfig { router_config: Box<RouterConfig> },
|
||||
}
|
||||
|
||||
impl Job {
|
||||
@@ -32,6 +40,7 @@ impl Job {
|
||||
match self {
|
||||
Job::AddWorker { .. } => "AddWorker",
|
||||
Job::RemoveWorker { .. } => "RemoveWorker",
|
||||
Job::InitializeWorkersFromConfig { .. } => "InitializeWorkersFromConfig",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +49,7 @@ impl Job {
|
||||
match self {
|
||||
Job::AddWorker { config } => &config.url,
|
||||
Job::RemoveWorker { url } => url,
|
||||
Job::InitializeWorkersFromConfig { .. } => "startup",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -98,7 +108,7 @@ impl Default for JobQueueConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
queue_capacity: 1000,
|
||||
worker_count: 2,
|
||||
worker_count: 10,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -166,7 +176,7 @@ impl JobQueue {
|
||||
pub async fn submit(&self, job: Job) -> Result<(), String> {
|
||||
// Check if context is still alive before accepting jobs
|
||||
if self.context.upgrade().is_none() {
|
||||
counter!("sgl_router_job_shutdown_rejected_total").increment(1);
|
||||
RouterMetrics::record_job_shutdown_rejected();
|
||||
return Err("Job queue shutting down: AppContext dropped".to_string());
|
||||
}
|
||||
|
||||
@@ -183,8 +193,7 @@ impl JobQueue {
|
||||
match self.tx.send(job).await {
|
||||
Ok(_) => {
|
||||
let queue_depth = self.tx.max_capacity() - self.tx.capacity();
|
||||
gauge!("sgl_router_job_queue_depth").set(queue_depth as f64);
|
||||
|
||||
RouterMetrics::set_job_queue_depth(queue_depth);
|
||||
info!(
|
||||
"Job submitted: type={}, worker={}, queue_depth={}",
|
||||
job_type, worker_url, queue_depth
|
||||
@@ -192,8 +201,7 @@ impl JobQueue {
|
||||
Ok(())
|
||||
}
|
||||
Err(_) => {
|
||||
counter!("sgl_router_job_queue_full_total").increment(1);
|
||||
// Remove status on failure
|
||||
RouterMetrics::record_job_queue_full();
|
||||
self.status_map.remove(&worker_url);
|
||||
Err("Worker job queue full".to_string())
|
||||
}
|
||||
@@ -246,39 +254,16 @@ impl JobQueue {
|
||||
// Upgrade weak reference to process job
|
||||
match context.upgrade() {
|
||||
Some(ctx) => {
|
||||
// Execute job
|
||||
let result = Self::execute_job(&job, &ctx).await;
|
||||
let duration = start.elapsed();
|
||||
|
||||
// Record metrics
|
||||
histogram!("sgl_router_job_duration_seconds", "job_type" => job_type.clone())
|
||||
.record(duration.as_secs_f64());
|
||||
|
||||
match result {
|
||||
Ok(message) => {
|
||||
counter!("sgl_router_job_success_total", "job_type" => job_type.clone())
|
||||
.increment(1);
|
||||
// Remove status on success - worker in registry is sufficient
|
||||
status_map.remove(&worker_url);
|
||||
info!(
|
||||
"Worker {} completed job: type={}, worker={}, duration={:.3}s, result={}",
|
||||
worker_id, job_type, worker_url, duration.as_secs_f64(), message
|
||||
);
|
||||
}
|
||||
Err(error) => {
|
||||
counter!("sgl_router_job_failure_total", "job_type" => job_type.clone())
|
||||
.increment(1);
|
||||
// Keep failed status for API to report error details
|
||||
status_map.insert(
|
||||
worker_url.clone(),
|
||||
JobStatus::failed(&job_type, &worker_url, error.clone()),
|
||||
);
|
||||
warn!(
|
||||
"Worker {} failed job: type={}, worker={}, duration={:.3}s, error={}",
|
||||
worker_id, job_type, worker_url, duration.as_secs_f64(), error
|
||||
);
|
||||
}
|
||||
}
|
||||
Self::record_job_completion(
|
||||
&job_type,
|
||||
&worker_url,
|
||||
worker_id,
|
||||
duration,
|
||||
&result,
|
||||
&status_map,
|
||||
);
|
||||
}
|
||||
None => {
|
||||
let error_msg = "AppContext dropped".to_string();
|
||||
@@ -311,12 +296,28 @@ impl JobQueue {
|
||||
async fn execute_job(job: &Job, context: &Arc<AppContext>) -> Result<String, String> {
|
||||
match job {
|
||||
Job::AddWorker { config } => {
|
||||
// Register worker with is_healthy=false
|
||||
let worker =
|
||||
WorkerManager::add_worker_from_config(config.as_ref(), context).await?;
|
||||
let engine = context
|
||||
.workflow_engine
|
||||
.get()
|
||||
.ok_or_else(|| "Workflow engine not initialized".to_string())?;
|
||||
|
||||
// Validate and activate
|
||||
WorkerManager::validate_and_activate_worker(&worker, context).await
|
||||
let instance_id = Self::start_worker_workflow(engine, config, context).await?;
|
||||
|
||||
info!(
|
||||
"Started worker registration workflow for {} (instance: {})",
|
||||
config.url, instance_id
|
||||
);
|
||||
|
||||
let timeout_duration =
|
||||
Duration::from_secs(context.router_config.worker_startup_timeout_secs + 30);
|
||||
|
||||
Self::wait_for_workflow_completion(
|
||||
engine,
|
||||
instance_id,
|
||||
&config.url,
|
||||
timeout_duration,
|
||||
)
|
||||
.await
|
||||
}
|
||||
Job::RemoveWorker { url } => {
|
||||
let result = WorkerManager::remove_worker(url, context);
|
||||
@@ -326,6 +327,204 @@ impl JobQueue {
|
||||
}
|
||||
result
|
||||
}
|
||||
Job::InitializeWorkersFromConfig { router_config } => {
|
||||
let api_key = router_config.api_key.clone();
|
||||
let mut worker_count = 0;
|
||||
|
||||
// Create iterator of (url, worker_type, bootstrap_port) tuples based on mode
|
||||
let workers: Vec<(String, &str, Option<u16>)> = match &router_config.mode {
|
||||
RoutingMode::Regular { worker_urls } => worker_urls
|
||||
.iter()
|
||||
.map(|url| (url.clone(), "regular", None))
|
||||
.collect(),
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
..
|
||||
} => {
|
||||
let prefill_workers = prefill_urls
|
||||
.iter()
|
||||
.map(|(url, port)| (url.clone(), "prefill", *port));
|
||||
|
||||
let decode_workers =
|
||||
decode_urls.iter().map(|url| (url.clone(), "decode", None));
|
||||
|
||||
prefill_workers.chain(decode_workers).collect()
|
||||
}
|
||||
RoutingMode::OpenAI { .. } => {
|
||||
info!("OpenAI mode: no workers to initialize");
|
||||
return Ok("OpenAI mode: no workers to initialize".to_string());
|
||||
}
|
||||
};
|
||||
|
||||
info!(
|
||||
"Creating AddWorker jobs for {} workers from config",
|
||||
workers.len()
|
||||
);
|
||||
|
||||
// Process all workers with unified loop
|
||||
for (url, worker_type, bootstrap_port) in workers {
|
||||
let url_for_error = url.clone(); // Clone for error message
|
||||
let config = WorkerConfigRequest {
|
||||
url,
|
||||
api_key: api_key.clone(),
|
||||
worker_type: Some(worker_type.to_string()),
|
||||
labels: HashMap::new(),
|
||||
model_id: None,
|
||||
priority: None,
|
||||
cost: None,
|
||||
tokenizer_path: None,
|
||||
reasoning_parser: None,
|
||||
tool_parser: None,
|
||||
chat_template: None,
|
||||
bootstrap_port,
|
||||
health_check_timeout_secs: router_config.health_check.timeout_secs,
|
||||
health_check_interval_secs: router_config.health_check.check_interval_secs,
|
||||
health_success_threshold: router_config.health_check.success_threshold,
|
||||
health_failure_threshold: router_config.health_check.failure_threshold,
|
||||
max_connection_attempts: router_config.health_check.success_threshold * 10,
|
||||
dp_aware: router_config.dp_aware,
|
||||
};
|
||||
|
||||
let job = Job::AddWorker {
|
||||
config: Box::new(config),
|
||||
};
|
||||
|
||||
if let Some(queue) = context.worker_job_queue.get() {
|
||||
queue.submit(job).await.map_err(|e| {
|
||||
format!(
|
||||
"Failed to submit AddWorker job for {} worker {}: {}",
|
||||
worker_type, url_for_error, e
|
||||
)
|
||||
})?;
|
||||
worker_count += 1;
|
||||
} else {
|
||||
return Err("JobQueue not available".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(format!("Submitted {} AddWorker jobs", worker_count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a workflow and return its instance ID
|
||||
async fn start_worker_workflow(
|
||||
engine: &Arc<WorkflowEngine>,
|
||||
config: &WorkerConfigRequest,
|
||||
context: &Arc<AppContext>,
|
||||
) -> Result<WorkflowInstanceId, String> {
|
||||
let mut workflow_context = WorkflowContext::new(WorkflowInstanceId::new());
|
||||
workflow_context.set("worker_config", config.clone());
|
||||
workflow_context.set_arc("app_context", Arc::clone(context));
|
||||
|
||||
engine
|
||||
.start_workflow(WorkflowId::new("worker_registration"), workflow_context)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to start worker registration workflow: {:?}", e))
|
||||
}
|
||||
|
||||
/// Wait for workflow completion with adaptive polling
|
||||
async fn wait_for_workflow_completion(
|
||||
engine: &Arc<WorkflowEngine>,
|
||||
instance_id: WorkflowInstanceId,
|
||||
worker_url: &str,
|
||||
timeout_duration: Duration,
|
||||
) -> Result<String, String> {
|
||||
let start = std::time::Instant::now();
|
||||
let mut poll_interval = Duration::from_millis(100);
|
||||
let max_poll_interval = Duration::from_millis(2000);
|
||||
let poll_backoff = Duration::from_millis(200);
|
||||
|
||||
loop {
|
||||
// Check timeout
|
||||
if start.elapsed() > timeout_duration {
|
||||
return Err(format!(
|
||||
"Workflow timeout after {}s for worker {}",
|
||||
timeout_duration.as_secs(),
|
||||
worker_url
|
||||
));
|
||||
}
|
||||
|
||||
// Get workflow status
|
||||
let state = engine
|
||||
.get_status(instance_id)
|
||||
.map_err(|e| format!("Failed to get workflow status: {:?}", e))?;
|
||||
|
||||
let result = match state.status {
|
||||
WorkflowStatus::Completed => Ok(format!(
|
||||
"Worker {} registered and activated successfully via workflow",
|
||||
worker_url
|
||||
)),
|
||||
WorkflowStatus::Failed => {
|
||||
let current_step = state.current_step.as_ref();
|
||||
let step_name = current_step
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let error_msg = current_step
|
||||
.and_then(|step_id| state.step_states.get(step_id))
|
||||
.and_then(|s| s.last_error.as_deref())
|
||||
.unwrap_or("Unknown error");
|
||||
Err(format!(
|
||||
"Workflow failed at step {}: {}",
|
||||
step_name, error_msg
|
||||
))
|
||||
}
|
||||
WorkflowStatus::Cancelled => {
|
||||
Err(format!("Workflow cancelled for worker {}", worker_url))
|
||||
}
|
||||
WorkflowStatus::Pending | WorkflowStatus::Paused | WorkflowStatus::Running => {
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
poll_interval = (poll_interval + poll_backoff).min(max_poll_interval);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Clean up terminal workflow states
|
||||
engine.state_store().cleanup_if_terminal(instance_id);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/// Record job completion metrics and update status
|
||||
fn record_job_completion(
|
||||
job_type: &str,
|
||||
worker_url: &str,
|
||||
worker_id: usize,
|
||||
duration: Duration,
|
||||
result: &Result<String, String>,
|
||||
status_map: &Arc<DashMap<String, JobStatus>>,
|
||||
) {
|
||||
RouterMetrics::record_job_duration(job_type, duration);
|
||||
|
||||
match result {
|
||||
Ok(message) => {
|
||||
RouterMetrics::record_job_success(job_type);
|
||||
status_map.remove(worker_url);
|
||||
info!(
|
||||
"Worker {} completed job: type={}, worker={}, duration={:.3}s, result={}",
|
||||
worker_id,
|
||||
job_type,
|
||||
worker_url,
|
||||
duration.as_secs_f64(),
|
||||
message
|
||||
);
|
||||
}
|
||||
Err(error) => {
|
||||
RouterMetrics::record_job_failure(job_type);
|
||||
status_map.insert(
|
||||
worker_url.to_string(),
|
||||
JobStatus::failed(job_type, worker_url, error.clone()),
|
||||
);
|
||||
warn!(
|
||||
"Worker {} failed job: type={}, worker={}, duration={:.3}s, error={}",
|
||||
worker_id,
|
||||
job_type,
|
||||
worker_url,
|
||||
duration.as_secs_f64(),
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -352,15 +551,3 @@ impl JobQueue {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_job_queue_config_default() {
|
||||
let config = JobQueueConfig::default();
|
||||
assert_eq!(config.queue_capacity, 1000);
|
||||
assert_eq!(config.worker_count, 2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
//! - Worker trait and implementations
|
||||
//! - Error types
|
||||
//! - Circuit breaker for reliability
|
||||
//! - Workflow engine for multi-step operations
|
||||
//! - Common utilities
|
||||
|
||||
pub mod circuit_breaker;
|
||||
@@ -15,6 +16,7 @@ pub mod worker;
|
||||
pub mod worker_builder;
|
||||
pub mod worker_manager;
|
||||
pub mod worker_registry;
|
||||
pub mod workflow;
|
||||
|
||||
pub use circuit_breaker::{
|
||||
CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState,
|
||||
@@ -23,8 +25,8 @@ pub use error::{WorkerError, WorkerResult};
|
||||
pub use job_queue::{Job, JobQueue, JobQueueConfig};
|
||||
pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor};
|
||||
pub use worker::{
|
||||
start_health_checker, worker_to_info, BasicWorker, ConnectionMode, DPAwareWorker,
|
||||
HealthChecker, HealthConfig, Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
|
||||
worker_to_info, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig,
|
||||
Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
|
||||
};
|
||||
pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder};
|
||||
pub use worker_manager::{DpInfo, LoadMonitor, ServerInfo, WorkerManager};
|
||||
|
||||
@@ -8,7 +8,6 @@ use std::{
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures;
|
||||
use serde_json;
|
||||
use tokio::{sync::RwLock, time};
|
||||
|
||||
@@ -910,89 +909,6 @@ impl HealthChecker {
|
||||
}
|
||||
}
|
||||
|
||||
/// Start an async background health checker for a collection of workers
|
||||
pub fn start_health_checker(
|
||||
workers: Arc<std::sync::RwLock<Vec<Arc<dyn Worker>>>>,
|
||||
check_interval_secs: u64,
|
||||
) -> HealthChecker {
|
||||
let shutdown = Arc::new(AtomicBool::new(false));
|
||||
let shutdown_clone = shutdown.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let mut interval = time::interval(Duration::from_secs(check_interval_secs));
|
||||
|
||||
// Counter for periodic load reset (every 10 health check cycles)
|
||||
let mut check_count = 0u64;
|
||||
const LOAD_RESET_INTERVAL: u64 = 10;
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
// Check for shutdown signal
|
||||
if shutdown_clone.load(Ordering::Acquire) {
|
||||
tracing::debug!("Health checker shutting down");
|
||||
break;
|
||||
}
|
||||
|
||||
check_count += 1;
|
||||
|
||||
// Check health of all workers
|
||||
let workers_to_check = match workers.read() {
|
||||
Ok(guard) => guard.clone(),
|
||||
Err(poisoned) => {
|
||||
tracing::error!("Worker lock poisoned: {}", poisoned);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Periodically reset load counters to prevent drift
|
||||
// Only do this when we believe all workers should be idle
|
||||
if check_count.is_multiple_of(LOAD_RESET_INTERVAL) {
|
||||
let max_load = workers_to_check.iter().map(|w| w.load()).max().unwrap_or(0);
|
||||
// Only reset if load appears to be very low (likely drift)
|
||||
if max_load <= 2 {
|
||||
tracing::debug!(
|
||||
"Resetting load counters to prevent drift (max_load: {})",
|
||||
max_load
|
||||
);
|
||||
for worker in &workers_to_check {
|
||||
worker.reset_load();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Perform health checks concurrently
|
||||
let health_checks = workers_to_check.iter().map(|worker| {
|
||||
let worker_url = worker.url().to_string();
|
||||
let was_healthy = worker.is_healthy();
|
||||
|
||||
async move {
|
||||
match worker.check_health_async().await {
|
||||
Ok(_) => {
|
||||
if !was_healthy {
|
||||
tracing::info!("Worker {} is now healthy", worker_url);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if was_healthy {
|
||||
tracing::warn!("Worker {} health check failed: {}", worker_url, e);
|
||||
} else {
|
||||
// Worker was already unhealthy, log at debug level
|
||||
tracing::debug!("Worker {} remains unhealthy: {}", worker_url, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Execute all health checks concurrently
|
||||
futures::future::join_all(health_checks).await;
|
||||
}
|
||||
});
|
||||
|
||||
HealthChecker { handle, shutdown }
|
||||
}
|
||||
|
||||
/// Helper to convert Worker trait object to WorkerInfo struct
|
||||
pub fn worker_to_info(worker: &Arc<dyn Worker>) -> WorkerInfo {
|
||||
let worker_type_str = match worker.worker_type() {
|
||||
|
||||
98
sgl-router/src/core/workflow/definition.rs
Normal file
98
sgl-router/src/core/workflow/definition.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
//! Workflow definition types
|
||||
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use super::{
|
||||
executor::StepExecutor,
|
||||
types::{FailureAction, RetryPolicy, StepId, WorkflowId},
|
||||
};
|
||||
|
||||
/// Definition of a single step within a workflow
|
||||
pub struct StepDefinition {
|
||||
pub id: StepId,
|
||||
pub name: String,
|
||||
pub executor: Arc<dyn StepExecutor>,
|
||||
pub retry_policy: Option<RetryPolicy>,
|
||||
pub timeout: Option<Duration>,
|
||||
pub on_failure: FailureAction,
|
||||
}
|
||||
|
||||
impl StepDefinition {
|
||||
pub fn new(
|
||||
id: impl Into<String>,
|
||||
name: impl Into<String>,
|
||||
executor: Arc<dyn StepExecutor>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: StepId::new(id.into()),
|
||||
name: name.into(),
|
||||
executor,
|
||||
retry_policy: None,
|
||||
timeout: None,
|
||||
on_failure: FailureAction::FailWorkflow,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_retry(mut self, policy: RetryPolicy) -> Self {
|
||||
self.retry_policy = Some(policy);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Self {
|
||||
self.timeout = Some(timeout);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_failure_action(mut self, action: FailureAction) -> Self {
|
||||
self.on_failure = action;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Complete workflow definition
|
||||
pub struct WorkflowDefinition {
|
||||
pub id: WorkflowId,
|
||||
pub name: String,
|
||||
pub steps: Vec<StepDefinition>,
|
||||
pub default_retry_policy: RetryPolicy,
|
||||
pub default_timeout: Duration,
|
||||
}
|
||||
|
||||
impl WorkflowDefinition {
|
||||
pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
|
||||
Self {
|
||||
id: WorkflowId::new(id.into()),
|
||||
name: name.into(),
|
||||
steps: Vec::new(),
|
||||
default_retry_policy: RetryPolicy::default(),
|
||||
default_timeout: Duration::from_secs(300), // 5 minutes
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_step(mut self, step: StepDefinition) -> Self {
|
||||
self.steps.push(step);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_default_retry(mut self, policy: RetryPolicy) -> Self {
|
||||
self.default_retry_policy = policy;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
|
||||
self.default_timeout = timeout;
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the retry policy for a step (step-specific or default)
|
||||
pub fn get_retry_policy<'a>(&'a self, step: &'a StepDefinition) -> &'a RetryPolicy {
|
||||
step.retry_policy
|
||||
.as_ref()
|
||||
.unwrap_or(&self.default_retry_policy)
|
||||
}
|
||||
|
||||
/// Get the timeout for a step (step-specific or default)
|
||||
pub fn get_timeout(&self, step: &StepDefinition) -> Duration {
|
||||
step.timeout.unwrap_or(self.default_timeout)
|
||||
}
|
||||
}
|
||||
484
sgl-router/src/core/workflow/engine.rs
Normal file
484
sgl-router/src/core/workflow/engine.rs
Normal file
@@ -0,0 +1,484 @@
|
||||
//! Workflow execution engine
|
||||
|
||||
use std::{collections::HashMap, sync::Arc, time::Duration};
|
||||
|
||||
use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
|
||||
use chrono::Utc;
|
||||
use parking_lot::RwLock;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::{
|
||||
definition::{StepDefinition, WorkflowDefinition},
|
||||
event::{EventBus, WorkflowEvent},
|
||||
state::WorkflowStateStore,
|
||||
types::*,
|
||||
};
|
||||
|
||||
/// Linear backoff implementation that increases delay by a fixed amount each retry
|
||||
struct LinearBackoff {
|
||||
current: Duration,
|
||||
increment: Duration,
|
||||
max: Duration,
|
||||
}
|
||||
|
||||
impl LinearBackoff {
|
||||
fn new(increment: Duration, max: Duration) -> Self {
|
||||
Self {
|
||||
current: increment,
|
||||
increment,
|
||||
max,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Backoff for LinearBackoff {
|
||||
fn next_backoff(&mut self) -> Option<Duration> {
|
||||
let next = self.current;
|
||||
self.current = (self.current + self.increment).min(self.max);
|
||||
Some(next)
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.current = self.increment;
|
||||
}
|
||||
}
|
||||
|
||||
/// Main workflow execution engine
|
||||
pub struct WorkflowEngine {
|
||||
definitions: Arc<RwLock<HashMap<WorkflowId, Arc<WorkflowDefinition>>>>,
|
||||
state_store: WorkflowStateStore,
|
||||
event_bus: Arc<EventBus>,
|
||||
}
|
||||
|
||||
impl WorkflowEngine {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
definitions: Arc::new(RwLock::new(HashMap::new())),
|
||||
state_store: WorkflowStateStore::new(),
|
||||
event_bus: Arc::new(EventBus::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a background task to periodically clean up old workflow states
|
||||
///
|
||||
/// This prevents unbounded memory growth by removing completed/failed workflows
|
||||
/// that are older than the specified TTL.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `ttl` - Time-to-live for terminal workflows (default: 1 hour)
|
||||
/// * `interval` - How often to run cleanup (default: 5 minutes)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A join handle for the cleanup task that can be used to stop it.
|
||||
pub fn start_cleanup_task(
|
||||
&self,
|
||||
ttl: Option<Duration>,
|
||||
interval: Option<Duration>,
|
||||
) -> tokio::task::JoinHandle<()> {
|
||||
let state_store = self.state_store.clone();
|
||||
let ttl = ttl.unwrap_or(Duration::from_secs(3600)); // 1 hour default
|
||||
let interval = interval.unwrap_or(Duration::from_secs(300)); // 5 minutes default
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut ticker = tokio::time::interval(interval);
|
||||
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
state_store.cleanup_old_workflows(ttl);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Register a workflow definition
|
||||
pub fn register_workflow(&self, definition: WorkflowDefinition) {
|
||||
let id = definition.id.clone();
|
||||
self.definitions.write().insert(id, Arc::new(definition));
|
||||
}
|
||||
|
||||
/// Get the event bus for subscribing to workflow events
|
||||
pub fn event_bus(&self) -> Arc<EventBus> {
|
||||
Arc::clone(&self.event_bus)
|
||||
}
|
||||
|
||||
/// Get the state store
|
||||
pub fn state_store(&self) -> &WorkflowStateStore {
|
||||
&self.state_store
|
||||
}
|
||||
|
||||
/// Start a new workflow instance
|
||||
pub async fn start_workflow(
|
||||
&self,
|
||||
definition_id: WorkflowId,
|
||||
context: WorkflowContext,
|
||||
) -> WorkflowResult<WorkflowInstanceId> {
|
||||
// Get workflow definition
|
||||
let definition = {
|
||||
let definitions = self.definitions.read();
|
||||
definitions
|
||||
.get(&definition_id)
|
||||
.cloned()
|
||||
.ok_or_else(|| WorkflowError::DefinitionNotFound(definition_id.clone()))?
|
||||
};
|
||||
|
||||
// Create new workflow instance
|
||||
let instance_id = context.instance_id;
|
||||
let mut state = WorkflowState::new(instance_id, definition_id.clone());
|
||||
state.status = WorkflowStatus::Running;
|
||||
state.context = context;
|
||||
|
||||
// Initialize step states
|
||||
for step in &definition.steps {
|
||||
state
|
||||
.step_states
|
||||
.insert(step.id.clone(), StepState::default());
|
||||
}
|
||||
|
||||
// Save initial state
|
||||
self.state_store.save(state)?;
|
||||
|
||||
// Emit workflow started event
|
||||
self.event_bus
|
||||
.publish(WorkflowEvent::WorkflowStarted {
|
||||
instance_id,
|
||||
definition_id,
|
||||
})
|
||||
.await;
|
||||
|
||||
// Execute workflow in background
|
||||
let engine = self.clone_for_execution();
|
||||
let def = Arc::clone(&definition);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = engine.execute_workflow(instance_id, def).await {
|
||||
tracing::error!(instance_id = %instance_id, error = ?e, "Workflow execution failed");
|
||||
}
|
||||
});
|
||||
|
||||
Ok(instance_id)
|
||||
}
|
||||
|
||||
/// Execute a workflow (internal)
|
||||
async fn execute_workflow(
|
||||
&self,
|
||||
instance_id: WorkflowInstanceId,
|
||||
definition: Arc<WorkflowDefinition>,
|
||||
) -> WorkflowResult<()> {
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
for step in &definition.steps {
|
||||
// Check if workflow was cancelled
|
||||
let state = self.state_store.load(instance_id)?;
|
||||
if state.status == WorkflowStatus::Cancelled {
|
||||
self.event_bus
|
||||
.publish(WorkflowEvent::WorkflowCancelled { instance_id })
|
||||
.await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Execute step with retry
|
||||
match self
|
||||
.execute_step_with_retry(instance_id, step, &definition)
|
||||
.await
|
||||
{
|
||||
Ok(StepResult::Success) => {
|
||||
// Continue to next step
|
||||
}
|
||||
Ok(StepResult::Skip) => {
|
||||
// Step was skipped, continue to next
|
||||
continue;
|
||||
}
|
||||
Ok(StepResult::Failure) | Err(_) => {
|
||||
// Handle failure based on failure action
|
||||
match step.on_failure {
|
||||
FailureAction::FailWorkflow => {
|
||||
let error_msg = format!("Step {} failed", step.id);
|
||||
self.state_store.update(instance_id, |s| {
|
||||
s.status = WorkflowStatus::Failed;
|
||||
})?;
|
||||
|
||||
self.event_bus
|
||||
.publish(WorkflowEvent::WorkflowFailed {
|
||||
instance_id,
|
||||
failed_step: step.id.clone(),
|
||||
error: error_msg,
|
||||
})
|
||||
.await;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
FailureAction::ContinueNextStep => {
|
||||
// Mark step as skipped and continue
|
||||
self.state_store.update(instance_id, |s| {
|
||||
if let Some(step_state) = s.step_states.get_mut(&step.id) {
|
||||
step_state.status = StepStatus::Skipped;
|
||||
}
|
||||
})?;
|
||||
continue;
|
||||
}
|
||||
FailureAction::RetryIndefinitely => {
|
||||
// This should not happen as execute_step_with_retry handles it
|
||||
unreachable!("RetryIndefinitely should be handled in retry logic");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Workflow completed successfully
|
||||
self.state_store.update(instance_id, |s| {
|
||||
s.status = WorkflowStatus::Completed;
|
||||
})?;
|
||||
|
||||
let duration = start_time.elapsed();
|
||||
self.event_bus
|
||||
.publish(WorkflowEvent::WorkflowCompleted {
|
||||
instance_id,
|
||||
duration,
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Execute a step with retry logic
|
||||
async fn execute_step_with_retry(
|
||||
&self,
|
||||
instance_id: WorkflowInstanceId,
|
||||
step: &StepDefinition,
|
||||
definition: &WorkflowDefinition,
|
||||
) -> WorkflowResult<StepResult> {
|
||||
let retry_policy = definition.get_retry_policy(step);
|
||||
let step_timeout = definition.get_timeout(step);
|
||||
|
||||
let mut attempt = 1;
|
||||
let max_attempts = if matches!(step.on_failure, FailureAction::RetryIndefinitely) {
|
||||
u32::MAX
|
||||
} else {
|
||||
retry_policy.max_attempts
|
||||
};
|
||||
|
||||
let mut backoff = Self::create_backoff(&retry_policy.backoff);
|
||||
|
||||
loop {
|
||||
// Check for cancellation before starting/retrying step
|
||||
{
|
||||
let state = self.state_store.load(instance_id)?;
|
||||
if state.status == WorkflowStatus::Cancelled {
|
||||
return Err(WorkflowError::Cancelled(instance_id));
|
||||
}
|
||||
}
|
||||
|
||||
// Update step state
|
||||
self.state_store.update(instance_id, |s| {
|
||||
s.current_step = Some(step.id.clone());
|
||||
if let Some(step_state) = s.step_states.get_mut(&step.id) {
|
||||
step_state.status = if attempt == 1 {
|
||||
StepStatus::Running
|
||||
} else {
|
||||
StepStatus::Retrying
|
||||
};
|
||||
step_state.attempt = attempt;
|
||||
step_state.started_at = Some(Utc::now());
|
||||
}
|
||||
})?;
|
||||
|
||||
// Emit step started event
|
||||
self.event_bus
|
||||
.publish(WorkflowEvent::StepStarted {
|
||||
instance_id,
|
||||
step_id: step.id.clone(),
|
||||
attempt,
|
||||
})
|
||||
.await;
|
||||
|
||||
// Get current context
|
||||
let mut context = self.state_store.load(instance_id)?.context;
|
||||
|
||||
// Execute step with timeout
|
||||
let step_start = std::time::Instant::now();
|
||||
let result = timeout(step_timeout, step.executor.execute(&mut context)).await;
|
||||
|
||||
let step_duration = step_start.elapsed();
|
||||
|
||||
// Save updated context
|
||||
self.state_store.update(instance_id, |s| {
|
||||
s.context = context.clone();
|
||||
})?;
|
||||
|
||||
match result {
|
||||
Ok(Ok(StepResult::Success)) => {
|
||||
// Step succeeded
|
||||
self.state_store.update(instance_id, |s| {
|
||||
if let Some(step_state) = s.step_states.get_mut(&step.id) {
|
||||
step_state.status = StepStatus::Succeeded;
|
||||
step_state.completed_at = Some(Utc::now());
|
||||
}
|
||||
})?;
|
||||
|
||||
self.event_bus
|
||||
.publish(WorkflowEvent::StepSucceeded {
|
||||
instance_id,
|
||||
step_id: step.id.clone(),
|
||||
duration: step_duration,
|
||||
})
|
||||
.await;
|
||||
|
||||
// Call on_success hook
|
||||
if let Err(e) = step.executor.on_success(&context).await {
|
||||
tracing::warn!(step_id = %step.id, error = ?e, "on_success hook failed");
|
||||
}
|
||||
|
||||
return Ok(StepResult::Success);
|
||||
}
|
||||
Ok(Ok(StepResult::Skip)) => {
|
||||
return Ok(StepResult::Skip);
|
||||
}
|
||||
Ok(Ok(StepResult::Failure)) | Ok(Err(_)) | Err(_) => {
|
||||
let (error_msg, should_retry) = match result {
|
||||
Ok(Err(e)) => {
|
||||
let msg = format!("{}", e);
|
||||
let retryable = step.executor.is_retryable(&e);
|
||||
(msg, retryable)
|
||||
}
|
||||
Err(_) => (
|
||||
format!("Step timeout after {:?}", step_timeout),
|
||||
true, // Timeouts are retryable
|
||||
),
|
||||
_ => ("Step failed".to_string(), false),
|
||||
};
|
||||
|
||||
let will_retry = should_retry && attempt < max_attempts;
|
||||
|
||||
// Update step state
|
||||
self.state_store.update(instance_id, |s| {
|
||||
if let Some(step_state) = s.step_states.get_mut(&step.id) {
|
||||
step_state.status = if will_retry {
|
||||
StepStatus::Retrying
|
||||
} else {
|
||||
StepStatus::Failed
|
||||
};
|
||||
step_state.last_error = Some(error_msg.clone());
|
||||
if !will_retry {
|
||||
step_state.completed_at = Some(Utc::now());
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
// Emit step failed event
|
||||
self.event_bus
|
||||
.publish(WorkflowEvent::StepFailed {
|
||||
instance_id,
|
||||
step_id: step.id.clone(),
|
||||
error: error_msg.clone(),
|
||||
will_retry,
|
||||
})
|
||||
.await;
|
||||
|
||||
if will_retry {
|
||||
// Calculate backoff delay
|
||||
let delay = backoff
|
||||
.next_backoff()
|
||||
.unwrap_or_else(|| Duration::from_secs(1));
|
||||
|
||||
self.event_bus
|
||||
.publish(WorkflowEvent::StepRetrying {
|
||||
instance_id,
|
||||
step_id: step.id.clone(),
|
||||
attempt: attempt + 1,
|
||||
delay,
|
||||
})
|
||||
.await;
|
||||
|
||||
tokio::time::sleep(delay).await;
|
||||
attempt += 1;
|
||||
} else {
|
||||
// No more retries, call on_failure hook
|
||||
// Create a generic error for the hook
|
||||
let hook_error = WorkflowError::StepFailed {
|
||||
step_id: step.id.clone(),
|
||||
message: error_msg,
|
||||
};
|
||||
if let Err(hook_err) = step.executor.on_failure(&context, &hook_error).await
|
||||
{
|
||||
tracing::warn!(step_id = %step.id, error = ?hook_err, "on_failure hook failed");
|
||||
}
|
||||
|
||||
return Ok(StepResult::Failure);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a backoff instance from strategy
|
||||
fn create_backoff(strategy: &BackoffStrategy) -> Box<dyn Backoff + Send> {
|
||||
match strategy {
|
||||
BackoffStrategy::Fixed(duration) => {
|
||||
// For fixed backoff, use exponential with multiplier 1.0
|
||||
let backoff = ExponentialBackoffBuilder::new()
|
||||
.with_initial_interval(*duration)
|
||||
.with_multiplier(1.0)
|
||||
.with_max_interval(*duration)
|
||||
.with_max_elapsed_time(None)
|
||||
.build();
|
||||
Box::new(backoff)
|
||||
}
|
||||
BackoffStrategy::Exponential { base, max } => {
|
||||
let backoff = ExponentialBackoffBuilder::new()
|
||||
.with_initial_interval(*base)
|
||||
.with_max_interval(*max)
|
||||
.with_max_elapsed_time(None)
|
||||
.build();
|
||||
Box::new(backoff)
|
||||
}
|
||||
BackoffStrategy::Linear { increment, max } => {
|
||||
// Use proper linear backoff: increment, 2*increment, 3*increment, ...
|
||||
Box::new(LinearBackoff::new(*increment, *max))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancel a running workflow
///
/// Marks the instance's persisted status as `Cancelled`, then publishes a
/// `WorkflowCancelled` event to all subscribers.
///
/// # Errors
///
/// Returns `WorkflowError::NotFound` if no state exists for `instance_id`.
///
/// NOTE(review): the status is overwritten unconditionally — a workflow that
/// already reached a terminal state would also be marked Cancelled here;
/// confirm callers only invoke this on active workflows.
pub async fn cancel_workflow(&self, instance_id: WorkflowInstanceId) -> WorkflowResult<()> {
    self.state_store.update(instance_id, |s| {
        s.status = WorkflowStatus::Cancelled;
    })?;

    self.event_bus
        .publish(WorkflowEvent::WorkflowCancelled { instance_id })
        .await;

    Ok(())
}
|
||||
|
||||
/// Get workflow status
///
/// Returns a clone of the full `WorkflowState` for `instance_id`.
///
/// # Errors
///
/// Returns `WorkflowError::NotFound` when the instance is unknown.
pub fn get_status(&self, instance_id: WorkflowInstanceId) -> WorkflowResult<WorkflowState> {
    self.state_store.load(instance_id)
}
|
||||
|
||||
/// Clone engine for async execution
///
/// Produces a cheap handle sharing the same definitions, state store and
/// event bus, so a spawned task can drive a workflow while the caller keeps
/// its own handle.
fn clone_for_execution(&self) -> Self {
    Self {
        definitions: Arc::clone(&self.definitions),
        // WorkflowStateStore is internally Arc-backed, so this clone shares state.
        state_store: self.state_store.clone(),
        event_bus: Arc::clone(&self.event_bus),
    }
}
|
||||
}
|
||||
|
||||
impl Default for WorkflowEngine {
    /// Delegates to [`WorkflowEngine::new`].
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
impl std::fmt::Debug for WorkflowEngine {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("WorkflowEngine")
|
||||
.field("definitions_count", &self.definitions.read().len())
|
||||
.field("state_count", &self.state_store.count())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
188
sgl-router/src/core/workflow/event.rs
Normal file
188
sgl-router/src/core/workflow/event.rs
Normal file
@@ -0,0 +1,188 @@
|
||||
//! Workflow event system for observability and monitoring
|
||||
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use super::types::{StepId, WorkflowId, WorkflowInstanceId};
|
||||
|
||||
/// Events emitted by the workflow engine
///
/// Events describe the lifecycle of a workflow instance and its steps, and
/// are delivered to subscribers via the event bus.
#[derive(Debug, Clone)]
pub enum WorkflowEvent {
    /// A workflow instance began executing.
    WorkflowStarted {
        instance_id: WorkflowInstanceId,
        definition_id: WorkflowId,
    },
    /// A step attempt began; `attempt` identifies which attempt this is.
    StepStarted {
        instance_id: WorkflowInstanceId,
        step_id: StepId,
        attempt: u32,
    },
    /// A step finished successfully after running for `duration`.
    StepSucceeded {
        instance_id: WorkflowInstanceId,
        step_id: StepId,
        duration: Duration,
    },
    /// A step attempt failed; `will_retry` tells whether another attempt follows.
    StepFailed {
        instance_id: WorkflowInstanceId,
        step_id: StepId,
        error: String,
        will_retry: bool,
    },
    /// A failed step is about to be retried after waiting `delay`.
    StepRetrying {
        instance_id: WorkflowInstanceId,
        step_id: StepId,
        attempt: u32,
        delay: Duration,
    },
    /// The workflow ran to completion.
    WorkflowCompleted {
        instance_id: WorkflowInstanceId,
        duration: Duration,
    },
    /// The workflow stopped because `failed_step` could not complete.
    WorkflowFailed {
        instance_id: WorkflowInstanceId,
        failed_step: StepId,
        error: String,
    },
    /// The workflow was cancelled.
    WorkflowCancelled {
        instance_id: WorkflowInstanceId,
    },
}
|
||||
|
||||
/// Trait for subscribing to workflow events
///
/// Implementors receive every event published on the bus they are registered
/// with. Handlers are awaited inline by the publisher, so they should return
/// quickly.
#[async_trait]
pub trait EventSubscriber: Send + Sync {
    /// Handle a single workflow event.
    async fn on_event(&self, event: &WorkflowEvent);
}
|
||||
|
||||
/// Event bus for publishing and subscribing to workflow events
pub struct EventBus {
    // Subscribers behind an async RwLock: concurrent publishers take read
    // access; registering a subscriber takes the exclusive write lock.
    subscribers: Arc<RwLock<Vec<Arc<dyn EventSubscriber>>>>,
}
|
||||
|
||||
impl EventBus {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
subscribers: Arc::new(RwLock::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe to workflow events
|
||||
pub async fn subscribe(&self, subscriber: Arc<dyn EventSubscriber>) {
|
||||
self.subscribers.write().await.push(subscriber);
|
||||
}
|
||||
|
||||
/// Publish an event to all subscribers
|
||||
pub async fn publish(&self, event: WorkflowEvent) {
|
||||
let subscribers = self.subscribers.read().await;
|
||||
for subscriber in subscribers.iter() {
|
||||
subscriber.on_event(&event).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EventBus {
    /// Equivalent to [`EventBus::new`].
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Logging subscriber that logs events using tracing
///
/// Stateless unit struct; register one instance on the bus to get a
/// structured log line for every workflow event.
pub struct LoggingSubscriber;
||||
|
||||
#[async_trait]
|
||||
impl EventSubscriber for LoggingSubscriber {
|
||||
async fn on_event(&self, event: &WorkflowEvent) {
|
||||
match event {
|
||||
WorkflowEvent::WorkflowStarted {
|
||||
instance_id,
|
||||
definition_id,
|
||||
} => {
|
||||
info!(
|
||||
instance_id = %instance_id,
|
||||
definition_id = %definition_id,
|
||||
"Workflow started"
|
||||
);
|
||||
}
|
||||
WorkflowEvent::StepStarted {
|
||||
instance_id,
|
||||
step_id,
|
||||
attempt,
|
||||
} => {
|
||||
info!(
|
||||
instance_id = %instance_id,
|
||||
step_id = %step_id,
|
||||
attempt = attempt,
|
||||
"Step started"
|
||||
);
|
||||
}
|
||||
WorkflowEvent::StepSucceeded {
|
||||
instance_id,
|
||||
step_id,
|
||||
duration,
|
||||
} => {
|
||||
info!(
|
||||
instance_id = %instance_id,
|
||||
step_id = %step_id,
|
||||
duration_ms = duration.as_millis(),
|
||||
"Step succeeded"
|
||||
);
|
||||
}
|
||||
WorkflowEvent::StepFailed {
|
||||
instance_id,
|
||||
step_id,
|
||||
error,
|
||||
will_retry,
|
||||
} => {
|
||||
warn!(
|
||||
instance_id = %instance_id,
|
||||
step_id = %step_id,
|
||||
error = error,
|
||||
will_retry = will_retry,
|
||||
"Step failed"
|
||||
);
|
||||
}
|
||||
WorkflowEvent::StepRetrying {
|
||||
instance_id,
|
||||
step_id,
|
||||
attempt,
|
||||
delay,
|
||||
} => {
|
||||
info!(
|
||||
instance_id = %instance_id,
|
||||
step_id = %step_id,
|
||||
attempt = attempt,
|
||||
delay_ms = delay.as_millis(),
|
||||
"Step retrying"
|
||||
);
|
||||
}
|
||||
WorkflowEvent::WorkflowCompleted {
|
||||
instance_id,
|
||||
duration,
|
||||
} => {
|
||||
info!(
|
||||
instance_id = %instance_id,
|
||||
duration_ms = duration.as_millis(),
|
||||
"Workflow completed"
|
||||
);
|
||||
}
|
||||
WorkflowEvent::WorkflowFailed {
|
||||
instance_id,
|
||||
failed_step,
|
||||
error,
|
||||
} => {
|
||||
error!(
|
||||
instance_id = %instance_id,
|
||||
failed_step = %failed_step,
|
||||
error = error,
|
||||
"Workflow failed"
|
||||
);
|
||||
}
|
||||
WorkflowEvent::WorkflowCancelled { instance_id } => {
|
||||
info!(instance_id = %instance_id, "Workflow cancelled");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
129
sgl-router/src/core/workflow/executor.rs
Normal file
129
sgl-router/src/core/workflow/executor.rs
Normal file
@@ -0,0 +1,129 @@
|
||||
//! Step executor trait and implementations
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::types::{StepResult, WorkflowContext, WorkflowError, WorkflowResult};
|
||||
|
||||
/// Trait for executing individual workflow steps
///
/// Implementors provide the step body via [`execute`](StepExecutor::execute);
/// the retry predicate and the success/failure hooks have no-op defaults that
/// individual steps may override.
#[async_trait]
pub trait StepExecutor: Send + Sync {
    /// Execute the step with the given context
    async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult>;

    /// Check if an error is retry-able
    ///
    /// Override this method to customize which errors should trigger retries.
    /// By default, all errors are considered retry-able.
    fn is_retryable(&self, _error: &WorkflowError) -> bool {
        true
    }

    /// Called when the step succeeds
    ///
    /// This hook allows steps to perform cleanup or additional actions
    /// after successful execution.
    async fn on_success(&self, _context: &WorkflowContext) -> WorkflowResult<()> {
        Ok(())
    }

    /// Called when the step fails after all retries
    ///
    /// This hook allows steps to perform cleanup or compensation logic
    /// when the step cannot complete successfully.
    async fn on_failure(
        &self,
        _context: &WorkflowContext,
        _error: &WorkflowError,
    ) -> WorkflowResult<()> {
        Ok(())
    }
}
|
||||
|
||||
/// Simple function-based step executor
|
||||
pub struct FunctionStep<F>
|
||||
where
|
||||
F: Fn(
|
||||
&mut WorkflowContext,
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = WorkflowResult<StepResult>> + Send + '_>,
|
||||
> + Send
|
||||
+ Sync,
|
||||
{
|
||||
func: F,
|
||||
}
|
||||
|
||||
impl<F> FunctionStep<F>
|
||||
where
|
||||
F: Fn(
|
||||
&mut WorkflowContext,
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = WorkflowResult<StepResult>> + Send + '_>,
|
||||
> + Send
|
||||
+ Sync,
|
||||
{
|
||||
pub fn new(func: F) -> Self {
|
||||
Self { func }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<F> StepExecutor for FunctionStep<F>
|
||||
where
|
||||
F: Fn(
|
||||
&mut WorkflowContext,
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = WorkflowResult<StepResult>> + Send + '_>,
|
||||
> + Send
|
||||
+ Sync,
|
||||
{
|
||||
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
|
||||
(self.func)(context).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::workflow::types::WorkflowInstanceId;

    /// Executor whose outcome is fixed at construction time.
    struct TestStep {
        should_succeed: bool,
    }

    #[async_trait]
    impl StepExecutor for TestStep {
        async fn execute(&self, _context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
            if self.should_succeed {
                return Ok(StepResult::Success);
            }
            Err(WorkflowError::StepFailed {
                step_id: crate::core::workflow::types::StepId::new("test"),
                message: "test error".to_string(),
            })
        }
    }

    #[tokio::test]
    async fn test_step_executor_success() {
        let step = TestStep { should_succeed: true };
        let mut context = WorkflowContext::new(WorkflowInstanceId::new());

        let outcome = step.execute(&mut context).await;
        assert_eq!(outcome.unwrap(), StepResult::Success);
    }

    #[tokio::test]
    async fn test_step_executor_failure() {
        let step = TestStep { should_succeed: false };
        let mut context = WorkflowContext::new(WorkflowInstanceId::new());

        assert!(step.execute(&mut context).await.is_err());
    }
}
|
||||
18
sgl-router/src/core/workflow/mod.rs
Normal file
18
sgl-router/src/core/workflow/mod.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
//! Workflow engine for managing multi-step operations

// Internal modules; their public surface is re-exported below.
mod definition;
mod engine;
mod event;
mod executor;
mod state;
pub mod steps;
pub mod types;

// Re-export main types
pub use definition::{StepDefinition, WorkflowDefinition};
pub use engine::WorkflowEngine;
pub use event::{EventBus, EventSubscriber, LoggingSubscriber, WorkflowEvent};
pub use executor::{FunctionStep, StepExecutor};
pub use state::WorkflowStateStore;
pub use steps::create_worker_registration_workflow;
pub use types::*;
|
||||
175
sgl-router/src/core/workflow/state.rs
Normal file
175
sgl-router/src/core/workflow/state.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
//! Workflow state management
|
||||
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use parking_lot::RwLock;
|
||||
|
||||
use super::types::{
|
||||
WorkflowError, WorkflowInstanceId, WorkflowResult, WorkflowState, WorkflowStatus,
|
||||
};
|
||||
|
||||
/// In-memory state storage for workflow instances
///
/// Cloning is cheap: the map lives behind an `Arc`, so clones share the same
/// underlying state.
#[derive(Clone)]
pub struct WorkflowStateStore {
    // Instance id -> full workflow state, guarded by a synchronous RwLock.
    states: Arc<RwLock<HashMap<WorkflowInstanceId, WorkflowState>>>,
}
|
||||
|
||||
impl WorkflowStateStore {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
states: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Save workflow state
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// This emits a warning if the workflow context contains unserializable data,
|
||||
/// which would be lost if state persistence is later implemented.
|
||||
pub fn save(&self, state: WorkflowState) -> WorkflowResult<()> {
|
||||
if state.context.has_unserializable_data() {
|
||||
tracing::warn!(
|
||||
instance_id = %state.instance_id,
|
||||
data_count = state.context.data_len(),
|
||||
"Saving workflow state with {} unserializable context entries. \
|
||||
This data cannot be persisted and will be lost on restart.",
|
||||
state.context.data_len()
|
||||
);
|
||||
}
|
||||
self.states.write().insert(state.instance_id, state);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load workflow state by instance ID
|
||||
pub fn load(&self, instance_id: WorkflowInstanceId) -> WorkflowResult<WorkflowState> {
|
||||
self.states
|
||||
.read()
|
||||
.get(&instance_id)
|
||||
.cloned()
|
||||
.ok_or(WorkflowError::NotFound(instance_id))
|
||||
}
|
||||
|
||||
/// List all active workflows (Running or Pending)
|
||||
pub fn list_active(&self) -> WorkflowResult<Vec<WorkflowState>> {
|
||||
let states = self.states.read();
|
||||
Ok(states
|
||||
.values()
|
||||
.filter(|s| matches!(s.status, WorkflowStatus::Running | WorkflowStatus::Pending))
|
||||
.cloned()
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// List all workflows
|
||||
pub fn list_all(&self) -> WorkflowResult<Vec<WorkflowState>> {
|
||||
let states = self.states.read();
|
||||
Ok(states.values().cloned().collect())
|
||||
}
|
||||
|
||||
/// Delete workflow state
|
||||
pub fn delete(&self, instance_id: WorkflowInstanceId) -> WorkflowResult<()> {
|
||||
self.states.write().remove(&instance_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update workflow state using a closure
|
||||
pub fn update<F>(&self, instance_id: WorkflowInstanceId, f: F) -> WorkflowResult<()>
|
||||
where
|
||||
F: FnOnce(&mut WorkflowState),
|
||||
{
|
||||
let mut states = self.states.write();
|
||||
let state = states
|
||||
.get_mut(&instance_id)
|
||||
.ok_or(WorkflowError::NotFound(instance_id))?;
|
||||
f(state);
|
||||
state.updated_at = chrono::Utc::now();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get count of workflows by status
|
||||
pub fn count_by_status(&self, status: WorkflowStatus) -> usize {
|
||||
self.states
|
||||
.read()
|
||||
.values()
|
||||
.filter(|s| s.status == status)
|
||||
.count()
|
||||
}
|
||||
|
||||
/// Get total count of all workflows
|
||||
pub fn count(&self) -> usize {
|
||||
self.states.read().len()
|
||||
}
|
||||
|
||||
/// Clean up old completed/failed/cancelled workflows beyond a time threshold
|
||||
///
|
||||
/// This prevents unbounded memory growth by removing workflow states that
|
||||
/// have been in a terminal state (Completed, Failed, Cancelled) for longer
|
||||
/// than the specified TTL (time-to-live).
|
||||
///
|
||||
/// Active workflows (Running, Pending, Paused) are never cleaned up.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `ttl` - Time-to-live for terminal workflows. Workflows in terminal states
|
||||
/// older than this will be removed.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The number of workflow states removed.
|
||||
pub fn cleanup_old_workflows(&self, ttl: std::time::Duration) -> usize {
|
||||
let now = chrono::Utc::now();
|
||||
let mut states = self.states.write();
|
||||
let initial_count = states.len();
|
||||
|
||||
states.retain(|_, state| {
|
||||
// Keep active workflows
|
||||
if matches!(
|
||||
state.status,
|
||||
WorkflowStatus::Running | WorkflowStatus::Pending | WorkflowStatus::Paused
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// For terminal workflows, check age
|
||||
let age = now
|
||||
.signed_duration_since(state.updated_at)
|
||||
.to_std()
|
||||
.unwrap_or_default();
|
||||
age < ttl
|
||||
});
|
||||
|
||||
let removed_count = initial_count - states.len();
|
||||
if removed_count > 0 {
|
||||
tracing::info!(
|
||||
removed = removed_count,
|
||||
remaining = states.len(),
|
||||
"Cleaned up old workflow states"
|
||||
);
|
||||
}
|
||||
removed_count
|
||||
}
|
||||
|
||||
/// Clean up a specific completed workflow immediately
|
||||
///
|
||||
/// This is useful for cleaning up workflows right after they complete
|
||||
/// when you know they won't be queried again.
|
||||
pub fn cleanup_if_terminal(&self, instance_id: WorkflowInstanceId) -> bool {
|
||||
let mut states = self.states.write();
|
||||
if let Some(state) = states.get(&instance_id) {
|
||||
if matches!(
|
||||
state.status,
|
||||
WorkflowStatus::Completed | WorkflowStatus::Failed | WorkflowStatus::Cancelled
|
||||
) {
|
||||
states.remove(&instance_id);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WorkflowStateStore {
    /// Equivalent to [`WorkflowStateStore::new`].
    fn default() -> Self {
        Self::new()
    }
}
|
||||
12
sgl-router/src/core/workflow/steps/mod.rs
Normal file
12
sgl-router/src/core/workflow/steps/mod.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
//! Workflow step implementations
//!
//! This module contains concrete step implementations for various workflows:
//! - Worker registration and activation
//! - Future: Tokenizer fetching, LoRA updates, etc.

pub mod worker_registration;

// Flat re-exports so callers can name steps without the submodule path.
pub use worker_registration::{
    create_worker_registration_workflow, ActivateWorkerStep, CreateWorkerStep,
    DetectConnectionModeStep, DiscoverMetadataStep, RegisterWorkerStep, UpdatePoliciesStep,
};
|
||||
837
sgl-router/src/core/workflow/steps/worker_registration.rs
Normal file
837
sgl-router/src/core/workflow/steps/worker_registration.rs
Normal file
@@ -0,0 +1,837 @@
|
||||
//! Worker registration workflow steps
|
||||
//!
|
||||
//! Each step is atomic and performs a single operation in the worker registration process.
|
||||
//!
|
||||
//! Workflow order:
|
||||
//! 1. DetectConnectionMode - Probe both HTTP and gRPC to determine connection mode
|
||||
//! 2. DiscoverMetadata - Fetch metadata from the worker
|
||||
//! 3. DiscoverDPInfo - Fetch DP (Data Parallel) information (only for DP-aware workers)
|
||||
//! 4. CreateWorker - Build worker object(s) with merged config + metadata
|
||||
//! 5. RegisterWorker - Register worker(s) in registry
|
||||
//! 6. UpdatePolicies - Update policy registry with worker information
|
||||
//! 7. ActivateWorker - Mark worker(s) as healthy
|
||||
|
||||
use std::{collections::HashMap, sync::Arc, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use once_cell::sync::Lazy;
|
||||
use reqwest::Client;
|
||||
use serde_json::Value;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
core::{
|
||||
workflow::*, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode,
|
||||
DPAwareWorkerBuilder, DpInfo, HealthConfig, Worker, WorkerManager, WorkerType,
|
||||
},
|
||||
grpc_client::SglangSchedulerClient,
|
||||
protocols::worker_spec::WorkerConfigRequest,
|
||||
server::AppContext,
|
||||
};
|
||||
|
||||
// HTTP client for metadata fetching
//
// Shared, lazily-initialized client with a 10s default timeout; individual
// requests may override the timeout per call.
static HTTP_CLIENT: Lazy<Client> = Lazy::new(|| {
    Client::builder()
        .timeout(Duration::from_secs(10))
        .build()
        .expect("Failed to create HTTP client")
});
|
||||
|
||||
/// Strip any `http://`, `https://` or `grpc://` scheme prefix from `url`.
///
/// Each prefix is trimmed repeatedly, in the order listed, mirroring chained
/// `trim_start_matches` calls.
fn strip_protocol(url: &str) -> String {
    ["http://", "https://", "grpc://"]
        .iter()
        .copied()
        .fold(url, |rest, prefix| rest.trim_start_matches(prefix))
        .to_string()
}
|
||||
|
||||
/// Helper: Try HTTP health check
|
||||
async fn try_http_health_check(url: &str, timeout_secs: u64) -> Result<(), String> {
|
||||
let clean_url = strip_protocol(url);
|
||||
let health_url = format!("http://{}/health", clean_url);
|
||||
|
||||
HTTP_CLIENT
|
||||
.get(&health_url)
|
||||
.timeout(Duration::from_secs(timeout_secs))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("HTTP health check failed: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Helper: Try gRPC health check
|
||||
async fn try_grpc_health_check(url: &str, timeout_secs: u64) -> Result<(), String> {
|
||||
let grpc_url = if url.starts_with("grpc://") {
|
||||
url.to_string()
|
||||
} else {
|
||||
format!("grpc://{}", strip_protocol(url))
|
||||
};
|
||||
|
||||
let connect_future = SglangSchedulerClient::connect(&grpc_url);
|
||||
let client = tokio::time::timeout(Duration::from_secs(timeout_secs), connect_future)
|
||||
.await
|
||||
.map_err(|_| "gRPC connection timeout".to_string())?
|
||||
.map_err(|e| format!("gRPC connection failed: {}", e))?;
|
||||
|
||||
let health_future = client.health_check();
|
||||
tokio::time::timeout(Duration::from_secs(timeout_secs), health_future)
|
||||
.await
|
||||
.map_err(|_| "gRPC health check timeout".to_string())?
|
||||
.map_err(|e| format!("gRPC health check failed: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Helper: Fetch HTTP metadata
|
||||
async fn fetch_http_metadata(
|
||||
url: &str,
|
||||
api_key: Option<&str>,
|
||||
) -> Result<HashMap<String, String>, String> {
|
||||
let clean_url = strip_protocol(url);
|
||||
let info_url = if clean_url.starts_with("http://") || clean_url.starts_with("https://") {
|
||||
format!("{}/get_server_info", clean_url)
|
||||
} else {
|
||||
format!("http://{}/get_server_info", clean_url)
|
||||
};
|
||||
|
||||
let mut request = HTTP_CLIENT.get(&info_url);
|
||||
if let Some(key) = api_key {
|
||||
request = request.header("Authorization", format!("Bearer {}", key));
|
||||
}
|
||||
|
||||
let response = request
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to fetch HTTP metadata: {}", e))?;
|
||||
|
||||
let server_info: Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse HTTP metadata: {}", e))?;
|
||||
|
||||
let mut labels = HashMap::new();
|
||||
|
||||
if let Some(model_path) = server_info.get("model_path").and_then(|v| v.as_str()) {
|
||||
if !model_path.is_empty() {
|
||||
labels.insert("model_path".to_string(), model_path.to_string());
|
||||
}
|
||||
}
|
||||
if let Some(tokenizer_path) = server_info.get("tokenizer_path").and_then(|v| v.as_str()) {
|
||||
if !tokenizer_path.is_empty() {
|
||||
labels.insert("tokenizer_path".to_string(), tokenizer_path.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(labels)
|
||||
}
|
||||
|
||||
/// Helper: Fetch gRPC metadata
|
||||
async fn fetch_grpc_metadata(url: &str) -> Result<HashMap<String, String>, String> {
|
||||
let grpc_url = if url.starts_with("grpc://") {
|
||||
url.to_string()
|
||||
} else {
|
||||
format!("grpc://{}", strip_protocol(url))
|
||||
};
|
||||
|
||||
let client = SglangSchedulerClient::connect(&grpc_url)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to connect to gRPC: {}", e))?;
|
||||
|
||||
let model_info = client
|
||||
.get_model_info()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to fetch gRPC metadata: {}", e))?;
|
||||
|
||||
let mut labels = HashMap::new();
|
||||
|
||||
// Extract all available fields
|
||||
if !model_info.model_path.is_empty() {
|
||||
labels.insert("model_path".to_string(), model_info.model_path.clone());
|
||||
}
|
||||
if !model_info.tokenizer_path.is_empty() {
|
||||
labels.insert(
|
||||
"tokenizer_path".to_string(),
|
||||
model_info.tokenizer_path.clone(),
|
||||
);
|
||||
}
|
||||
if !model_info.served_model_name.is_empty() {
|
||||
labels.insert(
|
||||
"served_model_name".to_string(),
|
||||
model_info.served_model_name.clone(),
|
||||
);
|
||||
}
|
||||
if !model_info.weight_version.is_empty() {
|
||||
labels.insert(
|
||||
"weight_version".to_string(),
|
||||
model_info.weight_version.clone(),
|
||||
);
|
||||
}
|
||||
if !model_info.model_type.is_empty() {
|
||||
labels.insert("model_type".to_string(), model_info.model_type.clone());
|
||||
}
|
||||
if model_info.max_context_length > 0 {
|
||||
labels.insert(
|
||||
"max_context_length".to_string(),
|
||||
model_info.max_context_length.to_string(),
|
||||
);
|
||||
}
|
||||
if model_info.max_req_input_len > 0 {
|
||||
labels.insert(
|
||||
"max_req_input_len".to_string(),
|
||||
model_info.max_req_input_len.to_string(),
|
||||
);
|
||||
}
|
||||
if model_info.vocab_size > 0 {
|
||||
labels.insert("vocab_size".to_string(), model_info.vocab_size.to_string());
|
||||
}
|
||||
if model_info.is_generation {
|
||||
labels.insert("is_generation".to_string(), "true".to_string());
|
||||
}
|
||||
|
||||
Ok(labels)
|
||||
}
|
||||
|
||||
/// Step 1: Detect connection mode by probing both HTTP and gRPC
pub struct DetectConnectionModeStep;

#[async_trait]
impl StepExecutor for DetectConnectionModeStep {
    /// Probe the worker over HTTP and gRPC concurrently and record the winning
    /// `ConnectionMode` in the context under `"connection_mode"`.
    ///
    /// Requires `"worker_config"` (an `Arc<WorkerConfigRequest>`) in the
    /// context. When both probes succeed, HTTP wins — the match arms below
    /// check the HTTP result first.
    ///
    /// # Errors
    ///
    /// Fails with `StepFailed` when both probes fail; the step is retryable.
    async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
        let config: Arc<WorkerConfigRequest> = context
            .get("worker_config")
            .ok_or_else(|| WorkflowError::ContextValueNotFound("worker_config".to_string()))?;

        info!(
            "Detecting connection mode for {} (timeout: {}s, max_attempts: {})",
            config.url, config.health_check_timeout_secs, config.max_connection_attempts
        );

        // Try both protocols in parallel using configured timeout
        let url = config.url.clone();
        let timeout = config.health_check_timeout_secs;
        let (http_result, grpc_result) = tokio::join!(
            try_http_health_check(&url, timeout),
            try_grpc_health_check(&url, timeout)
        );

        // HTTP is matched first, so it takes precedence when both succeed.
        let connection_mode = match (http_result, grpc_result) {
            (Ok(_), _) => {
                info!("{} detected as HTTP", config.url);
                ConnectionMode::Http
            }
            (_, Ok(_)) => {
                info!("{} detected as gRPC", config.url);
                ConnectionMode::Grpc { port: None }
            }
            (Err(http_err), Err(grpc_err)) => {
                return Err(WorkflowError::StepFailed {
                    step_id: StepId::new("detect_connection_mode"),
                    message: format!(
                        "Both HTTP and gRPC health checks failed for {}: HTTP: {}, gRPC: {}",
                        config.url, http_err, grpc_err
                    ),
                });
            }
        };

        // Store connection mode in context
        context.set("connection_mode", connection_mode);

        Ok(StepResult::Success)
    }

    fn is_retryable(&self, _error: &WorkflowError) -> bool {
        true // Connection issues are retryable
    }
}
|
||||
|
||||
/// Step 2: Discover metadata from worker
pub struct DiscoverMetadataStep;

#[async_trait]
impl StepExecutor for DiscoverMetadataStep {
    /// Fetch worker metadata over whichever protocol was detected and store
    /// the resulting label map under `"discovered_labels"`.
    ///
    /// Requires `"worker_config"` and `"connection_mode"` in the context.
    /// Metadata fetch failures are degraded to an empty label map (with a
    /// warning) rather than failing the step.
    async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
        let config: Arc<WorkerConfigRequest> = context
            .get("worker_config")
            .ok_or_else(|| WorkflowError::ContextValueNotFound("worker_config".to_string()))?;
        let connection_mode: Arc<ConnectionMode> = context
            .get("connection_mode")
            .ok_or_else(|| WorkflowError::ContextValueNotFound("connection_mode".to_string()))?;

        info!(
            "Discovering metadata for {} ({:?})",
            config.url, *connection_mode
        );

        // Protocol-specific fetch; failures fall back to no labels.
        let discovered_labels = match connection_mode.as_ref() {
            ConnectionMode::Http => {
                fetch_http_metadata(&config.url, config.api_key.as_deref()).await
            }
            ConnectionMode::Grpc { .. } => fetch_grpc_metadata(&config.url).await,
        }
        .unwrap_or_else(|e| {
            warn!("Failed to fetch metadata for {}: {}", config.url, e);
            HashMap::new()
        });

        info!(
            "Discovered {} metadata labels for {}",
            discovered_labels.len(),
            config.url
        );

        // Store discovered labels in context
        context.set("discovered_labels", discovered_labels);

        Ok(StepResult::Success)
    }

    fn is_retryable(&self, _error: &WorkflowError) -> bool {
        true // Metadata discovery failures are retryable
    }
}
|
||||
|
||||
/// Step 2.5: Discover DP (Data Parallel) information (only for DP-aware workers)
pub struct DiscoverDPInfoStep;

#[async_trait]
impl StepExecutor for DiscoverDPInfoStep {
    /// Fetch DP info for DP-aware workers and store it under `"dp_info"`.
    ///
    /// Requires `"worker_config"` in the context. Workers that are not
    /// DP-aware skip discovery and succeed immediately.
    ///
    /// # Errors
    ///
    /// Fails with `StepFailed` when the DP info request to the worker fails;
    /// the step is retryable.
    async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
        let config: Arc<WorkerConfigRequest> = context
            .get("worker_config")
            .ok_or_else(|| WorkflowError::ContextValueNotFound("worker_config".to_string()))?;

        // Skip DP discovery if not DP-aware
        if !config.dp_aware {
            info!(
                "Worker {} is not DP-aware, skipping DP discovery",
                config.url
            );
            return Ok(StepResult::Success);
        }

        info!("Discovering DP info for {} (DP-aware)", config.url);

        // Get DP info from worker
        let dp_info = WorkerManager::get_dp_info(&config.url, config.api_key.as_deref())
            .await
            .map_err(|e| WorkflowError::StepFailed {
                step_id: StepId::new("discover_dp_info"),
                message: format!("Failed to get DP info: {}", e),
            })?;

        info!(
            "Discovered DP size {} for {} (model: {})",
            dp_info.dp_size, config.url, dp_info.model_id
        );

        // Store DP info in context
        context.set("dp_info", Arc::new(dp_info));

        Ok(StepResult::Success)
    }

    fn is_retryable(&self, _error: &WorkflowError) -> bool {
        true // DP info discovery failures are retryable
    }
}
|
||||
|
||||
/// Step 3: Create worker object with merged configuration + metadata
|
||||
pub struct CreateWorkerStep;
|
||||
|
||||
#[async_trait]
|
||||
impl StepExecutor for CreateWorkerStep {
|
||||
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
|
||||
let config: Arc<WorkerConfigRequest> = context
|
||||
.get("worker_config")
|
||||
.ok_or_else(|| WorkflowError::ContextValueNotFound("worker_config".to_string()))?;
|
||||
let app_context: Arc<AppContext> = context
|
||||
.get("app_context")
|
||||
.ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
|
||||
let connection_mode: Arc<ConnectionMode> = context
|
||||
.get("connection_mode")
|
||||
.ok_or_else(|| WorkflowError::ContextValueNotFound("connection_mode".to_string()))?;
|
||||
let discovered_labels: Arc<HashMap<String, String>> = context
|
||||
.get("discovered_labels")
|
||||
.ok_or_else(|| WorkflowError::ContextValueNotFound("discovered_labels".to_string()))?;
|
||||
|
||||
// Check if worker already exists
|
||||
if app_context
|
||||
.worker_registry
|
||||
.get_by_url(&config.url)
|
||||
.is_some()
|
||||
{
|
||||
return Err(WorkflowError::StepFailed {
|
||||
step_id: StepId::new("create_worker"),
|
||||
message: format!("Worker {} already exists", config.url),
|
||||
});
|
||||
}
|
||||
|
||||
// Build labels from config
|
||||
let mut config_labels = config.labels.clone();
|
||||
if let Some(model_id) = &config.model_id {
|
||||
config_labels.insert("model_id".to_string(), model_id.clone());
|
||||
}
|
||||
if let Some(priority) = config.priority {
|
||||
config_labels.insert("priority".to_string(), priority.to_string());
|
||||
}
|
||||
if let Some(cost) = config.cost {
|
||||
config_labels.insert("cost".to_string(), cost.to_string());
|
||||
}
|
||||
if let Some(ref tokenizer_path) = config.tokenizer_path {
|
||||
config_labels.insert("tokenizer_path".to_string(), tokenizer_path.clone());
|
||||
}
|
||||
if let Some(ref reasoning_parser) = config.reasoning_parser {
|
||||
config_labels.insert("reasoning_parser".to_string(), reasoning_parser.clone());
|
||||
}
|
||||
if let Some(ref tool_parser) = config.tool_parser {
|
||||
config_labels.insert("tool_parser".to_string(), tool_parser.clone());
|
||||
}
|
||||
if let Some(ref chat_template) = config.chat_template {
|
||||
config_labels.insert("chat_template".to_string(), chat_template.clone());
|
||||
}
|
||||
|
||||
// Merge: discovered labels first, then config labels (config takes precedence)
|
||||
let mut final_labels = discovered_labels.as_ref().clone();
|
||||
for (key, value) in &config_labels {
|
||||
final_labels.insert(key.clone(), value.clone());
|
||||
}
|
||||
|
||||
// Derive model_id if not already set
|
||||
if !final_labels.contains_key("model_id") {
|
||||
let derived_model_id = final_labels
|
||||
.get("served_model_name")
|
||||
.or_else(|| final_labels.get("model_path"))
|
||||
.cloned();
|
||||
|
||||
if let Some(model_id) = derived_model_id {
|
||||
info!("Derived model_id from metadata: {}", model_id);
|
||||
final_labels.insert("model_id".to_string(), model_id);
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
"Creating worker {} with {} discovered + {} config = {} final labels",
|
||||
config.url,
|
||||
discovered_labels.len(),
|
||||
config_labels.len(),
|
||||
final_labels.len()
|
||||
);
|
||||
|
||||
// Parse worker type
|
||||
let worker_type = config
|
||||
.worker_type
|
||||
.as_ref()
|
||||
.map(|t| match t.as_str() {
|
||||
"prefill" => WorkerType::Prefill {
|
||||
bootstrap_port: config.bootstrap_port,
|
||||
},
|
||||
"decode" => WorkerType::Decode,
|
||||
_ => WorkerType::Regular,
|
||||
})
|
||||
.unwrap_or(WorkerType::Regular);
|
||||
|
||||
// Build circuit breaker config
|
||||
let circuit_breaker_config = {
|
||||
let cfg = app_context.router_config.effective_circuit_breaker_config();
|
||||
CircuitBreakerConfig {
|
||||
failure_threshold: cfg.failure_threshold,
|
||||
success_threshold: cfg.success_threshold,
|
||||
timeout_duration: Duration::from_secs(cfg.timeout_duration_secs),
|
||||
window_duration: Duration::from_secs(cfg.window_duration_secs),
|
||||
}
|
||||
};
|
||||
|
||||
// Build health config
|
||||
let health_config = {
|
||||
let cfg = &app_context.router_config.health_check;
|
||||
HealthConfig {
|
||||
timeout_secs: cfg.timeout_secs,
|
||||
check_interval_secs: cfg.check_interval_secs,
|
||||
endpoint: cfg.endpoint.clone(),
|
||||
failure_threshold: cfg.failure_threshold,
|
||||
success_threshold: cfg.success_threshold,
|
||||
}
|
||||
};
|
||||
|
||||
// Normalize URL: add protocol prefix only if missing
|
||||
let normalized_url = if config.url.starts_with("http://")
|
||||
|| config.url.starts_with("https://")
|
||||
|| config.url.starts_with("grpc://")
|
||||
{
|
||||
// URL already has protocol, use as-is
|
||||
config.url.clone()
|
||||
} else {
|
||||
// Bare IP:port format, add appropriate protocol based on detected mode
|
||||
match connection_mode.as_ref() {
|
||||
ConnectionMode::Http => format!("http://{}", config.url),
|
||||
ConnectionMode::Grpc { .. } => format!("grpc://{}", config.url),
|
||||
}
|
||||
};
|
||||
|
||||
if normalized_url != config.url {
|
||||
info!(
|
||||
"Normalized worker URL: {} -> {} ({:?})",
|
||||
config.url,
|
||||
normalized_url,
|
||||
connection_mode.as_ref()
|
||||
);
|
||||
}
|
||||
|
||||
// Handle DP-aware vs non-DP-aware workers
|
||||
if config.dp_aware {
|
||||
// DP-aware path: Create multiple workers (one per rank)
|
||||
let dp_info: Arc<DpInfo> = context
|
||||
.get("dp_info")
|
||||
.ok_or_else(|| WorkflowError::ContextValueNotFound("dp_info".to_string()))?;
|
||||
|
||||
info!(
|
||||
"Creating {} DP-aware workers for {} (dp_size: {})",
|
||||
dp_info.dp_size, config.url, dp_info.dp_size
|
||||
);
|
||||
|
||||
let mut workers = Vec::new();
|
||||
for rank in 0..dp_info.dp_size {
|
||||
let mut builder =
|
||||
DPAwareWorkerBuilder::new(normalized_url.clone(), rank, dp_info.dp_size)
|
||||
.worker_type(worker_type.clone())
|
||||
.connection_mode(connection_mode.as_ref().clone())
|
||||
.circuit_breaker_config(circuit_breaker_config.clone())
|
||||
.health_config(health_config.clone());
|
||||
|
||||
if let Some(ref api_key) = config.api_key {
|
||||
builder = builder.api_key(api_key.clone());
|
||||
}
|
||||
|
||||
if !final_labels.is_empty() {
|
||||
builder = builder.labels(final_labels.clone());
|
||||
}
|
||||
|
||||
let worker = Arc::new(builder.build()) as Arc<dyn Worker>;
|
||||
worker.set_healthy(false);
|
||||
workers.push(worker);
|
||||
|
||||
info!(
|
||||
"Created DP-aware worker {}@{}/{} ({:?})",
|
||||
config.url,
|
||||
rank,
|
||||
dp_info.dp_size,
|
||||
connection_mode.as_ref()
|
||||
);
|
||||
}
|
||||
|
||||
// Store workers (plural) and labels in context
|
||||
context.set("workers", Arc::new(workers));
|
||||
context.set("labels", final_labels);
|
||||
|
||||
Ok(StepResult::Success)
|
||||
} else {
|
||||
// Non-DP-aware path: Create single worker
|
||||
let mut builder = BasicWorkerBuilder::new(normalized_url.clone())
|
||||
.worker_type(worker_type)
|
||||
.connection_mode(connection_mode.as_ref().clone())
|
||||
.circuit_breaker_config(circuit_breaker_config)
|
||||
.health_config(health_config);
|
||||
|
||||
if let Some(ref api_key) = config.api_key {
|
||||
builder = builder.api_key(api_key.clone());
|
||||
}
|
||||
|
||||
if !final_labels.is_empty() {
|
||||
builder = builder.labels(final_labels.clone());
|
||||
}
|
||||
|
||||
let worker = Arc::new(builder.build()) as Arc<dyn Worker>;
|
||||
worker.set_healthy(false);
|
||||
|
||||
info!(
|
||||
"Created worker object for {} ({:?}) with {} labels",
|
||||
config.url,
|
||||
connection_mode.as_ref(),
|
||||
final_labels.len()
|
||||
);
|
||||
|
||||
// Store worker (singular) and labels in context
|
||||
context.set("worker", worker);
|
||||
context.set("labels", final_labels);
|
||||
|
||||
Ok(StepResult::Success)
|
||||
}
|
||||
}
|
||||
|
||||
fn is_retryable(&self, _error: &WorkflowError) -> bool {
|
||||
false // Worker creation failures are not retryable (likely config issues)
|
||||
}
|
||||
}
|
||||
|
||||
/// Step 4: Register worker(s) in the worker registry.
///
/// Consumes `worker_config`, `app_context`, and either `"workers"` (DP-aware)
/// or `"worker"` (single) from the context; stores the assigned registry id(s)
/// back into the context as `"worker_ids"` / `"worker_id"`.
pub struct RegisterWorkerStep;

#[async_trait]
impl StepExecutor for RegisterWorkerStep {
    async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
        let config: Arc<WorkerConfigRequest> = context
            .get("worker_config")
            .ok_or_else(|| WorkflowError::ContextValueNotFound("worker_config".to_string()))?;
        let app_context: Arc<AppContext> = context
            .get("app_context")
            .ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;

        // Check if we have multiple workers (DP-aware) or a single worker.
        if config.dp_aware {
            // DP-aware path: register one registry entry per rank.
            // NOTE(review): this `get::<Vec<...>>` only succeeds if the
            // producer stored the Vec without double-wrapping the Arc (i.e.
            // via `set_arc`); `set("workers", Arc::new(...))` would make this
            // downcast fail. Confirm against CreateWorkerStep.
            let workers: Arc<Vec<Arc<dyn Worker>>> = context
                .get("workers")
                .ok_or_else(|| WorkflowError::ContextValueNotFound("workers".to_string()))?;

            let mut worker_ids = Vec::new();
            for worker in workers.iter() {
                let worker_id = app_context.worker_registry.register(Arc::clone(worker));
                worker_ids.push(worker_id.clone());
                info!(
                    "Registered DP-aware worker {} with ID {:?}",
                    config.url, worker_id
                );
            }

            // NOTE(review): `set` wraps in Arc, so this stores the ids with
            // concrete type Arc<Vec<WorkerId>>; consumers must read it as
            // `get::<Arc<Vec<WorkerId>>>`. If consumers use `get::<Vec<...>>`,
            // this should be `set_arc` instead — verify against the reader.
            context.set("worker_ids", Arc::new(worker_ids));
            Ok(StepResult::Success)
        } else {
            // Non-DP-aware path: register the single worker. The double Arc is
            // expected: `set("worker", ...)` wrapped the Arc<dyn Worker> once.
            let worker: Arc<Arc<dyn Worker>> = context
                .get("worker")
                .ok_or_else(|| WorkflowError::ContextValueNotFound("worker".to_string()))?;

            let worker_id = app_context
                .worker_registry
                .register(Arc::clone(worker.as_ref()));

            info!("Registered worker {} with ID {:?}", config.url, worker_id);
            context.set("worker_id", worker_id);

            Ok(StepResult::Success)
        }
    }

    fn is_retryable(&self, _error: &WorkflowError) -> bool {
        false // Registration failures are not retryable
    }
}
|
||||
|
||||
/// Step 5: Update policy registry with worker information
|
||||
pub struct UpdatePoliciesStep;
|
||||
|
||||
#[async_trait]
|
||||
impl StepExecutor for UpdatePoliciesStep {
|
||||
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
|
||||
let config: Arc<WorkerConfigRequest> = context
|
||||
.get("worker_config")
|
||||
.ok_or_else(|| WorkflowError::ContextValueNotFound("worker_config".to_string()))?;
|
||||
let labels: Arc<HashMap<String, String>> = context
|
||||
.get("labels")
|
||||
.ok_or_else(|| WorkflowError::ContextValueNotFound("labels".to_string()))?;
|
||||
let app_context: Arc<AppContext> = context
|
||||
.get("app_context")
|
||||
.ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
|
||||
|
||||
let policy_hint = labels.get("policy").map(|s| s.as_str());
|
||||
|
||||
// Check if we have multiple workers (DP-aware) or single worker
|
||||
if config.dp_aware {
|
||||
// DP-aware path: Update policies for multiple workers
|
||||
let workers: Arc<Vec<Arc<dyn Worker>>> = context
|
||||
.get("workers")
|
||||
.ok_or_else(|| WorkflowError::ContextValueNotFound("workers".to_string()))?;
|
||||
|
||||
// Get model_id from first worker (all DP workers have same model)
|
||||
let model_id = workers[0].model_id().to_string();
|
||||
|
||||
// Notify policy registry for each worker
|
||||
for _ in 0..workers.len() {
|
||||
app_context
|
||||
.policy_registry
|
||||
.on_worker_added(&model_id, policy_hint);
|
||||
}
|
||||
|
||||
// Initialize cache-aware policy if needed
|
||||
let all_workers = app_context.worker_registry.get_by_model_fast(&model_id);
|
||||
if let Some(policy) = app_context.policy_registry.get_policy(&model_id) {
|
||||
if policy.name() == "cache_aware" {
|
||||
app_context
|
||||
.policy_registry
|
||||
.init_cache_aware_policy(&model_id, &all_workers);
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
"Updated policies for {} DP-aware workers {} (model: {})",
|
||||
workers.len(),
|
||||
config.url,
|
||||
model_id
|
||||
);
|
||||
} else {
|
||||
// Non-DP-aware path: Update policy for single worker
|
||||
let worker: Arc<Arc<dyn Worker>> = context
|
||||
.get("worker")
|
||||
.ok_or_else(|| WorkflowError::ContextValueNotFound("worker".to_string()))?;
|
||||
|
||||
let model_id = worker.model_id().to_string();
|
||||
|
||||
// Notify policy registry
|
||||
app_context
|
||||
.policy_registry
|
||||
.on_worker_added(&model_id, policy_hint);
|
||||
|
||||
// Initialize cache-aware policy if needed
|
||||
let all_workers = app_context.worker_registry.get_by_model_fast(&model_id);
|
||||
if let Some(policy) = app_context.policy_registry.get_policy(&model_id) {
|
||||
if policy.name() == "cache_aware" {
|
||||
app_context
|
||||
.policy_registry
|
||||
.init_cache_aware_policy(&model_id, &all_workers);
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
"Updated policies for worker {} (model: {})",
|
||||
config.url, model_id
|
||||
);
|
||||
}
|
||||
|
||||
Ok(StepResult::Success)
|
||||
}
|
||||
|
||||
fn is_retryable(&self, _error: &WorkflowError) -> bool {
|
||||
false // Policy update failures are not retryable
|
||||
}
|
||||
}
|
||||
|
||||
/// Step 6: Activate worker(s) by marking them as healthy
|
||||
pub struct ActivateWorkerStep;
|
||||
|
||||
#[async_trait]
|
||||
impl StepExecutor for ActivateWorkerStep {
|
||||
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
|
||||
let config: Arc<WorkerConfigRequest> = context
|
||||
.get("worker_config")
|
||||
.ok_or_else(|| WorkflowError::ContextValueNotFound("worker_config".to_string()))?;
|
||||
|
||||
// Check if we have multiple workers (DP-aware) or single worker
|
||||
if config.dp_aware {
|
||||
// DP-aware path: Activate multiple workers
|
||||
let workers: Arc<Vec<Arc<dyn Worker>>> = context
|
||||
.get("workers")
|
||||
.ok_or_else(|| WorkflowError::ContextValueNotFound("workers".to_string()))?;
|
||||
|
||||
for worker in workers.iter() {
|
||||
worker.set_healthy(true);
|
||||
}
|
||||
|
||||
info!(
|
||||
"Activated {} DP-aware workers {} (marked as healthy)",
|
||||
workers.len(),
|
||||
config.url
|
||||
);
|
||||
} else {
|
||||
// Non-DP-aware path: Activate single worker
|
||||
let worker: Arc<Arc<dyn Worker>> = context
|
||||
.get("worker")
|
||||
.ok_or_else(|| WorkflowError::ContextValueNotFound("worker".to_string()))?;
|
||||
|
||||
worker.set_healthy(true);
|
||||
|
||||
info!("Activated worker {} (marked as healthy)", config.url);
|
||||
}
|
||||
|
||||
Ok(StepResult::Success)
|
||||
}
|
||||
|
||||
fn is_retryable(&self, _error: &WorkflowError) -> bool {
|
||||
false // Activation is just setting a flag, not retryable
|
||||
}
|
||||
}
|
||||
|
||||
/// Create the worker registration workflow definition.
///
/// Steps run in order: detect connection mode -> discover metadata (optional)
/// -> discover DP info -> create worker -> register -> update policies
/// (optional) -> activate.
///
/// Note: Actual health check timeouts and retry attempts are configured per-worker
/// via WorkerConfigRequest (populated from router config). The timeouts and retry
/// policies here serve as workflow-level bounds to prevent infinite waiting.
pub fn create_worker_registration_workflow() -> WorkflowDefinition {
    WorkflowDefinition::new("worker_registration", "Worker Registration")
        .add_step(
            StepDefinition::new(
                "detect_connection_mode",
                "Detect Connection Mode",
                Arc::new(DetectConnectionModeStep),
            )
            // Generous retry budget: the worker process may still be starting up.
            .with_retry(RetryPolicy {
                max_attempts: 100,
                backoff: BackoffStrategy::Linear {
                    increment: Duration::from_secs(1),
                    max: Duration::from_secs(5),
                },
            })
            // Workflow-level timeout (upper bound); step uses config.health_check_timeout_secs
            .with_timeout(Duration::from_secs(7200)) // 2 hours max
            .with_failure_action(FailureAction::FailWorkflow),
        )
        .add_step(
            StepDefinition::new(
                "discover_metadata",
                "Discover Metadata",
                Arc::new(DiscoverMetadataStep),
            )
            .with_retry(RetryPolicy {
                max_attempts: 3,
                backoff: BackoffStrategy::Fixed(Duration::from_secs(1)),
            })
            .with_timeout(Duration::from_secs(10))
            .with_failure_action(FailureAction::ContinueNextStep), // Metadata discovery is optional
        )
        .add_step(
            StepDefinition::new(
                "discover_dp_info",
                "Discover DP Info",
                Arc::new(DiscoverDPInfoStep),
            )
            .with_retry(RetryPolicy {
                max_attempts: 3,
                backoff: BackoffStrategy::Fixed(Duration::from_secs(1)),
            })
            .with_timeout(Duration::from_secs(10))
            // The step itself short-circuits to Success for non-DP-aware configs.
            .with_failure_action(FailureAction::FailWorkflow), // DP info is required for DP-aware workers
        )
        .add_step(
            StepDefinition::new("create_worker", "Create Worker", Arc::new(CreateWorkerStep))
                .with_timeout(Duration::from_secs(5))
                .with_failure_action(FailureAction::FailWorkflow),
        )
        .add_step(
            StepDefinition::new(
                "register_worker",
                "Register Worker",
                Arc::new(RegisterWorkerStep),
            )
            .with_timeout(Duration::from_secs(5))
            .with_failure_action(FailureAction::FailWorkflow),
        )
        .add_step(
            StepDefinition::new(
                "update_policies",
                "Update Policies",
                Arc::new(UpdatePoliciesStep),
            )
            .with_timeout(Duration::from_secs(5))
            .with_failure_action(FailureAction::ContinueNextStep), // Policy updates are optional
        )
        .add_step(
            StepDefinition::new(
                "activate_worker",
                "Activate Worker",
                Arc::new(ActivateWorkerStep),
            )
            .with_timeout(Duration::from_secs(5))
            .with_failure_action(FailureAction::FailWorkflow),
        )
}
|
||||
271
sgl-router/src/core/workflow/types.rs
Normal file
271
sgl-router/src/core/workflow/types.rs
Normal file
@@ -0,0 +1,271 @@
|
||||
//! Core workflow types and definitions
|
||||
|
||||
use std::{collections::HashMap, fmt, sync::Arc, time::Duration};
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Unique identifier for a workflow definition (e.g. "worker_registration").
///
/// Newtype over `String` so definition ids cannot be mixed up with step ids
/// or arbitrary strings at the type level.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct WorkflowId(String);

impl WorkflowId {
    /// Build an id from any string-like value.
    pub fn new(id: impl Into<String>) -> Self {
        Self(id.into())
    }
}

impl fmt::Display for WorkflowId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
|
||||
|
||||
/// Unique identifier for a single run (instance) of a workflow.
///
/// Wraps a random v4 UUID; `Copy` because `Uuid` is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct WorkflowInstanceId(Uuid);

impl WorkflowInstanceId {
    /// Generate a fresh random instance id.
    pub fn new() -> Self {
        Self(Uuid::new_v4())
    }
}

impl Default for WorkflowInstanceId {
    fn default() -> Self {
        Self::new()
    }
}

impl fmt::Display for WorkflowInstanceId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
|
||||
|
||||
/// Unique identifier for a workflow step within a definition
/// (e.g. "create_worker").
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct StepId(String);

impl StepId {
    /// Build an id from any string-like value.
    pub fn new(id: impl Into<String>) -> Self {
        Self(id.into())
    }
}

impl fmt::Display for StepId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
|
||||
|
||||
/// Retry policy configuration for a workflow step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryPolicy {
    /// Maximum number of execution attempts (including the first).
    pub max_attempts: u32,
    /// Delay strategy applied between attempts.
    pub backoff: BackoffStrategy,
}

impl Default for RetryPolicy {
    /// Defaults to 3 attempts with exponential backoff from 1s, capped at 30s.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            backoff: BackoffStrategy::Exponential {
                base: Duration::from_secs(1),
                max: Duration::from_secs(30),
            },
        }
    }
}
|
||||
|
||||
/// Backoff strategy for retries.
///
/// How the engine interpolates between attempts is defined by the executor;
/// `max` fields cap the delay for the growing strategies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BackoffStrategy {
    /// Fixed delay between retries
    Fixed(Duration),
    /// Exponential backoff starting at `base`, capped at `max`
    Exponential { base: Duration, max: Duration },
    /// Linear backoff growing by `increment` per attempt, capped at `max`
    Linear { increment: Duration, max: Duration },
}
|
||||
|
||||
/// Action to take when a step exhausts its retries and still fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FailureAction {
    /// Stop the entire workflow
    FailWorkflow,
    /// Skip this step and continue to the next
    ContinueNextStep,
    /// Keep retrying indefinitely until manual intervention
    RetryIndefinitely,
}
|
||||
|
||||
/// Workflow execution status (instance-level lifecycle).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WorkflowStatus {
    /// Created but not yet started.
    Pending,
    /// Currently executing steps.
    Running,
    /// Execution temporarily suspended.
    Paused,
    /// All steps finished successfully (or were skipped).
    Completed,
    /// A step failed with `FailureAction::FailWorkflow`.
    Failed,
    /// Execution was cancelled externally.
    Cancelled,
}
|
||||
|
||||
/// Step execution status (per-step lifecycle within an instance).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StepStatus {
    /// Not yet reached.
    Pending,
    /// Currently executing.
    Running,
    /// Finished successfully.
    Succeeded,
    /// Finished with a terminal failure.
    Failed,
    /// Failed but will be attempted again per the retry policy.
    Retrying,
    /// Bypassed (e.g. via `FailureAction::ContinueNextStep`).
    Skipped,
}
|
||||
|
||||
/// Mutable execution state tracked for a single step of a workflow instance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepState {
    /// Current lifecycle status of the step.
    pub status: StepStatus,
    /// Number of execution attempts so far (0 before the first run).
    pub attempt: u32,
    /// Message from the most recent failure, if any.
    pub last_error: Option<String>,
    /// When the step last started executing.
    pub started_at: Option<DateTime<Utc>>,
    /// When the step finished running.
    pub completed_at: Option<DateTime<Utc>>,
}

impl Default for StepState {
    /// A pristine, never-run step: Pending with no attempts or timestamps.
    fn default() -> Self {
        Self {
            status: StepStatus::Pending,
            attempt: 0,
            last_error: None,
            started_at: None,
            completed_at: None,
        }
    }
}
|
||||
|
||||
/// Full runtime state of one workflow instance: overall status, per-step
/// state, and the shared context passed between steps.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowState {
    /// Id of this particular run.
    pub instance_id: WorkflowInstanceId,
    /// Id of the workflow definition being executed.
    pub definition_id: WorkflowId,
    /// Instance-level lifecycle status.
    pub status: WorkflowStatus,
    /// Step currently executing, if any.
    pub current_step: Option<StepId>,
    /// Per-step execution state, keyed by step id.
    pub step_states: HashMap<StepId, StepState>,
    /// Shared context visible to all steps (in-memory only; see WorkflowContext).
    pub context: WorkflowContext,
    /// Creation timestamp.
    pub created_at: DateTime<Utc>,
    /// Timestamp of the most recent state change.
    pub updated_at: DateTime<Utc>,
}

impl WorkflowState {
    /// Fresh Pending state with an empty context and identical
    /// created/updated timestamps.
    pub fn new(instance_id: WorkflowInstanceId, definition_id: WorkflowId) -> Self {
        let now = Utc::now();
        Self {
            instance_id,
            definition_id,
            status: WorkflowStatus::Pending,
            current_step: None,
            step_states: HashMap::new(),
            context: WorkflowContext::new(instance_id),
            created_at: now,
            updated_at: now,
        }
    }
}
|
||||
|
||||
/// Shared context passed between workflow steps
///
/// # Serialization Warning
///
/// The `data` field contains type-erased values that cannot be serialized.
/// This means workflow context is **not preserved** across:
/// - Process restarts
/// - State persistence to disk
/// - Network serialization
///
/// The workflow engine only supports **in-memory execution**. If you need
/// durable workflows, consider implementing a custom serializable context type.
///
/// # Type-erasure pitfall
///
/// `get::<T>` succeeds only when the stored *concrete* type is exactly `T`.
/// Storing an `Arc<T>` via `set` wraps it in another `Arc`, so the concrete
/// type becomes `Arc<T>` and `get::<T>` silently returns `None`. Use
/// [`WorkflowContext::set_arc`] for values that are already `Arc`-wrapped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowContext {
    /// Id of the owning workflow instance.
    pub instance_id: WorkflowInstanceId,
    // Type-erased key/value storage; skipped by serde (see warning above).
    #[serde(skip)]
    data: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
}

impl WorkflowContext {
    /// Create an empty context bound to `instance_id`.
    pub fn new(instance_id: WorkflowInstanceId) -> Self {
        Self {
            instance_id,
            data: HashMap::new(),
        }
    }

    /// Store a value in the context (will be wrapped in Arc).
    ///
    /// Do NOT pass an already-`Arc`-wrapped value here unless readers expect
    /// the doubly-wrapped type — prefer `set_arc` for `Arc<T>` values.
    pub fn set<T: Send + Sync + 'static>(&mut self, key: impl Into<String>, value: T) {
        self.data.insert(key.into(), Arc::new(value));
    }

    /// Store an Arc directly without double-wrapping, so `get::<T>` sees the
    /// inner type `T`.
    pub fn set_arc<T: Send + Sync + 'static>(&mut self, key: impl Into<String>, value: Arc<T>) {
        self.data.insert(key.into(), value);
    }

    /// Retrieve a value from the context.
    ///
    /// Returns `None` both when the key is absent AND when the stored
    /// concrete type is not `T` — callers cannot distinguish the two cases.
    pub fn get<T: Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
        self.data
            .get(key)
            .and_then(|v| v.clone().downcast::<T>().ok())
    }

    /// Check if the context has any data that would be lost during serialization
    pub fn has_unserializable_data(&self) -> bool {
        !self.data.is_empty()
    }

    /// Get the number of context entries (useful for debugging)
    pub fn data_len(&self) -> usize {
        self.data.len()
    }
}
|
||||
|
||||
/// Result returned by a step execution (the non-error outcomes).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StepResult {
    /// Step completed its work.
    Success,
    /// Step did not complete its work (without producing a `WorkflowError`).
    Failure,
    /// Step was skipped.
    Skip,
}
|
||||
|
||||
/// Error kinds for workflow operations
|
||||
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
|
||||
pub enum WorkflowError {
|
||||
#[error("Workflow not found: {0}")]
|
||||
NotFound(WorkflowInstanceId),
|
||||
|
||||
#[error("Workflow definition not found: {0}")]
|
||||
DefinitionNotFound(WorkflowId),
|
||||
|
||||
#[error("Step failed: {step_id} - {message}")]
|
||||
StepFailed { step_id: StepId, message: String },
|
||||
|
||||
#[error("Step timeout: {step_id}")]
|
||||
StepTimeout { step_id: StepId },
|
||||
|
||||
#[error("Workflow cancelled: {0}")]
|
||||
Cancelled(WorkflowInstanceId),
|
||||
|
||||
#[error("Invalid state transition: {from:?} -> {to:?}")]
|
||||
InvalidStateTransition {
|
||||
from: WorkflowStatus,
|
||||
to: WorkflowStatus,
|
||||
},
|
||||
|
||||
#[error("Context value not found: {0}")]
|
||||
ContextValueNotFound(String),
|
||||
|
||||
#[error("Context value type mismatch: {0}")]
|
||||
ContextTypeMismatch(String),
|
||||
}
|
||||
|
||||
pub type WorkflowResult<T> = Result<T, WorkflowError>;
|
||||
@@ -530,6 +530,39 @@ impl RouterMetrics {
|
||||
)
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
/// Record the current number of jobs waiting in the worker job queue.
pub fn set_job_queue_depth(depth: usize) {
    gauge!("sgl_router_job_queue_depth").set(depth as f64);
}
|
||||
|
||||
/// Record how long a background job took, labeled by job type
/// (e.g. "AddWorker", "RemoveWorker").
pub fn record_job_duration(job_type: &str, duration: Duration) {
    histogram!("sgl_router_job_duration_seconds",
        "job_type" => job_type.to_string()
    )
    .record(duration.as_secs_f64());
}
|
||||
|
||||
/// Count one successfully completed background job, labeled by job type.
pub fn record_job_success(job_type: &str) {
    counter!("sgl_router_job_success_total",
        "job_type" => job_type.to_string()
    )
    .increment(1);
}
|
||||
|
||||
/// Count one failed background job, labeled by job type.
pub fn record_job_failure(job_type: &str) {
    counter!("sgl_router_job_failure_total",
        "job_type" => job_type.to_string()
    )
    .increment(1);
}
|
||||
|
||||
/// Count one job submission rejected because the queue was at capacity.
pub fn record_job_queue_full() {
    counter!("sgl_router_job_queue_full_total").increment(1);
}
|
||||
|
||||
/// Count one job submission rejected because the queue was shutting down.
pub fn record_job_shutdown_rejected() {
    counter!("sgl_router_job_shutdown_rejected_total").increment(1);
}
|
||||
}
|
||||
|
||||
impl TokenizerMetrics {
|
||||
|
||||
@@ -56,6 +56,51 @@ pub struct WorkerConfigRequest {
|
||||
/// Additional labels (optional)
|
||||
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
|
||||
pub labels: HashMap<String, String>,
|
||||
|
||||
/// Health check timeout in seconds (default: 30)
|
||||
#[serde(default = "default_health_check_timeout")]
|
||||
pub health_check_timeout_secs: u64,
|
||||
|
||||
/// Health check interval in seconds (default: 60)
|
||||
#[serde(default = "default_health_check_interval")]
|
||||
pub health_check_interval_secs: u64,
|
||||
|
||||
/// Number of successful health checks needed to mark worker as healthy (default: 2)
|
||||
#[serde(default = "default_health_success_threshold")]
|
||||
pub health_success_threshold: u32,
|
||||
|
||||
/// Number of failed health checks before marking worker as unhealthy (default: 3)
|
||||
#[serde(default = "default_health_failure_threshold")]
|
||||
pub health_failure_threshold: u32,
|
||||
|
||||
/// Maximum connection attempts during worker registration (default: 20)
|
||||
#[serde(default = "default_max_connection_attempts")]
|
||||
pub max_connection_attempts: u32,
|
||||
|
||||
/// Enable data parallelism aware scheduling (default: false)
|
||||
#[serde(default)]
|
||||
pub dp_aware: bool,
|
||||
}
|
||||
|
||||
// Default value functions for serde
|
||||
/// Serde default for `WorkerConfigRequest::health_check_timeout_secs` (30s).
fn default_health_check_timeout() -> u64 {
    30
}
|
||||
|
||||
/// Serde default for `WorkerConfigRequest::health_check_interval_secs` (60s).
fn default_health_check_interval() -> u64 {
    60
}
|
||||
|
||||
/// Serde default for `WorkerConfigRequest::health_success_threshold` (2 checks).
fn default_health_success_threshold() -> u32 {
    2
}
|
||||
|
||||
/// Serde default for `WorkerConfigRequest::health_failure_threshold` (3 checks).
fn default_health_failure_threshold() -> u32 {
    3
}
|
||||
|
||||
/// Serde default for `WorkerConfigRequest::max_connection_attempts` (20).
fn default_max_connection_attempts() -> u32 {
    20
}
|
||||
|
||||
/// Worker information for API responses
|
||||
|
||||
@@ -22,8 +22,8 @@ use tracing::{error, info, warn, Level};
|
||||
use crate::{
|
||||
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
|
||||
core::{
|
||||
worker_to_info, Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry,
|
||||
WorkerType,
|
||||
worker_to_info, workflow::WorkflowEngine, Job, JobQueue, JobQueueConfig, LoadMonitor,
|
||||
WorkerManager, WorkerRegistry, WorkerType,
|
||||
},
|
||||
data_connector::{
|
||||
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
|
||||
@@ -77,6 +77,7 @@ pub struct AppContext {
|
||||
pub configured_reasoning_parser: Option<String>,
|
||||
pub configured_tool_parser: Option<String>,
|
||||
pub worker_job_queue: Arc<OnceLock<Arc<JobQueue>>>,
|
||||
pub workflow_engine: Arc<OnceLock<Arc<WorkflowEngine>>>,
|
||||
}
|
||||
|
||||
impl AppContext {
|
||||
@@ -95,6 +96,7 @@ impl AppContext {
|
||||
conversation_item_storage: SharedConversationItemStorage,
|
||||
load_monitor: Option<Arc<LoadMonitor>>,
|
||||
worker_job_queue: Arc<OnceLock<Arc<JobQueue>>>,
|
||||
workflow_engine: Arc<OnceLock<Arc<WorkflowEngine>>>,
|
||||
) -> Self {
|
||||
let configured_reasoning_parser = router_config.reasoning_parser.clone();
|
||||
let configured_tool_parser = router_config.tool_call_parser.clone();
|
||||
@@ -116,6 +118,7 @@ impl AppContext {
|
||||
configured_reasoning_parser,
|
||||
configured_tool_parser,
|
||||
worker_job_queue,
|
||||
workflow_engine,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -979,8 +982,9 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
||||
config.router_config.worker_startup_check_interval_secs,
|
||||
)));
|
||||
|
||||
// Create empty OnceLock for worker job queue (will be initialized below)
|
||||
// Create empty OnceLock for worker job queue and workflow engine (will be initialized below)
|
||||
let worker_job_queue = Arc::new(OnceLock::new());
|
||||
let workflow_engine = Arc::new(OnceLock::new());
|
||||
|
||||
// Create AppContext with all initialized components
|
||||
let app_context = AppContext::new(
|
||||
@@ -997,6 +1001,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
||||
conversation_item_storage,
|
||||
load_monitor,
|
||||
worker_job_queue,
|
||||
workflow_engine,
|
||||
);
|
||||
|
||||
let app_context = Arc::new(app_context);
|
||||
@@ -1008,17 +1013,38 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
||||
.set(worker_job_queue)
|
||||
.expect("JobQueue should only be initialized once");
|
||||
|
||||
// Initialize workflow engine and register workflows
|
||||
let engine = Arc::new(WorkflowEngine::new());
|
||||
|
||||
engine
|
||||
.event_bus()
|
||||
.subscribe(Arc::new(crate::core::workflow::LoggingSubscriber))
|
||||
.await;
|
||||
|
||||
engine.register_workflow(crate::core::workflow::create_worker_registration_workflow());
|
||||
app_context
|
||||
.workflow_engine
|
||||
.set(engine)
|
||||
.expect("WorkflowEngine should only be initialized once");
|
||||
info!("Workflow engine initialized with worker registration workflow");
|
||||
|
||||
info!(
|
||||
"Initializing workers for routing mode: {:?}",
|
||||
config.router_config.mode
|
||||
);
|
||||
WorkerManager::initialize_workers(
|
||||
&config.router_config,
|
||||
&app_context.worker_registry,
|
||||
Some(&app_context.policy_registry),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to initialize workers: {}", e))?;
|
||||
|
||||
// Submit worker initialization job to queue
|
||||
let job_queue = app_context
|
||||
.worker_job_queue
|
||||
.get()
|
||||
.expect("JobQueue should be initialized");
|
||||
let job = Job::InitializeWorkersFromConfig {
|
||||
router_config: Box::new(config.router_config.clone()),
|
||||
};
|
||||
job_queue
|
||||
.submit(job)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to submit worker initialization job: {}", e))?;
|
||||
|
||||
let worker_stats = app_context.worker_registry.stats();
|
||||
info!(
|
||||
|
||||
@@ -18,7 +18,11 @@ use rustls;
|
||||
use tokio::{task, time};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::{core::WorkerManager, protocols::worker_spec::WorkerConfigRequest, server::AppContext};
|
||||
use crate::{
|
||||
core::{Job, WorkerManager},
|
||||
protocols::worker_spec::WorkerConfigRequest,
|
||||
server::AppContext,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ServiceDiscoveryConfig {
|
||||
@@ -157,6 +161,7 @@ impl PodInfo {
|
||||
}
|
||||
|
||||
pub fn worker_url(&self, port: u16) -> String {
|
||||
// Default to http:// prefix; workflow will detect actual protocol (HTTP vs gRPC)
|
||||
format!("http://{}:{}", self.ip, port)
|
||||
}
|
||||
}
|
||||
@@ -382,10 +387,18 @@ async fn handle_pod_event(
|
||||
tool_parser: None,
|
||||
chat_template: None,
|
||||
api_key: None,
|
||||
health_check_timeout_secs: app_context.router_config.health_check.timeout_secs,
|
||||
health_check_interval_secs: app_context
|
||||
.router_config
|
||||
.health_check
|
||||
.check_interval_secs,
|
||||
health_success_threshold: app_context.router_config.health_check.success_threshold,
|
||||
health_failure_threshold: app_context.router_config.health_check.failure_threshold,
|
||||
max_connection_attempts: app_context.router_config.health_check.success_threshold
|
||||
* 20,
|
||||
dp_aware: false,
|
||||
};
|
||||
|
||||
// Submit job for async worker addition
|
||||
use crate::core::Job;
|
||||
let job = Job::AddWorker {
|
||||
config: Box::new(config.clone()),
|
||||
};
|
||||
@@ -568,6 +581,7 @@ mod tests {
|
||||
configured_reasoning_parser: None,
|
||||
configured_tool_parser: None,
|
||||
worker_job_queue: Arc::new(std::sync::OnceLock::new()),
|
||||
workflow_engine: Arc::new(std::sync::OnceLock::new()),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -815,19 +829,6 @@ mod tests {
|
||||
assert!(!not_running_pod.is_healthy());
|
||||
}
|
||||
|
||||
#[test]
fn test_pod_info_worker_url() {
    // worker_url should combine the pod IP and the given port under an
    // http:// scheme.
    let pod = PodInfo {
        name: "p1".into(),
        ip: "1.2.3.4".into(),
        status: "Running".into(),
        is_ready: true,
        pod_type: None,
        bootstrap_port: None,
    };
    let url = pod.worker_url(8080);
    assert_eq!(url, "http://1.2.3.4:8080");
}
|
||||
|
||||
#[test]
|
||||
fn test_pod_info_equality_with_pod_type() {
|
||||
let pod1 = PodInfo {
|
||||
|
||||
@@ -62,8 +62,9 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
|
||||
config.worker_startup_check_interval_secs,
|
||||
)));
|
||||
|
||||
// Create empty OnceLock for worker job queue
|
||||
// Create empty OnceLock for worker job queue and workflow engine
|
||||
let worker_job_queue = Arc::new(OnceLock::new());
|
||||
let workflow_engine = Arc::new(OnceLock::new());
|
||||
|
||||
Arc::new(AppContext::new(
|
||||
config,
|
||||
@@ -79,6 +80,7 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
|
||||
conversation_item_storage,
|
||||
load_monitor,
|
||||
worker_job_queue,
|
||||
workflow_engine,
|
||||
))
|
||||
}
|
||||
|
||||
|
||||
@@ -53,8 +53,9 @@ pub fn create_test_app(
|
||||
router_config.worker_startup_check_interval_secs,
|
||||
)));
|
||||
|
||||
// Create empty OnceLock for worker job queue
|
||||
// Create empty OnceLock for worker job queue and workflow engine
|
||||
let worker_job_queue = Arc::new(OnceLock::new());
|
||||
let workflow_engine = Arc::new(OnceLock::new());
|
||||
|
||||
// Create AppContext
|
||||
let app_context = Arc::new(AppContext::new(
|
||||
@@ -71,6 +72,7 @@ pub fn create_test_app(
|
||||
conversation_item_storage,
|
||||
load_monitor,
|
||||
worker_job_queue,
|
||||
workflow_engine,
|
||||
));
|
||||
|
||||
// Create AppState with the test router and context
|
||||
|
||||
@@ -36,6 +36,12 @@ async fn test_policy_registry_with_router_manager() {
|
||||
reasoning_parser: None,
|
||||
tool_parser: None,
|
||||
chat_template: None,
|
||||
health_check_timeout_secs: 30,
|
||||
health_check_interval_secs: 60,
|
||||
health_success_threshold: 2,
|
||||
health_failure_threshold: 3,
|
||||
max_connection_attempts: 20,
|
||||
dp_aware: false,
|
||||
};
|
||||
|
||||
// This would normally connect to a real worker, but for testing we'll just verify the structure
|
||||
@@ -61,6 +67,12 @@ async fn test_policy_registry_with_router_manager() {
|
||||
reasoning_parser: None,
|
||||
tool_parser: None,
|
||||
chat_template: None,
|
||||
health_check_timeout_secs: 30,
|
||||
health_check_interval_secs: 60,
|
||||
health_success_threshold: 2,
|
||||
health_failure_threshold: 3,
|
||||
max_connection_attempts: 20,
|
||||
dp_aware: false,
|
||||
};
|
||||
|
||||
// The second worker should use the same policy as the first (cache_aware)
|
||||
@@ -82,6 +94,12 @@ async fn test_policy_registry_with_router_manager() {
|
||||
reasoning_parser: None,
|
||||
tool_parser: None,
|
||||
chat_template: None,
|
||||
health_check_timeout_secs: 30,
|
||||
health_check_interval_secs: 60,
|
||||
health_success_threshold: 2,
|
||||
health_failure_threshold: 3,
|
||||
max_connection_attempts: 20,
|
||||
dp_aware: false,
|
||||
};
|
||||
|
||||
let _gpt_policy = policy_registry.get_policy("gpt-4");
|
||||
|
||||
@@ -238,8 +238,9 @@ mod test_pd_routing {
|
||||
config.worker_startup_check_interval_secs,
|
||||
)));
|
||||
|
||||
// Create empty OnceLock for worker job queue
|
||||
// Create empty OnceLock for worker job queue and workflow engine
|
||||
let worker_job_queue = Arc::new(OnceLock::new());
|
||||
let workflow_engine = Arc::new(OnceLock::new());
|
||||
|
||||
Arc::new(sglang_router_rs::server::AppContext::new(
|
||||
config,
|
||||
@@ -255,6 +256,7 @@ mod test_pd_routing {
|
||||
conversation_item_storage,
|
||||
load_monitor,
|
||||
worker_job_queue,
|
||||
workflow_engine,
|
||||
))
|
||||
};
|
||||
let result = RouterFactory::create_router(&app_context).await;
|
||||
|
||||
320
sgl-router/tests/workflow_test.rs
Normal file
320
sgl-router/tests/workflow_test.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
//! Integration tests for workflow engine
|
||||
|
||||
use std::{
|
||||
sync::{
|
||||
atomic::{AtomicU32, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use sglang_router_rs::core::workflow::*;
|
||||
use tokio::time::sleep;
|
||||
|
||||
// Test step that counts invocations
|
||||
struct CountingStep {
|
||||
counter: Arc<AtomicU32>,
|
||||
should_succeed_after: u32,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl StepExecutor for CountingStep {
|
||||
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
|
||||
let count = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
|
||||
// Store count in context
|
||||
context.set("execution_count", count);
|
||||
|
||||
if count >= self.should_succeed_after {
|
||||
Ok(StepResult::Success)
|
||||
} else {
|
||||
Err(WorkflowError::StepFailed {
|
||||
step_id: StepId::new("counting_step"),
|
||||
message: format!("Not ready yet, attempt {}", count),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test step that always succeeds
|
||||
struct AlwaysSucceedStep;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl StepExecutor for AlwaysSucceedStep {
|
||||
async fn execute(&self, _context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
|
||||
Ok(StepResult::Success)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_simple_workflow_execution() {
|
||||
let engine = WorkflowEngine::new();
|
||||
|
||||
// Subscribe to events for logging
|
||||
engine
|
||||
.event_bus()
|
||||
.subscribe(Arc::new(LoggingSubscriber))
|
||||
.await;
|
||||
|
||||
// Create a simple workflow
|
||||
let workflow = WorkflowDefinition::new("test_workflow", "Simple Test Workflow")
|
||||
.add_step(StepDefinition::new(
|
||||
"step1",
|
||||
"First Step",
|
||||
Arc::new(AlwaysSucceedStep),
|
||||
))
|
||||
.add_step(StepDefinition::new(
|
||||
"step2",
|
||||
"Second Step",
|
||||
Arc::new(AlwaysSucceedStep),
|
||||
));
|
||||
|
||||
let workflow_id = workflow.id.clone();
|
||||
engine.register_workflow(workflow);
|
||||
|
||||
// Start workflow
|
||||
let instance_id = engine
|
||||
.start_workflow(workflow_id, WorkflowContext::new(WorkflowInstanceId::new()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait for completion
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
|
||||
// Check status
|
||||
let state = engine.get_status(instance_id).unwrap();
|
||||
assert_eq!(state.status, WorkflowStatus::Completed);
|
||||
assert_eq!(state.step_states.len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_workflow_with_retry() {
|
||||
let engine = WorkflowEngine::new();
|
||||
engine
|
||||
.event_bus()
|
||||
.subscribe(Arc::new(LoggingSubscriber))
|
||||
.await;
|
||||
|
||||
let counter = Arc::new(AtomicU32::new(0));
|
||||
|
||||
// Create workflow with retry logic
|
||||
let workflow = WorkflowDefinition::new("retry_workflow", "Workflow with Retry").add_step(
|
||||
StepDefinition::new(
|
||||
"retry_step",
|
||||
"Step that retries",
|
||||
Arc::new(CountingStep {
|
||||
counter: Arc::clone(&counter),
|
||||
should_succeed_after: 3,
|
||||
}),
|
||||
)
|
||||
.with_retry(RetryPolicy {
|
||||
max_attempts: 5,
|
||||
backoff: BackoffStrategy::Fixed(Duration::from_millis(10)),
|
||||
})
|
||||
.with_timeout(Duration::from_secs(5)),
|
||||
);
|
||||
|
||||
let workflow_id = workflow.id.clone();
|
||||
engine.register_workflow(workflow);
|
||||
|
||||
// Start workflow
|
||||
let instance_id = engine
|
||||
.start_workflow(workflow_id, WorkflowContext::new(WorkflowInstanceId::new()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait for completion
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Check that step was retried and eventually succeeded
|
||||
let state = engine.get_status(instance_id).unwrap();
|
||||
assert_eq!(state.status, WorkflowStatus::Completed);
|
||||
|
||||
let step_state = state.step_states.get(&StepId::new("retry_step")).unwrap();
|
||||
assert_eq!(step_state.status, StepStatus::Succeeded);
|
||||
assert_eq!(step_state.attempt, 3); // Should have taken 3 attempts
|
||||
|
||||
// Verify counter
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_workflow_failure_after_max_retries() {
|
||||
let engine = WorkflowEngine::new();
|
||||
engine
|
||||
.event_bus()
|
||||
.subscribe(Arc::new(LoggingSubscriber))
|
||||
.await;
|
||||
|
||||
let counter = Arc::new(AtomicU32::new(0));
|
||||
|
||||
// Create workflow that will fail
|
||||
let workflow = WorkflowDefinition::new("failing_workflow", "Workflow that Fails").add_step(
|
||||
StepDefinition::new(
|
||||
"failing_step",
|
||||
"Step that always fails",
|
||||
Arc::new(CountingStep {
|
||||
counter: Arc::clone(&counter),
|
||||
should_succeed_after: 10, // Will never succeed within max_attempts
|
||||
}),
|
||||
)
|
||||
.with_retry(RetryPolicy {
|
||||
max_attempts: 3,
|
||||
backoff: BackoffStrategy::Fixed(Duration::from_millis(10)),
|
||||
})
|
||||
.with_failure_action(FailureAction::FailWorkflow),
|
||||
);
|
||||
|
||||
let workflow_id = workflow.id.clone();
|
||||
engine.register_workflow(workflow);
|
||||
|
||||
// Start workflow
|
||||
let instance_id = engine
|
||||
.start_workflow(workflow_id, WorkflowContext::new(WorkflowInstanceId::new()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait for completion
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Check that workflow failed
|
||||
let state = engine.get_status(instance_id).unwrap();
|
||||
assert_eq!(state.status, WorkflowStatus::Failed);
|
||||
|
||||
let step_state = state.step_states.get(&StepId::new("failing_step")).unwrap();
|
||||
assert_eq!(step_state.status, StepStatus::Failed);
|
||||
assert_eq!(step_state.attempt, 3); // Should have tried 3 times
|
||||
|
||||
// Verify counter
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_workflow_continue_on_failure() {
|
||||
let engine = WorkflowEngine::new();
|
||||
engine
|
||||
.event_bus()
|
||||
.subscribe(Arc::new(LoggingSubscriber))
|
||||
.await;
|
||||
|
||||
let counter = Arc::new(AtomicU32::new(0));
|
||||
|
||||
// Create workflow where first step fails but workflow continues
|
||||
let workflow = WorkflowDefinition::new("continue_workflow", "Continue on Failure")
|
||||
.add_step(
|
||||
StepDefinition::new(
|
||||
"failing_step",
|
||||
"Step that fails",
|
||||
Arc::new(CountingStep {
|
||||
counter: Arc::clone(&counter),
|
||||
should_succeed_after: 10,
|
||||
}),
|
||||
)
|
||||
.with_retry(RetryPolicy {
|
||||
max_attempts: 2,
|
||||
backoff: BackoffStrategy::Fixed(Duration::from_millis(10)),
|
||||
})
|
||||
.with_failure_action(FailureAction::ContinueNextStep),
|
||||
)
|
||||
.add_step(StepDefinition::new(
|
||||
"success_step",
|
||||
"Step that succeeds",
|
||||
Arc::new(AlwaysSucceedStep),
|
||||
));
|
||||
|
||||
let workflow_id = workflow.id.clone();
|
||||
engine.register_workflow(workflow);
|
||||
|
||||
// Start workflow
|
||||
let instance_id = engine
|
||||
.start_workflow(workflow_id, WorkflowContext::new(WorkflowInstanceId::new()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait for completion
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Workflow should complete despite first step failing
|
||||
let state = engine.get_status(instance_id).unwrap();
|
||||
assert_eq!(state.status, WorkflowStatus::Completed);
|
||||
|
||||
// First step should be skipped
|
||||
let step1_state = state.step_states.get(&StepId::new("failing_step")).unwrap();
|
||||
assert_eq!(step1_state.status, StepStatus::Skipped);
|
||||
|
||||
// Second step should succeed
|
||||
let step2_state = state.step_states.get(&StepId::new("success_step")).unwrap();
|
||||
assert_eq!(step2_state.status, StepStatus::Succeeded);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_workflow_context_sharing() {
|
||||
let engine = WorkflowEngine::new();
|
||||
|
||||
struct ContextWriterStep {
|
||||
key: String,
|
||||
value: String,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl StepExecutor for ContextWriterStep {
|
||||
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
|
||||
context.set(self.key.clone(), self.value.clone());
|
||||
Ok(StepResult::Success)
|
||||
}
|
||||
}
|
||||
|
||||
struct ContextReaderStep {
|
||||
key: String,
|
||||
expected_value: String,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl StepExecutor for ContextReaderStep {
|
||||
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
|
||||
let value: Arc<String> = context
|
||||
.get(&self.key)
|
||||
.ok_or_else(|| WorkflowError::ContextValueNotFound(self.key.clone()))?;
|
||||
|
||||
if *value == self.expected_value {
|
||||
Ok(StepResult::Success)
|
||||
} else {
|
||||
Err(WorkflowError::StepFailed {
|
||||
step_id: StepId::new("reader"),
|
||||
message: format!("Expected {}, got {}", self.expected_value, value),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let workflow = WorkflowDefinition::new("context_workflow", "Context Sharing Test")
|
||||
.add_step(StepDefinition::new(
|
||||
"writer",
|
||||
"Write to context",
|
||||
Arc::new(ContextWriterStep {
|
||||
key: "test_key".to_string(),
|
||||
value: "test_value".to_string(),
|
||||
}),
|
||||
))
|
||||
.add_step(StepDefinition::new(
|
||||
"reader",
|
||||
"Read from context",
|
||||
Arc::new(ContextReaderStep {
|
||||
key: "test_key".to_string(),
|
||||
expected_value: "test_value".to_string(),
|
||||
}),
|
||||
));
|
||||
|
||||
let workflow_id = workflow.id.clone();
|
||||
engine.register_workflow(workflow);
|
||||
|
||||
let instance_id = engine
|
||||
.start_workflow(workflow_id, WorkflowContext::new(WorkflowInstanceId::new()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
|
||||
let state = engine.get_status(instance_id).unwrap();
|
||||
assert_eq!(state.status, WorkflowStatus::Completed);
|
||||
}
|
||||
Reference in New Issue
Block a user