[router] add worker abstraction (#7960)
This commit is contained in:
@@ -30,6 +30,8 @@ tracing-appender = "0.2.3"
|
||||
kube = { version = "0.88.1", features = ["runtime", "derive"] }
|
||||
k8s-openapi = { version = "0.21.0", features = ["v1_29"] }
|
||||
futures = "0.3"
|
||||
async-trait = "0.1"
|
||||
once_cell = "1.21"
|
||||
# Added for metrics
|
||||
metrics = "0.24.2"
|
||||
metrics-exporter-prometheus = "0.17.0"
|
||||
|
||||
57
sgl-router/src/core/error.rs
Normal file
57
sgl-router/src/core/error.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
//! Error types for the SGLang router core
|
||||
//!
|
||||
//! This module defines error types used throughout the router for worker operations.
|
||||
|
||||
use std::fmt;
|
||||
|
||||
/// Worker-related errors.
///
/// Every variant carries the offending worker URL (or a configuration
/// message) so callers can identify the failing backend without extra
/// bookkeeping. `Clone`/`PartialEq`/`Eq` are derived — all payloads are
/// `String`s — so errors can be duplicated across tasks and compared in
/// tests and retry logic.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WorkerError {
    /// Health check failed
    HealthCheckFailed { url: String, reason: String },
    /// Worker not found
    WorkerNotFound { url: String },
    /// Invalid worker configuration
    InvalidConfiguration { message: String },
    /// Network error
    NetworkError { url: String, error: String },
    /// Worker is at capacity
    WorkerAtCapacity { url: String },
}
|
||||
|
||||
impl fmt::Display for WorkerError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
WorkerError::HealthCheckFailed { url, reason } => {
|
||||
write!(f, "Health check failed for worker {}: {}", url, reason)
|
||||
}
|
||||
WorkerError::WorkerNotFound { url } => {
|
||||
write!(f, "Worker not found: {}", url)
|
||||
}
|
||||
WorkerError::InvalidConfiguration { message } => {
|
||||
write!(f, "Invalid worker configuration: {}", message)
|
||||
}
|
||||
WorkerError::NetworkError { url, error } => {
|
||||
write!(f, "Network error for worker {}: {}", url, error)
|
||||
}
|
||||
WorkerError::WorkerAtCapacity { url } => {
|
||||
write!(f, "Worker at capacity: {}", url)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for WorkerError {}
|
||||
|
||||
/// Result type for worker operations
|
||||
pub type WorkerResult<T> = Result<T, WorkerError>;
|
||||
|
||||
/// Convert from reqwest errors to worker errors
|
||||
impl From<reqwest::Error> for WorkerError {
|
||||
fn from(err: reqwest::Error) -> Self {
|
||||
WorkerError::NetworkError {
|
||||
url: err.url().map(|u| u.to_string()).unwrap_or_default(),
|
||||
error: err.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
16
sgl-router/src/core/mod.rs
Normal file
16
sgl-router/src/core/mod.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
//! Core abstractions for the SGLang router
|
||||
//!
|
||||
//! This module contains the fundamental types and traits used throughout the router:
|
||||
//! - Worker trait and implementations
|
||||
//! - Error types
|
||||
//! - Common utilities
|
||||
|
||||
pub mod error;
|
||||
pub mod worker;
|
||||
|
||||
// Re-export commonly used types at the module level
|
||||
pub use error::{WorkerError, WorkerResult};
|
||||
pub use worker::{
|
||||
start_health_checker, BasicWorker, HealthChecker, Worker, WorkerCollection, WorkerFactory,
|
||||
WorkerLoadGuard, WorkerType,
|
||||
};
|
||||
454
sgl-router/src/core/worker.rs
Normal file
454
sgl-router/src/core/worker.rs
Normal file
@@ -0,0 +1,454 @@
|
||||
use super::{WorkerError, WorkerResult};
|
||||
use async_trait::async_trait;
|
||||
use once_cell::sync::Lazy;
|
||||
use std::fmt;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
// Shared HTTP client for health checks
|
||||
static HEALTH_CHECK_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
|
||||
reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30)) // Default timeout, overridden per request
|
||||
.build()
|
||||
.expect("Failed to create health check HTTP client")
|
||||
});
|
||||
|
||||
/// Core worker abstraction that represents a backend service
|
||||
#[async_trait]
|
||||
pub trait Worker: Send + Sync + fmt::Debug {
|
||||
/// Get the worker's URL
|
||||
fn url(&self) -> &str;
|
||||
|
||||
/// Get the worker's type (Regular, Prefill, or Decode)
|
||||
fn worker_type(&self) -> WorkerType;
|
||||
|
||||
/// Check if the worker is currently healthy
|
||||
fn is_healthy(&self) -> bool;
|
||||
|
||||
/// Set the worker's health status
|
||||
fn set_healthy(&self, healthy: bool);
|
||||
|
||||
/// Perform an async health check on the worker
|
||||
async fn check_health_async(&self) -> WorkerResult<()>;
|
||||
|
||||
/// Synchronous health check wrapper (for compatibility)
|
||||
fn check_health(&self) -> WorkerResult<()> {
|
||||
// Use a small runtime for synchronous contexts
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.map_err(|e| WorkerError::HealthCheckFailed {
|
||||
url: self.url().to_string(),
|
||||
reason: format!("Failed to create runtime: {}", e),
|
||||
})?
|
||||
.block_on(self.check_health_async())
|
||||
}
|
||||
|
||||
/// Get the current load (number of active requests)
|
||||
fn load(&self) -> usize;
|
||||
|
||||
/// Increment the load counter
|
||||
fn increment_load(&self);
|
||||
|
||||
/// Decrement the load counter
|
||||
fn decrement_load(&self);
|
||||
|
||||
/// Get the number of processed requests
|
||||
fn processed_requests(&self) -> usize;
|
||||
|
||||
/// Increment the processed requests counter
|
||||
fn increment_processed(&self);
|
||||
|
||||
/// Get worker-specific metadata
|
||||
fn metadata(&self) -> &WorkerMetadata;
|
||||
|
||||
/// Clone the worker (for trait objects)
|
||||
fn clone_worker(&self) -> Box<dyn Worker>;
|
||||
}
|
||||
|
||||
/// Worker type classification.
///
/// `Copy` is derived — every payload (`Option<u16>`) is itself `Copy` — so
/// call sites can pass the type by value without cloning; existing
/// `.clone()` callers keep working.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WorkerType {
    /// Regular worker for standard routing
    Regular,
    /// Prefill worker for PD disaggregated mode
    Prefill {
        /// Bootstrap port for communication with decode workers
        bootstrap_port: Option<u16>,
    },
    /// Decode worker for PD disaggregated mode
    Decode,
}
|
||||
|
||||
impl fmt::Display for WorkerType {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
WorkerType::Regular => write!(f, "Regular"),
|
||||
WorkerType::Prefill { bootstrap_port } => match bootstrap_port {
|
||||
Some(port) => write!(f, "Prefill(bootstrap:{})", port),
|
||||
None => write!(f, "Prefill"),
|
||||
},
|
||||
WorkerType::Decode => write!(f, "Decode"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Health check configuration.
#[derive(Debug, Clone)]
pub struct HealthConfig {
    /// Per-check timeout, in seconds.
    pub timeout_secs: u64,
    /// Delay between consecutive checks, in seconds.
    pub check_interval_secs: u64,
    /// Path of the health endpoint, e.g. "/health".
    pub endpoint: String,
}

impl Default for HealthConfig {
    /// Defaults: 5s timeout, 30s interval, "/health" endpoint.
    fn default() -> Self {
        HealthConfig {
            timeout_secs: 5,
            check_interval_secs: 30,
            endpoint: String::from("/health"),
        }
    }
}
|
||||
|
||||
/// Metadata associated with a worker
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WorkerMetadata {
|
||||
/// Worker URL
|
||||
pub url: String,
|
||||
/// Worker type
|
||||
pub worker_type: WorkerType,
|
||||
/// Additional labels/tags
|
||||
pub labels: std::collections::HashMap<String, String>,
|
||||
/// Health check configuration
|
||||
pub health_config: HealthConfig,
|
||||
}
|
||||
|
||||
/// Basic worker implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BasicWorker {
|
||||
metadata: WorkerMetadata,
|
||||
load_counter: Arc<AtomicUsize>,
|
||||
processed_counter: Arc<AtomicUsize>,
|
||||
healthy: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl BasicWorker {
|
||||
pub fn new(url: String, worker_type: WorkerType) -> Self {
|
||||
let metadata = WorkerMetadata {
|
||||
url: url.clone(),
|
||||
worker_type,
|
||||
labels: std::collections::HashMap::new(),
|
||||
health_config: HealthConfig::default(),
|
||||
};
|
||||
|
||||
Self {
|
||||
metadata,
|
||||
load_counter: Arc::new(AtomicUsize::new(0)),
|
||||
processed_counter: Arc::new(AtomicUsize::new(0)),
|
||||
healthy: Arc::new(AtomicBool::new(true)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_labels(mut self, labels: std::collections::HashMap<String, String>) -> Self {
|
||||
self.metadata.labels = labels;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_health_config(mut self, config: HealthConfig) -> Self {
|
||||
self.metadata.health_config = config;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Worker for BasicWorker {
|
||||
fn url(&self) -> &str {
|
||||
&self.metadata.url
|
||||
}
|
||||
|
||||
fn worker_type(&self) -> WorkerType {
|
||||
self.metadata.worker_type.clone()
|
||||
}
|
||||
|
||||
fn is_healthy(&self) -> bool {
|
||||
self.healthy.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
fn set_healthy(&self, healthy: bool) {
|
||||
self.healthy.store(healthy, Ordering::Release);
|
||||
}
|
||||
|
||||
async fn check_health_async(&self) -> WorkerResult<()> {
|
||||
use std::time::Duration;
|
||||
|
||||
// Perform actual HTTP health check
|
||||
let health_url = format!("{}{}", self.url(), self.metadata.health_config.endpoint);
|
||||
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
|
||||
|
||||
// Use the shared client with a custom timeout for this request
|
||||
match HEALTH_CHECK_CLIENT
|
||||
.get(&health_url)
|
||||
.timeout(timeout)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
if response.status().is_success() {
|
||||
self.set_healthy(true);
|
||||
Ok(())
|
||||
} else {
|
||||
self.set_healthy(false);
|
||||
Err(WorkerError::HealthCheckFailed {
|
||||
url: self.url().to_string(),
|
||||
reason: format!("Health check returned status: {}", response.status()),
|
||||
})
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
self.set_healthy(false);
|
||||
Err(WorkerError::HealthCheckFailed {
|
||||
url: self.url().to_string(),
|
||||
reason: format!("Health check request failed: {}", e),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load(&self) -> usize {
|
||||
self.load_counter.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn increment_load(&self) {
|
||||
self.load_counter.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn decrement_load(&self) {
|
||||
self.load_counter
|
||||
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
|
||||
current.checked_sub(1)
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn processed_requests(&self) -> usize {
|
||||
self.processed_counter.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn increment_processed(&self) {
|
||||
self.processed_counter.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn metadata(&self) -> &WorkerMetadata {
|
||||
&self.metadata
|
||||
}
|
||||
|
||||
fn clone_worker(&self) -> Box<dyn Worker> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Worker factory for creating workers of different types
|
||||
pub struct WorkerFactory;
|
||||
|
||||
impl WorkerFactory {
|
||||
/// Create a regular worker
|
||||
pub fn create_regular(url: String) -> Box<dyn Worker> {
|
||||
Box::new(BasicWorker::new(url, WorkerType::Regular))
|
||||
}
|
||||
|
||||
/// Create a prefill worker with optional bootstrap port
|
||||
pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> {
|
||||
Box::new(BasicWorker::new(
|
||||
url,
|
||||
WorkerType::Prefill { bootstrap_port },
|
||||
))
|
||||
}
|
||||
|
||||
/// Create a decode worker
|
||||
pub fn create_decode(url: String) -> Box<dyn Worker> {
|
||||
Box::new(BasicWorker::new(url, WorkerType::Decode))
|
||||
}
|
||||
|
||||
/// Create workers from URLs with automatic type detection
|
||||
pub fn create_from_urls(
|
||||
regular_urls: Vec<String>,
|
||||
prefill_urls: Vec<(String, Option<u16>)>,
|
||||
decode_urls: Vec<String>,
|
||||
) -> (
|
||||
Vec<Box<dyn Worker>>,
|
||||
Vec<Box<dyn Worker>>,
|
||||
Vec<Box<dyn Worker>>,
|
||||
) {
|
||||
let regular_workers: Vec<Box<dyn Worker>> =
|
||||
regular_urls.into_iter().map(Self::create_regular).collect();
|
||||
|
||||
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
|
||||
.into_iter()
|
||||
.map(|(url, port)| Self::create_prefill(url, port))
|
||||
.collect();
|
||||
|
||||
let decode_workers: Vec<Box<dyn Worker>> =
|
||||
decode_urls.into_iter().map(Self::create_decode).collect();
|
||||
|
||||
(regular_workers, prefill_workers, decode_workers)
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper trait for collections of workers
|
||||
pub trait WorkerCollection {
|
||||
fn healthy_workers(&self) -> Vec<&dyn Worker>;
|
||||
fn total_load(&self) -> usize;
|
||||
fn find_worker(&self, url: &str) -> Option<&dyn Worker>;
|
||||
fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>>;
|
||||
}
|
||||
|
||||
impl WorkerCollection for Vec<Box<dyn Worker>> {
|
||||
fn healthy_workers(&self) -> Vec<&dyn Worker> {
|
||||
self.iter()
|
||||
.filter(|w| w.is_healthy())
|
||||
.map(|w| w.as_ref())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn total_load(&self) -> usize {
|
||||
self.iter().map(|w| w.load()).sum()
|
||||
}
|
||||
|
||||
fn find_worker(&self, url: &str) -> Option<&dyn Worker> {
|
||||
self.iter().find(|w| w.url() == url).map(|w| w.as_ref())
|
||||
}
|
||||
|
||||
fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>> {
|
||||
self.iter_mut().find(|w| w.url() == url)
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a list of worker URLs to worker trait objects
|
||||
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
|
||||
urls.into_iter()
|
||||
.map(WorkerFactory::create_regular)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Convert worker trait objects back to URLs
|
||||
pub fn workers_to_urls(workers: &[Box<dyn Worker>]) -> Vec<String> {
|
||||
workers.iter().map(|w| w.url().to_string()).collect()
|
||||
}
|
||||
|
||||
/// RAII guard for worker load management
|
||||
pub struct WorkerLoadGuard<'a> {
|
||||
workers: Vec<&'a dyn Worker>,
|
||||
}
|
||||
|
||||
impl<'a> WorkerLoadGuard<'a> {
|
||||
/// Create a new load guard for a single worker
|
||||
pub fn new(worker: &'a dyn Worker) -> Self {
|
||||
worker.increment_load();
|
||||
Self {
|
||||
workers: vec![worker],
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new load guard for multiple workers
|
||||
pub fn new_multi(workers: Vec<&'a dyn Worker>) -> Self {
|
||||
// Increment load counters for all workers
|
||||
for worker in &workers {
|
||||
worker.increment_load();
|
||||
}
|
||||
Self { workers }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for WorkerLoadGuard<'a> {
|
||||
fn drop(&mut self) {
|
||||
// Decrement load counters for all workers
|
||||
for worker in &self.workers {
|
||||
worker.decrement_load();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Health checker handle with graceful shutdown
|
||||
pub struct HealthChecker {
|
||||
handle: tokio::task::JoinHandle<()>,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for HealthChecker {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("HealthChecker")
|
||||
.field("shutdown", &self.shutdown.load(Ordering::Relaxed))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl HealthChecker {
|
||||
/// Shutdown the health checker gracefully
|
||||
pub async fn shutdown(self) {
|
||||
self.shutdown.store(true, Ordering::Release);
|
||||
let _ = self.handle.await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Start an async background health checker for a collection of workers
|
||||
pub fn start_health_checker(
|
||||
workers: std::sync::Arc<std::sync::RwLock<Vec<Box<dyn Worker>>>>,
|
||||
check_interval_secs: u64,
|
||||
) -> HealthChecker {
|
||||
let shutdown = Arc::new(AtomicBool::new(false));
|
||||
let shutdown_clone = shutdown.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let mut interval =
|
||||
tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
// Check for shutdown signal
|
||||
if shutdown_clone.load(Ordering::Acquire) {
|
||||
tracing::info!("Health checker shutting down");
|
||||
break;
|
||||
}
|
||||
|
||||
// Check health of all workers
|
||||
let workers_to_check = match workers.read() {
|
||||
Ok(guard) => guard.iter().map(|w| w.clone_worker()).collect::<Vec<_>>(),
|
||||
Err(poisoned) => {
|
||||
tracing::error!("Worker lock poisoned: {}", poisoned);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Perform health checks concurrently
|
||||
let health_checks = workers_to_check.iter().map(|worker| {
|
||||
let worker_url = worker.url().to_string();
|
||||
let was_healthy = worker.is_healthy();
|
||||
|
||||
async move {
|
||||
match worker.check_health_async().await {
|
||||
Ok(_) => {
|
||||
if !was_healthy {
|
||||
tracing::info!("Worker {} is now healthy", worker_url);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if was_healthy {
|
||||
tracing::warn!("Worker {} health check failed: {}", worker_url, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Execute all health checks concurrently
|
||||
futures::future::join_all(health_checks).await;
|
||||
}
|
||||
});
|
||||
|
||||
HealthChecker { handle, shutdown }
|
||||
}
|
||||
@@ -2,6 +2,7 @@ use pyo3::prelude::*;
|
||||
pub mod config;
|
||||
pub mod logging;
|
||||
use std::collections::HashMap;
|
||||
pub mod core;
|
||||
pub mod openai_api_types;
|
||||
pub mod pd_router;
|
||||
pub mod pd_types;
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
// PD (Prefill-Decode) Router Implementation
|
||||
// This module handles routing for disaggregated prefill-decode systems
|
||||
|
||||
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
|
||||
use crate::pd_types::{
|
||||
Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDRouterError, PDSelectionPolicy,
|
||||
api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError, PDSelectionPolicy,
|
||||
};
|
||||
use crate::tree::Tree;
|
||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||
@@ -11,7 +12,6 @@ use futures_util::{StreamExt, TryStreamExt};
|
||||
use metrics::{counter, histogram};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::{debug, error, info, warn};
|
||||
@@ -21,49 +21,17 @@ use uuid::Uuid;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PDRouter {
|
||||
pub prefill_workers: Arc<RwLock<Vec<EngineInfo>>>,
|
||||
pub decode_workers: Arc<RwLock<Vec<EngineInfo>>>,
|
||||
pub prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
pub selection_policy: PDSelectionPolicy,
|
||||
pub load_tracking: Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
|
||||
pub prefill_tree: Option<Arc<Mutex<Tree>>>,
|
||||
pub timeout_secs: u64,
|
||||
pub interval_secs: u64,
|
||||
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||
pub http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
// RAII guard for load tracking to ensure cleanup even on panic
|
||||
struct LoadGuard<'a> {
|
||||
tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
|
||||
urls: Vec<String>,
|
||||
}
|
||||
|
||||
impl<'a> LoadGuard<'a> {
|
||||
fn new(
|
||||
tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
|
||||
urls: Vec<String>,
|
||||
) -> Self {
|
||||
// Increment counters
|
||||
for url in &urls {
|
||||
let counter = tracking
|
||||
.entry(url.clone())
|
||||
.or_insert_with(|| Arc::new(AtomicUsize::new(0)));
|
||||
counter.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
LoadGuard { tracking, urls }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for LoadGuard<'_> {
|
||||
fn drop(&mut self) {
|
||||
// Guaranteed cleanup even on panic
|
||||
for url in &self.urls {
|
||||
if let Some(counter) = self.tracking.get(url) {
|
||||
counter.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
_prefill_health_checker: Option<HealthChecker>,
|
||||
_decode_health_checker: Option<HealthChecker>,
|
||||
}
|
||||
|
||||
impl PDRouter {
|
||||
@@ -73,9 +41,6 @@ impl PDRouter {
|
||||
url: String,
|
||||
bootstrap_port: Option<u16>,
|
||||
) -> Result<String, PDRouterError> {
|
||||
// Create EngineInfo for the new prefill server
|
||||
let engine_info = EngineInfo::new_prefill(url.clone(), bootstrap_port);
|
||||
|
||||
// Wait for the new server to be healthy
|
||||
crate::router::Router::wait_for_healthy_workers(
|
||||
&[url.clone()],
|
||||
@@ -84,6 +49,9 @@ impl PDRouter {
|
||||
)
|
||||
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
|
||||
|
||||
// Create Worker for the new prefill server
|
||||
let worker = WorkerFactory::create_prefill(url.clone(), bootstrap_port);
|
||||
|
||||
// Add to prefill workers list
|
||||
let mut workers = self
|
||||
.prefill_workers
|
||||
@@ -93,15 +61,11 @@ impl PDRouter {
|
||||
})?;
|
||||
|
||||
// Check if already exists
|
||||
if workers.iter().any(|w| w.url == url) {
|
||||
if workers.iter().any(|w| w.url() == &url) {
|
||||
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
|
||||
}
|
||||
|
||||
workers.push(engine_info);
|
||||
|
||||
// Initialize load tracking
|
||||
self.load_tracking
|
||||
.insert(url.clone(), Arc::new(AtomicUsize::new(0)));
|
||||
workers.push(worker);
|
||||
|
||||
// Add to cache tree if using cache-aware policy
|
||||
if let Some(ref tree) = self.prefill_tree {
|
||||
@@ -113,9 +77,6 @@ impl PDRouter {
|
||||
}
|
||||
|
||||
pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
|
||||
// Create EngineInfo for the new decode server
|
||||
let engine_info = EngineInfo::new_decode(url.clone());
|
||||
|
||||
// Wait for the new server to be healthy
|
||||
crate::router::Router::wait_for_healthy_workers(
|
||||
&[url.clone()],
|
||||
@@ -124,6 +85,9 @@ impl PDRouter {
|
||||
)
|
||||
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
|
||||
|
||||
// Create Worker for the new decode server
|
||||
let worker = WorkerFactory::create_decode(url.clone());
|
||||
|
||||
// Add to decode workers list
|
||||
let mut workers = self
|
||||
.decode_workers
|
||||
@@ -133,15 +97,14 @@ impl PDRouter {
|
||||
})?;
|
||||
|
||||
// Check if already exists
|
||||
if workers.iter().any(|w| w.url == url) {
|
||||
if workers.iter().any(|w| w.url() == &url) {
|
||||
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
|
||||
}
|
||||
|
||||
workers.push(engine_info);
|
||||
workers.push(worker);
|
||||
|
||||
// Initialize load tracking
|
||||
self.load_tracking
|
||||
.insert(url.clone(), Arc::new(AtomicUsize::new(0)));
|
||||
// Worker tracks its own load internally
|
||||
|
||||
info!("Added decode server: {}", url);
|
||||
Ok(format!("Successfully added decode server: {}", url))
|
||||
@@ -157,7 +120,7 @@ impl PDRouter {
|
||||
|
||||
// Find and remove the server
|
||||
let initial_len = workers.len();
|
||||
workers.retain(|w| w.url != url);
|
||||
workers.retain(|w| w.url() != url);
|
||||
|
||||
if workers.len() == initial_len {
|
||||
return Err(PDRouterError::WorkerNotFound {
|
||||
@@ -166,7 +129,7 @@ impl PDRouter {
|
||||
}
|
||||
|
||||
// Remove from load tracking
|
||||
self.load_tracking.remove(url);
|
||||
// Worker load tracking is internal
|
||||
|
||||
// Remove from cache tree if using cache-aware policy
|
||||
if let Some(ref tree) = self.prefill_tree {
|
||||
@@ -174,7 +137,7 @@ impl PDRouter {
|
||||
let mut tree_guard = tree.lock().unwrap();
|
||||
*tree_guard = Tree::new();
|
||||
for worker in workers.iter() {
|
||||
tree_guard.insert("", &worker.url);
|
||||
tree_guard.insert("", worker.url());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -192,7 +155,7 @@ impl PDRouter {
|
||||
|
||||
// Find and remove the server
|
||||
let initial_len = workers.len();
|
||||
workers.retain(|w| w.url != url);
|
||||
workers.retain(|w| w.url() != url);
|
||||
|
||||
if workers.len() == initial_len {
|
||||
return Err(PDRouterError::WorkerNotFound {
|
||||
@@ -200,9 +163,6 @@ impl PDRouter {
|
||||
});
|
||||
}
|
||||
|
||||
// Remove from load tracking
|
||||
self.load_tracking.remove(url);
|
||||
|
||||
info!("Removed decode server: {}", url);
|
||||
Ok(format!("Successfully removed decode server: {}", url))
|
||||
}
|
||||
@@ -214,41 +174,32 @@ impl PDRouter {
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
) -> Result<Self, String> {
|
||||
// Convert URLs to EngineInfo
|
||||
let prefill_workers: Vec<EngineInfo> = prefill_urls
|
||||
// Convert URLs to Worker trait objects
|
||||
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
|
||||
.into_iter()
|
||||
.map(|(url, port)| EngineInfo::new_prefill(url, port))
|
||||
.map(|(url, port)| WorkerFactory::create_prefill(url, port))
|
||||
.collect();
|
||||
|
||||
let decode_workers: Vec<EngineInfo> = decode_urls
|
||||
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
|
||||
.into_iter()
|
||||
.map(EngineInfo::new_decode)
|
||||
.map(WorkerFactory::create_decode)
|
||||
.collect();
|
||||
|
||||
// Wait for PD workers to be healthy
|
||||
let all_urls: Vec<String> = prefill_workers
|
||||
.iter()
|
||||
.chain(decode_workers.iter())
|
||||
.map(|engine| engine.url.clone())
|
||||
.map(|worker| worker.url().to_string())
|
||||
.collect();
|
||||
crate::router::Router::wait_for_healthy_workers(&all_urls, timeout_secs, interval_secs)?;
|
||||
|
||||
// Initialize load tracking with atomic counters
|
||||
let load_tracking = Arc::new(dashmap::DashMap::new());
|
||||
for engine in &prefill_workers {
|
||||
load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
|
||||
}
|
||||
for engine in &decode_workers {
|
||||
load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
|
||||
}
|
||||
|
||||
// Initialize cache-aware components if needed
|
||||
let prefill_tree = match &selection_policy {
|
||||
PDSelectionPolicy::CacheAware { .. } => {
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
// Initialize tree with prefill workers
|
||||
for engine in &prefill_workers {
|
||||
tree.lock().unwrap().insert("", &engine.url);
|
||||
for worker in &prefill_workers {
|
||||
tree.lock().unwrap().insert("", worker.url());
|
||||
}
|
||||
Some(tree)
|
||||
}
|
||||
@@ -283,17 +234,27 @@ impl PDRouter {
|
||||
None
|
||||
};
|
||||
|
||||
let prefill_workers = Arc::new(RwLock::new(prefill_workers));
|
||||
let decode_workers = Arc::new(RwLock::new(decode_workers));
|
||||
|
||||
// Start health checkers for both worker pools
|
||||
let prefill_health_checker =
|
||||
crate::core::start_health_checker(Arc::clone(&prefill_workers), interval_secs);
|
||||
let decode_health_checker =
|
||||
crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs);
|
||||
|
||||
Ok(PDRouter {
|
||||
prefill_workers: Arc::new(RwLock::new(prefill_workers)),
|
||||
decode_workers: Arc::new(RwLock::new(decode_workers)),
|
||||
prefill_workers,
|
||||
decode_workers,
|
||||
selection_policy,
|
||||
load_tracking,
|
||||
prefill_tree,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
worker_loads,
|
||||
load_monitor_handle,
|
||||
http_client,
|
||||
_prefill_health_checker: Some(prefill_health_checker),
|
||||
_decode_health_checker: Some(decode_health_checker),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -330,11 +291,13 @@ impl PDRouter {
|
||||
// Log routing decision
|
||||
info!(
|
||||
"PD routing: {} -> prefill={}, decode={}",
|
||||
route, prefill.url, decode.url
|
||||
route,
|
||||
prefill.url(),
|
||||
decode.url()
|
||||
);
|
||||
|
||||
// Add bootstrap info using the trait method
|
||||
if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
|
||||
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
|
||||
error!("Failed to add bootstrap info: {}", e);
|
||||
counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
|
||||
return HttpResponse::InternalServerError()
|
||||
@@ -356,8 +319,8 @@ impl PDRouter {
|
||||
req,
|
||||
json_with_bootstrap,
|
||||
route,
|
||||
&prefill,
|
||||
&decode,
|
||||
prefill.as_ref(),
|
||||
decode.as_ref(),
|
||||
is_stream,
|
||||
return_logprob,
|
||||
start,
|
||||
@@ -397,11 +360,13 @@ impl PDRouter {
|
||||
// Log routing decision
|
||||
info!(
|
||||
"PD routing: {} -> prefill={}, decode={}",
|
||||
route, prefill.url, decode.url
|
||||
route,
|
||||
prefill.url(),
|
||||
decode.url()
|
||||
);
|
||||
|
||||
// Add bootstrap info using the trait method
|
||||
if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
|
||||
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
|
||||
error!("Failed to add bootstrap info: {}", e);
|
||||
counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
|
||||
return HttpResponse::InternalServerError()
|
||||
@@ -423,8 +388,8 @@ impl PDRouter {
|
||||
req,
|
||||
json_with_bootstrap,
|
||||
route,
|
||||
&prefill,
|
||||
&decode,
|
||||
prefill.as_ref(),
|
||||
decode.as_ref(),
|
||||
is_stream,
|
||||
return_logprob,
|
||||
start,
|
||||
@@ -440,22 +405,23 @@ impl PDRouter {
|
||||
req: &HttpRequest,
|
||||
json_request: serde_json::Value,
|
||||
route: &str,
|
||||
prefill: &EngineInfo,
|
||||
decode: &EngineInfo,
|
||||
prefill: &dyn Worker,
|
||||
decode: &dyn Worker,
|
||||
is_stream: bool,
|
||||
return_logprob: bool,
|
||||
start_time: Instant,
|
||||
) -> HttpResponse {
|
||||
// Update load tracking for both workers
|
||||
let _guard = LoadGuard::new(
|
||||
&self.load_tracking,
|
||||
vec![prefill.url.clone(), decode.url.clone()],
|
||||
);
|
||||
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
|
||||
|
||||
// Build requests using .json() method
|
||||
let mut prefill_request = client.post(prefill.api_path(route)).json(&json_request);
|
||||
let mut prefill_request = client
|
||||
.post(api_path(prefill.url(), route))
|
||||
.json(&json_request);
|
||||
|
||||
let mut decode_request = client.post(decode.api_path(route)).json(&json_request);
|
||||
let mut decode_request = client
|
||||
.post(api_path(decode.url(), route))
|
||||
.json(&json_request);
|
||||
|
||||
// Copy headers from original request
|
||||
for (name, value) in crate::router::copy_request_headers(req) {
|
||||
@@ -474,9 +440,9 @@ impl PDRouter {
|
||||
histogram!("sgl_router_pd_request_duration_seconds", "route" => route.to_string())
|
||||
.record(duration.as_secs_f64());
|
||||
counter!("sgl_router_pd_requests_total", "route" => route.to_string()).increment(1);
|
||||
counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url.to_string())
|
||||
counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url().to_string())
|
||||
.increment(1);
|
||||
counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url.to_string())
|
||||
counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url().to_string())
|
||||
.increment(1);
|
||||
|
||||
// Process decode response
|
||||
@@ -486,10 +452,11 @@ impl PDRouter {
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
if !status.is_success() {
|
||||
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string()).increment(1);
|
||||
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string()).increment(1);
|
||||
error!(
|
||||
"Decode server {} returned error status: {}",
|
||||
decode.url, status
|
||||
decode.url(),
|
||||
status
|
||||
);
|
||||
|
||||
// Return the error response from decode server
|
||||
@@ -508,9 +475,10 @@ impl PDRouter {
|
||||
if let Err(e) = &prefill_result {
|
||||
error!(
|
||||
"Prefill server {} failed (non-critical): {}",
|
||||
prefill.url, e
|
||||
prefill.url(),
|
||||
e
|
||||
);
|
||||
counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url.to_string()).increment(1);
|
||||
counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url().to_string()).increment(1);
|
||||
}
|
||||
|
||||
if is_stream {
|
||||
@@ -559,7 +527,7 @@ impl PDRouter {
|
||||
HttpResponse::build(status)
|
||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
||||
.streaming({
|
||||
let decode_url = decode.url.clone();
|
||||
let decode_url = decode.url().to_string();
|
||||
res.bytes_stream().map_err(move |e| {
|
||||
error!("Stream error from decode server {}: {}", decode_url, e);
|
||||
counter!("sgl_router_pd_stream_errors_total", "worker" => decode_url.to_string()).increment(1);
|
||||
@@ -587,7 +555,7 @@ impl PDRouter {
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Decode request failed: {}", e);
|
||||
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string())
|
||||
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string())
|
||||
.increment(1);
|
||||
HttpResponse::BadGateway().body(format!("Decode server error: {}", e))
|
||||
}
|
||||
@@ -652,7 +620,7 @@ impl PDRouter {
|
||||
async fn select_pd_pair(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
) -> Result<(EngineInfo, EngineInfo), String> {
|
||||
) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
|
||||
// Check we have workers
|
||||
if self
|
||||
.prefill_workers
|
||||
@@ -681,17 +649,17 @@ impl PDRouter {
|
||||
}
|
||||
}
|
||||
|
||||
fn select_random(&self) -> Result<(EngineInfo, EngineInfo), String> {
|
||||
fn select_random(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
|
||||
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
|
||||
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
|
||||
|
||||
let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone();
|
||||
let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone();
|
||||
let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone_worker();
|
||||
let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone_worker();
|
||||
|
||||
Ok((prefill, decode))
|
||||
}
|
||||
|
||||
async fn select_power_of_two(&self) -> Result<(EngineInfo, EngineInfo), String> {
|
||||
async fn select_power_of_two(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
|
||||
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
|
||||
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
|
||||
|
||||
@@ -700,33 +668,45 @@ impl PDRouter {
|
||||
|
||||
let loads = self.worker_loads.borrow();
|
||||
|
||||
let p1_load = loads.get(&prefill_list[p1_idx].url).copied().unwrap_or(0);
|
||||
let p2_load = loads.get(&prefill_list[p2_idx].url).copied().unwrap_or(0);
|
||||
let d1_load = loads.get(&decode_list[d1_idx].url).copied().unwrap_or(0);
|
||||
let d2_load = loads.get(&decode_list[d2_idx].url).copied().unwrap_or(0);
|
||||
let p1_load = loads
|
||||
.get(prefill_list[p1_idx].url())
|
||||
.copied()
|
||||
.unwrap_or(isize::MAX);
|
||||
let p2_load = loads
|
||||
.get(prefill_list[p2_idx].url())
|
||||
.copied()
|
||||
.unwrap_or(isize::MAX);
|
||||
let d1_load = loads
|
||||
.get(decode_list[d1_idx].url())
|
||||
.copied()
|
||||
.unwrap_or(isize::MAX);
|
||||
let d2_load = loads
|
||||
.get(decode_list[d2_idx].url())
|
||||
.copied()
|
||||
.unwrap_or(isize::MAX);
|
||||
|
||||
info!(
|
||||
"Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}",
|
||||
prefill_list[p1_idx].url,
|
||||
prefill_list[p1_idx].url(),
|
||||
p1_load,
|
||||
prefill_list[p2_idx].url,
|
||||
prefill_list[p2_idx].url(),
|
||||
p2_load,
|
||||
decode_list[d1_idx].url,
|
||||
decode_list[d1_idx].url(),
|
||||
d1_load,
|
||||
decode_list[d2_idx].url,
|
||||
decode_list[d2_idx].url(),
|
||||
d2_load
|
||||
);
|
||||
|
||||
let selected_prefill = if p1_load <= p2_load {
|
||||
prefill_list[p1_idx].clone()
|
||||
prefill_list[p1_idx].clone_worker()
|
||||
} else {
|
||||
prefill_list[p2_idx].clone()
|
||||
prefill_list[p2_idx].clone_worker()
|
||||
};
|
||||
|
||||
let selected_decode = if d1_load <= d2_load {
|
||||
decode_list[d1_idx].clone()
|
||||
decode_list[d1_idx].clone_worker()
|
||||
} else {
|
||||
decode_list[d2_idx].clone()
|
||||
decode_list[d2_idx].clone_worker()
|
||||
};
|
||||
|
||||
Ok((selected_prefill, selected_decode))
|
||||
@@ -868,11 +848,11 @@ impl PDRouter {
|
||||
let mut worker_infos = Vec::new();
|
||||
|
||||
for worker in self.prefill_workers.read().unwrap().iter() {
|
||||
worker_infos.push((worker.url.clone(), "prefill"));
|
||||
worker_infos.push((worker.url().to_string(), "prefill"));
|
||||
}
|
||||
|
||||
for worker in self.decode_workers.read().unwrap().iter() {
|
||||
worker_infos.push((worker.url.clone(), "decode"));
|
||||
worker_infos.push((worker.url().to_string(), "decode"));
|
||||
}
|
||||
|
||||
// Create tasks with URL tracking
|
||||
@@ -922,7 +902,7 @@ impl PDRouter {
|
||||
pub async fn get_server_info(&self, client: &reqwest::Client) -> HttpResponse {
|
||||
// Get info from the first decode server to match sglang's server info format
|
||||
let first_decode_url = if let Ok(workers) = self.decode_workers.read() {
|
||||
workers.first().map(|w| w.url.clone())
|
||||
workers.first().map(|w| w.url().to_string())
|
||||
} else {
|
||||
return HttpResponse::InternalServerError().body("Failed to access decode workers");
|
||||
};
|
||||
@@ -967,7 +947,7 @@ impl PDRouter {
|
||||
pub async fn get_models(&self, client: &reqwest::Client, req: &HttpRequest) -> HttpResponse {
|
||||
// Get first prefill worker URL to avoid holding lock across await
|
||||
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
|
||||
workers.first().map(|w| w.url.clone())
|
||||
workers.first().map(|w| w.url().to_string())
|
||||
} else {
|
||||
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
|
||||
};
|
||||
@@ -1005,14 +985,14 @@ impl PDRouter {
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|w| w.url.clone())
|
||||
.map(|w| w.url().to_string())
|
||||
.collect();
|
||||
let d_urls: Vec<_> = self
|
||||
.decode_workers
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|w| w.url.clone())
|
||||
.map(|w| w.url().to_string())
|
||||
.collect();
|
||||
|
||||
let mut prefill_loads = Vec::new();
|
||||
@@ -1048,7 +1028,7 @@ impl PDRouter {
|
||||
// Get model info from the first prefill server (matches original Rust PDLB behavior)
|
||||
// Get first prefill worker URL to avoid holding lock across await
|
||||
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
|
||||
workers.first().map(|w| w.url.clone())
|
||||
workers.first().map(|w| w.url().to_string())
|
||||
} else {
|
||||
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
|
||||
};
|
||||
@@ -1084,13 +1064,13 @@ impl PDRouter {
|
||||
|
||||
// Flush cache on all prefill servers
|
||||
for worker in self.prefill_workers.read().unwrap().iter() {
|
||||
let url = format!("{}/flush_cache", worker.url);
|
||||
let url = format!("{}/flush_cache", worker.url());
|
||||
tasks.push(client.post(&url).send());
|
||||
}
|
||||
|
||||
// Flush cache on all decode servers
|
||||
for worker in self.decode_workers.read().unwrap().iter() {
|
||||
let url = format!("{}/flush_cache", worker.url);
|
||||
let url = format!("{}/flush_cache", worker.url());
|
||||
tasks.push(client.post(&url).send());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
// Essential PDLB types extracted for PD routing
|
||||
|
||||
use crate::core::{Worker, WorkerType};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
@@ -28,52 +29,21 @@ pub enum PDRouterError {
|
||||
Timeout { url: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum EngineType {
|
||||
Prefill,
|
||||
Decode,
|
||||
// Helper functions for workers
|
||||
pub fn api_path(url: &str, api_path: &str) -> String {
|
||||
if api_path.starts_with("/") {
|
||||
format!("{}{}", url, api_path)
|
||||
} else {
|
||||
format!("{}/{}", url, api_path)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EngineInfo {
|
||||
pub engine_type: EngineType,
|
||||
pub url: String,
|
||||
pub bootstrap_port: Option<u16>,
|
||||
}
|
||||
|
||||
impl EngineInfo {
|
||||
pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
|
||||
EngineInfo {
|
||||
engine_type: EngineType::Prefill,
|
||||
url,
|
||||
bootstrap_port,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_decode(url: String) -> Self {
|
||||
EngineInfo {
|
||||
engine_type: EngineType::Decode,
|
||||
url,
|
||||
bootstrap_port: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn api_path(&self, api_path: &str) -> String {
|
||||
if api_path.starts_with("/") {
|
||||
format!("{}{}", self.url, api_path)
|
||||
} else {
|
||||
format!("{}/{}", self.url, api_path)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_hostname(&self) -> String {
|
||||
// Simple hostname extraction without external dependencies
|
||||
let url = self
|
||||
.url
|
||||
.trim_start_matches("http://")
|
||||
.trim_start_matches("https://");
|
||||
url.split(':').next().unwrap_or("localhost").to_string()
|
||||
}
|
||||
pub fn get_hostname(url: &str) -> String {
|
||||
// Simple hostname extraction without external dependencies
|
||||
let url = url
|
||||
.trim_start_matches("http://")
|
||||
.trim_start_matches("https://");
|
||||
url.split(':').next().unwrap_or("localhost").to_string()
|
||||
}
|
||||
|
||||
// PD-specific routing policies
|
||||
@@ -112,12 +82,21 @@ pub trait Bootstrap: Send + Sync {
|
||||
bootstrap_room: BootstrapRoom,
|
||||
);
|
||||
|
||||
fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> {
|
||||
fn add_bootstrap_info(&mut self, prefill_worker: &dyn Worker) -> Result<(), String> {
|
||||
let batch_size = self.get_batch_size()?;
|
||||
|
||||
// Extract bootstrap port from prefill worker if it's a prefill type
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let hostname = get_hostname(prefill_worker.url());
|
||||
|
||||
if let Some(batch_size) = batch_size {
|
||||
self.set_bootstrap_info(
|
||||
BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
|
||||
BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
|
||||
BootstrapHost::Batch(vec![hostname; batch_size]),
|
||||
BootstrapPort::Batch(vec![bootstrap_port; batch_size]),
|
||||
// Use high-quality random numbers to minimize collision risk
|
||||
BootstrapRoom::Batch(
|
||||
(0..batch_size)
|
||||
@@ -132,8 +111,8 @@ pub trait Bootstrap: Send + Sync {
|
||||
);
|
||||
} else {
|
||||
self.set_bootstrap_info(
|
||||
BootstrapHost::Single(prefill_info.get_hostname()),
|
||||
BootstrapPort::Single(prefill_info.bootstrap_port),
|
||||
BootstrapHost::Single(hostname),
|
||||
BootstrapPort::Single(bootstrap_port),
|
||||
BootstrapRoom::Single({
|
||||
// Use high-quality random number for single requests too
|
||||
let r1 = rand::random::<u64>();
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::core::{HealthChecker, Worker, WorkerFactory};
|
||||
use crate::pd_router::PDRouter;
|
||||
use crate::pd_types::PDSelectionPolicy;
|
||||
use crate::tree::Tree;
|
||||
@@ -5,7 +6,6 @@ use ::metrics::{counter, gauge, histogram};
|
||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||
use actix_web::{HttpRequest, HttpResponse};
|
||||
use futures_util::{StreamExt, TryStreamExt};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
@@ -30,15 +30,17 @@ pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
|
||||
#[derive(Debug)]
|
||||
pub enum Router {
|
||||
RoundRobin {
|
||||
worker_urls: Arc<RwLock<Vec<String>>>,
|
||||
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
current_index: AtomicUsize,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
_health_checker: Option<HealthChecker>,
|
||||
},
|
||||
Random {
|
||||
worker_urls: Arc<RwLock<Vec<String>>>,
|
||||
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
_health_checker: Option<HealthChecker>,
|
||||
},
|
||||
PrefillDecode {
|
||||
pd_router: Arc<PDRouter>,
|
||||
@@ -104,16 +106,15 @@ pub enum Router {
|
||||
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
|
||||
during the next eviction cycle.
|
||||
*/
|
||||
worker_urls: Arc<RwLock<Vec<String>>>,
|
||||
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
tree: Arc<Mutex<Tree>>,
|
||||
running_queue: Arc<Mutex<HashMap<String, usize>>>,
|
||||
processed_queue: Arc<Mutex<HashMap<String, usize>>>,
|
||||
cache_threshold: f32,
|
||||
balance_abs_threshold: usize,
|
||||
balance_rel_threshold: f32,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
_eviction_thread: Option<thread::JoinHandle<()>>,
|
||||
_health_checker: Option<HealthChecker>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -192,25 +193,43 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
// Create Worker trait objects from URLs
|
||||
let workers: Vec<Box<dyn Worker>> = worker_urls
|
||||
.iter()
|
||||
.map(|url| WorkerFactory::create_regular(url.clone()))
|
||||
.collect();
|
||||
|
||||
// Create router based on policy...
|
||||
Ok(match policy_config {
|
||||
PolicyConfig::RandomConfig {
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
} => Router::Random {
|
||||
worker_urls: Arc::new(RwLock::new(worker_urls)),
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
},
|
||||
} => {
|
||||
let workers = Arc::new(RwLock::new(workers));
|
||||
let health_checker =
|
||||
crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
|
||||
Router::Random {
|
||||
workers,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
_health_checker: Some(health_checker),
|
||||
}
|
||||
}
|
||||
PolicyConfig::RoundRobinConfig {
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
} => Router::RoundRobin {
|
||||
worker_urls: Arc::new(RwLock::new(worker_urls)),
|
||||
current_index: std::sync::atomic::AtomicUsize::new(0),
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
},
|
||||
} => {
|
||||
let workers = Arc::new(RwLock::new(workers));
|
||||
let health_checker =
|
||||
crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
|
||||
Router::RoundRobin {
|
||||
workers,
|
||||
current_index: std::sync::atomic::AtomicUsize::new(0),
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
_health_checker: Some(health_checker),
|
||||
}
|
||||
}
|
||||
PolicyConfig::CacheAwareConfig {
|
||||
cache_threshold,
|
||||
balance_abs_threshold,
|
||||
@@ -220,24 +239,12 @@ impl Router {
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
} => {
|
||||
let mut running_queue = HashMap::new();
|
||||
for url in &worker_urls {
|
||||
running_queue.insert(url.clone(), 0);
|
||||
}
|
||||
|
||||
let mut processed_queue = HashMap::new();
|
||||
for url in &worker_urls {
|
||||
processed_queue.insert(url.clone(), 0);
|
||||
}
|
||||
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
let running_queue = Arc::new(Mutex::new(running_queue));
|
||||
let processed_queue = Arc::new(Mutex::new(processed_queue));
|
||||
|
||||
// Create background eviction thread
|
||||
let tree_clone = Arc::clone(&tree);
|
||||
let processed_queue_clone = Arc::clone(&processed_queue);
|
||||
let running_queue_clone = Arc::clone(&running_queue);
|
||||
let workers = Arc::new(RwLock::new(workers));
|
||||
let workers_clone = Arc::clone(&workers);
|
||||
let eviction_thread = thread::spawn(move || {
|
||||
loop {
|
||||
// Sleep for the specified interval
|
||||
@@ -246,32 +253,41 @@ impl Router {
|
||||
let locked_tree_clone = tree_clone.lock().unwrap();
|
||||
// Run eviction
|
||||
locked_tree_clone.evict_tenant_by_size(max_tree_size);
|
||||
drop(locked_tree_clone);
|
||||
|
||||
// Print the process queue
|
||||
let locked_processed_queue = processed_queue_clone.lock().unwrap();
|
||||
info!("Processed Queue: {:?}", locked_processed_queue);
|
||||
// Log worker loads and processed requests
|
||||
let workers_guard = workers_clone.read().unwrap();
|
||||
let loads: Vec<(String, usize)> = workers_guard
|
||||
.iter()
|
||||
.map(|w| (w.url().to_string(), w.load()))
|
||||
.collect();
|
||||
info!("Worker loads: {:?}", loads);
|
||||
|
||||
// Print the running queue
|
||||
let locked_running_queue = running_queue_clone.lock().unwrap();
|
||||
info!("Running Queue: {:?}", locked_running_queue);
|
||||
let processed: Vec<(String, usize)> = workers_guard
|
||||
.iter()
|
||||
.map(|w| (w.url().to_string(), w.processed_requests()))
|
||||
.collect();
|
||||
info!("Processed requests: {:?}", processed);
|
||||
}
|
||||
});
|
||||
|
||||
for url in &worker_urls {
|
||||
tree.lock().unwrap().insert("", url);
|
||||
for worker in workers.read().unwrap().iter() {
|
||||
tree.lock().unwrap().insert("", worker.url());
|
||||
}
|
||||
|
||||
let health_checker =
|
||||
crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
|
||||
|
||||
Router::CacheAware {
|
||||
worker_urls: Arc::new(RwLock::new(worker_urls)),
|
||||
workers,
|
||||
tree,
|
||||
running_queue,
|
||||
processed_queue,
|
||||
cache_threshold,
|
||||
balance_abs_threshold,
|
||||
balance_rel_threshold,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
_eviction_thread: Some(eviction_thread),
|
||||
_health_checker: Some(health_checker),
|
||||
}
|
||||
}
|
||||
PolicyConfig::PrefillDecodeConfig {
|
||||
@@ -297,16 +313,18 @@ impl Router {
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a reference to the worker URLs shared across threads
|
||||
pub fn get_worker_urls(&self) -> Arc<RwLock<Vec<String>>> {
|
||||
/// Get the current list of worker URLs
|
||||
pub fn get_worker_urls(&self) -> Vec<String> {
|
||||
match self {
|
||||
Router::RoundRobin { worker_urls, .. } => Arc::clone(worker_urls),
|
||||
Router::Random { worker_urls, .. } => Arc::clone(worker_urls),
|
||||
Router::CacheAware { worker_urls, .. } => Arc::clone(worker_urls),
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, return empty list since we manage workers differently
|
||||
Arc::new(RwLock::new(Vec::new()))
|
||||
}
|
||||
Router::RoundRobin { workers, .. }
|
||||
| Router::Random { workers, .. }
|
||||
| Router::CacheAware { workers, .. } => workers
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.collect(),
|
||||
Router::PrefillDecode { .. } => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -373,13 +391,14 @@ impl Router {
|
||||
|
||||
fn select_first_worker(&self) -> Result<String, String> {
|
||||
match self {
|
||||
Router::RoundRobin { worker_urls, .. }
|
||||
| Router::Random { worker_urls, .. }
|
||||
| Router::CacheAware { worker_urls, .. } => {
|
||||
if worker_urls.read().unwrap().is_empty() {
|
||||
Router::RoundRobin { workers, .. }
|
||||
| Router::Random { workers, .. }
|
||||
| Router::CacheAware { workers, .. } => {
|
||||
let workers_guard = workers.read().unwrap();
|
||||
if workers_guard.is_empty() {
|
||||
Err("No workers are available".to_string())
|
||||
} else {
|
||||
Ok(worker_urls.read().unwrap()[0].clone())
|
||||
Ok(workers_guard[0].url().to_string())
|
||||
}
|
||||
}
|
||||
Router::PrefillDecode { .. } => {
|
||||
@@ -514,7 +533,7 @@ impl Router {
|
||||
return HttpResponse::NotImplemented()
|
||||
.body("route_to_all not implemented for PrefillDecode mode");
|
||||
}
|
||||
_ => self.get_worker_urls().read().unwrap().clone(),
|
||||
_ => self.get_worker_urls(),
|
||||
};
|
||||
|
||||
// Send requests to all workers concurrently
|
||||
@@ -562,7 +581,7 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
let urls = self.get_worker_urls().read().unwrap().clone();
|
||||
let urls = self.get_worker_urls();
|
||||
let prefill_urls: Vec<String> = Vec::new();
|
||||
let decode_urls = urls;
|
||||
|
||||
@@ -631,6 +650,24 @@ impl Router {
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
// For CacheAware router, increment load before request
|
||||
let load_incremented = match self {
|
||||
Router::CacheAware { workers, .. } => {
|
||||
let workers_guard = workers.read().unwrap();
|
||||
if let Some(worker) =
|
||||
workers_guard.iter().find(|w| w.url() == &worker_url)
|
||||
{
|
||||
worker.increment_load();
|
||||
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
|
||||
.set(worker.load() as f64);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
_ => false,
|
||||
};
|
||||
|
||||
// Send typed request directly
|
||||
let response = self
|
||||
.send_typed_request(
|
||||
@@ -640,6 +677,7 @@ impl Router {
|
||||
route,
|
||||
&worker_url,
|
||||
is_stream,
|
||||
load_incremented,
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -684,44 +722,47 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
// Helper method to select worker from text
|
||||
// Helper method to select worker from text (returns index for RoundRobin/Random, URL for CacheAware)
|
||||
fn select_generate_worker_from_text(&self, text: &str) -> String {
|
||||
match self {
|
||||
Router::RoundRobin {
|
||||
worker_urls,
|
||||
workers,
|
||||
current_index,
|
||||
..
|
||||
} => {
|
||||
let workers_guard = workers.read().unwrap();
|
||||
let idx = current_index
|
||||
.fetch_update(
|
||||
std::sync::atomic::Ordering::SeqCst,
|
||||
std::sync::atomic::Ordering::SeqCst,
|
||||
|x| Some((x + 1) % worker_urls.read().unwrap().len()),
|
||||
|x| Some((x + 1) % workers_guard.len()),
|
||||
)
|
||||
.unwrap();
|
||||
worker_urls.read().unwrap()[idx].clone()
|
||||
workers_guard[idx].url().to_string()
|
||||
}
|
||||
|
||||
Router::Random { worker_urls, .. } => worker_urls.read().unwrap()
|
||||
[rand::random::<usize>() % worker_urls.read().unwrap().len()]
|
||||
.clone(),
|
||||
Router::Random { workers, .. } => {
|
||||
let workers_guard = workers.read().unwrap();
|
||||
workers_guard[rand::random::<usize>() % workers_guard.len()]
|
||||
.url()
|
||||
.to_string()
|
||||
}
|
||||
|
||||
Router::CacheAware {
|
||||
worker_urls,
|
||||
workers,
|
||||
tree,
|
||||
running_queue,
|
||||
processed_queue,
|
||||
cache_threshold,
|
||||
balance_abs_threshold,
|
||||
balance_rel_threshold,
|
||||
..
|
||||
} => {
|
||||
let tree = tree.lock().unwrap();
|
||||
let mut running_queue = running_queue.lock().unwrap();
|
||||
let workers_guard = workers.read().unwrap();
|
||||
|
||||
// Get current load statistics
|
||||
let max_load = *running_queue.values().max().unwrap_or(&0);
|
||||
let min_load = *running_queue.values().min().unwrap_or(&0);
|
||||
// Get current load statistics from workers
|
||||
let loads: Vec<usize> = workers_guard.iter().map(|w| w.load()).collect();
|
||||
let max_load = *loads.iter().max().unwrap_or(&0);
|
||||
let min_load = *loads.iter().min().unwrap_or(&0);
|
||||
|
||||
// Load is considered imbalanced if:
|
||||
// 1. (max - min) > abs_threshold AND
|
||||
@@ -731,11 +772,16 @@ impl Router {
|
||||
|
||||
let selected_url = if is_imbalanced {
|
||||
// Log load balancing trigger and current queue state
|
||||
let worker_loads: Vec<(String, usize)> = workers_guard
|
||||
.iter()
|
||||
.map(|w| (w.url().to_string(), w.load()))
|
||||
.collect();
|
||||
|
||||
info!(
|
||||
"Load balancing triggered due to workload imbalance:\n\
|
||||
Max load: {}, Min load: {}\n\
|
||||
Current running queue: {:?}",
|
||||
max_load, min_load, running_queue
|
||||
Current worker loads: {:?}",
|
||||
max_load, min_load, worker_loads
|
||||
);
|
||||
|
||||
counter!("sgl_router_load_balancing_events_total").increment(1);
|
||||
@@ -743,11 +789,11 @@ impl Router {
|
||||
gauge!("sgl_router_min_load").set(min_load as f64);
|
||||
|
||||
// Use shortest queue routing when load is imbalanced
|
||||
running_queue
|
||||
workers_guard
|
||||
.iter()
|
||||
.min_by_key(|(_url, &count)| count)
|
||||
.map(|(url, _)| url.clone())
|
||||
.unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
|
||||
.min_by_key(|w| w.load())
|
||||
.map(|w| w.url().to_string())
|
||||
.unwrap_or_else(|| workers_guard[0].url().to_string())
|
||||
} else {
|
||||
// Use cache-aware routing when load is balanced
|
||||
let (matched_text, matched_worker) = tree.prefix_match(&text);
|
||||
@@ -763,18 +809,12 @@ impl Router {
|
||||
}
|
||||
};
|
||||
|
||||
// Update queues and tree
|
||||
*running_queue.get_mut(&selected_url).unwrap() += 1;
|
||||
|
||||
*processed_queue
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&selected_url)
|
||||
.unwrap() += 1;
|
||||
|
||||
gauge!("sgl_router_running_requests", "worker" => selected_url.to_string())
|
||||
.set(*running_queue.get(&selected_url).unwrap() as f64);
|
||||
counter!("sgl_router_processed_requests_total", "worker" => selected_url.to_string()).increment(1);
|
||||
// Find the selected worker and increment processed counter only
|
||||
if let Some(worker) = workers_guard.iter().find(|w| w.url() == &selected_url) {
|
||||
worker.increment_processed();
|
||||
counter!("sgl_router_processed_requests_total", "worker" => selected_url.to_string())
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
tree.insert(&text, &selected_url);
|
||||
|
||||
@@ -796,6 +836,7 @@ impl Router {
|
||||
route: &str,
|
||||
worker_url: &str,
|
||||
is_stream: bool,
|
||||
load_incremented: bool, // Whether load was incremented for this request
|
||||
) -> HttpResponse {
|
||||
let start = Instant::now();
|
||||
|
||||
@@ -820,6 +861,22 @@ impl Router {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
error!("Failed to send request to {}: {}", worker_url, e);
|
||||
|
||||
// Decrement load on error for CacheAware router
|
||||
if load_incremented {
|
||||
if let Router::CacheAware { workers, .. } = self {
|
||||
if let Ok(workers_guard) = workers.read() {
|
||||
if let Some(worker) =
|
||||
workers_guard.iter().find(|w| w.url() == worker_url)
|
||||
{
|
||||
worker.decrement_load();
|
||||
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
|
||||
.set(worker.load() as f64);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
|
||||
}
|
||||
};
|
||||
@@ -837,13 +894,15 @@ impl Router {
|
||||
}
|
||||
};
|
||||
|
||||
// Then decrement running queue counter if using CacheAware
|
||||
if let Router::CacheAware { running_queue, .. } = self {
|
||||
if let Ok(mut queue) = running_queue.lock() {
|
||||
if let Some(count) = queue.get_mut(worker_url) {
|
||||
*count = count.saturating_sub(1);
|
||||
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
|
||||
.set(*count as f64);
|
||||
// Decrement load counter for non-streaming CacheAware requests
|
||||
if load_incremented && !is_stream {
|
||||
if let Router::CacheAware { workers, .. } = self {
|
||||
if let Ok(workers_guard) = workers.read() {
|
||||
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
|
||||
worker.decrement_load();
|
||||
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
|
||||
.set(worker.load() as f64);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -855,8 +914,9 @@ impl Router {
|
||||
counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);
|
||||
|
||||
response
|
||||
} else if let Router::CacheAware { running_queue, .. } = self {
|
||||
let running_queue = Arc::clone(running_queue);
|
||||
} else if let Router::CacheAware { workers, .. } = self {
|
||||
// For streaming with CacheAware router, we need to manually decrement when done
|
||||
let workers = Arc::clone(workers);
|
||||
let worker_url = worker_url.to_string();
|
||||
|
||||
HttpResponse::build(status)
|
||||
@@ -867,21 +927,28 @@ impl Router {
|
||||
actix_web::error::ErrorInternalServerError("Failed to read stream")
|
||||
})
|
||||
.inspect(move |bytes| {
|
||||
let bytes = bytes.as_ref().unwrap();
|
||||
if bytes
|
||||
.as_ref()
|
||||
.windows(12)
|
||||
.any(|window| window == b"data: [DONE]")
|
||||
{
|
||||
let mut locked_queue = running_queue.lock().unwrap();
|
||||
let count = locked_queue.get_mut(&worker_url).unwrap();
|
||||
*count = count.saturating_sub(1);
|
||||
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()).set(*count as f64);
|
||||
debug!("Streaming is done!!")
|
||||
if let Ok(bytes) = bytes {
|
||||
if bytes
|
||||
.as_ref()
|
||||
.windows(12)
|
||||
.any(|window| window == b"data: [DONE]")
|
||||
{
|
||||
if let Ok(workers_guard) = workers.read() {
|
||||
if let Some(worker) =
|
||||
workers_guard.iter().find(|w| w.url() == &worker_url)
|
||||
{
|
||||
worker.decrement_load();
|
||||
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
|
||||
.set(worker.load() as f64);
|
||||
debug!("Streaming is done!!")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
} else {
|
||||
// For non-CacheAware routers, just stream without load tracking
|
||||
HttpResponse::build(status)
|
||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
||||
.streaming(res.bytes_stream().map_err(|_| {
|
||||
@@ -935,43 +1002,27 @@ impl Router {
|
||||
Ok(res) => {
|
||||
if res.status().is_success() {
|
||||
match self {
|
||||
Router::RoundRobin { worker_urls, .. }
|
||||
| Router::Random { worker_urls, .. }
|
||||
| Router::CacheAware { worker_urls, .. } => {
|
||||
Router::RoundRobin { workers, .. }
|
||||
| Router::Random { workers, .. }
|
||||
| Router::CacheAware { workers, .. } => {
|
||||
info!("Worker {} health check passed", worker_url);
|
||||
let mut urls = worker_urls.write().unwrap();
|
||||
if urls.contains(&worker_url.to_string()) {
|
||||
let mut workers_guard = workers.write().unwrap();
|
||||
if workers_guard.iter().any(|w| w.url() == worker_url) {
|
||||
return Err(format!("Worker {} already exists", worker_url));
|
||||
}
|
||||
info!("Added worker: {}", worker_url);
|
||||
urls.push(worker_url.to_string());
|
||||
gauge!("sgl_router_active_workers").set(urls.len() as f64);
|
||||
let new_worker =
|
||||
WorkerFactory::create_regular(worker_url.to_string());
|
||||
workers_guard.push(new_worker);
|
||||
gauge!("sgl_router_active_workers").set(workers_guard.len() as f64);
|
||||
}
|
||||
Router::PrefillDecode { .. } => {
|
||||
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// If cache aware, initialize the queues for the new worker
|
||||
if let Router::CacheAware {
|
||||
running_queue,
|
||||
processed_queue,
|
||||
tree,
|
||||
..
|
||||
} = self
|
||||
{
|
||||
// Add worker to running queue with initial count of 0
|
||||
running_queue
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(worker_url.to_string(), 0);
|
||||
|
||||
// Add worker to processed queue with initial count of 0
|
||||
processed_queue
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(worker_url.to_string(), 0);
|
||||
|
||||
// If cache aware, add worker to tree
|
||||
if let Router::CacheAware { tree, .. } = self {
|
||||
// Add worker to tree
|
||||
tree.lock().unwrap().insert("", worker_url);
|
||||
}
|
||||
@@ -1013,14 +1064,14 @@ impl Router {
|
||||
|
||||
pub fn remove_worker(&self, worker_url: &str) {
|
||||
match self {
|
||||
Router::RoundRobin { worker_urls, .. }
|
||||
| Router::Random { worker_urls, .. }
|
||||
| Router::CacheAware { worker_urls, .. } => {
|
||||
let mut urls = worker_urls.write().unwrap();
|
||||
if let Some(index) = urls.iter().position(|url| url == &worker_url) {
|
||||
urls.remove(index);
|
||||
Router::RoundRobin { workers, .. }
|
||||
| Router::Random { workers, .. }
|
||||
| Router::CacheAware { workers, .. } => {
|
||||
let mut workers_guard = workers.write().unwrap();
|
||||
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
|
||||
workers_guard.remove(index);
|
||||
info!("Removed worker: {}", worker_url);
|
||||
gauge!("sgl_router_active_workers").set(urls.len() as f64);
|
||||
gauge!("sgl_router_active_workers").set(workers_guard.len() as f64);
|
||||
} else {
|
||||
warn!("Worker {} not found, skipping removal", worker_url);
|
||||
return;
|
||||
@@ -1033,26 +1084,9 @@ impl Router {
|
||||
}
|
||||
|
||||
// if cache aware, remove the worker from the tree
|
||||
if let Router::CacheAware {
|
||||
tree,
|
||||
running_queue,
|
||||
processed_queue,
|
||||
..
|
||||
} = self
|
||||
{
|
||||
if let Router::CacheAware { tree, .. } = self {
|
||||
tree.lock().unwrap().remove_tenant(&worker_url);
|
||||
running_queue
|
||||
.lock()
|
||||
.unwrap()
|
||||
.remove(&worker_url.to_string());
|
||||
processed_queue
|
||||
.lock()
|
||||
.unwrap()
|
||||
.remove(&worker_url.to_string());
|
||||
info!(
|
||||
"Removed worker from tree and cleaned up queues: {}",
|
||||
worker_url
|
||||
);
|
||||
info!("Removed worker from tree: {}", worker_url);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1241,21 +1275,22 @@ mod tests {
|
||||
use crate::service_discovery::PodType;
|
||||
|
||||
fn create_test_regular_router() -> Router {
|
||||
let workers = vec![
|
||||
WorkerFactory::create_regular("http://worker1:8080".to_string()),
|
||||
WorkerFactory::create_regular("http://worker2:8080".to_string()),
|
||||
];
|
||||
Router::Random {
|
||||
worker_urls: Arc::new(RwLock::new(vec![
|
||||
"http://worker1:8080".to_string(),
|
||||
"http://worker2:8080".to_string(),
|
||||
])),
|
||||
workers: Arc::new(RwLock::new(workers)),
|
||||
timeout_secs: 5,
|
||||
interval_secs: 1,
|
||||
_health_checker: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_router_get_worker_urls_regular() {
|
||||
let router = create_test_regular_router();
|
||||
let worker_urls = router.get_worker_urls();
|
||||
let urls = worker_urls.read().unwrap();
|
||||
let urls = router.get_worker_urls();
|
||||
|
||||
assert_eq!(urls.len(), 2);
|
||||
assert!(urls.contains(&"http://worker1:8080".to_string()));
|
||||
|
||||
@@ -236,8 +236,7 @@ async fn add_worker(
|
||||
|
||||
#[get("/list_workers")]
|
||||
async fn list_workers(data: web::Data<AppState>) -> impl Responder {
|
||||
let workers = data.router.get_worker_urls();
|
||||
let worker_list = workers.read().unwrap().clone();
|
||||
let worker_list = data.router.get_worker_urls();
|
||||
HttpResponse::Ok().json(serde_json::json!({ "urls": worker_list }))
|
||||
}
|
||||
|
||||
@@ -381,7 +380,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
info!("✅ Serving router on {}:{}", config.host, config.port);
|
||||
info!(
|
||||
"✅ Serving workers on {:?}",
|
||||
app_state.router.get_worker_urls().read().unwrap()
|
||||
app_state.router.get_worker_urls()
|
||||
);
|
||||
|
||||
HttpServer::new(move || {
|
||||
|
||||
@@ -547,11 +547,12 @@ mod tests {
|
||||
|
||||
// Helper to create a Router instance for testing event handlers
|
||||
fn create_test_router() -> Arc<Router> {
|
||||
let worker_urls = Arc::new(RwLock::new(Vec::new()));
|
||||
let workers = Arc::new(RwLock::new(Vec::new()));
|
||||
Arc::new(Router::Random {
|
||||
worker_urls,
|
||||
workers,
|
||||
timeout_secs: 5,
|
||||
interval_secs: 1,
|
||||
_health_checker: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -878,8 +879,6 @@ mod tests {
|
||||
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
||||
assert!(!router
|
||||
.get_worker_urls()
|
||||
.read()
|
||||
.unwrap()
|
||||
.contains(&pod_info.worker_url(port)));
|
||||
}
|
||||
|
||||
@@ -907,7 +906,7 @@ mod tests {
|
||||
.await;
|
||||
|
||||
assert!(tracked_pods.lock().unwrap().is_empty());
|
||||
assert!(router.get_worker_urls().read().unwrap().is_empty());
|
||||
assert!(router.get_worker_urls().is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
mod test_pd_routing {
|
||||
use rand::Rng;
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::pd_types::{EngineInfo, EngineType, PDSelectionPolicy};
|
||||
use sglang_router_rs::pd_types::PDSelectionPolicy;
|
||||
use sglang_router_rs::router::{PolicyConfig, Router};
|
||||
|
||||
// Test-only struct to help validate PD request parsing
|
||||
@@ -51,40 +51,35 @@ mod test_pd_routing {
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_engine_info_creation() {
|
||||
// Test EngineInfo creation for prefill servers
|
||||
let prefill_engine = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
match prefill_engine.engine_type {
|
||||
EngineType::Prefill => (),
|
||||
_ => panic!("Expected Prefill engine type"),
|
||||
}
|
||||
assert_eq!(prefill_engine.url, "http://prefill:8080");
|
||||
assert_eq!(prefill_engine.bootstrap_port, Some(9000));
|
||||
assert_eq!(prefill_engine.get_hostname(), "prefill");
|
||||
fn test_worker_types() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
|
||||
// Test EngineInfo creation for decode servers
|
||||
let decode_engine = EngineInfo::new_decode("http://decode:8080".to_string());
|
||||
match decode_engine.engine_type {
|
||||
EngineType::Decode => (),
|
||||
_ => panic!("Expected Decode engine type"),
|
||||
// Test worker creation for prefill servers
|
||||
let prefill_worker =
|
||||
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
assert_eq!(prefill_worker.url(), "http://prefill:8080");
|
||||
match prefill_worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => {
|
||||
assert_eq!(bootstrap_port, Some(9000));
|
||||
}
|
||||
_ => panic!("Expected Prefill worker type"),
|
||||
}
|
||||
assert_eq!(decode_engine.url, "http://decode:8080");
|
||||
assert_eq!(decode_engine.bootstrap_port, None);
|
||||
assert_eq!(decode_engine.get_hostname(), "decode");
|
||||
|
||||
// Test API path generation
|
||||
assert_eq!(
|
||||
prefill_engine.api_path("/generate"),
|
||||
"http://prefill:8080/generate"
|
||||
);
|
||||
assert_eq!(
|
||||
prefill_engine.api_path("health"),
|
||||
"http://prefill:8080/health"
|
||||
);
|
||||
assert_eq!(
|
||||
decode_engine.api_path("/v1/chat/completions"),
|
||||
"http://decode:8080/v1/chat/completions"
|
||||
);
|
||||
// Test worker creation for decode servers
|
||||
let decode_worker = WorkerFactory::create_decode("http://decode:8080".to_string());
|
||||
assert_eq!(decode_worker.url(), "http://decode:8080");
|
||||
match decode_worker.worker_type() {
|
||||
WorkerType::Decode => (),
|
||||
_ => panic!("Expected Decode worker type"),
|
||||
}
|
||||
|
||||
// Test regular worker creation
|
||||
let regular_worker = WorkerFactory::create_regular("http://regular:8080".to_string());
|
||||
assert_eq!(regular_worker.url(), "http://regular:8080");
|
||||
match regular_worker.worker_type() {
|
||||
WorkerType::Regular => (),
|
||||
_ => panic!("Expected Regular worker type"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -230,6 +225,9 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_bootstrap_injection_simulation() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Since we can't test the actual inject_bootstrap_fields function here
|
||||
// (it's private in the router module), we'll test the expected behavior
|
||||
|
||||
@@ -240,15 +238,24 @@ mod test_pd_routing {
|
||||
"temperature": 0.7
|
||||
});
|
||||
|
||||
// Create a prefill worker to simulate injection
|
||||
let prefill_worker =
|
||||
WorkerFactory::create_prefill("http://prefill1:8080".to_string(), Some(9000));
|
||||
|
||||
// Extract bootstrap port from worker type
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// Simulate what inject_bootstrap_fields would do
|
||||
let prefill_info = EngineInfo::new_prefill("http://prefill1:8080".to_string(), Some(9000));
|
||||
single_json["bootstrap_host"] = json!(prefill_info.get_hostname());
|
||||
single_json["bootstrap_port"] = json!(prefill_info.bootstrap_port);
|
||||
single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url()));
|
||||
single_json["bootstrap_port"] = json!(bootstrap_port);
|
||||
single_json["bootstrap_room"] = json!(12345u64); // Random room ID
|
||||
|
||||
// Verify bootstrap fields are added correctly
|
||||
assert_eq!(single_json["bootstrap_host"], "prefill1");
|
||||
assert_eq!(single_json["bootstrap_port"], 9000);
|
||||
assert_eq!(single_json["bootstrap_port"], json!(Some(9000)));
|
||||
assert!(single_json["bootstrap_room"].is_u64());
|
||||
assert_eq!(single_json["temperature"], 0.7); // Original field preserved
|
||||
|
||||
@@ -259,8 +266,9 @@ mod test_pd_routing {
|
||||
});
|
||||
|
||||
let batch_size = 3;
|
||||
batch_json["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
|
||||
batch_json["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
|
||||
let hostname = get_hostname(prefill_worker.url());
|
||||
batch_json["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
||||
batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||
batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);
|
||||
|
||||
// Verify batch bootstrap fields
|
||||
@@ -306,7 +314,9 @@ mod test_pd_routing {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_engine_info_hostname_extraction() {
|
||||
fn test_hostname_extraction() {
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Test various URL formats
|
||||
let test_cases = vec![
|
||||
("http://localhost:8080", "localhost"),
|
||||
@@ -318,8 +328,7 @@ mod test_pd_routing {
|
||||
];
|
||||
|
||||
for (url, expected_hostname) in test_cases {
|
||||
let engine = EngineInfo::new_prefill(url.to_string(), None);
|
||||
assert_eq!(engine.get_hostname(), expected_hostname);
|
||||
assert_eq!(get_hostname(url), expected_hostname);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -652,6 +661,9 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_bootstrap_injection_with_benchmark_requests() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Test bootstrap injection with actual benchmark request patterns
|
||||
let mut benchmark_request = json!({
|
||||
"input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
|
||||
@@ -664,12 +676,20 @@ mod test_pd_routing {
|
||||
"stream": true
|
||||
});
|
||||
|
||||
// Simulate bootstrap injection
|
||||
let prefill_info = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
let batch_size = 16;
|
||||
// Create a prefill worker to simulate injection
|
||||
let prefill_worker =
|
||||
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
|
||||
benchmark_request["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
|
||||
benchmark_request["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
|
||||
// Extract bootstrap port from worker type
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
};
|
||||
let batch_size = 16;
|
||||
let hostname = get_hostname(prefill_worker.url());
|
||||
|
||||
benchmark_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
||||
benchmark_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||
benchmark_request["bootstrap_room"] =
|
||||
json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());
|
||||
|
||||
@@ -770,6 +790,9 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_large_batch_bootstrap_injection() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Test bootstrap injection performance with very large batches
|
||||
// This simulates the bench_one_batch_server.py scenario
|
||||
let large_batch_sizes = vec![1024, 4096, 8192];
|
||||
@@ -787,14 +810,19 @@ mod test_pd_routing {
|
||||
"stream": true
|
||||
});
|
||||
|
||||
// Simulate bootstrap injection
|
||||
let prefill_info =
|
||||
EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
// Create a prefill worker to simulate injection
|
||||
let prefill_worker =
|
||||
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
|
||||
large_batch_request["bootstrap_host"] =
|
||||
json!(vec![prefill_info.get_hostname(); batch_size]);
|
||||
large_batch_request["bootstrap_port"] =
|
||||
json!(vec![prefill_info.bootstrap_port; batch_size]);
|
||||
// Extract bootstrap port from worker type
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
};
|
||||
let hostname = get_hostname(prefill_worker.url());
|
||||
|
||||
large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
||||
large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||
large_batch_request["bootstrap_room"] = json!((0..batch_size)
|
||||
.map(|_| rand::thread_rng().gen::<u64>())
|
||||
.collect::<Vec<_>>());
|
||||
|
||||
Reference in New Issue
Block a user