[router] add worker abstraction (#7960)
This commit is contained in:
@@ -30,6 +30,8 @@ tracing-appender = "0.2.3"
|
|||||||
kube = { version = "0.88.1", features = ["runtime", "derive"] }
|
kube = { version = "0.88.1", features = ["runtime", "derive"] }
|
||||||
k8s-openapi = { version = "0.21.0", features = ["v1_29"] }
|
k8s-openapi = { version = "0.21.0", features = ["v1_29"] }
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
|
async-trait = "0.1"
|
||||||
|
once_cell = "1.21"
|
||||||
# Added for metrics
|
# Added for metrics
|
||||||
metrics = "0.24.2"
|
metrics = "0.24.2"
|
||||||
metrics-exporter-prometheus = "0.17.0"
|
metrics-exporter-prometheus = "0.17.0"
|
||||||
|
|||||||
57
sgl-router/src/core/error.rs
Normal file
57
sgl-router/src/core/error.rs
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
//! Error types for the SGLang router core
|
||||||
|
//!
|
||||||
|
//! This module defines error types used throughout the router for worker operations.
|
||||||
|
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
/// Worker-related errors
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum WorkerError {
|
||||||
|
/// Health check failed
|
||||||
|
HealthCheckFailed { url: String, reason: String },
|
||||||
|
/// Worker not found
|
||||||
|
WorkerNotFound { url: String },
|
||||||
|
/// Invalid worker configuration
|
||||||
|
InvalidConfiguration { message: String },
|
||||||
|
/// Network error
|
||||||
|
NetworkError { url: String, error: String },
|
||||||
|
/// Worker is at capacity
|
||||||
|
WorkerAtCapacity { url: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for WorkerError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
WorkerError::HealthCheckFailed { url, reason } => {
|
||||||
|
write!(f, "Health check failed for worker {}: {}", url, reason)
|
||||||
|
}
|
||||||
|
WorkerError::WorkerNotFound { url } => {
|
||||||
|
write!(f, "Worker not found: {}", url)
|
||||||
|
}
|
||||||
|
WorkerError::InvalidConfiguration { message } => {
|
||||||
|
write!(f, "Invalid worker configuration: {}", message)
|
||||||
|
}
|
||||||
|
WorkerError::NetworkError { url, error } => {
|
||||||
|
write!(f, "Network error for worker {}: {}", url, error)
|
||||||
|
}
|
||||||
|
WorkerError::WorkerAtCapacity { url } => {
|
||||||
|
write!(f, "Worker at capacity: {}", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for WorkerError {}
|
||||||
|
|
||||||
|
/// Result type for worker operations
|
||||||
|
pub type WorkerResult<T> = Result<T, WorkerError>;
|
||||||
|
|
||||||
|
/// Convert from reqwest errors to worker errors
|
||||||
|
impl From<reqwest::Error> for WorkerError {
|
||||||
|
fn from(err: reqwest::Error) -> Self {
|
||||||
|
WorkerError::NetworkError {
|
||||||
|
url: err.url().map(|u| u.to_string()).unwrap_or_default(),
|
||||||
|
error: err.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
16
sgl-router/src/core/mod.rs
Normal file
16
sgl-router/src/core/mod.rs
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
//! Core abstractions for the SGLang router
|
||||||
|
//!
|
||||||
|
//! This module contains the fundamental types and traits used throughout the router:
|
||||||
|
//! - Worker trait and implementations
|
||||||
|
//! - Error types
|
||||||
|
//! - Common utilities
|
||||||
|
|
||||||
|
pub mod error;
|
||||||
|
pub mod worker;
|
||||||
|
|
||||||
|
// Re-export commonly used types at the module level
|
||||||
|
pub use error::{WorkerError, WorkerResult};
|
||||||
|
pub use worker::{
|
||||||
|
start_health_checker, BasicWorker, HealthChecker, Worker, WorkerCollection, WorkerFactory,
|
||||||
|
WorkerLoadGuard, WorkerType,
|
||||||
|
};
|
||||||
454
sgl-router/src/core/worker.rs
Normal file
454
sgl-router/src/core/worker.rs
Normal file
@@ -0,0 +1,454 @@
|
|||||||
|
use super::{WorkerError, WorkerResult};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use std::fmt;
|
||||||
|
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
// Shared HTTP client for health checks
|
||||||
|
static HEALTH_CHECK_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
|
||||||
|
reqwest::Client::builder()
|
||||||
|
.timeout(std::time::Duration::from_secs(30)) // Default timeout, overridden per request
|
||||||
|
.build()
|
||||||
|
.expect("Failed to create health check HTTP client")
|
||||||
|
});
|
||||||
|
|
||||||
|
/// Core worker abstraction that represents a backend service
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Worker: Send + Sync + fmt::Debug {
|
||||||
|
/// Get the worker's URL
|
||||||
|
fn url(&self) -> &str;
|
||||||
|
|
||||||
|
/// Get the worker's type (Regular, Prefill, or Decode)
|
||||||
|
fn worker_type(&self) -> WorkerType;
|
||||||
|
|
||||||
|
/// Check if the worker is currently healthy
|
||||||
|
fn is_healthy(&self) -> bool;
|
||||||
|
|
||||||
|
/// Set the worker's health status
|
||||||
|
fn set_healthy(&self, healthy: bool);
|
||||||
|
|
||||||
|
/// Perform an async health check on the worker
|
||||||
|
async fn check_health_async(&self) -> WorkerResult<()>;
|
||||||
|
|
||||||
|
/// Synchronous health check wrapper (for compatibility)
|
||||||
|
fn check_health(&self) -> WorkerResult<()> {
|
||||||
|
// Use a small runtime for synchronous contexts
|
||||||
|
tokio::runtime::Builder::new_current_thread()
|
||||||
|
.enable_all()
|
||||||
|
.build()
|
||||||
|
.map_err(|e| WorkerError::HealthCheckFailed {
|
||||||
|
url: self.url().to_string(),
|
||||||
|
reason: format!("Failed to create runtime: {}", e),
|
||||||
|
})?
|
||||||
|
.block_on(self.check_health_async())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the current load (number of active requests)
|
||||||
|
fn load(&self) -> usize;
|
||||||
|
|
||||||
|
/// Increment the load counter
|
||||||
|
fn increment_load(&self);
|
||||||
|
|
||||||
|
/// Decrement the load counter
|
||||||
|
fn decrement_load(&self);
|
||||||
|
|
||||||
|
/// Get the number of processed requests
|
||||||
|
fn processed_requests(&self) -> usize;
|
||||||
|
|
||||||
|
/// Increment the processed requests counter
|
||||||
|
fn increment_processed(&self);
|
||||||
|
|
||||||
|
/// Get worker-specific metadata
|
||||||
|
fn metadata(&self) -> &WorkerMetadata;
|
||||||
|
|
||||||
|
/// Clone the worker (for trait objects)
|
||||||
|
fn clone_worker(&self) -> Box<dyn Worker>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Worker type classification
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
pub enum WorkerType {
|
||||||
|
/// Regular worker for standard routing
|
||||||
|
Regular,
|
||||||
|
/// Prefill worker for PD disaggregated mode
|
||||||
|
Prefill {
|
||||||
|
/// Bootstrap port for communication with decode workers
|
||||||
|
bootstrap_port: Option<u16>,
|
||||||
|
},
|
||||||
|
/// Decode worker for PD disaggregated mode
|
||||||
|
Decode,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for WorkerType {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
WorkerType::Regular => write!(f, "Regular"),
|
||||||
|
WorkerType::Prefill { bootstrap_port } => match bootstrap_port {
|
||||||
|
Some(port) => write!(f, "Prefill(bootstrap:{})", port),
|
||||||
|
None => write!(f, "Prefill"),
|
||||||
|
},
|
||||||
|
WorkerType::Decode => write!(f, "Decode"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Health check configuration
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct HealthConfig {
|
||||||
|
/// Timeout for health checks in seconds
|
||||||
|
pub timeout_secs: u64,
|
||||||
|
/// Interval between health checks in seconds
|
||||||
|
pub check_interval_secs: u64,
|
||||||
|
/// Health check endpoint path
|
||||||
|
pub endpoint: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for HealthConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
timeout_secs: 5,
|
||||||
|
check_interval_secs: 30,
|
||||||
|
endpoint: "/health".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Metadata associated with a worker
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct WorkerMetadata {
|
||||||
|
/// Worker URL
|
||||||
|
pub url: String,
|
||||||
|
/// Worker type
|
||||||
|
pub worker_type: WorkerType,
|
||||||
|
/// Additional labels/tags
|
||||||
|
pub labels: std::collections::HashMap<String, String>,
|
||||||
|
/// Health check configuration
|
||||||
|
pub health_config: HealthConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Basic worker implementation
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct BasicWorker {
|
||||||
|
metadata: WorkerMetadata,
|
||||||
|
load_counter: Arc<AtomicUsize>,
|
||||||
|
processed_counter: Arc<AtomicUsize>,
|
||||||
|
healthy: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BasicWorker {
|
||||||
|
pub fn new(url: String, worker_type: WorkerType) -> Self {
|
||||||
|
let metadata = WorkerMetadata {
|
||||||
|
url: url.clone(),
|
||||||
|
worker_type,
|
||||||
|
labels: std::collections::HashMap::new(),
|
||||||
|
health_config: HealthConfig::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Self {
|
||||||
|
metadata,
|
||||||
|
load_counter: Arc::new(AtomicUsize::new(0)),
|
||||||
|
processed_counter: Arc::new(AtomicUsize::new(0)),
|
||||||
|
healthy: Arc::new(AtomicBool::new(true)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_labels(mut self, labels: std::collections::HashMap<String, String>) -> Self {
|
||||||
|
self.metadata.labels = labels;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_health_config(mut self, config: HealthConfig) -> Self {
|
||||||
|
self.metadata.health_config = config;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Worker for BasicWorker {
|
||||||
|
fn url(&self) -> &str {
|
||||||
|
&self.metadata.url
|
||||||
|
}
|
||||||
|
|
||||||
|
fn worker_type(&self) -> WorkerType {
|
||||||
|
self.metadata.worker_type.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_healthy(&self) -> bool {
|
||||||
|
self.healthy.load(Ordering::Acquire)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_healthy(&self, healthy: bool) {
|
||||||
|
self.healthy.store(healthy, Ordering::Release);
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn check_health_async(&self) -> WorkerResult<()> {
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
// Perform actual HTTP health check
|
||||||
|
let health_url = format!("{}{}", self.url(), self.metadata.health_config.endpoint);
|
||||||
|
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
|
||||||
|
|
||||||
|
// Use the shared client with a custom timeout for this request
|
||||||
|
match HEALTH_CHECK_CLIENT
|
||||||
|
.get(&health_url)
|
||||||
|
.timeout(timeout)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(response) => {
|
||||||
|
if response.status().is_success() {
|
||||||
|
self.set_healthy(true);
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
self.set_healthy(false);
|
||||||
|
Err(WorkerError::HealthCheckFailed {
|
||||||
|
url: self.url().to_string(),
|
||||||
|
reason: format!("Health check returned status: {}", response.status()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
self.set_healthy(false);
|
||||||
|
Err(WorkerError::HealthCheckFailed {
|
||||||
|
url: self.url().to_string(),
|
||||||
|
reason: format!("Health check request failed: {}", e),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load(&self) -> usize {
|
||||||
|
self.load_counter.load(Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn increment_load(&self) {
|
||||||
|
self.load_counter.fetch_add(1, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decrement_load(&self) {
|
||||||
|
self.load_counter
|
||||||
|
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
|
||||||
|
current.checked_sub(1)
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn processed_requests(&self) -> usize {
|
||||||
|
self.processed_counter.load(Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn increment_processed(&self) {
|
||||||
|
self.processed_counter.fetch_add(1, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn metadata(&self) -> &WorkerMetadata {
|
||||||
|
&self.metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clone_worker(&self) -> Box<dyn Worker> {
|
||||||
|
Box::new(self.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Worker factory for creating workers of different types
|
||||||
|
pub struct WorkerFactory;
|
||||||
|
|
||||||
|
impl WorkerFactory {
|
||||||
|
/// Create a regular worker
|
||||||
|
pub fn create_regular(url: String) -> Box<dyn Worker> {
|
||||||
|
Box::new(BasicWorker::new(url, WorkerType::Regular))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a prefill worker with optional bootstrap port
|
||||||
|
pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> {
|
||||||
|
Box::new(BasicWorker::new(
|
||||||
|
url,
|
||||||
|
WorkerType::Prefill { bootstrap_port },
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a decode worker
|
||||||
|
pub fn create_decode(url: String) -> Box<dyn Worker> {
|
||||||
|
Box::new(BasicWorker::new(url, WorkerType::Decode))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create workers from URLs with automatic type detection
|
||||||
|
pub fn create_from_urls(
|
||||||
|
regular_urls: Vec<String>,
|
||||||
|
prefill_urls: Vec<(String, Option<u16>)>,
|
||||||
|
decode_urls: Vec<String>,
|
||||||
|
) -> (
|
||||||
|
Vec<Box<dyn Worker>>,
|
||||||
|
Vec<Box<dyn Worker>>,
|
||||||
|
Vec<Box<dyn Worker>>,
|
||||||
|
) {
|
||||||
|
let regular_workers: Vec<Box<dyn Worker>> =
|
||||||
|
regular_urls.into_iter().map(Self::create_regular).collect();
|
||||||
|
|
||||||
|
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
|
||||||
|
.into_iter()
|
||||||
|
.map(|(url, port)| Self::create_prefill(url, port))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let decode_workers: Vec<Box<dyn Worker>> =
|
||||||
|
decode_urls.into_iter().map(Self::create_decode).collect();
|
||||||
|
|
||||||
|
(regular_workers, prefill_workers, decode_workers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper trait for collections of workers
|
||||||
|
pub trait WorkerCollection {
|
||||||
|
fn healthy_workers(&self) -> Vec<&dyn Worker>;
|
||||||
|
fn total_load(&self) -> usize;
|
||||||
|
fn find_worker(&self, url: &str) -> Option<&dyn Worker>;
|
||||||
|
fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WorkerCollection for Vec<Box<dyn Worker>> {
|
||||||
|
fn healthy_workers(&self) -> Vec<&dyn Worker> {
|
||||||
|
self.iter()
|
||||||
|
.filter(|w| w.is_healthy())
|
||||||
|
.map(|w| w.as_ref())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn total_load(&self) -> usize {
|
||||||
|
self.iter().map(|w| w.load()).sum()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_worker(&self, url: &str) -> Option<&dyn Worker> {
|
||||||
|
self.iter().find(|w| w.url() == url).map(|w| w.as_ref())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>> {
|
||||||
|
self.iter_mut().find(|w| w.url() == url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert a list of worker URLs to worker trait objects
|
||||||
|
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
|
||||||
|
urls.into_iter()
|
||||||
|
.map(WorkerFactory::create_regular)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert worker trait objects back to URLs
|
||||||
|
pub fn workers_to_urls(workers: &[Box<dyn Worker>]) -> Vec<String> {
|
||||||
|
workers.iter().map(|w| w.url().to_string()).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// RAII guard for worker load management
|
||||||
|
pub struct WorkerLoadGuard<'a> {
|
||||||
|
workers: Vec<&'a dyn Worker>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> WorkerLoadGuard<'a> {
|
||||||
|
/// Create a new load guard for a single worker
|
||||||
|
pub fn new(worker: &'a dyn Worker) -> Self {
|
||||||
|
worker.increment_load();
|
||||||
|
Self {
|
||||||
|
workers: vec![worker],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new load guard for multiple workers
|
||||||
|
pub fn new_multi(workers: Vec<&'a dyn Worker>) -> Self {
|
||||||
|
// Increment load counters for all workers
|
||||||
|
for worker in &workers {
|
||||||
|
worker.increment_load();
|
||||||
|
}
|
||||||
|
Self { workers }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Drop for WorkerLoadGuard<'a> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
// Decrement load counters for all workers
|
||||||
|
for worker in &self.workers {
|
||||||
|
worker.decrement_load();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Health checker handle with graceful shutdown
|
||||||
|
pub struct HealthChecker {
|
||||||
|
handle: tokio::task::JoinHandle<()>,
|
||||||
|
shutdown: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for HealthChecker {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.debug_struct("HealthChecker")
|
||||||
|
.field("shutdown", &self.shutdown.load(Ordering::Relaxed))
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HealthChecker {
|
||||||
|
/// Shutdown the health checker gracefully
|
||||||
|
pub async fn shutdown(self) {
|
||||||
|
self.shutdown.store(true, Ordering::Release);
|
||||||
|
let _ = self.handle.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start an async background health checker for a collection of workers
|
||||||
|
pub fn start_health_checker(
|
||||||
|
workers: std::sync::Arc<std::sync::RwLock<Vec<Box<dyn Worker>>>>,
|
||||||
|
check_interval_secs: u64,
|
||||||
|
) -> HealthChecker {
|
||||||
|
let shutdown = Arc::new(AtomicBool::new(false));
|
||||||
|
let shutdown_clone = shutdown.clone();
|
||||||
|
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
let mut interval =
|
||||||
|
tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
|
||||||
|
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
|
||||||
|
// Check for shutdown signal
|
||||||
|
if shutdown_clone.load(Ordering::Acquire) {
|
||||||
|
tracing::info!("Health checker shutting down");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check health of all workers
|
||||||
|
let workers_to_check = match workers.read() {
|
||||||
|
Ok(guard) => guard.iter().map(|w| w.clone_worker()).collect::<Vec<_>>(),
|
||||||
|
Err(poisoned) => {
|
||||||
|
tracing::error!("Worker lock poisoned: {}", poisoned);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Perform health checks concurrently
|
||||||
|
let health_checks = workers_to_check.iter().map(|worker| {
|
||||||
|
let worker_url = worker.url().to_string();
|
||||||
|
let was_healthy = worker.is_healthy();
|
||||||
|
|
||||||
|
async move {
|
||||||
|
match worker.check_health_async().await {
|
||||||
|
Ok(_) => {
|
||||||
|
if !was_healthy {
|
||||||
|
tracing::info!("Worker {} is now healthy", worker_url);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
if was_healthy {
|
||||||
|
tracing::warn!("Worker {} health check failed: {}", worker_url, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Execute all health checks concurrently
|
||||||
|
futures::future::join_all(health_checks).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
HealthChecker { handle, shutdown }
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ use pyo3::prelude::*;
|
|||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
pub mod core;
|
||||||
pub mod openai_api_types;
|
pub mod openai_api_types;
|
||||||
pub mod pd_router;
|
pub mod pd_router;
|
||||||
pub mod pd_types;
|
pub mod pd_types;
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
// PD (Prefill-Decode) Router Implementation
|
// PD (Prefill-Decode) Router Implementation
|
||||||
// This module handles routing for disaggregated prefill-decode systems
|
// This module handles routing for disaggregated prefill-decode systems
|
||||||
|
|
||||||
|
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
|
||||||
use crate::pd_types::{
|
use crate::pd_types::{
|
||||||
Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDRouterError, PDSelectionPolicy,
|
api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError, PDSelectionPolicy,
|
||||||
};
|
};
|
||||||
use crate::tree::Tree;
|
use crate::tree::Tree;
|
||||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||||
@@ -11,7 +12,6 @@ use futures_util::{StreamExt, TryStreamExt};
|
|||||||
use metrics::{counter, histogram};
|
use metrics::{counter, histogram};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
||||||
use std::sync::{Arc, Mutex, RwLock};
|
use std::sync::{Arc, Mutex, RwLock};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
@@ -21,49 +21,17 @@ use uuid::Uuid;
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct PDRouter {
|
pub struct PDRouter {
|
||||||
pub prefill_workers: Arc<RwLock<Vec<EngineInfo>>>,
|
pub prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||||
pub decode_workers: Arc<RwLock<Vec<EngineInfo>>>,
|
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||||
pub selection_policy: PDSelectionPolicy,
|
pub selection_policy: PDSelectionPolicy,
|
||||||
pub load_tracking: Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
|
|
||||||
pub prefill_tree: Option<Arc<Mutex<Tree>>>,
|
pub prefill_tree: Option<Arc<Mutex<Tree>>>,
|
||||||
pub timeout_secs: u64,
|
pub timeout_secs: u64,
|
||||||
pub interval_secs: u64,
|
pub interval_secs: u64,
|
||||||
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||||
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||||
pub http_client: reqwest::Client,
|
pub http_client: reqwest::Client,
|
||||||
}
|
_prefill_health_checker: Option<HealthChecker>,
|
||||||
|
_decode_health_checker: Option<HealthChecker>,
|
||||||
// RAII guard for load tracking to ensure cleanup even on panic
|
|
||||||
struct LoadGuard<'a> {
|
|
||||||
tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
|
|
||||||
urls: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> LoadGuard<'a> {
|
|
||||||
fn new(
|
|
||||||
tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
|
|
||||||
urls: Vec<String>,
|
|
||||||
) -> Self {
|
|
||||||
// Increment counters
|
|
||||||
for url in &urls {
|
|
||||||
let counter = tracking
|
|
||||||
.entry(url.clone())
|
|
||||||
.or_insert_with(|| Arc::new(AtomicUsize::new(0)));
|
|
||||||
counter.fetch_add(1, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
LoadGuard { tracking, urls }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for LoadGuard<'_> {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
// Guaranteed cleanup even on panic
|
|
||||||
for url in &self.urls {
|
|
||||||
if let Some(counter) = self.tracking.get(url) {
|
|
||||||
counter.fetch_sub(1, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PDRouter {
|
impl PDRouter {
|
||||||
@@ -73,9 +41,6 @@ impl PDRouter {
|
|||||||
url: String,
|
url: String,
|
||||||
bootstrap_port: Option<u16>,
|
bootstrap_port: Option<u16>,
|
||||||
) -> Result<String, PDRouterError> {
|
) -> Result<String, PDRouterError> {
|
||||||
// Create EngineInfo for the new prefill server
|
|
||||||
let engine_info = EngineInfo::new_prefill(url.clone(), bootstrap_port);
|
|
||||||
|
|
||||||
// Wait for the new server to be healthy
|
// Wait for the new server to be healthy
|
||||||
crate::router::Router::wait_for_healthy_workers(
|
crate::router::Router::wait_for_healthy_workers(
|
||||||
&[url.clone()],
|
&[url.clone()],
|
||||||
@@ -84,6 +49,9 @@ impl PDRouter {
|
|||||||
)
|
)
|
||||||
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
|
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
|
||||||
|
|
||||||
|
// Create Worker for the new prefill server
|
||||||
|
let worker = WorkerFactory::create_prefill(url.clone(), bootstrap_port);
|
||||||
|
|
||||||
// Add to prefill workers list
|
// Add to prefill workers list
|
||||||
let mut workers = self
|
let mut workers = self
|
||||||
.prefill_workers
|
.prefill_workers
|
||||||
@@ -93,15 +61,11 @@ impl PDRouter {
|
|||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Check if already exists
|
// Check if already exists
|
||||||
if workers.iter().any(|w| w.url == url) {
|
if workers.iter().any(|w| w.url() == &url) {
|
||||||
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
|
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
|
||||||
}
|
}
|
||||||
|
|
||||||
workers.push(engine_info);
|
workers.push(worker);
|
||||||
|
|
||||||
// Initialize load tracking
|
|
||||||
self.load_tracking
|
|
||||||
.insert(url.clone(), Arc::new(AtomicUsize::new(0)));
|
|
||||||
|
|
||||||
// Add to cache tree if using cache-aware policy
|
// Add to cache tree if using cache-aware policy
|
||||||
if let Some(ref tree) = self.prefill_tree {
|
if let Some(ref tree) = self.prefill_tree {
|
||||||
@@ -113,9 +77,6 @@ impl PDRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
|
pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
|
||||||
// Create EngineInfo for the new decode server
|
|
||||||
let engine_info = EngineInfo::new_decode(url.clone());
|
|
||||||
|
|
||||||
// Wait for the new server to be healthy
|
// Wait for the new server to be healthy
|
||||||
crate::router::Router::wait_for_healthy_workers(
|
crate::router::Router::wait_for_healthy_workers(
|
||||||
&[url.clone()],
|
&[url.clone()],
|
||||||
@@ -124,6 +85,9 @@ impl PDRouter {
|
|||||||
)
|
)
|
||||||
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
|
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
|
||||||
|
|
||||||
|
// Create Worker for the new decode server
|
||||||
|
let worker = WorkerFactory::create_decode(url.clone());
|
||||||
|
|
||||||
// Add to decode workers list
|
// Add to decode workers list
|
||||||
let mut workers = self
|
let mut workers = self
|
||||||
.decode_workers
|
.decode_workers
|
||||||
@@ -133,15 +97,14 @@ impl PDRouter {
|
|||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Check if already exists
|
// Check if already exists
|
||||||
if workers.iter().any(|w| w.url == url) {
|
if workers.iter().any(|w| w.url() == &url) {
|
||||||
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
|
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
|
||||||
}
|
}
|
||||||
|
|
||||||
workers.push(engine_info);
|
workers.push(worker);
|
||||||
|
|
||||||
// Initialize load tracking
|
// Initialize load tracking
|
||||||
self.load_tracking
|
// Worker tracks its own load internally
|
||||||
.insert(url.clone(), Arc::new(AtomicUsize::new(0)));
|
|
||||||
|
|
||||||
info!("Added decode server: {}", url);
|
info!("Added decode server: {}", url);
|
||||||
Ok(format!("Successfully added decode server: {}", url))
|
Ok(format!("Successfully added decode server: {}", url))
|
||||||
@@ -157,7 +120,7 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Find and remove the server
|
// Find and remove the server
|
||||||
let initial_len = workers.len();
|
let initial_len = workers.len();
|
||||||
workers.retain(|w| w.url != url);
|
workers.retain(|w| w.url() != url);
|
||||||
|
|
||||||
if workers.len() == initial_len {
|
if workers.len() == initial_len {
|
||||||
return Err(PDRouterError::WorkerNotFound {
|
return Err(PDRouterError::WorkerNotFound {
|
||||||
@@ -166,7 +129,7 @@ impl PDRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remove from load tracking
|
// Remove from load tracking
|
||||||
self.load_tracking.remove(url);
|
// Worker load tracking is internal
|
||||||
|
|
||||||
// Remove from cache tree if using cache-aware policy
|
// Remove from cache tree if using cache-aware policy
|
||||||
if let Some(ref tree) = self.prefill_tree {
|
if let Some(ref tree) = self.prefill_tree {
|
||||||
@@ -174,7 +137,7 @@ impl PDRouter {
|
|||||||
let mut tree_guard = tree.lock().unwrap();
|
let mut tree_guard = tree.lock().unwrap();
|
||||||
*tree_guard = Tree::new();
|
*tree_guard = Tree::new();
|
||||||
for worker in workers.iter() {
|
for worker in workers.iter() {
|
||||||
tree_guard.insert("", &worker.url);
|
tree_guard.insert("", worker.url());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,7 +155,7 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Find and remove the server
|
// Find and remove the server
|
||||||
let initial_len = workers.len();
|
let initial_len = workers.len();
|
||||||
workers.retain(|w| w.url != url);
|
workers.retain(|w| w.url() != url);
|
||||||
|
|
||||||
if workers.len() == initial_len {
|
if workers.len() == initial_len {
|
||||||
return Err(PDRouterError::WorkerNotFound {
|
return Err(PDRouterError::WorkerNotFound {
|
||||||
@@ -200,9 +163,6 @@ impl PDRouter {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove from load tracking
|
|
||||||
self.load_tracking.remove(url);
|
|
||||||
|
|
||||||
info!("Removed decode server: {}", url);
|
info!("Removed decode server: {}", url);
|
||||||
Ok(format!("Successfully removed decode server: {}", url))
|
Ok(format!("Successfully removed decode server: {}", url))
|
||||||
}
|
}
|
||||||
@@ -214,41 +174,32 @@ impl PDRouter {
|
|||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
) -> Result<Self, String> {
|
) -> Result<Self, String> {
|
||||||
// Convert URLs to EngineInfo
|
// Convert URLs to Worker trait objects
|
||||||
let prefill_workers: Vec<EngineInfo> = prefill_urls
|
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(url, port)| EngineInfo::new_prefill(url, port))
|
.map(|(url, port)| WorkerFactory::create_prefill(url, port))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let decode_workers: Vec<EngineInfo> = decode_urls
|
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(EngineInfo::new_decode)
|
.map(WorkerFactory::create_decode)
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Wait for PD workers to be healthy
|
// Wait for PD workers to be healthy
|
||||||
let all_urls: Vec<String> = prefill_workers
|
let all_urls: Vec<String> = prefill_workers
|
||||||
.iter()
|
.iter()
|
||||||
.chain(decode_workers.iter())
|
.chain(decode_workers.iter())
|
||||||
.map(|engine| engine.url.clone())
|
.map(|worker| worker.url().to_string())
|
||||||
.collect();
|
.collect();
|
||||||
crate::router::Router::wait_for_healthy_workers(&all_urls, timeout_secs, interval_secs)?;
|
crate::router::Router::wait_for_healthy_workers(&all_urls, timeout_secs, interval_secs)?;
|
||||||
|
|
||||||
// Initialize load tracking with atomic counters
|
|
||||||
let load_tracking = Arc::new(dashmap::DashMap::new());
|
|
||||||
for engine in &prefill_workers {
|
|
||||||
load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
|
|
||||||
}
|
|
||||||
for engine in &decode_workers {
|
|
||||||
load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize cache-aware components if needed
|
// Initialize cache-aware components if needed
|
||||||
let prefill_tree = match &selection_policy {
|
let prefill_tree = match &selection_policy {
|
||||||
PDSelectionPolicy::CacheAware { .. } => {
|
PDSelectionPolicy::CacheAware { .. } => {
|
||||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||||
// Initialize tree with prefill workers
|
// Initialize tree with prefill workers
|
||||||
for engine in &prefill_workers {
|
for worker in &prefill_workers {
|
||||||
tree.lock().unwrap().insert("", &engine.url);
|
tree.lock().unwrap().insert("", worker.url());
|
||||||
}
|
}
|
||||||
Some(tree)
|
Some(tree)
|
||||||
}
|
}
|
||||||
@@ -283,17 +234,27 @@ impl PDRouter {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let prefill_workers = Arc::new(RwLock::new(prefill_workers));
|
||||||
|
let decode_workers = Arc::new(RwLock::new(decode_workers));
|
||||||
|
|
||||||
|
// Start health checkers for both worker pools
|
||||||
|
let prefill_health_checker =
|
||||||
|
crate::core::start_health_checker(Arc::clone(&prefill_workers), interval_secs);
|
||||||
|
let decode_health_checker =
|
||||||
|
crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs);
|
||||||
|
|
||||||
Ok(PDRouter {
|
Ok(PDRouter {
|
||||||
prefill_workers: Arc::new(RwLock::new(prefill_workers)),
|
prefill_workers,
|
||||||
decode_workers: Arc::new(RwLock::new(decode_workers)),
|
decode_workers,
|
||||||
selection_policy,
|
selection_policy,
|
||||||
load_tracking,
|
|
||||||
prefill_tree,
|
prefill_tree,
|
||||||
timeout_secs,
|
timeout_secs,
|
||||||
interval_secs,
|
interval_secs,
|
||||||
worker_loads,
|
worker_loads,
|
||||||
load_monitor_handle,
|
load_monitor_handle,
|
||||||
http_client,
|
http_client,
|
||||||
|
_prefill_health_checker: Some(prefill_health_checker),
|
||||||
|
_decode_health_checker: Some(decode_health_checker),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -330,11 +291,13 @@ impl PDRouter {
|
|||||||
// Log routing decision
|
// Log routing decision
|
||||||
info!(
|
info!(
|
||||||
"PD routing: {} -> prefill={}, decode={}",
|
"PD routing: {} -> prefill={}, decode={}",
|
||||||
route, prefill.url, decode.url
|
route,
|
||||||
|
prefill.url(),
|
||||||
|
decode.url()
|
||||||
);
|
);
|
||||||
|
|
||||||
// Add bootstrap info using the trait method
|
// Add bootstrap info using the trait method
|
||||||
if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
|
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
|
||||||
error!("Failed to add bootstrap info: {}", e);
|
error!("Failed to add bootstrap info: {}", e);
|
||||||
counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
|
counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
|
||||||
return HttpResponse::InternalServerError()
|
return HttpResponse::InternalServerError()
|
||||||
@@ -356,8 +319,8 @@ impl PDRouter {
|
|||||||
req,
|
req,
|
||||||
json_with_bootstrap,
|
json_with_bootstrap,
|
||||||
route,
|
route,
|
||||||
&prefill,
|
prefill.as_ref(),
|
||||||
&decode,
|
decode.as_ref(),
|
||||||
is_stream,
|
is_stream,
|
||||||
return_logprob,
|
return_logprob,
|
||||||
start,
|
start,
|
||||||
@@ -397,11 +360,13 @@ impl PDRouter {
|
|||||||
// Log routing decision
|
// Log routing decision
|
||||||
info!(
|
info!(
|
||||||
"PD routing: {} -> prefill={}, decode={}",
|
"PD routing: {} -> prefill={}, decode={}",
|
||||||
route, prefill.url, decode.url
|
route,
|
||||||
|
prefill.url(),
|
||||||
|
decode.url()
|
||||||
);
|
);
|
||||||
|
|
||||||
// Add bootstrap info using the trait method
|
// Add bootstrap info using the trait method
|
||||||
if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
|
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
|
||||||
error!("Failed to add bootstrap info: {}", e);
|
error!("Failed to add bootstrap info: {}", e);
|
||||||
counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
|
counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
|
||||||
return HttpResponse::InternalServerError()
|
return HttpResponse::InternalServerError()
|
||||||
@@ -423,8 +388,8 @@ impl PDRouter {
|
|||||||
req,
|
req,
|
||||||
json_with_bootstrap,
|
json_with_bootstrap,
|
||||||
route,
|
route,
|
||||||
&prefill,
|
prefill.as_ref(),
|
||||||
&decode,
|
decode.as_ref(),
|
||||||
is_stream,
|
is_stream,
|
||||||
return_logprob,
|
return_logprob,
|
||||||
start,
|
start,
|
||||||
@@ -440,22 +405,23 @@ impl PDRouter {
|
|||||||
req: &HttpRequest,
|
req: &HttpRequest,
|
||||||
json_request: serde_json::Value,
|
json_request: serde_json::Value,
|
||||||
route: &str,
|
route: &str,
|
||||||
prefill: &EngineInfo,
|
prefill: &dyn Worker,
|
||||||
decode: &EngineInfo,
|
decode: &dyn Worker,
|
||||||
is_stream: bool,
|
is_stream: bool,
|
||||||
return_logprob: bool,
|
return_logprob: bool,
|
||||||
start_time: Instant,
|
start_time: Instant,
|
||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
// Update load tracking for both workers
|
// Update load tracking for both workers
|
||||||
let _guard = LoadGuard::new(
|
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
|
||||||
&self.load_tracking,
|
|
||||||
vec![prefill.url.clone(), decode.url.clone()],
|
|
||||||
);
|
|
||||||
|
|
||||||
// Build requests using .json() method
|
// Build requests using .json() method
|
||||||
let mut prefill_request = client.post(prefill.api_path(route)).json(&json_request);
|
let mut prefill_request = client
|
||||||
|
.post(api_path(prefill.url(), route))
|
||||||
|
.json(&json_request);
|
||||||
|
|
||||||
let mut decode_request = client.post(decode.api_path(route)).json(&json_request);
|
let mut decode_request = client
|
||||||
|
.post(api_path(decode.url(), route))
|
||||||
|
.json(&json_request);
|
||||||
|
|
||||||
// Copy headers from original request
|
// Copy headers from original request
|
||||||
for (name, value) in crate::router::copy_request_headers(req) {
|
for (name, value) in crate::router::copy_request_headers(req) {
|
||||||
@@ -474,9 +440,9 @@ impl PDRouter {
|
|||||||
histogram!("sgl_router_pd_request_duration_seconds", "route" => route.to_string())
|
histogram!("sgl_router_pd_request_duration_seconds", "route" => route.to_string())
|
||||||
.record(duration.as_secs_f64());
|
.record(duration.as_secs_f64());
|
||||||
counter!("sgl_router_pd_requests_total", "route" => route.to_string()).increment(1);
|
counter!("sgl_router_pd_requests_total", "route" => route.to_string()).increment(1);
|
||||||
counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url.to_string())
|
counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url().to_string())
|
||||||
.increment(1);
|
.increment(1);
|
||||||
counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url.to_string())
|
counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url().to_string())
|
||||||
.increment(1);
|
.increment(1);
|
||||||
|
|
||||||
// Process decode response
|
// Process decode response
|
||||||
@@ -486,10 +452,11 @@ impl PDRouter {
|
|||||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
if !status.is_success() {
|
if !status.is_success() {
|
||||||
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string()).increment(1);
|
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string()).increment(1);
|
||||||
error!(
|
error!(
|
||||||
"Decode server {} returned error status: {}",
|
"Decode server {} returned error status: {}",
|
||||||
decode.url, status
|
decode.url(),
|
||||||
|
status
|
||||||
);
|
);
|
||||||
|
|
||||||
// Return the error response from decode server
|
// Return the error response from decode server
|
||||||
@@ -508,9 +475,10 @@ impl PDRouter {
|
|||||||
if let Err(e) = &prefill_result {
|
if let Err(e) = &prefill_result {
|
||||||
error!(
|
error!(
|
||||||
"Prefill server {} failed (non-critical): {}",
|
"Prefill server {} failed (non-critical): {}",
|
||||||
prefill.url, e
|
prefill.url(),
|
||||||
|
e
|
||||||
);
|
);
|
||||||
counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url.to_string()).increment(1);
|
counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url().to_string()).increment(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_stream {
|
if is_stream {
|
||||||
@@ -559,7 +527,7 @@ impl PDRouter {
|
|||||||
HttpResponse::build(status)
|
HttpResponse::build(status)
|
||||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
||||||
.streaming({
|
.streaming({
|
||||||
let decode_url = decode.url.clone();
|
let decode_url = decode.url().to_string();
|
||||||
res.bytes_stream().map_err(move |e| {
|
res.bytes_stream().map_err(move |e| {
|
||||||
error!("Stream error from decode server {}: {}", decode_url, e);
|
error!("Stream error from decode server {}: {}", decode_url, e);
|
||||||
counter!("sgl_router_pd_stream_errors_total", "worker" => decode_url.to_string()).increment(1);
|
counter!("sgl_router_pd_stream_errors_total", "worker" => decode_url.to_string()).increment(1);
|
||||||
@@ -587,7 +555,7 @@ impl PDRouter {
|
|||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Decode request failed: {}", e);
|
error!("Decode request failed: {}", e);
|
||||||
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string())
|
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string())
|
||||||
.increment(1);
|
.increment(1);
|
||||||
HttpResponse::BadGateway().body(format!("Decode server error: {}", e))
|
HttpResponse::BadGateway().body(format!("Decode server error: {}", e))
|
||||||
}
|
}
|
||||||
@@ -652,7 +620,7 @@ impl PDRouter {
|
|||||||
async fn select_pd_pair(
|
async fn select_pd_pair(
|
||||||
&self,
|
&self,
|
||||||
_client: &reqwest::Client,
|
_client: &reqwest::Client,
|
||||||
) -> Result<(EngineInfo, EngineInfo), String> {
|
) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
|
||||||
// Check we have workers
|
// Check we have workers
|
||||||
if self
|
if self
|
||||||
.prefill_workers
|
.prefill_workers
|
||||||
@@ -681,17 +649,17 @@ impl PDRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn select_random(&self) -> Result<(EngineInfo, EngineInfo), String> {
|
fn select_random(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
|
||||||
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
|
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
|
||||||
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
|
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
|
||||||
|
|
||||||
let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone();
|
let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone_worker();
|
||||||
let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone();
|
let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone_worker();
|
||||||
|
|
||||||
Ok((prefill, decode))
|
Ok((prefill, decode))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn select_power_of_two(&self) -> Result<(EngineInfo, EngineInfo), String> {
|
async fn select_power_of_two(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
|
||||||
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
|
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
|
||||||
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
|
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
|
||||||
|
|
||||||
@@ -700,33 +668,45 @@ impl PDRouter {
|
|||||||
|
|
||||||
let loads = self.worker_loads.borrow();
|
let loads = self.worker_loads.borrow();
|
||||||
|
|
||||||
let p1_load = loads.get(&prefill_list[p1_idx].url).copied().unwrap_or(0);
|
let p1_load = loads
|
||||||
let p2_load = loads.get(&prefill_list[p2_idx].url).copied().unwrap_or(0);
|
.get(prefill_list[p1_idx].url())
|
||||||
let d1_load = loads.get(&decode_list[d1_idx].url).copied().unwrap_or(0);
|
.copied()
|
||||||
let d2_load = loads.get(&decode_list[d2_idx].url).copied().unwrap_or(0);
|
.unwrap_or(isize::MAX);
|
||||||
|
let p2_load = loads
|
||||||
|
.get(prefill_list[p2_idx].url())
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(isize::MAX);
|
||||||
|
let d1_load = loads
|
||||||
|
.get(decode_list[d1_idx].url())
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(isize::MAX);
|
||||||
|
let d2_load = loads
|
||||||
|
.get(decode_list[d2_idx].url())
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(isize::MAX);
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}",
|
"Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}",
|
||||||
prefill_list[p1_idx].url,
|
prefill_list[p1_idx].url(),
|
||||||
p1_load,
|
p1_load,
|
||||||
prefill_list[p2_idx].url,
|
prefill_list[p2_idx].url(),
|
||||||
p2_load,
|
p2_load,
|
||||||
decode_list[d1_idx].url,
|
decode_list[d1_idx].url(),
|
||||||
d1_load,
|
d1_load,
|
||||||
decode_list[d2_idx].url,
|
decode_list[d2_idx].url(),
|
||||||
d2_load
|
d2_load
|
||||||
);
|
);
|
||||||
|
|
||||||
let selected_prefill = if p1_load <= p2_load {
|
let selected_prefill = if p1_load <= p2_load {
|
||||||
prefill_list[p1_idx].clone()
|
prefill_list[p1_idx].clone_worker()
|
||||||
} else {
|
} else {
|
||||||
prefill_list[p2_idx].clone()
|
prefill_list[p2_idx].clone_worker()
|
||||||
};
|
};
|
||||||
|
|
||||||
let selected_decode = if d1_load <= d2_load {
|
let selected_decode = if d1_load <= d2_load {
|
||||||
decode_list[d1_idx].clone()
|
decode_list[d1_idx].clone_worker()
|
||||||
} else {
|
} else {
|
||||||
decode_list[d2_idx].clone()
|
decode_list[d2_idx].clone_worker()
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok((selected_prefill, selected_decode))
|
Ok((selected_prefill, selected_decode))
|
||||||
@@ -868,11 +848,11 @@ impl PDRouter {
|
|||||||
let mut worker_infos = Vec::new();
|
let mut worker_infos = Vec::new();
|
||||||
|
|
||||||
for worker in self.prefill_workers.read().unwrap().iter() {
|
for worker in self.prefill_workers.read().unwrap().iter() {
|
||||||
worker_infos.push((worker.url.clone(), "prefill"));
|
worker_infos.push((worker.url().to_string(), "prefill"));
|
||||||
}
|
}
|
||||||
|
|
||||||
for worker in self.decode_workers.read().unwrap().iter() {
|
for worker in self.decode_workers.read().unwrap().iter() {
|
||||||
worker_infos.push((worker.url.clone(), "decode"));
|
worker_infos.push((worker.url().to_string(), "decode"));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create tasks with URL tracking
|
// Create tasks with URL tracking
|
||||||
@@ -922,7 +902,7 @@ impl PDRouter {
|
|||||||
pub async fn get_server_info(&self, client: &reqwest::Client) -> HttpResponse {
|
pub async fn get_server_info(&self, client: &reqwest::Client) -> HttpResponse {
|
||||||
// Get info from the first decode server to match sglang's server info format
|
// Get info from the first decode server to match sglang's server info format
|
||||||
let first_decode_url = if let Ok(workers) = self.decode_workers.read() {
|
let first_decode_url = if let Ok(workers) = self.decode_workers.read() {
|
||||||
workers.first().map(|w| w.url.clone())
|
workers.first().map(|w| w.url().to_string())
|
||||||
} else {
|
} else {
|
||||||
return HttpResponse::InternalServerError().body("Failed to access decode workers");
|
return HttpResponse::InternalServerError().body("Failed to access decode workers");
|
||||||
};
|
};
|
||||||
@@ -967,7 +947,7 @@ impl PDRouter {
|
|||||||
pub async fn get_models(&self, client: &reqwest::Client, req: &HttpRequest) -> HttpResponse {
|
pub async fn get_models(&self, client: &reqwest::Client, req: &HttpRequest) -> HttpResponse {
|
||||||
// Get first prefill worker URL to avoid holding lock across await
|
// Get first prefill worker URL to avoid holding lock across await
|
||||||
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
|
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
|
||||||
workers.first().map(|w| w.url.clone())
|
workers.first().map(|w| w.url().to_string())
|
||||||
} else {
|
} else {
|
||||||
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
|
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
|
||||||
};
|
};
|
||||||
@@ -1005,14 +985,14 @@ impl PDRouter {
|
|||||||
.read()
|
.read()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.iter()
|
.iter()
|
||||||
.map(|w| w.url.clone())
|
.map(|w| w.url().to_string())
|
||||||
.collect();
|
.collect();
|
||||||
let d_urls: Vec<_> = self
|
let d_urls: Vec<_> = self
|
||||||
.decode_workers
|
.decode_workers
|
||||||
.read()
|
.read()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.iter()
|
.iter()
|
||||||
.map(|w| w.url.clone())
|
.map(|w| w.url().to_string())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let mut prefill_loads = Vec::new();
|
let mut prefill_loads = Vec::new();
|
||||||
@@ -1048,7 +1028,7 @@ impl PDRouter {
|
|||||||
// Get model info from the first prefill server (matches original Rust PDLB behavior)
|
// Get model info from the first prefill server (matches original Rust PDLB behavior)
|
||||||
// Get first prefill worker URL to avoid holding lock across await
|
// Get first prefill worker URL to avoid holding lock across await
|
||||||
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
|
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
|
||||||
workers.first().map(|w| w.url.clone())
|
workers.first().map(|w| w.url().to_string())
|
||||||
} else {
|
} else {
|
||||||
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
|
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
|
||||||
};
|
};
|
||||||
@@ -1084,13 +1064,13 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Flush cache on all prefill servers
|
// Flush cache on all prefill servers
|
||||||
for worker in self.prefill_workers.read().unwrap().iter() {
|
for worker in self.prefill_workers.read().unwrap().iter() {
|
||||||
let url = format!("{}/flush_cache", worker.url);
|
let url = format!("{}/flush_cache", worker.url());
|
||||||
tasks.push(client.post(&url).send());
|
tasks.push(client.post(&url).send());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush cache on all decode servers
|
// Flush cache on all decode servers
|
||||||
for worker in self.decode_workers.read().unwrap().iter() {
|
for worker in self.decode_workers.read().unwrap().iter() {
|
||||||
let url = format!("{}/flush_cache", worker.url);
|
let url = format!("{}/flush_cache", worker.url());
|
||||||
tasks.push(client.post(&url).send());
|
tasks.push(client.post(&url).send());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
// Essential PDLB types extracted for PD routing
|
// Essential PDLB types extracted for PD routing
|
||||||
|
|
||||||
|
use crate::core::{Worker, WorkerType};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
@@ -28,52 +29,21 @@ pub enum PDRouterError {
|
|||||||
Timeout { url: String },
|
Timeout { url: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
// Helper functions for workers
|
||||||
pub enum EngineType {
|
pub fn api_path(url: &str, api_path: &str) -> String {
|
||||||
Prefill,
|
|
||||||
Decode,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct EngineInfo {
|
|
||||||
pub engine_type: EngineType,
|
|
||||||
pub url: String,
|
|
||||||
pub bootstrap_port: Option<u16>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EngineInfo {
|
|
||||||
pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
|
|
||||||
EngineInfo {
|
|
||||||
engine_type: EngineType::Prefill,
|
|
||||||
url,
|
|
||||||
bootstrap_port,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new_decode(url: String) -> Self {
|
|
||||||
EngineInfo {
|
|
||||||
engine_type: EngineType::Decode,
|
|
||||||
url,
|
|
||||||
bootstrap_port: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn api_path(&self, api_path: &str) -> String {
|
|
||||||
if api_path.starts_with("/") {
|
if api_path.starts_with("/") {
|
||||||
format!("{}{}", self.url, api_path)
|
format!("{}{}", url, api_path)
|
||||||
} else {
|
} else {
|
||||||
format!("{}/{}", self.url, api_path)
|
format!("{}/{}", url, api_path)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get_hostname(&self) -> String {
|
pub fn get_hostname(url: &str) -> String {
|
||||||
// Simple hostname extraction without external dependencies
|
// Simple hostname extraction without external dependencies
|
||||||
let url = self
|
let url = url
|
||||||
.url
|
|
||||||
.trim_start_matches("http://")
|
.trim_start_matches("http://")
|
||||||
.trim_start_matches("https://");
|
.trim_start_matches("https://");
|
||||||
url.split(':').next().unwrap_or("localhost").to_string()
|
url.split(':').next().unwrap_or("localhost").to_string()
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// PD-specific routing policies
|
// PD-specific routing policies
|
||||||
@@ -112,12 +82,21 @@ pub trait Bootstrap: Send + Sync {
|
|||||||
bootstrap_room: BootstrapRoom,
|
bootstrap_room: BootstrapRoom,
|
||||||
);
|
);
|
||||||
|
|
||||||
fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> {
|
fn add_bootstrap_info(&mut self, prefill_worker: &dyn Worker) -> Result<(), String> {
|
||||||
let batch_size = self.get_batch_size()?;
|
let batch_size = self.get_batch_size()?;
|
||||||
|
|
||||||
|
// Extract bootstrap port from prefill worker if it's a prefill type
|
||||||
|
let bootstrap_port = match prefill_worker.worker_type() {
|
||||||
|
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let hostname = get_hostname(prefill_worker.url());
|
||||||
|
|
||||||
if let Some(batch_size) = batch_size {
|
if let Some(batch_size) = batch_size {
|
||||||
self.set_bootstrap_info(
|
self.set_bootstrap_info(
|
||||||
BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
|
BootstrapHost::Batch(vec![hostname; batch_size]),
|
||||||
BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
|
BootstrapPort::Batch(vec![bootstrap_port; batch_size]),
|
||||||
// Use high-quality random numbers to minimize collision risk
|
// Use high-quality random numbers to minimize collision risk
|
||||||
BootstrapRoom::Batch(
|
BootstrapRoom::Batch(
|
||||||
(0..batch_size)
|
(0..batch_size)
|
||||||
@@ -132,8 +111,8 @@ pub trait Bootstrap: Send + Sync {
|
|||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
self.set_bootstrap_info(
|
self.set_bootstrap_info(
|
||||||
BootstrapHost::Single(prefill_info.get_hostname()),
|
BootstrapHost::Single(hostname),
|
||||||
BootstrapPort::Single(prefill_info.bootstrap_port),
|
BootstrapPort::Single(bootstrap_port),
|
||||||
BootstrapRoom::Single({
|
BootstrapRoom::Single({
|
||||||
// Use high-quality random number for single requests too
|
// Use high-quality random number for single requests too
|
||||||
let r1 = rand::random::<u64>();
|
let r1 = rand::random::<u64>();
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use crate::core::{HealthChecker, Worker, WorkerFactory};
|
||||||
use crate::pd_router::PDRouter;
|
use crate::pd_router::PDRouter;
|
||||||
use crate::pd_types::PDSelectionPolicy;
|
use crate::pd_types::PDSelectionPolicy;
|
||||||
use crate::tree::Tree;
|
use crate::tree::Tree;
|
||||||
@@ -5,7 +6,6 @@ use ::metrics::{counter, gauge, histogram};
|
|||||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||||
use actix_web::{HttpRequest, HttpResponse};
|
use actix_web::{HttpRequest, HttpResponse};
|
||||||
use futures_util::{StreamExt, TryStreamExt};
|
use futures_util::{StreamExt, TryStreamExt};
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::sync::atomic::AtomicUsize;
|
use std::sync::atomic::AtomicUsize;
|
||||||
use std::sync::{Arc, Mutex, RwLock};
|
use std::sync::{Arc, Mutex, RwLock};
|
||||||
@@ -30,15 +30,17 @@ pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum Router {
|
pub enum Router {
|
||||||
RoundRobin {
|
RoundRobin {
|
||||||
worker_urls: Arc<RwLock<Vec<String>>>,
|
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||||
current_index: AtomicUsize,
|
current_index: AtomicUsize,
|
||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
|
_health_checker: Option<HealthChecker>,
|
||||||
},
|
},
|
||||||
Random {
|
Random {
|
||||||
worker_urls: Arc<RwLock<Vec<String>>>,
|
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
|
_health_checker: Option<HealthChecker>,
|
||||||
},
|
},
|
||||||
PrefillDecode {
|
PrefillDecode {
|
||||||
pd_router: Arc<PDRouter>,
|
pd_router: Arc<PDRouter>,
|
||||||
@@ -104,16 +106,15 @@ pub enum Router {
|
|||||||
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
|
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
|
||||||
during the next eviction cycle.
|
during the next eviction cycle.
|
||||||
*/
|
*/
|
||||||
worker_urls: Arc<RwLock<Vec<String>>>,
|
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||||
tree: Arc<Mutex<Tree>>,
|
tree: Arc<Mutex<Tree>>,
|
||||||
running_queue: Arc<Mutex<HashMap<String, usize>>>,
|
|
||||||
processed_queue: Arc<Mutex<HashMap<String, usize>>>,
|
|
||||||
cache_threshold: f32,
|
cache_threshold: f32,
|
||||||
balance_abs_threshold: usize,
|
balance_abs_threshold: usize,
|
||||||
balance_rel_threshold: f32,
|
balance_rel_threshold: f32,
|
||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
_eviction_thread: Option<thread::JoinHandle<()>>,
|
_eviction_thread: Option<thread::JoinHandle<()>>,
|
||||||
|
_health_checker: Option<HealthChecker>,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,25 +193,43 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create Worker trait objects from URLs
|
||||||
|
let workers: Vec<Box<dyn Worker>> = worker_urls
|
||||||
|
.iter()
|
||||||
|
.map(|url| WorkerFactory::create_regular(url.clone()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
// Create router based on policy...
|
// Create router based on policy...
|
||||||
Ok(match policy_config {
|
Ok(match policy_config {
|
||||||
PolicyConfig::RandomConfig {
|
PolicyConfig::RandomConfig {
|
||||||
timeout_secs,
|
timeout_secs,
|
||||||
interval_secs,
|
interval_secs,
|
||||||
} => Router::Random {
|
} => {
|
||||||
worker_urls: Arc::new(RwLock::new(worker_urls)),
|
let workers = Arc::new(RwLock::new(workers));
|
||||||
|
let health_checker =
|
||||||
|
crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
|
||||||
|
Router::Random {
|
||||||
|
workers,
|
||||||
timeout_secs,
|
timeout_secs,
|
||||||
interval_secs,
|
interval_secs,
|
||||||
},
|
_health_checker: Some(health_checker),
|
||||||
|
}
|
||||||
|
}
|
||||||
PolicyConfig::RoundRobinConfig {
|
PolicyConfig::RoundRobinConfig {
|
||||||
timeout_secs,
|
timeout_secs,
|
||||||
interval_secs,
|
interval_secs,
|
||||||
} => Router::RoundRobin {
|
} => {
|
||||||
worker_urls: Arc::new(RwLock::new(worker_urls)),
|
let workers = Arc::new(RwLock::new(workers));
|
||||||
|
let health_checker =
|
||||||
|
crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
|
||||||
|
Router::RoundRobin {
|
||||||
|
workers,
|
||||||
current_index: std::sync::atomic::AtomicUsize::new(0),
|
current_index: std::sync::atomic::AtomicUsize::new(0),
|
||||||
timeout_secs,
|
timeout_secs,
|
||||||
interval_secs,
|
interval_secs,
|
||||||
},
|
_health_checker: Some(health_checker),
|
||||||
|
}
|
||||||
|
}
|
||||||
PolicyConfig::CacheAwareConfig {
|
PolicyConfig::CacheAwareConfig {
|
||||||
cache_threshold,
|
cache_threshold,
|
||||||
balance_abs_threshold,
|
balance_abs_threshold,
|
||||||
@@ -220,24 +239,12 @@ impl Router {
|
|||||||
timeout_secs,
|
timeout_secs,
|
||||||
interval_secs,
|
interval_secs,
|
||||||
} => {
|
} => {
|
||||||
let mut running_queue = HashMap::new();
|
|
||||||
for url in &worker_urls {
|
|
||||||
running_queue.insert(url.clone(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut processed_queue = HashMap::new();
|
|
||||||
for url in &worker_urls {
|
|
||||||
processed_queue.insert(url.clone(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||||
let running_queue = Arc::new(Mutex::new(running_queue));
|
|
||||||
let processed_queue = Arc::new(Mutex::new(processed_queue));
|
|
||||||
|
|
||||||
// Create background eviction thread
|
// Create background eviction thread
|
||||||
let tree_clone = Arc::clone(&tree);
|
let tree_clone = Arc::clone(&tree);
|
||||||
let processed_queue_clone = Arc::clone(&processed_queue);
|
let workers = Arc::new(RwLock::new(workers));
|
||||||
let running_queue_clone = Arc::clone(&running_queue);
|
let workers_clone = Arc::clone(&workers);
|
||||||
let eviction_thread = thread::spawn(move || {
|
let eviction_thread = thread::spawn(move || {
|
||||||
loop {
|
loop {
|
||||||
// Sleep for the specified interval
|
// Sleep for the specified interval
|
||||||
@@ -246,32 +253,41 @@ impl Router {
|
|||||||
let locked_tree_clone = tree_clone.lock().unwrap();
|
let locked_tree_clone = tree_clone.lock().unwrap();
|
||||||
// Run eviction
|
// Run eviction
|
||||||
locked_tree_clone.evict_tenant_by_size(max_tree_size);
|
locked_tree_clone.evict_tenant_by_size(max_tree_size);
|
||||||
|
drop(locked_tree_clone);
|
||||||
|
|
||||||
// Print the process queue
|
// Log worker loads and processed requests
|
||||||
let locked_processed_queue = processed_queue_clone.lock().unwrap();
|
let workers_guard = workers_clone.read().unwrap();
|
||||||
info!("Processed Queue: {:?}", locked_processed_queue);
|
let loads: Vec<(String, usize)> = workers_guard
|
||||||
|
.iter()
|
||||||
|
.map(|w| (w.url().to_string(), w.load()))
|
||||||
|
.collect();
|
||||||
|
info!("Worker loads: {:?}", loads);
|
||||||
|
|
||||||
// Print the running queue
|
let processed: Vec<(String, usize)> = workers_guard
|
||||||
let locked_running_queue = running_queue_clone.lock().unwrap();
|
.iter()
|
||||||
info!("Running Queue: {:?}", locked_running_queue);
|
.map(|w| (w.url().to_string(), w.processed_requests()))
|
||||||
|
.collect();
|
||||||
|
info!("Processed requests: {:?}", processed);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
for url in &worker_urls {
|
for worker in workers.read().unwrap().iter() {
|
||||||
tree.lock().unwrap().insert("", url);
|
tree.lock().unwrap().insert("", worker.url());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let health_checker =
|
||||||
|
crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
|
||||||
|
|
||||||
Router::CacheAware {
|
Router::CacheAware {
|
||||||
worker_urls: Arc::new(RwLock::new(worker_urls)),
|
workers,
|
||||||
tree,
|
tree,
|
||||||
running_queue,
|
|
||||||
processed_queue,
|
|
||||||
cache_threshold,
|
cache_threshold,
|
||||||
balance_abs_threshold,
|
balance_abs_threshold,
|
||||||
balance_rel_threshold,
|
balance_rel_threshold,
|
||||||
timeout_secs,
|
timeout_secs,
|
||||||
interval_secs,
|
interval_secs,
|
||||||
_eviction_thread: Some(eviction_thread),
|
_eviction_thread: Some(eviction_thread),
|
||||||
|
_health_checker: Some(health_checker),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
PolicyConfig::PrefillDecodeConfig {
|
PolicyConfig::PrefillDecodeConfig {
|
||||||
@@ -297,16 +313,18 @@ impl Router {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a reference to the worker URLs shared across threads
|
/// Get the current list of worker URLs
|
||||||
pub fn get_worker_urls(&self) -> Arc<RwLock<Vec<String>>> {
|
pub fn get_worker_urls(&self) -> Vec<String> {
|
||||||
match self {
|
match self {
|
||||||
Router::RoundRobin { worker_urls, .. } => Arc::clone(worker_urls),
|
Router::RoundRobin { workers, .. }
|
||||||
Router::Random { worker_urls, .. } => Arc::clone(worker_urls),
|
| Router::Random { workers, .. }
|
||||||
Router::CacheAware { worker_urls, .. } => Arc::clone(worker_urls),
|
| Router::CacheAware { workers, .. } => workers
|
||||||
Router::PrefillDecode { .. } => {
|
.read()
|
||||||
// For PD mode, return empty list since we manage workers differently
|
.unwrap()
|
||||||
Arc::new(RwLock::new(Vec::new()))
|
.iter()
|
||||||
}
|
.map(|w| w.url().to_string())
|
||||||
|
.collect(),
|
||||||
|
Router::PrefillDecode { .. } => Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -373,13 +391,14 @@ impl Router {
|
|||||||
|
|
||||||
fn select_first_worker(&self) -> Result<String, String> {
|
fn select_first_worker(&self) -> Result<String, String> {
|
||||||
match self {
|
match self {
|
||||||
Router::RoundRobin { worker_urls, .. }
|
Router::RoundRobin { workers, .. }
|
||||||
| Router::Random { worker_urls, .. }
|
| Router::Random { workers, .. }
|
||||||
| Router::CacheAware { worker_urls, .. } => {
|
| Router::CacheAware { workers, .. } => {
|
||||||
if worker_urls.read().unwrap().is_empty() {
|
let workers_guard = workers.read().unwrap();
|
||||||
|
if workers_guard.is_empty() {
|
||||||
Err("No workers are available".to_string())
|
Err("No workers are available".to_string())
|
||||||
} else {
|
} else {
|
||||||
Ok(worker_urls.read().unwrap()[0].clone())
|
Ok(workers_guard[0].url().to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Router::PrefillDecode { .. } => {
|
Router::PrefillDecode { .. } => {
|
||||||
@@ -514,7 +533,7 @@ impl Router {
|
|||||||
return HttpResponse::NotImplemented()
|
return HttpResponse::NotImplemented()
|
||||||
.body("route_to_all not implemented for PrefillDecode mode");
|
.body("route_to_all not implemented for PrefillDecode mode");
|
||||||
}
|
}
|
||||||
_ => self.get_worker_urls().read().unwrap().clone(),
|
_ => self.get_worker_urls(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Send requests to all workers concurrently
|
// Send requests to all workers concurrently
|
||||||
@@ -562,7 +581,7 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let urls = self.get_worker_urls().read().unwrap().clone();
|
let urls = self.get_worker_urls();
|
||||||
let prefill_urls: Vec<String> = Vec::new();
|
let prefill_urls: Vec<String> = Vec::new();
|
||||||
let decode_urls = urls;
|
let decode_urls = urls;
|
||||||
|
|
||||||
@@ -631,6 +650,24 @@ impl Router {
|
|||||||
.increment(1);
|
.increment(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For CacheAware router, increment load before request
|
||||||
|
let load_incremented = match self {
|
||||||
|
Router::CacheAware { workers, .. } => {
|
||||||
|
let workers_guard = workers.read().unwrap();
|
||||||
|
if let Some(worker) =
|
||||||
|
workers_guard.iter().find(|w| w.url() == &worker_url)
|
||||||
|
{
|
||||||
|
worker.increment_load();
|
||||||
|
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
|
||||||
|
.set(worker.load() as f64);
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => false,
|
||||||
|
};
|
||||||
|
|
||||||
// Send typed request directly
|
// Send typed request directly
|
||||||
let response = self
|
let response = self
|
||||||
.send_typed_request(
|
.send_typed_request(
|
||||||
@@ -640,6 +677,7 @@ impl Router {
|
|||||||
route,
|
route,
|
||||||
&worker_url,
|
&worker_url,
|
||||||
is_stream,
|
is_stream,
|
||||||
|
load_incremented,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
@@ -684,44 +722,47 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper method to select worker from text
|
// Helper method to select worker from text (returns index for RoundRobin/Random, URL for CacheAware)
|
||||||
fn select_generate_worker_from_text(&self, text: &str) -> String {
|
fn select_generate_worker_from_text(&self, text: &str) -> String {
|
||||||
match self {
|
match self {
|
||||||
Router::RoundRobin {
|
Router::RoundRobin {
|
||||||
worker_urls,
|
workers,
|
||||||
current_index,
|
current_index,
|
||||||
..
|
..
|
||||||
} => {
|
} => {
|
||||||
|
let workers_guard = workers.read().unwrap();
|
||||||
let idx = current_index
|
let idx = current_index
|
||||||
.fetch_update(
|
.fetch_update(
|
||||||
std::sync::atomic::Ordering::SeqCst,
|
std::sync::atomic::Ordering::SeqCst,
|
||||||
std::sync::atomic::Ordering::SeqCst,
|
std::sync::atomic::Ordering::SeqCst,
|
||||||
|x| Some((x + 1) % worker_urls.read().unwrap().len()),
|
|x| Some((x + 1) % workers_guard.len()),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
worker_urls.read().unwrap()[idx].clone()
|
workers_guard[idx].url().to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
Router::Random { worker_urls, .. } => worker_urls.read().unwrap()
|
Router::Random { workers, .. } => {
|
||||||
[rand::random::<usize>() % worker_urls.read().unwrap().len()]
|
let workers_guard = workers.read().unwrap();
|
||||||
.clone(),
|
workers_guard[rand::random::<usize>() % workers_guard.len()]
|
||||||
|
.url()
|
||||||
|
.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
Router::CacheAware {
|
Router::CacheAware {
|
||||||
worker_urls,
|
workers,
|
||||||
tree,
|
tree,
|
||||||
running_queue,
|
|
||||||
processed_queue,
|
|
||||||
cache_threshold,
|
cache_threshold,
|
||||||
balance_abs_threshold,
|
balance_abs_threshold,
|
||||||
balance_rel_threshold,
|
balance_rel_threshold,
|
||||||
..
|
..
|
||||||
} => {
|
} => {
|
||||||
let tree = tree.lock().unwrap();
|
let tree = tree.lock().unwrap();
|
||||||
let mut running_queue = running_queue.lock().unwrap();
|
let workers_guard = workers.read().unwrap();
|
||||||
|
|
||||||
// Get current load statistics
|
// Get current load statistics from workers
|
||||||
let max_load = *running_queue.values().max().unwrap_or(&0);
|
let loads: Vec<usize> = workers_guard.iter().map(|w| w.load()).collect();
|
||||||
let min_load = *running_queue.values().min().unwrap_or(&0);
|
let max_load = *loads.iter().max().unwrap_or(&0);
|
||||||
|
let min_load = *loads.iter().min().unwrap_or(&0);
|
||||||
|
|
||||||
// Load is considered imbalanced if:
|
// Load is considered imbalanced if:
|
||||||
// 1. (max - min) > abs_threshold AND
|
// 1. (max - min) > abs_threshold AND
|
||||||
@@ -731,11 +772,16 @@ impl Router {
|
|||||||
|
|
||||||
let selected_url = if is_imbalanced {
|
let selected_url = if is_imbalanced {
|
||||||
// Log load balancing trigger and current queue state
|
// Log load balancing trigger and current queue state
|
||||||
|
let worker_loads: Vec<(String, usize)> = workers_guard
|
||||||
|
.iter()
|
||||||
|
.map(|w| (w.url().to_string(), w.load()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Load balancing triggered due to workload imbalance:\n\
|
"Load balancing triggered due to workload imbalance:\n\
|
||||||
Max load: {}, Min load: {}\n\
|
Max load: {}, Min load: {}\n\
|
||||||
Current running queue: {:?}",
|
Current worker loads: {:?}",
|
||||||
max_load, min_load, running_queue
|
max_load, min_load, worker_loads
|
||||||
);
|
);
|
||||||
|
|
||||||
counter!("sgl_router_load_balancing_events_total").increment(1);
|
counter!("sgl_router_load_balancing_events_total").increment(1);
|
||||||
@@ -743,11 +789,11 @@ impl Router {
|
|||||||
gauge!("sgl_router_min_load").set(min_load as f64);
|
gauge!("sgl_router_min_load").set(min_load as f64);
|
||||||
|
|
||||||
// Use shortest queue routing when load is imbalanced
|
// Use shortest queue routing when load is imbalanced
|
||||||
running_queue
|
workers_guard
|
||||||
.iter()
|
.iter()
|
||||||
.min_by_key(|(_url, &count)| count)
|
.min_by_key(|w| w.load())
|
||||||
.map(|(url, _)| url.clone())
|
.map(|w| w.url().to_string())
|
||||||
.unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
|
.unwrap_or_else(|| workers_guard[0].url().to_string())
|
||||||
} else {
|
} else {
|
||||||
// Use cache-aware routing when load is balanced
|
// Use cache-aware routing when load is balanced
|
||||||
let (matched_text, matched_worker) = tree.prefix_match(&text);
|
let (matched_text, matched_worker) = tree.prefix_match(&text);
|
||||||
@@ -763,18 +809,12 @@ impl Router {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Update queues and tree
|
// Find the selected worker and increment processed counter only
|
||||||
*running_queue.get_mut(&selected_url).unwrap() += 1;
|
if let Some(worker) = workers_guard.iter().find(|w| w.url() == &selected_url) {
|
||||||
|
worker.increment_processed();
|
||||||
*processed_queue
|
counter!("sgl_router_processed_requests_total", "worker" => selected_url.to_string())
|
||||||
.lock()
|
.increment(1);
|
||||||
.unwrap()
|
}
|
||||||
.get_mut(&selected_url)
|
|
||||||
.unwrap() += 1;
|
|
||||||
|
|
||||||
gauge!("sgl_router_running_requests", "worker" => selected_url.to_string())
|
|
||||||
.set(*running_queue.get(&selected_url).unwrap() as f64);
|
|
||||||
counter!("sgl_router_processed_requests_total", "worker" => selected_url.to_string()).increment(1);
|
|
||||||
|
|
||||||
tree.insert(&text, &selected_url);
|
tree.insert(&text, &selected_url);
|
||||||
|
|
||||||
@@ -796,6 +836,7 @@ impl Router {
|
|||||||
route: &str,
|
route: &str,
|
||||||
worker_url: &str,
|
worker_url: &str,
|
||||||
is_stream: bool,
|
is_stream: bool,
|
||||||
|
load_incremented: bool, // Whether load was incremented for this request
|
||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
@@ -820,6 +861,22 @@ impl Router {
|
|||||||
Ok(res) => res,
|
Ok(res) => res,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to send request to {}: {}", worker_url, e);
|
error!("Failed to send request to {}: {}", worker_url, e);
|
||||||
|
|
||||||
|
// Decrement load on error for CacheAware router
|
||||||
|
if load_incremented {
|
||||||
|
if let Router::CacheAware { workers, .. } = self {
|
||||||
|
if let Ok(workers_guard) = workers.read() {
|
||||||
|
if let Some(worker) =
|
||||||
|
workers_guard.iter().find(|w| w.url() == worker_url)
|
||||||
|
{
|
||||||
|
worker.decrement_load();
|
||||||
|
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
|
||||||
|
.set(worker.load() as f64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
|
return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -837,13 +894,15 @@ impl Router {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Then decrement running queue counter if using CacheAware
|
// Decrement load counter for non-streaming CacheAware requests
|
||||||
if let Router::CacheAware { running_queue, .. } = self {
|
if load_incremented && !is_stream {
|
||||||
if let Ok(mut queue) = running_queue.lock() {
|
if let Router::CacheAware { workers, .. } = self {
|
||||||
if let Some(count) = queue.get_mut(worker_url) {
|
if let Ok(workers_guard) = workers.read() {
|
||||||
*count = count.saturating_sub(1);
|
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
|
||||||
|
worker.decrement_load();
|
||||||
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
|
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
|
||||||
.set(*count as f64);
|
.set(worker.load() as f64);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -855,8 +914,9 @@ impl Router {
|
|||||||
counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);
|
counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);
|
||||||
|
|
||||||
response
|
response
|
||||||
} else if let Router::CacheAware { running_queue, .. } = self {
|
} else if let Router::CacheAware { workers, .. } = self {
|
||||||
let running_queue = Arc::clone(running_queue);
|
// For streaming with CacheAware router, we need to manually decrement when done
|
||||||
|
let workers = Arc::clone(workers);
|
||||||
let worker_url = worker_url.to_string();
|
let worker_url = worker_url.to_string();
|
||||||
|
|
||||||
HttpResponse::build(status)
|
HttpResponse::build(status)
|
||||||
@@ -867,21 +927,28 @@ impl Router {
|
|||||||
actix_web::error::ErrorInternalServerError("Failed to read stream")
|
actix_web::error::ErrorInternalServerError("Failed to read stream")
|
||||||
})
|
})
|
||||||
.inspect(move |bytes| {
|
.inspect(move |bytes| {
|
||||||
let bytes = bytes.as_ref().unwrap();
|
if let Ok(bytes) = bytes {
|
||||||
if bytes
|
if bytes
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.windows(12)
|
.windows(12)
|
||||||
.any(|window| window == b"data: [DONE]")
|
.any(|window| window == b"data: [DONE]")
|
||||||
{
|
{
|
||||||
let mut locked_queue = running_queue.lock().unwrap();
|
if let Ok(workers_guard) = workers.read() {
|
||||||
let count = locked_queue.get_mut(&worker_url).unwrap();
|
if let Some(worker) =
|
||||||
*count = count.saturating_sub(1);
|
workers_guard.iter().find(|w| w.url() == &worker_url)
|
||||||
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()).set(*count as f64);
|
{
|
||||||
|
worker.decrement_load();
|
||||||
|
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
|
||||||
|
.set(worker.load() as f64);
|
||||||
debug!("Streaming is done!!")
|
debug!("Streaming is done!!")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
|
// For non-CacheAware routers, just stream without load tracking
|
||||||
HttpResponse::build(status)
|
HttpResponse::build(status)
|
||||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
||||||
.streaming(res.bytes_stream().map_err(|_| {
|
.streaming(res.bytes_stream().map_err(|_| {
|
||||||
@@ -935,43 +1002,27 @@ impl Router {
|
|||||||
Ok(res) => {
|
Ok(res) => {
|
||||||
if res.status().is_success() {
|
if res.status().is_success() {
|
||||||
match self {
|
match self {
|
||||||
Router::RoundRobin { worker_urls, .. }
|
Router::RoundRobin { workers, .. }
|
||||||
| Router::Random { worker_urls, .. }
|
| Router::Random { workers, .. }
|
||||||
| Router::CacheAware { worker_urls, .. } => {
|
| Router::CacheAware { workers, .. } => {
|
||||||
info!("Worker {} health check passed", worker_url);
|
info!("Worker {} health check passed", worker_url);
|
||||||
let mut urls = worker_urls.write().unwrap();
|
let mut workers_guard = workers.write().unwrap();
|
||||||
if urls.contains(&worker_url.to_string()) {
|
if workers_guard.iter().any(|w| w.url() == worker_url) {
|
||||||
return Err(format!("Worker {} already exists", worker_url));
|
return Err(format!("Worker {} already exists", worker_url));
|
||||||
}
|
}
|
||||||
info!("Added worker: {}", worker_url);
|
info!("Added worker: {}", worker_url);
|
||||||
urls.push(worker_url.to_string());
|
let new_worker =
|
||||||
gauge!("sgl_router_active_workers").set(urls.len() as f64);
|
WorkerFactory::create_regular(worker_url.to_string());
|
||||||
|
workers_guard.push(new_worker);
|
||||||
|
gauge!("sgl_router_active_workers").set(workers_guard.len() as f64);
|
||||||
}
|
}
|
||||||
Router::PrefillDecode { .. } => {
|
Router::PrefillDecode { .. } => {
|
||||||
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
|
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If cache aware, initialize the queues for the new worker
|
// If cache aware, add worker to tree
|
||||||
if let Router::CacheAware {
|
if let Router::CacheAware { tree, .. } = self {
|
||||||
running_queue,
|
|
||||||
processed_queue,
|
|
||||||
tree,
|
|
||||||
..
|
|
||||||
} = self
|
|
||||||
{
|
|
||||||
// Add worker to running queue with initial count of 0
|
|
||||||
running_queue
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.insert(worker_url.to_string(), 0);
|
|
||||||
|
|
||||||
// Add worker to processed queue with initial count of 0
|
|
||||||
processed_queue
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.insert(worker_url.to_string(), 0);
|
|
||||||
|
|
||||||
// Add worker to tree
|
// Add worker to tree
|
||||||
tree.lock().unwrap().insert("", worker_url);
|
tree.lock().unwrap().insert("", worker_url);
|
||||||
}
|
}
|
||||||
@@ -1013,14 +1064,14 @@ impl Router {
|
|||||||
|
|
||||||
pub fn remove_worker(&self, worker_url: &str) {
|
pub fn remove_worker(&self, worker_url: &str) {
|
||||||
match self {
|
match self {
|
||||||
Router::RoundRobin { worker_urls, .. }
|
Router::RoundRobin { workers, .. }
|
||||||
| Router::Random { worker_urls, .. }
|
| Router::Random { workers, .. }
|
||||||
| Router::CacheAware { worker_urls, .. } => {
|
| Router::CacheAware { workers, .. } => {
|
||||||
let mut urls = worker_urls.write().unwrap();
|
let mut workers_guard = workers.write().unwrap();
|
||||||
if let Some(index) = urls.iter().position(|url| url == &worker_url) {
|
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
|
||||||
urls.remove(index);
|
workers_guard.remove(index);
|
||||||
info!("Removed worker: {}", worker_url);
|
info!("Removed worker: {}", worker_url);
|
||||||
gauge!("sgl_router_active_workers").set(urls.len() as f64);
|
gauge!("sgl_router_active_workers").set(workers_guard.len() as f64);
|
||||||
} else {
|
} else {
|
||||||
warn!("Worker {} not found, skipping removal", worker_url);
|
warn!("Worker {} not found, skipping removal", worker_url);
|
||||||
return;
|
return;
|
||||||
@@ -1033,26 +1084,9 @@ impl Router {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// if cache aware, remove the worker from the tree
|
// if cache aware, remove the worker from the tree
|
||||||
if let Router::CacheAware {
|
if let Router::CacheAware { tree, .. } = self {
|
||||||
tree,
|
|
||||||
running_queue,
|
|
||||||
processed_queue,
|
|
||||||
..
|
|
||||||
} = self
|
|
||||||
{
|
|
||||||
tree.lock().unwrap().remove_tenant(&worker_url);
|
tree.lock().unwrap().remove_tenant(&worker_url);
|
||||||
running_queue
|
info!("Removed worker from tree: {}", worker_url);
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.remove(&worker_url.to_string());
|
|
||||||
processed_queue
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.remove(&worker_url.to_string());
|
|
||||||
info!(
|
|
||||||
"Removed worker from tree and cleaned up queues: {}",
|
|
||||||
worker_url
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1241,21 +1275,22 @@ mod tests {
|
|||||||
use crate::service_discovery::PodType;
|
use crate::service_discovery::PodType;
|
||||||
|
|
||||||
fn create_test_regular_router() -> Router {
|
fn create_test_regular_router() -> Router {
|
||||||
|
let workers = vec![
|
||||||
|
WorkerFactory::create_regular("http://worker1:8080".to_string()),
|
||||||
|
WorkerFactory::create_regular("http://worker2:8080".to_string()),
|
||||||
|
];
|
||||||
Router::Random {
|
Router::Random {
|
||||||
worker_urls: Arc::new(RwLock::new(vec![
|
workers: Arc::new(RwLock::new(workers)),
|
||||||
"http://worker1:8080".to_string(),
|
|
||||||
"http://worker2:8080".to_string(),
|
|
||||||
])),
|
|
||||||
timeout_secs: 5,
|
timeout_secs: 5,
|
||||||
interval_secs: 1,
|
interval_secs: 1,
|
||||||
|
_health_checker: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_router_get_worker_urls_regular() {
|
fn test_router_get_worker_urls_regular() {
|
||||||
let router = create_test_regular_router();
|
let router = create_test_regular_router();
|
||||||
let worker_urls = router.get_worker_urls();
|
let urls = router.get_worker_urls();
|
||||||
let urls = worker_urls.read().unwrap();
|
|
||||||
|
|
||||||
assert_eq!(urls.len(), 2);
|
assert_eq!(urls.len(), 2);
|
||||||
assert!(urls.contains(&"http://worker1:8080".to_string()));
|
assert!(urls.contains(&"http://worker1:8080".to_string()));
|
||||||
|
|||||||
@@ -236,8 +236,7 @@ async fn add_worker(
|
|||||||
|
|
||||||
#[get("/list_workers")]
|
#[get("/list_workers")]
|
||||||
async fn list_workers(data: web::Data<AppState>) -> impl Responder {
|
async fn list_workers(data: web::Data<AppState>) -> impl Responder {
|
||||||
let workers = data.router.get_worker_urls();
|
let worker_list = data.router.get_worker_urls();
|
||||||
let worker_list = workers.read().unwrap().clone();
|
|
||||||
HttpResponse::Ok().json(serde_json::json!({ "urls": worker_list }))
|
HttpResponse::Ok().json(serde_json::json!({ "urls": worker_list }))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -381,7 +380,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
|||||||
info!("✅ Serving router on {}:{}", config.host, config.port);
|
info!("✅ Serving router on {}:{}", config.host, config.port);
|
||||||
info!(
|
info!(
|
||||||
"✅ Serving workers on {:?}",
|
"✅ Serving workers on {:?}",
|
||||||
app_state.router.get_worker_urls().read().unwrap()
|
app_state.router.get_worker_urls()
|
||||||
);
|
);
|
||||||
|
|
||||||
HttpServer::new(move || {
|
HttpServer::new(move || {
|
||||||
|
|||||||
@@ -547,11 +547,12 @@ mod tests {
|
|||||||
|
|
||||||
// Helper to create a Router instance for testing event handlers
|
// Helper to create a Router instance for testing event handlers
|
||||||
fn create_test_router() -> Arc<Router> {
|
fn create_test_router() -> Arc<Router> {
|
||||||
let worker_urls = Arc::new(RwLock::new(Vec::new()));
|
let workers = Arc::new(RwLock::new(Vec::new()));
|
||||||
Arc::new(Router::Random {
|
Arc::new(Router::Random {
|
||||||
worker_urls,
|
workers,
|
||||||
timeout_secs: 5,
|
timeout_secs: 5,
|
||||||
interval_secs: 1,
|
interval_secs: 1,
|
||||||
|
_health_checker: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -878,8 +879,6 @@ mod tests {
|
|||||||
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
||||||
assert!(!router
|
assert!(!router
|
||||||
.get_worker_urls()
|
.get_worker_urls()
|
||||||
.read()
|
|
||||||
.unwrap()
|
|
||||||
.contains(&pod_info.worker_url(port)));
|
.contains(&pod_info.worker_url(port)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -907,7 +906,7 @@ mod tests {
|
|||||||
.await;
|
.await;
|
||||||
|
|
||||||
assert!(tracked_pods.lock().unwrap().is_empty());
|
assert!(tracked_pods.lock().unwrap().is_empty());
|
||||||
assert!(router.get_worker_urls().read().unwrap().is_empty());
|
assert!(router.get_worker_urls().is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
mod test_pd_routing {
|
mod test_pd_routing {
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sglang_router_rs::pd_types::{EngineInfo, EngineType, PDSelectionPolicy};
|
use sglang_router_rs::pd_types::PDSelectionPolicy;
|
||||||
use sglang_router_rs::router::{PolicyConfig, Router};
|
use sglang_router_rs::router::{PolicyConfig, Router};
|
||||||
|
|
||||||
// Test-only struct to help validate PD request parsing
|
// Test-only struct to help validate PD request parsing
|
||||||
@@ -51,40 +51,35 @@ mod test_pd_routing {
|
|||||||
// ========================================================================
|
// ========================================================================
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_engine_info_creation() {
|
fn test_worker_types() {
|
||||||
// Test EngineInfo creation for prefill servers
|
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||||
let prefill_engine = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
|
|
||||||
match prefill_engine.engine_type {
|
|
||||||
EngineType::Prefill => (),
|
|
||||||
_ => panic!("Expected Prefill engine type"),
|
|
||||||
}
|
|
||||||
assert_eq!(prefill_engine.url, "http://prefill:8080");
|
|
||||||
assert_eq!(prefill_engine.bootstrap_port, Some(9000));
|
|
||||||
assert_eq!(prefill_engine.get_hostname(), "prefill");
|
|
||||||
|
|
||||||
// Test EngineInfo creation for decode servers
|
// Test worker creation for prefill servers
|
||||||
let decode_engine = EngineInfo::new_decode("http://decode:8080".to_string());
|
let prefill_worker =
|
||||||
match decode_engine.engine_type {
|
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||||
EngineType::Decode => (),
|
assert_eq!(prefill_worker.url(), "http://prefill:8080");
|
||||||
_ => panic!("Expected Decode engine type"),
|
match prefill_worker.worker_type() {
|
||||||
|
WorkerType::Prefill { bootstrap_port } => {
|
||||||
|
assert_eq!(bootstrap_port, Some(9000));
|
||||||
|
}
|
||||||
|
_ => panic!("Expected Prefill worker type"),
|
||||||
}
|
}
|
||||||
assert_eq!(decode_engine.url, "http://decode:8080");
|
|
||||||
assert_eq!(decode_engine.bootstrap_port, None);
|
|
||||||
assert_eq!(decode_engine.get_hostname(), "decode");
|
|
||||||
|
|
||||||
// Test API path generation
|
// Test worker creation for decode servers
|
||||||
assert_eq!(
|
let decode_worker = WorkerFactory::create_decode("http://decode:8080".to_string());
|
||||||
prefill_engine.api_path("/generate"),
|
assert_eq!(decode_worker.url(), "http://decode:8080");
|
||||||
"http://prefill:8080/generate"
|
match decode_worker.worker_type() {
|
||||||
);
|
WorkerType::Decode => (),
|
||||||
assert_eq!(
|
_ => panic!("Expected Decode worker type"),
|
||||||
prefill_engine.api_path("health"),
|
}
|
||||||
"http://prefill:8080/health"
|
|
||||||
);
|
// Test regular worker creation
|
||||||
assert_eq!(
|
let regular_worker = WorkerFactory::create_regular("http://regular:8080".to_string());
|
||||||
decode_engine.api_path("/v1/chat/completions"),
|
assert_eq!(regular_worker.url(), "http://regular:8080");
|
||||||
"http://decode:8080/v1/chat/completions"
|
match regular_worker.worker_type() {
|
||||||
);
|
WorkerType::Regular => (),
|
||||||
|
_ => panic!("Expected Regular worker type"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -230,6 +225,9 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_bootstrap_injection_simulation() {
|
fn test_bootstrap_injection_simulation() {
|
||||||
|
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||||
|
use sglang_router_rs::pd_types::get_hostname;
|
||||||
|
|
||||||
// Since we can't test the actual inject_bootstrap_fields function here
|
// Since we can't test the actual inject_bootstrap_fields function here
|
||||||
// (it's private in the router module), we'll test the expected behavior
|
// (it's private in the router module), we'll test the expected behavior
|
||||||
|
|
||||||
@@ -240,15 +238,24 @@ mod test_pd_routing {
|
|||||||
"temperature": 0.7
|
"temperature": 0.7
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Create a prefill worker to simulate injection
|
||||||
|
let prefill_worker =
|
||||||
|
WorkerFactory::create_prefill("http://prefill1:8080".to_string(), Some(9000));
|
||||||
|
|
||||||
|
// Extract bootstrap port from worker type
|
||||||
|
let bootstrap_port = match prefill_worker.worker_type() {
|
||||||
|
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
// Simulate what inject_bootstrap_fields would do
|
// Simulate what inject_bootstrap_fields would do
|
||||||
let prefill_info = EngineInfo::new_prefill("http://prefill1:8080".to_string(), Some(9000));
|
single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url()));
|
||||||
single_json["bootstrap_host"] = json!(prefill_info.get_hostname());
|
single_json["bootstrap_port"] = json!(bootstrap_port);
|
||||||
single_json["bootstrap_port"] = json!(prefill_info.bootstrap_port);
|
|
||||||
single_json["bootstrap_room"] = json!(12345u64); // Random room ID
|
single_json["bootstrap_room"] = json!(12345u64); // Random room ID
|
||||||
|
|
||||||
// Verify bootstrap fields are added correctly
|
// Verify bootstrap fields are added correctly
|
||||||
assert_eq!(single_json["bootstrap_host"], "prefill1");
|
assert_eq!(single_json["bootstrap_host"], "prefill1");
|
||||||
assert_eq!(single_json["bootstrap_port"], 9000);
|
assert_eq!(single_json["bootstrap_port"], json!(Some(9000)));
|
||||||
assert!(single_json["bootstrap_room"].is_u64());
|
assert!(single_json["bootstrap_room"].is_u64());
|
||||||
assert_eq!(single_json["temperature"], 0.7); // Original field preserved
|
assert_eq!(single_json["temperature"], 0.7); // Original field preserved
|
||||||
|
|
||||||
@@ -259,8 +266,9 @@ mod test_pd_routing {
|
|||||||
});
|
});
|
||||||
|
|
||||||
let batch_size = 3;
|
let batch_size = 3;
|
||||||
batch_json["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
|
let hostname = get_hostname(prefill_worker.url());
|
||||||
batch_json["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
|
batch_json["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
||||||
|
batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||||
batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);
|
batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);
|
||||||
|
|
||||||
// Verify batch bootstrap fields
|
// Verify batch bootstrap fields
|
||||||
@@ -306,7 +314,9 @@ mod test_pd_routing {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_engine_info_hostname_extraction() {
|
fn test_hostname_extraction() {
|
||||||
|
use sglang_router_rs::pd_types::get_hostname;
|
||||||
|
|
||||||
// Test various URL formats
|
// Test various URL formats
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
("http://localhost:8080", "localhost"),
|
("http://localhost:8080", "localhost"),
|
||||||
@@ -318,8 +328,7 @@ mod test_pd_routing {
|
|||||||
];
|
];
|
||||||
|
|
||||||
for (url, expected_hostname) in test_cases {
|
for (url, expected_hostname) in test_cases {
|
||||||
let engine = EngineInfo::new_prefill(url.to_string(), None);
|
assert_eq!(get_hostname(url), expected_hostname);
|
||||||
assert_eq!(engine.get_hostname(), expected_hostname);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -652,6 +661,9 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_bootstrap_injection_with_benchmark_requests() {
|
fn test_bootstrap_injection_with_benchmark_requests() {
|
||||||
|
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||||
|
use sglang_router_rs::pd_types::get_hostname;
|
||||||
|
|
||||||
// Test bootstrap injection with actual benchmark request patterns
|
// Test bootstrap injection with actual benchmark request patterns
|
||||||
let mut benchmark_request = json!({
|
let mut benchmark_request = json!({
|
||||||
"input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
|
"input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
|
||||||
@@ -664,12 +676,20 @@ mod test_pd_routing {
|
|||||||
"stream": true
|
"stream": true
|
||||||
});
|
});
|
||||||
|
|
||||||
// Simulate bootstrap injection
|
// Create a prefill worker to simulate injection
|
||||||
let prefill_info = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
|
let prefill_worker =
|
||||||
let batch_size = 16;
|
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||||
|
|
||||||
benchmark_request["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
|
// Extract bootstrap port from worker type
|
||||||
benchmark_request["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
|
let bootstrap_port = match prefill_worker.worker_type() {
|
||||||
|
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
let batch_size = 16;
|
||||||
|
let hostname = get_hostname(prefill_worker.url());
|
||||||
|
|
||||||
|
benchmark_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
||||||
|
benchmark_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||||
benchmark_request["bootstrap_room"] =
|
benchmark_request["bootstrap_room"] =
|
||||||
json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());
|
json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());
|
||||||
|
|
||||||
@@ -770,6 +790,9 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_large_batch_bootstrap_injection() {
|
fn test_large_batch_bootstrap_injection() {
|
||||||
|
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||||
|
use sglang_router_rs::pd_types::get_hostname;
|
||||||
|
|
||||||
// Test bootstrap injection performance with very large batches
|
// Test bootstrap injection performance with very large batches
|
||||||
// This simulates the bench_one_batch_server.py scenario
|
// This simulates the bench_one_batch_server.py scenario
|
||||||
let large_batch_sizes = vec![1024, 4096, 8192];
|
let large_batch_sizes = vec![1024, 4096, 8192];
|
||||||
@@ -787,14 +810,19 @@ mod test_pd_routing {
|
|||||||
"stream": true
|
"stream": true
|
||||||
});
|
});
|
||||||
|
|
||||||
// Simulate bootstrap injection
|
// Create a prefill worker to simulate injection
|
||||||
let prefill_info =
|
let prefill_worker =
|
||||||
EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
|
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||||
|
|
||||||
large_batch_request["bootstrap_host"] =
|
// Extract bootstrap port from worker type
|
||||||
json!(vec![prefill_info.get_hostname(); batch_size]);
|
let bootstrap_port = match prefill_worker.worker_type() {
|
||||||
large_batch_request["bootstrap_port"] =
|
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||||
json!(vec![prefill_info.bootstrap_port; batch_size]);
|
_ => None,
|
||||||
|
};
|
||||||
|
let hostname = get_hostname(prefill_worker.url());
|
||||||
|
|
||||||
|
large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
||||||
|
large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||||
large_batch_request["bootstrap_room"] = json!((0..batch_size)
|
large_batch_request["bootstrap_room"] = json!((0..batch_size)
|
||||||
.map(|_| rand::thread_rng().gen::<u64>())
|
.map(|_| rand::thread_rng().gen::<u64>())
|
||||||
.collect::<Vec<_>>());
|
.collect::<Vec<_>>());
|
||||||
|
|||||||
Reference in New Issue
Block a user