[router] remove old/outdated/useless comments (#10967)
This commit is contained in:
@@ -67,58 +67,47 @@ struct Router {
|
|||||||
decode_policy: Option<PolicyType>,
|
decode_policy: Option<PolicyType>,
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
cors_allowed_origins: Vec<String>,
|
cors_allowed_origins: Vec<String>,
|
||||||
// Retry configuration
|
|
||||||
retry_max_retries: u32,
|
retry_max_retries: u32,
|
||||||
retry_initial_backoff_ms: u64,
|
retry_initial_backoff_ms: u64,
|
||||||
retry_max_backoff_ms: u64,
|
retry_max_backoff_ms: u64,
|
||||||
retry_backoff_multiplier: f32,
|
retry_backoff_multiplier: f32,
|
||||||
retry_jitter_factor: f32,
|
retry_jitter_factor: f32,
|
||||||
disable_retries: bool,
|
disable_retries: bool,
|
||||||
// Circuit breaker configuration
|
|
||||||
cb_failure_threshold: u32,
|
cb_failure_threshold: u32,
|
||||||
cb_success_threshold: u32,
|
cb_success_threshold: u32,
|
||||||
cb_timeout_duration_secs: u64,
|
cb_timeout_duration_secs: u64,
|
||||||
cb_window_duration_secs: u64,
|
cb_window_duration_secs: u64,
|
||||||
disable_circuit_breaker: bool,
|
disable_circuit_breaker: bool,
|
||||||
// Health check configuration
|
|
||||||
health_failure_threshold: u32,
|
health_failure_threshold: u32,
|
||||||
health_success_threshold: u32,
|
health_success_threshold: u32,
|
||||||
health_check_timeout_secs: u64,
|
health_check_timeout_secs: u64,
|
||||||
health_check_interval_secs: u64,
|
health_check_interval_secs: u64,
|
||||||
health_check_endpoint: String,
|
health_check_endpoint: String,
|
||||||
// IGW (Inference Gateway) configuration
|
|
||||||
enable_igw: bool,
|
enable_igw: bool,
|
||||||
queue_size: usize,
|
queue_size: usize,
|
||||||
queue_timeout_secs: u64,
|
queue_timeout_secs: u64,
|
||||||
rate_limit_tokens_per_second: Option<usize>,
|
rate_limit_tokens_per_second: Option<usize>,
|
||||||
// Connection mode (determined from worker URLs)
|
|
||||||
connection_mode: config::ConnectionMode,
|
connection_mode: config::ConnectionMode,
|
||||||
// Model path for tokenizer
|
|
||||||
model_path: Option<String>,
|
model_path: Option<String>,
|
||||||
// Explicit tokenizer path
|
|
||||||
tokenizer_path: Option<String>,
|
tokenizer_path: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Router {
|
impl Router {
|
||||||
/// Determine connection mode from worker URLs
|
/// Determine connection mode from worker URLs
|
||||||
fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
|
fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
|
||||||
// Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
|
|
||||||
for url in worker_urls {
|
for url in worker_urls {
|
||||||
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
|
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
|
||||||
return config::ConnectionMode::Grpc;
|
return config::ConnectionMode::Grpc;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Default to HTTP for all other cases (including http://, https://, or no scheme)
|
|
||||||
config::ConnectionMode::Http
|
config::ConnectionMode::Http
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert PyO3 Router to RouterConfig
|
|
||||||
pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
|
pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
|
||||||
use config::{
|
use config::{
|
||||||
DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
|
DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Convert policy helper function
|
|
||||||
let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
|
let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
|
||||||
match policy {
|
match policy {
|
||||||
PolicyType::Random => ConfigPolicyConfig::Random,
|
PolicyType::Random => ConfigPolicyConfig::Random,
|
||||||
@@ -131,14 +120,12 @@ impl Router {
|
|||||||
max_tree_size: self.max_tree_size,
|
max_tree_size: self.max_tree_size,
|
||||||
},
|
},
|
||||||
PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
|
PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
|
||||||
load_check_interval_secs: 5, // Default value
|
load_check_interval_secs: 5,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Determine routing mode
|
|
||||||
let mode = if self.enable_igw {
|
let mode = if self.enable_igw {
|
||||||
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
|
|
||||||
RoutingMode::Regular {
|
RoutingMode::Regular {
|
||||||
worker_urls: vec![],
|
worker_urls: vec![],
|
||||||
}
|
}
|
||||||
@@ -155,10 +142,8 @@ impl Router {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Convert main policy
|
|
||||||
let policy = convert_policy(&self.policy);
|
let policy = convert_policy(&self.policy);
|
||||||
|
|
||||||
// Service discovery configuration
|
|
||||||
let discovery = if self.service_discovery {
|
let discovery = if self.service_discovery {
|
||||||
Some(DiscoveryConfig {
|
Some(DiscoveryConfig {
|
||||||
enabled: true,
|
enabled: true,
|
||||||
@@ -174,7 +159,6 @@ impl Router {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// Metrics configuration
|
|
||||||
let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
|
let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
|
||||||
(Some(port), Some(host)) => Some(MetricsConfig {
|
(Some(port), Some(host)) => Some(MetricsConfig {
|
||||||
port,
|
port,
|
||||||
@@ -251,7 +235,7 @@ impl Router {
|
|||||||
balance_rel_threshold = 1.5,
|
balance_rel_threshold = 1.5,
|
||||||
eviction_interval_secs = 120,
|
eviction_interval_secs = 120,
|
||||||
max_tree_size = 2usize.pow(26),
|
max_tree_size = 2usize.pow(26),
|
||||||
max_payload_size = 512 * 1024 * 1024, // 512MB default for large batches
|
max_payload_size = 512 * 1024 * 1024,
|
||||||
dp_aware = false,
|
dp_aware = false,
|
||||||
api_key = None,
|
api_key = None,
|
||||||
log_dir = None,
|
log_dir = None,
|
||||||
@@ -265,40 +249,35 @@ impl Router {
|
|||||||
bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
|
bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
|
||||||
prometheus_port = None,
|
prometheus_port = None,
|
||||||
prometheus_host = None,
|
prometheus_host = None,
|
||||||
request_timeout_secs = 1800, // Add configurable request timeout
|
request_timeout_secs = 1800,
|
||||||
request_id_headers = None, // Custom request ID headers
|
request_id_headers = None,
|
||||||
pd_disaggregation = false, // New flag for PD mode
|
pd_disaggregation = false,
|
||||||
prefill_urls = None,
|
prefill_urls = None,
|
||||||
decode_urls = None,
|
decode_urls = None,
|
||||||
prefill_policy = None,
|
prefill_policy = None,
|
||||||
decode_policy = None,
|
decode_policy = None,
|
||||||
max_concurrent_requests = 256,
|
max_concurrent_requests = 256,
|
||||||
cors_allowed_origins = vec![],
|
cors_allowed_origins = vec![],
|
||||||
// Retry defaults
|
|
||||||
retry_max_retries = 5,
|
retry_max_retries = 5,
|
||||||
retry_initial_backoff_ms = 50,
|
retry_initial_backoff_ms = 50,
|
||||||
retry_max_backoff_ms = 30_000,
|
retry_max_backoff_ms = 30_000,
|
||||||
retry_backoff_multiplier = 1.5,
|
retry_backoff_multiplier = 1.5,
|
||||||
retry_jitter_factor = 0.2,
|
retry_jitter_factor = 0.2,
|
||||||
disable_retries = false,
|
disable_retries = false,
|
||||||
// Circuit breaker defaults
|
|
||||||
cb_failure_threshold = 10,
|
cb_failure_threshold = 10,
|
||||||
cb_success_threshold = 3,
|
cb_success_threshold = 3,
|
||||||
cb_timeout_duration_secs = 60,
|
cb_timeout_duration_secs = 60,
|
||||||
cb_window_duration_secs = 120,
|
cb_window_duration_secs = 120,
|
||||||
disable_circuit_breaker = false,
|
disable_circuit_breaker = false,
|
||||||
// Health check defaults
|
|
||||||
health_failure_threshold = 3,
|
health_failure_threshold = 3,
|
||||||
health_success_threshold = 2,
|
health_success_threshold = 2,
|
||||||
health_check_timeout_secs = 5,
|
health_check_timeout_secs = 5,
|
||||||
health_check_interval_secs = 60,
|
health_check_interval_secs = 60,
|
||||||
health_check_endpoint = String::from("/health"),
|
health_check_endpoint = String::from("/health"),
|
||||||
// IGW defaults
|
|
||||||
enable_igw = false,
|
enable_igw = false,
|
||||||
queue_size = 100,
|
queue_size = 100,
|
||||||
queue_timeout_secs = 60,
|
queue_timeout_secs = 60,
|
||||||
rate_limit_tokens_per_second = None,
|
rate_limit_tokens_per_second = None,
|
||||||
// Tokenizer defaults
|
|
||||||
model_path = None,
|
model_path = None,
|
||||||
tokenizer_path = None,
|
tokenizer_path = None,
|
||||||
))]
|
))]
|
||||||
@@ -361,17 +340,14 @@ impl Router {
|
|||||||
model_path: Option<String>,
|
model_path: Option<String>,
|
||||||
tokenizer_path: Option<String>,
|
tokenizer_path: Option<String>,
|
||||||
) -> PyResult<Self> {
|
) -> PyResult<Self> {
|
||||||
// Determine connection mode from worker URLs
|
|
||||||
let mut all_urls = worker_urls.clone();
|
let mut all_urls = worker_urls.clone();
|
||||||
|
|
||||||
// Add prefill URLs if in PD mode
|
|
||||||
if let Some(ref prefill_urls) = prefill_urls {
|
if let Some(ref prefill_urls) = prefill_urls {
|
||||||
for (url, _) in prefill_urls {
|
for (url, _) in prefill_urls {
|
||||||
all_urls.push(url.clone());
|
all_urls.push(url.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add decode URLs if in PD mode
|
|
||||||
if let Some(ref decode_urls) = decode_urls {
|
if let Some(ref decode_urls) = decode_urls {
|
||||||
all_urls.extend(decode_urls.clone());
|
all_urls.extend(decode_urls.clone());
|
||||||
}
|
}
|
||||||
@@ -440,12 +416,10 @@ impl Router {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn start(&self) -> PyResult<()> {
|
fn start(&self) -> PyResult<()> {
|
||||||
// Convert to RouterConfig and validate
|
|
||||||
let router_config = self.to_router_config().map_err(|e| {
|
let router_config = self.to_router_config().map_err(|e| {
|
||||||
pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
|
pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Validate the configuration
|
|
||||||
router_config.validate().map_err(|e| {
|
router_config.validate().map_err(|e| {
|
||||||
pyo3::exceptions::PyValueError::new_err(format!(
|
pyo3::exceptions::PyValueError::new_err(format!(
|
||||||
"Configuration validation failed: {}",
|
"Configuration validation failed: {}",
|
||||||
@@ -453,7 +427,6 @@ impl Router {
|
|||||||
))
|
))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Create service discovery config if enabled
|
|
||||||
let service_discovery_config = if self.service_discovery {
|
let service_discovery_config = if self.service_discovery {
|
||||||
Some(service_discovery::ServiceDiscoveryConfig {
|
Some(service_discovery::ServiceDiscoveryConfig {
|
||||||
enabled: true,
|
enabled: true,
|
||||||
@@ -470,7 +443,6 @@ impl Router {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create Prometheus config if enabled
|
|
||||||
let prometheus_config = Some(PrometheusConfig {
|
let prometheus_config = Some(PrometheusConfig {
|
||||||
port: self.prometheus_port.unwrap_or(29000),
|
port: self.prometheus_port.unwrap_or(29000),
|
||||||
host: self
|
host: self
|
||||||
@@ -479,11 +451,9 @@ impl Router {
|
|||||||
.unwrap_or_else(|| "127.0.0.1".to_string()),
|
.unwrap_or_else(|| "127.0.0.1".to_string()),
|
||||||
});
|
});
|
||||||
|
|
||||||
// Use tokio runtime instead of actix-web System for better compatibility
|
|
||||||
let runtime = tokio::runtime::Runtime::new()
|
let runtime = tokio::runtime::Runtime::new()
|
||||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
|
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
|
||||||
|
|
||||||
// Block on the async startup function
|
|
||||||
runtime.block_on(async move {
|
runtime.block_on(async move {
|
||||||
server::startup(server::ServerConfig {
|
server::startup(server::ServerConfig {
|
||||||
host: self.host.clone(),
|
host: self.host.clone(),
|
||||||
|
|||||||
@@ -8,20 +8,13 @@ use tracing_subscriber::layer::SubscriberExt;
|
|||||||
use tracing_subscriber::util::SubscriberInitExt;
|
use tracing_subscriber::util::SubscriberInitExt;
|
||||||
use tracing_subscriber::{EnvFilter, Layer};
|
use tracing_subscriber::{EnvFilter, Layer};
|
||||||
|
|
||||||
/// Configuration for the logging system
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct LoggingConfig {
|
pub struct LoggingConfig {
|
||||||
/// Log level for the application (default: INFO)
|
|
||||||
pub level: Level,
|
pub level: Level,
|
||||||
/// Whether to use json format for logs (default: false)
|
|
||||||
pub json_format: bool,
|
pub json_format: bool,
|
||||||
/// Path to store log files. If None, logs will only go to stdout/stderr
|
|
||||||
pub log_dir: Option<String>,
|
pub log_dir: Option<String>,
|
||||||
/// Whether to colorize logs when output is a terminal (default: true)
|
|
||||||
pub colorize: bool,
|
pub colorize: bool,
|
||||||
/// Log file name to use if log_dir is specified (default: "sgl-router")
|
|
||||||
pub log_file_name: String,
|
pub log_file_name: String,
|
||||||
/// Custom log targets to filter (default: "sglang_router_rs")
|
|
||||||
pub log_targets: Option<Vec<String>>,
|
pub log_targets: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -38,30 +31,14 @@ impl Default for LoggingConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Guard that keeps the file appender worker thread alive
|
|
||||||
///
|
|
||||||
/// This must be kept in scope for the duration of the program
|
|
||||||
/// to ensure logs are properly written to files
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub struct LogGuard {
|
pub struct LogGuard {
|
||||||
_file_guard: Option<WorkerGuard>,
|
_file_guard: Option<WorkerGuard>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Initialize the logging system with the given configuration
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * `config` - Configuration for the logging system
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// A LogGuard that must be kept alive for the duration of the program
|
|
||||||
///
|
|
||||||
/// # Panics
|
|
||||||
/// Will not panic, as initialization errors are handled gracefully
|
|
||||||
pub fn init_logging(config: LoggingConfig) -> LogGuard {
|
pub fn init_logging(config: LoggingConfig) -> LogGuard {
|
||||||
// Forward logs to tracing - ignore errors to allow for multiple initialization
|
|
||||||
let _ = LogTracer::init();
|
let _ = LogTracer::init();
|
||||||
|
|
||||||
// Convert log level to filter string
|
|
||||||
let level_filter = match config.level {
|
let level_filter = match config.level {
|
||||||
Level::TRACE => "trace",
|
Level::TRACE => "trace",
|
||||||
Level::DEBUG => "debug",
|
Level::DEBUG => "debug",
|
||||||
@@ -70,9 +47,7 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
|
|||||||
Level::ERROR => "error",
|
Level::ERROR => "error",
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create env filter
|
|
||||||
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| {
|
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| {
|
||||||
// Format: <target>=<level>,<target2>=<level2>,...
|
|
||||||
let filter_string = if let Some(targets) = &config.log_targets {
|
let filter_string = if let Some(targets) = &config.log_targets {
|
||||||
targets
|
targets
|
||||||
.iter()
|
.iter()
|
||||||
@@ -92,13 +67,10 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
|
|||||||
EnvFilter::new(filter_string)
|
EnvFilter::new(filter_string)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Setup stdout/stderr layer
|
|
||||||
let mut layers = Vec::new();
|
let mut layers = Vec::new();
|
||||||
|
|
||||||
// Standard timestamp format: YYYY-MM-DD HH:MM:SS
|
|
||||||
let time_format = "%Y-%m-%d %H:%M:%S".to_string();
|
let time_format = "%Y-%m-%d %H:%M:%S".to_string();
|
||||||
|
|
||||||
// Configure the console stdout layer
|
|
||||||
let stdout_layer = tracing_subscriber::fmt::layer()
|
let stdout_layer = tracing_subscriber::fmt::layer()
|
||||||
.with_ansi(config.colorize)
|
.with_ansi(config.colorize)
|
||||||
.with_file(true)
|
.with_file(true)
|
||||||
@@ -113,14 +85,12 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
|
|||||||
|
|
||||||
layers.push(stdout_layer);
|
layers.push(stdout_layer);
|
||||||
|
|
||||||
// Create a file appender if log_dir is specified
|
|
||||||
let mut file_guard = None;
|
let mut file_guard = None;
|
||||||
|
|
||||||
if let Some(log_dir) = &config.log_dir {
|
if let Some(log_dir) = &config.log_dir {
|
||||||
let file_name = config.log_file_name.clone();
|
let file_name = config.log_file_name.clone();
|
||||||
let log_dir = PathBuf::from(log_dir);
|
let log_dir = PathBuf::from(log_dir);
|
||||||
|
|
||||||
// Create log directory if it doesn't exist
|
|
||||||
if !log_dir.exists() {
|
if !log_dir.exists() {
|
||||||
if let Err(e) = std::fs::create_dir_all(&log_dir) {
|
if let Err(e) = std::fs::create_dir_all(&log_dir) {
|
||||||
eprintln!("Failed to create log directory: {}", e);
|
eprintln!("Failed to create log directory: {}", e);
|
||||||
@@ -134,7 +104,7 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
|
|||||||
file_guard = Some(guard);
|
file_guard = Some(guard);
|
||||||
|
|
||||||
let file_layer = tracing_subscriber::fmt::layer()
|
let file_layer = tracing_subscriber::fmt::layer()
|
||||||
.with_ansi(false) // Never use ANSI colors in log files
|
.with_ansi(false)
|
||||||
.with_file(true)
|
.with_file(true)
|
||||||
.with_line_number(true)
|
.with_line_number(true)
|
||||||
.with_timer(ChronoUtc::new(time_format))
|
.with_timer(ChronoUtc::new(time_format))
|
||||||
@@ -149,14 +119,11 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
|
|||||||
layers.push(file_layer);
|
layers.push(file_layer);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the subscriber with all layers
|
|
||||||
// Use try_init to handle errors gracefully in case another subscriber is already set
|
|
||||||
let _ = tracing_subscriber::registry()
|
let _ = tracing_subscriber::registry()
|
||||||
.with(env_filter)
|
.with(env_filter)
|
||||||
.with(layers)
|
.with(layers)
|
||||||
.try_init();
|
.try_init();
|
||||||
|
|
||||||
// Return the guard to keep the file appender worker thread alive
|
|
||||||
LogGuard {
|
LogGuard {
|
||||||
_file_guard: file_guard,
|
_file_guard: file_guard,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ use sglang_router_rs::server::{self, ServerConfig};
|
|||||||
use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
|
use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
// Helper function to parse prefill arguments from command line
|
|
||||||
fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
|
fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
|
||||||
let args: Vec<String> = std::env::args().collect();
|
let args: Vec<String> = std::env::args().collect();
|
||||||
let mut prefill_entries = Vec::new();
|
let mut prefill_entries = Vec::new();
|
||||||
@@ -19,12 +18,11 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
|
|||||||
if args[i] == "--prefill" && i + 1 < args.len() {
|
if args[i] == "--prefill" && i + 1 < args.len() {
|
||||||
let url = args[i + 1].clone();
|
let url = args[i + 1].clone();
|
||||||
let bootstrap_port = if i + 2 < args.len() && !args[i + 2].starts_with("--") {
|
let bootstrap_port = if i + 2 < args.len() && !args[i + 2].starts_with("--") {
|
||||||
// Check if next arg is a port number
|
|
||||||
if let Ok(port) = args[i + 2].parse::<u16>() {
|
if let Ok(port) = args[i + 2].parse::<u16>() {
|
||||||
i += 1; // Skip the port argument
|
i += 1;
|
||||||
Some(port)
|
Some(port)
|
||||||
} else if args[i + 2].to_lowercase() == "none" {
|
} else if args[i + 2].to_lowercase() == "none" {
|
||||||
i += 1; // Skip the "none" argument
|
i += 1;
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@@ -33,7 +31,7 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
prefill_entries.push((url, bootstrap_port));
|
prefill_entries.push((url, bootstrap_port));
|
||||||
i += 2; // Skip --prefill and URL
|
i += 2;
|
||||||
} else {
|
} else {
|
||||||
i += 1;
|
i += 1;
|
||||||
}
|
}
|
||||||
@@ -101,252 +99,186 @@ Examples:
|
|||||||
|
|
||||||
"#)]
|
"#)]
|
||||||
struct CliArgs {
|
struct CliArgs {
|
||||||
/// Host address to bind the router server
|
|
||||||
#[arg(long, default_value = "127.0.0.1")]
|
#[arg(long, default_value = "127.0.0.1")]
|
||||||
host: String,
|
host: String,
|
||||||
|
|
||||||
/// Port number to bind the router server
|
|
||||||
#[arg(long, default_value_t = 30000)]
|
#[arg(long, default_value_t = 30000)]
|
||||||
port: u16,
|
port: u16,
|
||||||
|
|
||||||
/// List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)
|
|
||||||
#[arg(long, num_args = 0..)]
|
#[arg(long, num_args = 0..)]
|
||||||
worker_urls: Vec<String>,
|
worker_urls: Vec<String>,
|
||||||
|
|
||||||
/// Load balancing policy to use
|
|
||||||
#[arg(long, default_value = "cache_aware", value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
|
#[arg(long, default_value = "cache_aware", value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
|
||||||
policy: String,
|
policy: String,
|
||||||
|
|
||||||
/// Enable PD (Prefill-Decode) disaggregated mode
|
|
||||||
#[arg(long, default_value_t = false)]
|
#[arg(long, default_value_t = false)]
|
||||||
pd_disaggregation: bool,
|
pd_disaggregation: bool,
|
||||||
|
|
||||||
/// Decode server URL (can be specified multiple times)
|
|
||||||
#[arg(long, action = ArgAction::Append)]
|
#[arg(long, action = ArgAction::Append)]
|
||||||
decode: Vec<String>,
|
decode: Vec<String>,
|
||||||
|
|
||||||
/// Specific policy for prefill nodes in PD mode
|
|
||||||
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
|
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
|
||||||
prefill_policy: Option<String>,
|
prefill_policy: Option<String>,
|
||||||
|
|
||||||
/// Specific policy for decode nodes in PD mode
|
|
||||||
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
|
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
|
||||||
decode_policy: Option<String>,
|
decode_policy: Option<String>,
|
||||||
|
|
||||||
/// Timeout in seconds for worker startup
|
|
||||||
#[arg(long, default_value_t = 600)]
|
#[arg(long, default_value_t = 600)]
|
||||||
worker_startup_timeout_secs: u64,
|
worker_startup_timeout_secs: u64,
|
||||||
|
|
||||||
/// Interval in seconds between checks for worker startup
|
|
||||||
#[arg(long, default_value_t = 30)]
|
#[arg(long, default_value_t = 30)]
|
||||||
worker_startup_check_interval: u64,
|
worker_startup_check_interval: u64,
|
||||||
|
|
||||||
/// Cache threshold (0.0-1.0) for cache-aware routing
|
|
||||||
#[arg(long, default_value_t = 0.3)]
|
#[arg(long, default_value_t = 0.3)]
|
||||||
cache_threshold: f32,
|
cache_threshold: f32,
|
||||||
|
|
||||||
/// Absolute threshold for load balancing
|
|
||||||
#[arg(long, default_value_t = 64)]
|
#[arg(long, default_value_t = 64)]
|
||||||
balance_abs_threshold: usize,
|
balance_abs_threshold: usize,
|
||||||
|
|
||||||
/// Relative threshold for load balancing
|
|
||||||
#[arg(long, default_value_t = 1.5)]
|
#[arg(long, default_value_t = 1.5)]
|
||||||
balance_rel_threshold: f32,
|
balance_rel_threshold: f32,
|
||||||
|
|
||||||
/// Interval in seconds between cache eviction operations
|
|
||||||
#[arg(long, default_value_t = 120)]
|
#[arg(long, default_value_t = 120)]
|
||||||
eviction_interval: u64,
|
eviction_interval: u64,
|
||||||
|
|
||||||
/// Maximum size of the approximation tree for cache-aware routing
|
#[arg(long, default_value_t = 67108864)]
|
||||||
#[arg(long, default_value_t = 67108864)] // 2^26
|
|
||||||
max_tree_size: usize,
|
max_tree_size: usize,
|
||||||
|
|
||||||
/// Maximum payload size in bytes
|
#[arg(long, default_value_t = 536870912)]
|
||||||
#[arg(long, default_value_t = 536870912)] // 512MB
|
|
||||||
max_payload_size: usize,
|
max_payload_size: usize,
|
||||||
|
|
||||||
/// Enable data parallelism aware schedule
|
|
||||||
#[arg(long, default_value_t = false)]
|
#[arg(long, default_value_t = false)]
|
||||||
dp_aware: bool,
|
dp_aware: bool,
|
||||||
|
|
||||||
/// API key for worker authorization
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
api_key: Option<String>,
|
api_key: Option<String>,
|
||||||
|
|
||||||
/// Backend to route requests to (sglang, vllm, trtllm, openai, anthropic)
|
|
||||||
#[arg(long, value_enum, default_value_t = Backend::Sglang, alias = "runtime")]
|
#[arg(long, value_enum, default_value_t = Backend::Sglang, alias = "runtime")]
|
||||||
backend: Backend,
|
backend: Backend,
|
||||||
|
|
||||||
/// Directory to store log files
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
log_dir: Option<String>,
|
log_dir: Option<String>,
|
||||||
|
|
||||||
/// Set the logging level
|
|
||||||
#[arg(long, default_value = "info", value_parser = ["debug", "info", "warn", "error"])]
|
#[arg(long, default_value = "info", value_parser = ["debug", "info", "warn", "error"])]
|
||||||
log_level: String,
|
log_level: String,
|
||||||
|
|
||||||
/// Enable Kubernetes service discovery
|
|
||||||
#[arg(long, default_value_t = false)]
|
#[arg(long, default_value_t = false)]
|
||||||
service_discovery: bool,
|
service_discovery: bool,
|
||||||
|
|
||||||
/// Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)
|
|
||||||
#[arg(long, num_args = 0..)]
|
#[arg(long, num_args = 0..)]
|
||||||
selector: Vec<String>,
|
selector: Vec<String>,
|
||||||
|
|
||||||
/// Port to use for discovered worker pods
|
|
||||||
#[arg(long, default_value_t = 80)]
|
#[arg(long, default_value_t = 80)]
|
||||||
service_discovery_port: u16,
|
service_discovery_port: u16,
|
||||||
|
|
||||||
/// Kubernetes namespace to watch for pods
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
service_discovery_namespace: Option<String>,
|
service_discovery_namespace: Option<String>,
|
||||||
|
|
||||||
/// Label selector for prefill server pods in PD mode
|
|
||||||
#[arg(long, num_args = 0..)]
|
#[arg(long, num_args = 0..)]
|
||||||
prefill_selector: Vec<String>,
|
prefill_selector: Vec<String>,
|
||||||
|
|
||||||
/// Label selector for decode server pods in PD mode
|
|
||||||
#[arg(long, num_args = 0..)]
|
#[arg(long, num_args = 0..)]
|
||||||
decode_selector: Vec<String>,
|
decode_selector: Vec<String>,
|
||||||
|
|
||||||
/// Port to expose Prometheus metrics
|
|
||||||
#[arg(long, default_value_t = 29000)]
|
#[arg(long, default_value_t = 29000)]
|
||||||
prometheus_port: u16,
|
prometheus_port: u16,
|
||||||
|
|
||||||
/// Host address to bind the Prometheus metrics server
|
|
||||||
#[arg(long, default_value = "127.0.0.1")]
|
#[arg(long, default_value = "127.0.0.1")]
|
||||||
prometheus_host: String,
|
prometheus_host: String,
|
||||||
|
|
||||||
/// Custom HTTP headers to check for request IDs
|
|
||||||
#[arg(long, num_args = 0..)]
|
#[arg(long, num_args = 0..)]
|
||||||
request_id_headers: Vec<String>,
|
request_id_headers: Vec<String>,
|
||||||
|
|
||||||
/// Request timeout in seconds
|
|
||||||
#[arg(long, default_value_t = 1800)]
|
#[arg(long, default_value_t = 1800)]
|
||||||
request_timeout_secs: u64,
|
request_timeout_secs: u64,
|
||||||
|
|
||||||
/// Maximum number of concurrent requests allowed
|
|
||||||
#[arg(long, default_value_t = 256)]
|
#[arg(long, default_value_t = 256)]
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
|
|
||||||
/// CORS allowed origins
|
|
||||||
#[arg(long, num_args = 0..)]
|
#[arg(long, num_args = 0..)]
|
||||||
cors_allowed_origins: Vec<String>,
|
cors_allowed_origins: Vec<String>,
|
||||||
|
|
||||||
// Retry configuration
|
|
||||||
/// Maximum number of retries
|
|
||||||
#[arg(long, default_value_t = 5)]
|
#[arg(long, default_value_t = 5)]
|
||||||
retry_max_retries: u32,
|
retry_max_retries: u32,
|
||||||
|
|
||||||
/// Initial backoff in milliseconds for retries
|
|
||||||
#[arg(long, default_value_t = 50)]
|
#[arg(long, default_value_t = 50)]
|
||||||
retry_initial_backoff_ms: u64,
|
retry_initial_backoff_ms: u64,
|
||||||
|
|
||||||
/// Maximum backoff in milliseconds for retries
|
|
||||||
#[arg(long, default_value_t = 30000)]
|
#[arg(long, default_value_t = 30000)]
|
||||||
retry_max_backoff_ms: u64,
|
retry_max_backoff_ms: u64,
|
||||||
|
|
||||||
/// Backoff multiplier for exponential backoff
|
|
||||||
#[arg(long, default_value_t = 1.5)]
|
#[arg(long, default_value_t = 1.5)]
|
||||||
retry_backoff_multiplier: f32,
|
retry_backoff_multiplier: f32,
|
||||||
|
|
||||||
/// Jitter factor for retry backoff
|
|
||||||
#[arg(long, default_value_t = 0.2)]
|
#[arg(long, default_value_t = 0.2)]
|
||||||
retry_jitter_factor: f32,
|
retry_jitter_factor: f32,
|
||||||
|
|
||||||
/// Disable retries
|
|
||||||
#[arg(long, default_value_t = false)]
|
#[arg(long, default_value_t = false)]
|
||||||
disable_retries: bool,
|
disable_retries: bool,
|
||||||
|
|
||||||
// Circuit breaker configuration
|
|
||||||
/// Number of failures before circuit breaker opens
|
|
||||||
#[arg(long, default_value_t = 10)]
|
#[arg(long, default_value_t = 10)]
|
||||||
cb_failure_threshold: u32,
|
cb_failure_threshold: u32,
|
||||||
|
|
||||||
/// Number of successes before circuit breaker closes
|
|
||||||
#[arg(long, default_value_t = 3)]
|
#[arg(long, default_value_t = 3)]
|
||||||
cb_success_threshold: u32,
|
cb_success_threshold: u32,
|
||||||
|
|
||||||
/// Timeout duration in seconds for circuit breaker
|
|
||||||
#[arg(long, default_value_t = 60)]
|
#[arg(long, default_value_t = 60)]
|
||||||
cb_timeout_duration_secs: u64,
|
cb_timeout_duration_secs: u64,
|
||||||
|
|
||||||
/// Window duration in seconds for circuit breaker
|
|
||||||
#[arg(long, default_value_t = 120)]
|
#[arg(long, default_value_t = 120)]
|
||||||
cb_window_duration_secs: u64,
|
cb_window_duration_secs: u64,
|
||||||
|
|
||||||
/// Disable circuit breaker
|
|
||||||
#[arg(long, default_value_t = false)]
|
#[arg(long, default_value_t = false)]
|
||||||
disable_circuit_breaker: bool,
|
disable_circuit_breaker: bool,
|
||||||
|
|
||||||
// Health check configuration
|
|
||||||
/// Number of consecutive health check failures before marking worker unhealthy
|
|
||||||
#[arg(long, default_value_t = 3)]
|
#[arg(long, default_value_t = 3)]
|
||||||
health_failure_threshold: u32,
|
health_failure_threshold: u32,
|
||||||
|
|
||||||
/// Number of consecutive health check successes before marking worker healthy
|
|
||||||
#[arg(long, default_value_t = 2)]
|
#[arg(long, default_value_t = 2)]
|
||||||
health_success_threshold: u32,
|
health_success_threshold: u32,
|
||||||
|
|
||||||
/// Timeout in seconds for health check requests
|
|
||||||
#[arg(long, default_value_t = 5)]
|
#[arg(long, default_value_t = 5)]
|
||||||
health_check_timeout_secs: u64,
|
health_check_timeout_secs: u64,
|
||||||
|
|
||||||
/// Interval in seconds between runtime health checks
|
|
||||||
#[arg(long, default_value_t = 60)]
|
#[arg(long, default_value_t = 60)]
|
||||||
health_check_interval_secs: u64,
|
health_check_interval_secs: u64,
|
||||||
|
|
||||||
/// Health check endpoint path
|
|
||||||
#[arg(long, default_value = "/health")]
|
#[arg(long, default_value = "/health")]
|
||||||
health_check_endpoint: String,
|
health_check_endpoint: String,
|
||||||
|
|
||||||
// IGW (Inference Gateway) configuration
|
|
||||||
/// Enable Inference Gateway mode
|
|
||||||
#[arg(long, default_value_t = false)]
|
#[arg(long, default_value_t = false)]
|
||||||
enable_igw: bool,
|
enable_igw: bool,
|
||||||
|
|
||||||
// Tokenizer configuration
|
|
||||||
/// Model path for loading tokenizer (HuggingFace model ID or local path)
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model_path: Option<String>,
|
model_path: Option<String>,
|
||||||
|
|
||||||
/// Explicit tokenizer path (overrides model_path tokenizer if provided)
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tokenizer_path: Option<String>,
|
tokenizer_path: Option<String>,
|
||||||
|
|
||||||
/// History backend configuration (memory, none, or oracle)
|
|
||||||
#[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])]
|
#[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])]
|
||||||
history_backend: String,
|
history_backend: String,
|
||||||
|
|
||||||
/// Directory containing the Oracle ATP wallet/config files (optional)
|
|
||||||
#[arg(long, env = "ATP_WALLET_PATH")]
|
#[arg(long, env = "ATP_WALLET_PATH")]
|
||||||
oracle_wallet_path: Option<String>,
|
oracle_wallet_path: Option<String>,
|
||||||
|
|
||||||
/// Wallet TNS alias to use (e.g. `<db_name>_low`)
|
|
||||||
#[arg(long, env = "ATP_TNS_ALIAS")]
|
#[arg(long, env = "ATP_TNS_ALIAS")]
|
||||||
oracle_tns_alias: Option<String>,
|
oracle_tns_alias: Option<String>,
|
||||||
|
|
||||||
/// Oracle connection descriptor / DSN (e.g. `tcps://host:port/service_name`)
|
|
||||||
#[arg(long, env = "ATP_DSN")]
|
#[arg(long, env = "ATP_DSN")]
|
||||||
oracle_dsn: Option<String>,
|
oracle_dsn: Option<String>,
|
||||||
|
|
||||||
/// Oracle ATP username
|
|
||||||
#[arg(long, env = "ATP_USER")]
|
#[arg(long, env = "ATP_USER")]
|
||||||
oracle_user: Option<String>,
|
oracle_user: Option<String>,
|
||||||
|
|
||||||
/// Oracle ATP password
|
|
||||||
#[arg(long, env = "ATP_PASSWORD")]
|
#[arg(long, env = "ATP_PASSWORD")]
|
||||||
oracle_password: Option<String>,
|
oracle_password: Option<String>,
|
||||||
|
|
||||||
/// Minimum number of pooled ATP connections (defaults to 1 when omitted)
|
|
||||||
#[arg(long, env = "ATP_POOL_MIN")]
|
#[arg(long, env = "ATP_POOL_MIN")]
|
||||||
oracle_pool_min: Option<usize>,
|
oracle_pool_min: Option<usize>,
|
||||||
|
|
||||||
/// Maximum number of pooled ATP connections (defaults to 16 when omitted)
|
|
||||||
#[arg(long, env = "ATP_POOL_MAX")]
|
#[arg(long, env = "ATP_POOL_MAX")]
|
||||||
oracle_pool_max: Option<usize>,
|
oracle_pool_max: Option<usize>,
|
||||||
|
|
||||||
/// Connection acquisition timeout in seconds (defaults to 30 when omitted)
|
|
||||||
#[arg(long, env = "ATP_POOL_TIMEOUT_SECS")]
|
#[arg(long, env = "ATP_POOL_TIMEOUT_SECS")]
|
||||||
oracle_pool_timeout_secs: Option<u64>,
|
oracle_pool_timeout_secs: Option<u64>,
|
||||||
}
|
}
|
||||||
@@ -357,19 +289,15 @@ enum OracleConnectSource {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl CliArgs {
|
impl CliArgs {
|
||||||
/// Determine connection mode from worker URLs
|
|
||||||
fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode {
|
fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode {
|
||||||
// Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
|
|
||||||
for url in worker_urls {
|
for url in worker_urls {
|
||||||
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
|
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
|
||||||
return ConnectionMode::Grpc;
|
return ConnectionMode::Grpc;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Default to HTTP for all other cases (including http://, https://, or no scheme)
|
|
||||||
ConnectionMode::Http
|
ConnectionMode::Http
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parse selector strings into HashMap
|
|
||||||
fn parse_selector(selector_list: &[String]) -> HashMap<String, String> {
|
fn parse_selector(selector_list: &[String]) -> HashMap<String, String> {
|
||||||
let mut map = HashMap::new();
|
let mut map = HashMap::new();
|
||||||
for item in selector_list {
|
for item in selector_list {
|
||||||
@@ -382,7 +310,6 @@ impl CliArgs {
|
|||||||
map
|
map
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert policy string to PolicyConfig
|
|
||||||
fn parse_policy(&self, policy_str: &str) -> PolicyConfig {
|
fn parse_policy(&self, policy_str: &str) -> PolicyConfig {
|
||||||
match policy_str {
|
match policy_str {
|
||||||
"random" => PolicyConfig::Random,
|
"random" => PolicyConfig::Random,
|
||||||
@@ -395,9 +322,9 @@ impl CliArgs {
|
|||||||
max_tree_size: self.max_tree_size,
|
max_tree_size: self.max_tree_size,
|
||||||
},
|
},
|
||||||
"power_of_two" => PolicyConfig::PowerOfTwo {
|
"power_of_two" => PolicyConfig::PowerOfTwo {
|
||||||
load_check_interval_secs: 5, // Default value
|
load_check_interval_secs: 5,
|
||||||
},
|
},
|
||||||
_ => PolicyConfig::RoundRobin, // Fallback
|
_ => PolicyConfig::RoundRobin,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -482,26 +409,21 @@ impl CliArgs {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert CLI arguments to RouterConfig
|
|
||||||
fn to_router_config(
|
fn to_router_config(
|
||||||
&self,
|
&self,
|
||||||
prefill_urls: Vec<(String, Option<u16>)>,
|
prefill_urls: Vec<(String, Option<u16>)>,
|
||||||
) -> ConfigResult<RouterConfig> {
|
) -> ConfigResult<RouterConfig> {
|
||||||
// Determine routing mode
|
|
||||||
let mode = if self.enable_igw {
|
let mode = if self.enable_igw {
|
||||||
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
|
|
||||||
RoutingMode::Regular {
|
RoutingMode::Regular {
|
||||||
worker_urls: vec![],
|
worker_urls: vec![],
|
||||||
}
|
}
|
||||||
} else if matches!(self.backend, Backend::Openai) {
|
} else if matches!(self.backend, Backend::Openai) {
|
||||||
// OpenAI backend mode - use worker_urls as base(s)
|
|
||||||
RoutingMode::OpenAI {
|
RoutingMode::OpenAI {
|
||||||
worker_urls: self.worker_urls.clone(),
|
worker_urls: self.worker_urls.clone(),
|
||||||
}
|
}
|
||||||
} else if self.pd_disaggregation {
|
} else if self.pd_disaggregation {
|
||||||
let decode_urls = self.decode.clone();
|
let decode_urls = self.decode.clone();
|
||||||
|
|
||||||
// Validate PD configuration if not using service discovery
|
|
||||||
if !self.service_discovery && (prefill_urls.is_empty() || decode_urls.is_empty()) {
|
if !self.service_discovery && (prefill_urls.is_empty() || decode_urls.is_empty()) {
|
||||||
return Err(ConfigError::ValidationFailed {
|
return Err(ConfigError::ValidationFailed {
|
||||||
reason: "PD disaggregation mode requires --prefill and --decode URLs when not using service discovery".to_string(),
|
reason: "PD disaggregation mode requires --prefill and --decode URLs when not using service discovery".to_string(),
|
||||||
@@ -515,7 +437,6 @@ impl CliArgs {
|
|||||||
decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)),
|
decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Regular mode
|
|
||||||
if !self.service_discovery && self.worker_urls.is_empty() {
|
if !self.service_discovery && self.worker_urls.is_empty() {
|
||||||
return Err(ConfigError::ValidationFailed {
|
return Err(ConfigError::ValidationFailed {
|
||||||
reason: "Regular mode requires --worker-urls when not using service discovery"
|
reason: "Regular mode requires --worker-urls when not using service discovery"
|
||||||
@@ -527,10 +448,8 @@ impl CliArgs {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Main policy
|
|
||||||
let policy = self.parse_policy(&self.policy);
|
let policy = self.parse_policy(&self.policy);
|
||||||
|
|
||||||
// Service discovery configuration
|
|
||||||
let discovery = if self.service_discovery {
|
let discovery = if self.service_discovery {
|
||||||
Some(DiscoveryConfig {
|
Some(DiscoveryConfig {
|
||||||
enabled: true,
|
enabled: true,
|
||||||
@@ -546,13 +465,11 @@ impl CliArgs {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// Metrics configuration
|
|
||||||
let metrics = Some(MetricsConfig {
|
let metrics = Some(MetricsConfig {
|
||||||
port: self.prometheus_port,
|
port: self.prometheus_port,
|
||||||
host: self.prometheus_host.clone(),
|
host: self.prometheus_host.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
// Determine connection mode from all worker URLs
|
|
||||||
let mut all_urls = Vec::new();
|
let mut all_urls = Vec::new();
|
||||||
match &mode {
|
match &mode {
|
||||||
RoutingMode::Regular { worker_urls } => {
|
RoutingMode::Regular { worker_urls } => {
|
||||||
@@ -568,9 +485,7 @@ impl CliArgs {
|
|||||||
}
|
}
|
||||||
all_urls.extend(decode_urls.clone());
|
all_urls.extend(decode_urls.clone());
|
||||||
}
|
}
|
||||||
RoutingMode::OpenAI { .. } => {
|
RoutingMode::OpenAI { .. } => {}
|
||||||
// For connection-mode detection, skip URLs; OpenAI forces HTTP below.
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
let connection_mode = match &mode {
|
let connection_mode = match &mode {
|
||||||
RoutingMode::OpenAI { .. } => ConnectionMode::Http,
|
RoutingMode::OpenAI { .. } => ConnectionMode::Http,
|
||||||
@@ -589,7 +504,6 @@ impl CliArgs {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// Build RouterConfig
|
|
||||||
Ok(RouterConfig {
|
Ok(RouterConfig {
|
||||||
mode,
|
mode,
|
||||||
policy,
|
policy,
|
||||||
@@ -612,8 +526,8 @@ impl CliArgs {
|
|||||||
Some(self.request_id_headers.clone())
|
Some(self.request_id_headers.clone())
|
||||||
},
|
},
|
||||||
max_concurrent_requests: self.max_concurrent_requests,
|
max_concurrent_requests: self.max_concurrent_requests,
|
||||||
queue_size: 100, // Default queue size
|
queue_size: 100,
|
||||||
queue_timeout_secs: 60, // Default timeout
|
queue_timeout_secs: 60,
|
||||||
cors_allowed_origins: self.cors_allowed_origins.clone(),
|
cors_allowed_origins: self.cors_allowed_origins.clone(),
|
||||||
retry: RetryConfig {
|
retry: RetryConfig {
|
||||||
max_retries: self.retry_max_retries,
|
max_retries: self.retry_max_retries,
|
||||||
@@ -646,9 +560,7 @@ impl CliArgs {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create ServerConfig from CLI args and RouterConfig
|
|
||||||
fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig {
|
fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig {
|
||||||
// Create service discovery config if enabled
|
|
||||||
let service_discovery_config = if self.service_discovery {
|
let service_discovery_config = if self.service_discovery {
|
||||||
Some(ServiceDiscoveryConfig {
|
Some(ServiceDiscoveryConfig {
|
||||||
enabled: true,
|
enabled: true,
|
||||||
@@ -665,7 +577,6 @@ impl CliArgs {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create Prometheus config
|
|
||||||
let prometheus_config = Some(PrometheusConfig {
|
let prometheus_config = Some(PrometheusConfig {
|
||||||
port: self.prometheus_port,
|
port: self.prometheus_port,
|
||||||
host: self.prometheus_host.clone(),
|
host: self.prometheus_host.clone(),
|
||||||
@@ -691,19 +602,15 @@ impl CliArgs {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// Parse prefill arguments manually before clap parsing
|
|
||||||
let prefill_urls = parse_prefill_args();
|
let prefill_urls = parse_prefill_args();
|
||||||
|
|
||||||
// Filter out prefill arguments and their values before passing to clap
|
|
||||||
let mut filtered_args: Vec<String> = Vec::new();
|
let mut filtered_args: Vec<String> = Vec::new();
|
||||||
let raw_args: Vec<String> = std::env::args().collect();
|
let raw_args: Vec<String> = std::env::args().collect();
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
|
|
||||||
while i < raw_args.len() {
|
while i < raw_args.len() {
|
||||||
if raw_args[i] == "--prefill" && i + 1 < raw_args.len() {
|
if raw_args[i] == "--prefill" && i + 1 < raw_args.len() {
|
||||||
// Skip --prefill and its URL
|
|
||||||
i += 2;
|
i += 2;
|
||||||
// Also skip bootstrap port if present
|
|
||||||
if i < raw_args.len()
|
if i < raw_args.len()
|
||||||
&& !raw_args[i].starts_with("--")
|
&& !raw_args[i].starts_with("--")
|
||||||
&& (raw_args[i].parse::<u16>().is_ok() || raw_args[i].to_lowercase() == "none")
|
&& (raw_args[i].parse::<u16>().is_ok() || raw_args[i].to_lowercase() == "none")
|
||||||
@@ -716,10 +623,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse CLI arguments with clap using filtered args
|
|
||||||
let cli_args = CliArgs::parse_from(filtered_args);
|
let cli_args = CliArgs::parse_from(filtered_args);
|
||||||
|
|
||||||
// Print startup info
|
|
||||||
println!("SGLang Router starting...");
|
println!("SGLang Router starting...");
|
||||||
println!("Host: {}:{}", cli_args.host, cli_args.port);
|
println!("Host: {}:{}", cli_args.host, cli_args.port);
|
||||||
let mode_str = if cli_args.enable_igw {
|
let mode_str = if cli_args.enable_igw {
|
||||||
@@ -733,7 +638,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
};
|
};
|
||||||
println!("Mode: {}", mode_str);
|
println!("Mode: {}", mode_str);
|
||||||
|
|
||||||
// Warn for runtimes that are parsed but not yet implemented
|
|
||||||
match cli_args.backend {
|
match cli_args.backend {
|
||||||
Backend::Vllm | Backend::Trtllm | Backend::Anthropic => {
|
Backend::Vllm | Backend::Trtllm | Backend::Anthropic => {
|
||||||
println!(
|
println!(
|
||||||
@@ -754,19 +658,10 @@ Provide --worker-urls or PD flags as usual.",
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert to RouterConfig
|
|
||||||
let router_config = cli_args.to_router_config(prefill_urls)?;
|
let router_config = cli_args.to_router_config(prefill_urls)?;
|
||||||
|
|
||||||
// Validate configuration
|
|
||||||
router_config.validate()?;
|
router_config.validate()?;
|
||||||
|
|
||||||
// Create ServerConfig
|
|
||||||
let server_config = cli_args.to_server_config(router_config);
|
let server_config = cli_args.to_server_config(router_config);
|
||||||
|
|
||||||
// Create a new runtime for the server (like Python binding does)
|
|
||||||
let runtime = tokio::runtime::Runtime::new()?;
|
let runtime = tokio::runtime::Runtime::new()?;
|
||||||
|
|
||||||
// Block on the async startup function
|
|
||||||
runtime.block_on(async move { server::startup(server_config).await })?;
|
runtime.block_on(async move { server::startup(server_config).await })?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ impl Default for PrometheusConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn init_metrics() {
|
pub fn init_metrics() {
|
||||||
// Request metrics
|
|
||||||
describe_counter!(
|
describe_counter!(
|
||||||
"sgl_router_requests_total",
|
"sgl_router_requests_total",
|
||||||
"Total number of requests by route and method"
|
"Total number of requests by route and method"
|
||||||
@@ -45,7 +44,6 @@ pub fn init_metrics() {
|
|||||||
"Total number of requests that exhausted retries by route"
|
"Total number of requests that exhausted retries by route"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Circuit breaker metrics
|
|
||||||
describe_gauge!(
|
describe_gauge!(
|
||||||
"sgl_router_cb_state",
|
"sgl_router_cb_state",
|
||||||
"Circuit breaker state per worker (0=closed, 1=open, 2=half_open)"
|
"Circuit breaker state per worker (0=closed, 1=open, 2=half_open)"
|
||||||
@@ -59,7 +57,6 @@ pub fn init_metrics() {
|
|||||||
"Total number of circuit breaker outcomes by worker and outcome type (success/failure)"
|
"Total number of circuit breaker outcomes by worker and outcome type (success/failure)"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Worker metrics
|
|
||||||
describe_gauge!(
|
describe_gauge!(
|
||||||
"sgl_router_active_workers",
|
"sgl_router_active_workers",
|
||||||
"Number of currently active workers"
|
"Number of currently active workers"
|
||||||
@@ -74,7 +71,6 @@ pub fn init_metrics() {
|
|||||||
"Total requests processed by each worker"
|
"Total requests processed by each worker"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Policy metrics
|
|
||||||
describe_counter!(
|
describe_counter!(
|
||||||
"sgl_router_policy_decisions_total",
|
"sgl_router_policy_decisions_total",
|
||||||
"Total routing policy decisions by policy and worker"
|
"Total routing policy decisions by policy and worker"
|
||||||
@@ -92,7 +88,6 @@ pub fn init_metrics() {
|
|||||||
describe_gauge!("sgl_router_max_load", "Maximum worker load");
|
describe_gauge!("sgl_router_max_load", "Maximum worker load");
|
||||||
describe_gauge!("sgl_router_min_load", "Minimum worker load");
|
describe_gauge!("sgl_router_min_load", "Minimum worker load");
|
||||||
|
|
||||||
// PD-specific metrics
|
|
||||||
describe_counter!("sgl_router_pd_requests_total", "Total PD requests by route");
|
describe_counter!("sgl_router_pd_requests_total", "Total PD requests by route");
|
||||||
describe_counter!(
|
describe_counter!(
|
||||||
"sgl_router_pd_prefill_requests_total",
|
"sgl_router_pd_prefill_requests_total",
|
||||||
@@ -123,7 +118,6 @@ pub fn init_metrics() {
|
|||||||
"PD request duration by route"
|
"PD request duration by route"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Service discovery metrics
|
|
||||||
describe_counter!(
|
describe_counter!(
|
||||||
"sgl_router_discovery_updates_total",
|
"sgl_router_discovery_updates_total",
|
||||||
"Total service discovery update events"
|
"Total service discovery update events"
|
||||||
@@ -137,13 +131,11 @@ pub fn init_metrics() {
|
|||||||
"Number of workers removed in last discovery update"
|
"Number of workers removed in last discovery update"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Generate request specific metrics
|
|
||||||
describe_histogram!(
|
describe_histogram!(
|
||||||
"sgl_router_generate_duration_seconds",
|
"sgl_router_generate_duration_seconds",
|
||||||
"Generate request duration"
|
"Generate request duration"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Embedding request specific metrics
|
|
||||||
describe_counter!("sgl_router_embeddings_total", "Total embedding requests");
|
describe_counter!("sgl_router_embeddings_total", "Total embedding requests");
|
||||||
describe_histogram!(
|
describe_histogram!(
|
||||||
"sgl_router_embeddings_duration_seconds",
|
"sgl_router_embeddings_duration_seconds",
|
||||||
@@ -155,13 +147,11 @@ pub fn init_metrics() {
|
|||||||
);
|
);
|
||||||
describe_gauge!("sgl_router_embeddings_queue_size", "Embedding queue size");
|
describe_gauge!("sgl_router_embeddings_queue_size", "Embedding queue size");
|
||||||
|
|
||||||
// Running requests gauge for cache-aware policy
|
|
||||||
describe_gauge!(
|
describe_gauge!(
|
||||||
"sgl_router_running_requests",
|
"sgl_router_running_requests",
|
||||||
"Number of running requests per worker"
|
"Number of running requests per worker"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Tokenizer metrics
|
|
||||||
describe_histogram!(
|
describe_histogram!(
|
||||||
"sgl_tokenizer_encode_duration_seconds",
|
"sgl_tokenizer_encode_duration_seconds",
|
||||||
"Time to encode text to tokens"
|
"Time to encode text to tokens"
|
||||||
@@ -207,7 +197,6 @@ pub fn init_metrics() {
|
|||||||
"Vocabulary size of the loaded tokenizer"
|
"Vocabulary size of the loaded tokenizer"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Stop sequence detection metrics
|
|
||||||
describe_counter!(
|
describe_counter!(
|
||||||
"sgl_tokenizer_stop_sequences_detected_total",
|
"sgl_tokenizer_stop_sequences_detected_total",
|
||||||
"Total stop sequences detected by type"
|
"Total stop sequences detected by type"
|
||||||
@@ -221,7 +210,6 @@ pub fn init_metrics() {
|
|||||||
"Time to check for stop sequences per token"
|
"Time to check for stop sequences per token"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Streaming decode metrics
|
|
||||||
describe_counter!(
|
describe_counter!(
|
||||||
"sgl_tokenizer_stream_tokens_total",
|
"sgl_tokenizer_stream_tokens_total",
|
||||||
"Total tokens processed in streaming decode"
|
"Total tokens processed in streaming decode"
|
||||||
@@ -235,7 +223,6 @@ pub fn init_metrics() {
|
|||||||
"Time per streaming decode step"
|
"Time per streaming decode step"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Factory metrics
|
|
||||||
describe_counter!(
|
describe_counter!(
|
||||||
"sgl_tokenizer_factory_loads_total",
|
"sgl_tokenizer_factory_loads_total",
|
||||||
"Total tokenizer loads by file type"
|
"Total tokenizer loads by file type"
|
||||||
@@ -251,7 +238,6 @@ pub fn init_metrics() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn start_prometheus(config: PrometheusConfig) {
|
pub fn start_prometheus(config: PrometheusConfig) {
|
||||||
// Initialize metric descriptions
|
|
||||||
init_metrics();
|
init_metrics();
|
||||||
|
|
||||||
let duration_matcher = Matcher::Suffix(String::from("duration_seconds"));
|
let duration_matcher = Matcher::Suffix(String::from("duration_seconds"));
|
||||||
@@ -280,7 +266,6 @@ pub struct RouterMetrics;
|
|||||||
pub struct TokenizerMetrics;
|
pub struct TokenizerMetrics;
|
||||||
|
|
||||||
impl RouterMetrics {
|
impl RouterMetrics {
|
||||||
// Request metrics
|
|
||||||
pub fn record_request(route: &str) {
|
pub fn record_request(route: &str) {
|
||||||
counter!("sgl_router_requests_total",
|
counter!("sgl_router_requests_total",
|
||||||
"route" => route.to_string()
|
"route" => route.to_string()
|
||||||
@@ -324,7 +309,6 @@ impl RouterMetrics {
|
|||||||
.increment(1);
|
.increment(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Worker metrics
|
|
||||||
pub fn set_active_workers(count: usize) {
|
pub fn set_active_workers(count: usize) {
|
||||||
gauge!("sgl_router_active_workers").set(count as f64);
|
gauge!("sgl_router_active_workers").set(count as f64);
|
||||||
}
|
}
|
||||||
@@ -350,7 +334,6 @@ impl RouterMetrics {
|
|||||||
.increment(1);
|
.increment(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Policy metrics
|
|
||||||
pub fn record_policy_decision(policy: &str, worker: &str) {
|
pub fn record_policy_decision(policy: &str, worker: &str) {
|
||||||
counter!("sgl_router_policy_decisions_total",
|
counter!("sgl_router_policy_decisions_total",
|
||||||
"policy" => policy.to_string(),
|
"policy" => policy.to_string(),
|
||||||
@@ -383,7 +366,6 @@ impl RouterMetrics {
|
|||||||
gauge!("sgl_router_min_load").set(min_load as f64);
|
gauge!("sgl_router_min_load").set(min_load as f64);
|
||||||
}
|
}
|
||||||
|
|
||||||
// PD-specific metrics
|
|
||||||
pub fn record_pd_request(route: &str) {
|
pub fn record_pd_request(route: &str) {
|
||||||
counter!("sgl_router_pd_requests_total",
|
counter!("sgl_router_pd_requests_total",
|
||||||
"route" => route.to_string()
|
"route" => route.to_string()
|
||||||
@@ -440,19 +422,16 @@ impl RouterMetrics {
|
|||||||
.increment(1);
|
.increment(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Service discovery metrics
|
|
||||||
pub fn record_discovery_update(added: usize, removed: usize) {
|
pub fn record_discovery_update(added: usize, removed: usize) {
|
||||||
counter!("sgl_router_discovery_updates_total").increment(1);
|
counter!("sgl_router_discovery_updates_total").increment(1);
|
||||||
gauge!("sgl_router_discovery_workers_added").set(added as f64);
|
gauge!("sgl_router_discovery_workers_added").set(added as f64);
|
||||||
gauge!("sgl_router_discovery_workers_removed").set(removed as f64);
|
gauge!("sgl_router_discovery_workers_removed").set(removed as f64);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate request metrics
|
|
||||||
pub fn record_generate_duration(duration: Duration) {
|
pub fn record_generate_duration(duration: Duration) {
|
||||||
histogram!("sgl_router_generate_duration_seconds").record(duration.as_secs_f64());
|
histogram!("sgl_router_generate_duration_seconds").record(duration.as_secs_f64());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Embeddings metrics
|
|
||||||
pub fn record_embeddings_request() {
|
pub fn record_embeddings_request() {
|
||||||
counter!("sgl_router_embeddings_total").increment(1);
|
counter!("sgl_router_embeddings_total").increment(1);
|
||||||
}
|
}
|
||||||
@@ -473,7 +452,6 @@ impl RouterMetrics {
|
|||||||
gauge!("sgl_router_embeddings_queue_size").set(size as f64);
|
gauge!("sgl_router_embeddings_queue_size").set(size as f64);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Running requests for cache-aware policy
|
|
||||||
pub fn set_running_requests(worker: &str, count: usize) {
|
pub fn set_running_requests(worker: &str, count: usize) {
|
||||||
gauge!("sgl_router_running_requests",
|
gauge!("sgl_router_running_requests",
|
||||||
"worker" => worker.to_string()
|
"worker" => worker.to_string()
|
||||||
@@ -481,7 +459,6 @@ impl RouterMetrics {
|
|||||||
.set(count as f64);
|
.set(count as f64);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Circuit breaker metrics
|
|
||||||
pub fn set_cb_state(worker: &str, state_code: u8) {
|
pub fn set_cb_state(worker: &str, state_code: u8) {
|
||||||
gauge!("sgl_router_cb_state",
|
gauge!("sgl_router_cb_state",
|
||||||
"worker" => worker.to_string()
|
"worker" => worker.to_string()
|
||||||
@@ -508,7 +485,6 @@ impl RouterMetrics {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl TokenizerMetrics {
|
impl TokenizerMetrics {
|
||||||
// Encoding metrics
|
|
||||||
pub fn record_encode_request(tokenizer_type: &str) {
|
pub fn record_encode_request(tokenizer_type: &str) {
|
||||||
counter!("sgl_tokenizer_encode_requests_total",
|
counter!("sgl_tokenizer_encode_requests_total",
|
||||||
"tokenizer_type" => tokenizer_type.to_string()
|
"tokenizer_type" => tokenizer_type.to_string()
|
||||||
@@ -535,7 +511,6 @@ impl TokenizerMetrics {
|
|||||||
histogram!("sgl_tokenizer_chars_per_encode").record(char_count as f64);
|
histogram!("sgl_tokenizer_chars_per_encode").record(char_count as f64);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decoding metrics
|
|
||||||
pub fn record_decode_request(tokenizer_type: &str) {
|
pub fn record_decode_request(tokenizer_type: &str) {
|
||||||
counter!("sgl_tokenizer_decode_requests_total",
|
counter!("sgl_tokenizer_decode_requests_total",
|
||||||
"tokenizer_type" => tokenizer_type.to_string()
|
"tokenizer_type" => tokenizer_type.to_string()
|
||||||
@@ -558,7 +533,6 @@ impl TokenizerMetrics {
|
|||||||
histogram!("sgl_tokenizer_tokens_per_decode").record(token_count as f64);
|
histogram!("sgl_tokenizer_tokens_per_decode").record(token_count as f64);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Batch encoding metrics
|
|
||||||
pub fn record_encode_batch_duration(duration: Duration, batch_size: usize) {
|
pub fn record_encode_batch_duration(duration: Duration, batch_size: usize) {
|
||||||
histogram!("sgl_tokenizer_encode_batch_duration_seconds",
|
histogram!("sgl_tokenizer_encode_batch_duration_seconds",
|
||||||
"batch_size" => batch_size.to_string()
|
"batch_size" => batch_size.to_string()
|
||||||
@@ -566,7 +540,6 @@ impl TokenizerMetrics {
|
|||||||
.record(duration.as_secs_f64());
|
.record(duration.as_secs_f64());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop sequence detection metrics
|
|
||||||
pub fn record_stop_sequence_detected(stop_type: &str) {
|
pub fn record_stop_sequence_detected(stop_type: &str) {
|
||||||
counter!("sgl_tokenizer_stop_sequences_detected_total",
|
counter!("sgl_tokenizer_stop_sequences_detected_total",
|
||||||
"type" => stop_type.to_string()
|
"type" => stop_type.to_string()
|
||||||
@@ -582,7 +555,6 @@ impl TokenizerMetrics {
|
|||||||
histogram!("sgl_tokenizer_stop_detection_duration_seconds").record(duration.as_secs_f64());
|
histogram!("sgl_tokenizer_stop_detection_duration_seconds").record(duration.as_secs_f64());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Streaming decode metrics
|
|
||||||
pub fn record_stream_token() {
|
pub fn record_stream_token() {
|
||||||
counter!("sgl_tokenizer_stream_tokens_total").increment(1);
|
counter!("sgl_tokenizer_stream_tokens_total").increment(1);
|
||||||
}
|
}
|
||||||
@@ -595,7 +567,6 @@ impl TokenizerMetrics {
|
|||||||
histogram!("sgl_tokenizer_stream_step_duration_seconds").record(duration.as_secs_f64());
|
histogram!("sgl_tokenizer_stream_step_duration_seconds").record(duration.as_secs_f64());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Factory metrics
|
|
||||||
pub fn record_factory_load(file_type: &str) {
|
pub fn record_factory_load(file_type: &str) {
|
||||||
counter!("sgl_tokenizer_factory_loads_total",
|
counter!("sgl_tokenizer_factory_loads_total",
|
||||||
"file_type" => file_type.to_string()
|
"file_type" => file_type.to_string()
|
||||||
@@ -614,7 +585,6 @@ impl TokenizerMetrics {
|
|||||||
histogram!("sgl_tokenizer_factory_load_duration_seconds").record(duration.as_secs_f64());
|
histogram!("sgl_tokenizer_factory_load_duration_seconds").record(duration.as_secs_f64());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Vocabulary metrics
|
|
||||||
pub fn set_vocab_size(tokenizer_type: &str, size: usize) {
|
pub fn set_vocab_size(tokenizer_type: &str, size: usize) {
|
||||||
gauge!("sgl_tokenizer_vocab_size",
|
gauge!("sgl_tokenizer_vocab_size",
|
||||||
"tokenizer_type" => tokenizer_type.to_string()
|
"tokenizer_type" => tokenizer_type.to_string()
|
||||||
@@ -705,7 +675,6 @@ mod tests {
|
|||||||
.parse()
|
.parse()
|
||||||
.unwrap_or(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
|
.unwrap_or(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
|
||||||
|
|
||||||
// Should fall back to 0.0.0.0
|
|
||||||
assert_eq!(ip_addr, IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
|
assert_eq!(ip_addr, IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -780,7 +749,6 @@ mod tests {
|
|||||||
fn test_duration_suffix_matcher() {
|
fn test_duration_suffix_matcher() {
|
||||||
let matcher = Matcher::Suffix(String::from("duration_seconds"));
|
let matcher = Matcher::Suffix(String::from("duration_seconds"));
|
||||||
|
|
||||||
// Test matching behavior
|
|
||||||
let _matching_metrics = [
|
let _matching_metrics = [
|
||||||
"request_duration_seconds",
|
"request_duration_seconds",
|
||||||
"response_duration_seconds",
|
"response_duration_seconds",
|
||||||
@@ -789,8 +757,6 @@ mod tests {
|
|||||||
|
|
||||||
let _non_matching_metrics = ["duration_total", "duration_seconds_total", "other_metric"];
|
let _non_matching_metrics = ["duration_total", "duration_seconds_total", "other_metric"];
|
||||||
|
|
||||||
// Note: We can't directly test Matcher matching without the internals,
|
|
||||||
// but we can verify the matcher is created correctly
|
|
||||||
match matcher {
|
match matcher {
|
||||||
Matcher::Suffix(suffix) => assert_eq!(suffix, "duration_seconds"),
|
Matcher::Suffix(suffix) => assert_eq!(suffix, "duration_seconds"),
|
||||||
_ => panic!("Expected Suffix matcher"),
|
_ => panic!("Expected Suffix matcher"),
|
||||||
@@ -801,7 +767,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_prometheus_builder_configuration() {
|
fn test_prometheus_builder_configuration() {
|
||||||
// This test verifies the builder configuration without actually starting Prometheus
|
|
||||||
let _config = PrometheusConfig::default();
|
let _config = PrometheusConfig::default();
|
||||||
|
|
||||||
let duration_matcher = Matcher::Suffix(String::from("duration_seconds"));
|
let duration_matcher = Matcher::Suffix(String::from("duration_seconds"));
|
||||||
@@ -810,10 +775,8 @@ mod tests {
|
|||||||
60.0, 90.0, 120.0, 180.0, 240.0,
|
60.0, 90.0, 120.0, 180.0, 240.0,
|
||||||
];
|
];
|
||||||
|
|
||||||
// Verify bucket configuration
|
|
||||||
assert_eq!(duration_bucket.len(), 20);
|
assert_eq!(duration_bucket.len(), 20);
|
||||||
|
|
||||||
// Verify matcher is suffix type
|
|
||||||
match duration_matcher {
|
match duration_matcher {
|
||||||
Matcher::Suffix(s) => assert_eq!(s, "duration_seconds"),
|
Matcher::Suffix(s) => assert_eq!(s, "duration_seconds"),
|
||||||
_ => panic!("Expected Suffix matcher"),
|
_ => panic!("Expected Suffix matcher"),
|
||||||
@@ -832,14 +795,12 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_custom_buckets_for_different_metrics() {
|
fn test_custom_buckets_for_different_metrics() {
|
||||||
// Test that we can create different bucket configurations
|
|
||||||
let request_buckets = [0.001, 0.01, 0.1, 1.0, 10.0];
|
let request_buckets = [0.001, 0.01, 0.1, 1.0, 10.0];
|
||||||
let generate_buckets = [0.1, 0.5, 1.0, 5.0, 30.0, 60.0];
|
let generate_buckets = [0.1, 0.5, 1.0, 5.0, 30.0, 60.0];
|
||||||
|
|
||||||
assert_eq!(request_buckets.len(), 5);
|
assert_eq!(request_buckets.len(), 5);
|
||||||
assert_eq!(generate_buckets.len(), 6);
|
assert_eq!(generate_buckets.len(), 6);
|
||||||
|
|
||||||
// Verify each set is sorted
|
|
||||||
for i in 1..request_buckets.len() {
|
for i in 1..request_buckets.len() {
|
||||||
assert!(request_buckets[i] > request_buckets[i - 1]);
|
assert!(request_buckets[i] > request_buckets[i - 1]);
|
||||||
}
|
}
|
||||||
@@ -853,7 +814,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_metrics_static_methods() {
|
fn test_metrics_static_methods() {
|
||||||
// Test that all static methods can be called without panic
|
|
||||||
RouterMetrics::record_request("/generate");
|
RouterMetrics::record_request("/generate");
|
||||||
RouterMetrics::record_request_duration("/generate", Duration::from_millis(100));
|
RouterMetrics::record_request_duration("/generate", Duration::from_millis(100));
|
||||||
RouterMetrics::record_request_error("/generate", "timeout");
|
RouterMetrics::record_request_error("/generate", "timeout");
|
||||||
@@ -887,41 +847,32 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_tokenizer_metrics_static_methods() {
|
fn test_tokenizer_metrics_static_methods() {
|
||||||
// Test that all tokenizer metric methods can be called without panic
|
|
||||||
|
|
||||||
// Encoding metrics
|
|
||||||
TokenizerMetrics::record_encode_request("huggingface");
|
TokenizerMetrics::record_encode_request("huggingface");
|
||||||
TokenizerMetrics::record_encode_duration(Duration::from_millis(10));
|
TokenizerMetrics::record_encode_duration(Duration::from_millis(10));
|
||||||
TokenizerMetrics::record_encode_error("invalid_input");
|
TokenizerMetrics::record_encode_error("invalid_input");
|
||||||
TokenizerMetrics::record_tokens_per_encode(100);
|
TokenizerMetrics::record_tokens_per_encode(100);
|
||||||
TokenizerMetrics::record_chars_per_encode(500);
|
TokenizerMetrics::record_chars_per_encode(500);
|
||||||
|
|
||||||
// Decoding metrics
|
|
||||||
TokenizerMetrics::record_decode_request("huggingface");
|
TokenizerMetrics::record_decode_request("huggingface");
|
||||||
TokenizerMetrics::record_decode_duration(Duration::from_millis(5));
|
TokenizerMetrics::record_decode_duration(Duration::from_millis(5));
|
||||||
TokenizerMetrics::record_decode_error("invalid_tokens");
|
TokenizerMetrics::record_decode_error("invalid_tokens");
|
||||||
TokenizerMetrics::record_tokens_per_decode(50);
|
TokenizerMetrics::record_tokens_per_decode(50);
|
||||||
|
|
||||||
// Batch encoding
|
|
||||||
TokenizerMetrics::record_encode_batch_duration(Duration::from_millis(100), 10);
|
TokenizerMetrics::record_encode_batch_duration(Duration::from_millis(100), 10);
|
||||||
|
|
||||||
// Stop sequence detection
|
|
||||||
TokenizerMetrics::record_stop_sequence_detected("token");
|
TokenizerMetrics::record_stop_sequence_detected("token");
|
||||||
TokenizerMetrics::record_stop_sequence_detected("string");
|
TokenizerMetrics::record_stop_sequence_detected("string");
|
||||||
TokenizerMetrics::record_partial_match();
|
TokenizerMetrics::record_partial_match();
|
||||||
TokenizerMetrics::record_stop_detection_duration(Duration::from_micros(100));
|
TokenizerMetrics::record_stop_detection_duration(Duration::from_micros(100));
|
||||||
|
|
||||||
// Streaming decode
|
|
||||||
TokenizerMetrics::record_stream_token();
|
TokenizerMetrics::record_stream_token();
|
||||||
TokenizerMetrics::record_incomplete_utf8();
|
TokenizerMetrics::record_incomplete_utf8();
|
||||||
TokenizerMetrics::record_stream_step_duration(Duration::from_micros(50));
|
TokenizerMetrics::record_stream_step_duration(Duration::from_micros(50));
|
||||||
|
|
||||||
// Factory metrics
|
|
||||||
TokenizerMetrics::record_factory_load("json");
|
TokenizerMetrics::record_factory_load("json");
|
||||||
TokenizerMetrics::record_factory_error("unsupported_format");
|
TokenizerMetrics::record_factory_error("unsupported_format");
|
||||||
TokenizerMetrics::record_factory_load_duration(Duration::from_millis(200));
|
TokenizerMetrics::record_factory_load_duration(Duration::from_millis(200));
|
||||||
|
|
||||||
// Vocabulary metrics
|
|
||||||
TokenizerMetrics::set_vocab_size("huggingface", 50000);
|
TokenizerMetrics::set_vocab_size("huggingface", 50000);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -929,17 +880,14 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_port_already_in_use() {
|
fn test_port_already_in_use() {
|
||||||
// Skip this test if we can't bind to the port
|
let port = 29123;
|
||||||
let port = 29123; // Use a different port to avoid conflicts
|
|
||||||
|
|
||||||
if let Ok(_listener) = TcpListener::bind(("127.0.0.1", port)) {
|
if let Ok(_listener) = TcpListener::bind(("127.0.0.1", port)) {
|
||||||
// Port is available, we can test
|
|
||||||
let config = PrometheusConfig {
|
let config = PrometheusConfig {
|
||||||
port,
|
port,
|
||||||
host: "127.0.0.1".to_string(),
|
host: "127.0.0.1".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Just verify config is created correctly
|
|
||||||
assert_eq!(config.port, port);
|
assert_eq!(config.port, port);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -948,8 +896,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_metrics_endpoint_accessibility() {
|
fn test_metrics_endpoint_accessibility() {
|
||||||
// This would be an integration test in practice
|
|
||||||
// Here we just verify the configuration
|
|
||||||
let config = PrometheusConfig {
|
let config = PrometheusConfig {
|
||||||
port: 29000,
|
port: 29000,
|
||||||
host: "127.0.0.1".to_string(),
|
host: "127.0.0.1".to_string(),
|
||||||
@@ -963,7 +909,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_concurrent_metric_updates() {
|
fn test_concurrent_metric_updates() {
|
||||||
// Test that metric updates can be called concurrently
|
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::thread;
|
use std::thread;
|
||||||
@@ -984,11 +929,9 @@ mod tests {
|
|||||||
handles.push(handle);
|
handles.push(handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let threads run briefly
|
|
||||||
thread::sleep(Duration::from_millis(10));
|
thread::sleep(Duration::from_millis(10));
|
||||||
done.store(true, Ordering::Relaxed);
|
done.store(true, Ordering::Relaxed);
|
||||||
|
|
||||||
// Wait for all threads
|
|
||||||
for handle in handles {
|
for handle in handles {
|
||||||
handle.join().unwrap();
|
handle.join().unwrap();
|
||||||
}
|
}
|
||||||
@@ -998,7 +941,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_empty_string_metrics() {
|
fn test_empty_string_metrics() {
|
||||||
// Test that empty strings don't cause issues
|
|
||||||
RouterMetrics::record_request("");
|
RouterMetrics::record_request("");
|
||||||
RouterMetrics::set_worker_health("", true);
|
RouterMetrics::set_worker_health("", true);
|
||||||
RouterMetrics::record_policy_decision("", "");
|
RouterMetrics::record_policy_decision("", "");
|
||||||
@@ -1030,7 +972,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_extreme_metric_values() {
|
fn test_extreme_metric_values() {
|
||||||
// Test extreme values
|
|
||||||
RouterMetrics::set_active_workers(0);
|
RouterMetrics::set_active_workers(0);
|
||||||
RouterMetrics::set_active_workers(usize::MAX);
|
RouterMetrics::set_active_workers(usize::MAX);
|
||||||
|
|
||||||
@@ -1038,7 +979,6 @@ mod tests {
|
|||||||
RouterMetrics::set_worker_load("worker", usize::MAX);
|
RouterMetrics::set_worker_load("worker", usize::MAX);
|
||||||
|
|
||||||
RouterMetrics::record_request_duration("route", Duration::from_nanos(1));
|
RouterMetrics::record_request_duration("route", Duration::from_nanos(1));
|
||||||
// 24 hours
|
|
||||||
RouterMetrics::record_request_duration("route", Duration::from_secs(86400));
|
RouterMetrics::record_request_duration("route", Duration::from_secs(86400));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ use tokio::task;
|
|||||||
use tokio::time;
|
use tokio::time;
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
/// Represents the service discovery configuration
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ServiceDiscoveryConfig {
|
pub struct ServiceDiscoveryConfig {
|
||||||
pub enabled: bool,
|
pub enabled: bool,
|
||||||
@@ -41,8 +40,8 @@ impl Default for ServiceDiscoveryConfig {
|
|||||||
enabled: false,
|
enabled: false,
|
||||||
selector: HashMap::new(),
|
selector: HashMap::new(),
|
||||||
check_interval: Duration::from_secs(60),
|
check_interval: Duration::from_secs(60),
|
||||||
port: 8000, // Standard port for modern services
|
port: 8000,
|
||||||
namespace: None, // None means watch all namespaces
|
namespace: None,
|
||||||
pd_mode: false,
|
pd_mode: false,
|
||||||
prefill_selector: HashMap::new(),
|
prefill_selector: HashMap::new(),
|
||||||
decode_selector: HashMap::new(),
|
decode_selector: HashMap::new(),
|
||||||
@@ -51,7 +50,6 @@ impl Default for ServiceDiscoveryConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Pod type for PD mode service discovery
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
pub enum PodType {
|
pub enum PodType {
|
||||||
Prefill,
|
Prefill,
|
||||||
@@ -59,7 +57,6 @@ pub enum PodType {
|
|||||||
Regular,
|
Regular,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Represents a Kubernetes pod's information used for worker management
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
pub struct PodInfo {
|
pub struct PodInfo {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
@@ -71,7 +68,6 @@ pub struct PodInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl PodInfo {
|
impl PodInfo {
|
||||||
/// Check if a pod matches any of the given selectors
|
|
||||||
fn matches_selector(pod: &Pod, selector: &HashMap<String, String>) -> bool {
|
fn matches_selector(pod: &Pod, selector: &HashMap<String, String>) -> bool {
|
||||||
if selector.is_empty() {
|
if selector.is_empty() {
|
||||||
return false;
|
return false;
|
||||||
@@ -83,19 +79,15 @@ impl PodInfo {
|
|||||||
.is_some_and(|labels| selector.iter().all(|(k, v)| labels.get(k) == Some(v)))
|
.is_some_and(|labels| selector.iter().all(|(k, v)| labels.get(k) == Some(v)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if a pod should be included in service discovery
|
|
||||||
pub fn should_include(pod: &Pod, config: &ServiceDiscoveryConfig) -> bool {
|
pub fn should_include(pod: &Pod, config: &ServiceDiscoveryConfig) -> bool {
|
||||||
if config.pd_mode {
|
if config.pd_mode {
|
||||||
// In PD mode, at least one selector must be non-empty
|
|
||||||
if config.prefill_selector.is_empty() && config.decode_selector.is_empty() {
|
if config.prefill_selector.is_empty() && config.decode_selector.is_empty() {
|
||||||
warn!("PD mode enabled but both prefill_selector and decode_selector are empty");
|
warn!("PD mode enabled but both prefill_selector and decode_selector are empty");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// In PD mode, pod must match either prefill or decode selector
|
|
||||||
Self::matches_selector(pod, &config.prefill_selector)
|
Self::matches_selector(pod, &config.prefill_selector)
|
||||||
|| Self::matches_selector(pod, &config.decode_selector)
|
|| Self::matches_selector(pod, &config.decode_selector)
|
||||||
} else {
|
} else {
|
||||||
// In regular mode, pod must match the general selector
|
|
||||||
if config.selector.is_empty() {
|
if config.selector.is_empty() {
|
||||||
warn!("Regular mode enabled but selector is empty");
|
warn!("Regular mode enabled but selector is empty");
|
||||||
return false;
|
return false;
|
||||||
@@ -104,7 +96,6 @@ impl PodInfo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unified PodInfo creation with optional PD configuration
|
|
||||||
pub fn from_pod(pod: &Pod, config: Option<&ServiceDiscoveryConfig>) -> Option<Self> {
|
pub fn from_pod(pod: &Pod, config: Option<&ServiceDiscoveryConfig>) -> Option<Self> {
|
||||||
let name = pod.metadata.name.clone()?;
|
let name = pod.metadata.name.clone()?;
|
||||||
let status = pod.status.clone()?;
|
let status = pod.status.clone()?;
|
||||||
@@ -120,10 +111,8 @@ impl PodInfo {
|
|||||||
|
|
||||||
let pod_status = status.phase.unwrap_or_else(|| "Unknown".to_string());
|
let pod_status = status.phase.unwrap_or_else(|| "Unknown".to_string());
|
||||||
|
|
||||||
// Determine pod type based on labels if config is provided and in PD mode
|
|
||||||
let pod_type = if let Some(config) = config {
|
let pod_type = if let Some(config) = config {
|
||||||
if config.pd_mode {
|
if config.pd_mode {
|
||||||
// Use simplified helper methods for cleaner logic
|
|
||||||
if Self::matches_selector(pod, &config.prefill_selector) {
|
if Self::matches_selector(pod, &config.prefill_selector) {
|
||||||
Some(PodType::Prefill)
|
Some(PodType::Prefill)
|
||||||
} else if Self::matches_selector(pod, &config.decode_selector) {
|
} else if Self::matches_selector(pod, &config.decode_selector) {
|
||||||
@@ -135,11 +124,9 @@ impl PodInfo {
|
|||||||
Some(PodType::Regular)
|
Some(PodType::Regular)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// No config provided, default to None (for backwards compatibility)
|
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// Extract bootstrap port from annotations for prefill pods
|
|
||||||
let bootstrap_port = if matches!(pod_type, Some(PodType::Prefill)) {
|
let bootstrap_port = if matches!(pod_type, Some(PodType::Prefill)) {
|
||||||
if let Some(config) = config {
|
if let Some(config) = config {
|
||||||
pod.metadata
|
pod.metadata
|
||||||
@@ -164,12 +151,10 @@ impl PodInfo {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns true if the pod is in a state where it can accept traffic
|
|
||||||
pub fn is_healthy(&self) -> bool {
|
pub fn is_healthy(&self) -> bool {
|
||||||
self.is_ready && self.status == "Running"
|
self.is_ready && self.status == "Running"
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates a worker URL for this pod
|
|
||||||
pub fn worker_url(&self, port: u16) -> String {
|
pub fn worker_url(&self, port: u16) -> String {
|
||||||
format!("http://{}:{}", self.ip, port)
|
format!("http://{}:{}", self.ip, port)
|
||||||
}
|
}
|
||||||
@@ -179,9 +164,7 @@ pub async fn start_service_discovery(
|
|||||||
config: ServiceDiscoveryConfig,
|
config: ServiceDiscoveryConfig,
|
||||||
app_context: Arc<AppContext>,
|
app_context: Arc<AppContext>,
|
||||||
) -> Result<task::JoinHandle<()>, kube::Error> {
|
) -> Result<task::JoinHandle<()>, kube::Error> {
|
||||||
// Don't initialize anything if service discovery is disabled
|
|
||||||
if !config.enabled {
|
if !config.enabled {
|
||||||
// Return a generic error when service discovery is disabled
|
|
||||||
return Err(kube::Error::Api(kube::error::ErrorResponse {
|
return Err(kube::Error::Api(kube::error::ErrorResponse {
|
||||||
status: "Disabled".to_string(),
|
status: "Disabled".to_string(),
|
||||||
message: "Service discovery is disabled".to_string(),
|
message: "Service discovery is disabled".to_string(),
|
||||||
@@ -192,7 +175,6 @@ pub async fn start_service_discovery(
|
|||||||
|
|
||||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||||
|
|
||||||
// Initialize Kubernetes client
|
|
||||||
let client = Client::try_default().await?;
|
let client = Client::try_default().await?;
|
||||||
|
|
||||||
// Log the appropriate selectors based on mode
|
// Log the appropriate selectors based on mode
|
||||||
@@ -229,12 +211,9 @@ pub async fn start_service_discovery(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the task that will run in the background
|
|
||||||
let handle = task::spawn(async move {
|
let handle = task::spawn(async move {
|
||||||
// We'll track pods we've already added to avoid duplicates
|
|
||||||
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
|
|
||||||
// Create a watcher for pods
|
|
||||||
let pods: Api<Pod> = if let Some(namespace) = &config.namespace {
|
let pods: Api<Pod> = if let Some(namespace) = &config.namespace {
|
||||||
Api::namespaced(client, namespace)
|
Api::namespaced(client, namespace)
|
||||||
} else {
|
} else {
|
||||||
@@ -243,23 +222,19 @@ pub async fn start_service_discovery(
|
|||||||
|
|
||||||
debug!("K8s service discovery initialized");
|
debug!("K8s service discovery initialized");
|
||||||
|
|
||||||
// Create Arcs for configuration data
|
|
||||||
let config_arc = Arc::new(config.clone());
|
let config_arc = Arc::new(config.clone());
|
||||||
let port = config.port;
|
let port = config.port;
|
||||||
|
|
||||||
let mut retry_delay = Duration::from_secs(1);
|
let mut retry_delay = Duration::from_secs(1);
|
||||||
const MAX_RETRY_DELAY: Duration = Duration::from_secs(300); // 5 minutes max
|
const MAX_RETRY_DELAY: Duration = Duration::from_secs(300);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
// Create a watcher with the proper parameters according to the kube-rs API
|
|
||||||
let watcher_config = Config::default();
|
let watcher_config = Config::default();
|
||||||
let watcher_stream = watcher(pods.clone(), watcher_config).applied_objects();
|
let watcher_stream = watcher(pods.clone(), watcher_config).applied_objects();
|
||||||
|
|
||||||
// Clone Arcs for the closures
|
|
||||||
let config_clone = Arc::clone(&config_arc);
|
let config_clone = Arc::clone(&config_arc);
|
||||||
let tracked_pods_clone = Arc::clone(&tracked_pods);
|
let tracked_pods_clone = Arc::clone(&tracked_pods);
|
||||||
|
|
||||||
// Simplified label selector filter using helper method
|
|
||||||
let filtered_stream = watcher_stream.filter_map(move |obj_res| {
|
let filtered_stream = watcher_stream.filter_map(move |obj_res| {
|
||||||
let config_inner = Arc::clone(&config_clone);
|
let config_inner = Arc::clone(&config_clone);
|
||||||
|
|
||||||
@@ -277,7 +252,6 @@ pub async fn start_service_discovery(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// Clone again for the next closure
|
|
||||||
let tracked_pods_clone2 = Arc::clone(&tracked_pods_clone);
|
let tracked_pods_clone2 = Arc::clone(&tracked_pods_clone);
|
||||||
let app_context_clone = Arc::clone(&app_context);
|
let app_context_clone = Arc::clone(&app_context);
|
||||||
let config_clone2 = Arc::clone(&config_arc);
|
let config_clone2 = Arc::clone(&config_arc);
|
||||||
@@ -317,7 +291,6 @@ pub async fn start_service_discovery(
|
|||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
// Reset retry delay on success
|
|
||||||
retry_delay = Duration::from_secs(1);
|
retry_delay = Duration::from_secs(1);
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
@@ -328,12 +301,10 @@ pub async fn start_service_discovery(
|
|||||||
);
|
);
|
||||||
time::sleep(retry_delay).await;
|
time::sleep(retry_delay).await;
|
||||||
|
|
||||||
// Exponential backoff with jitter
|
|
||||||
retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY);
|
retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the watcher exits for some reason, wait a bit before restarting
|
|
||||||
warn!(
|
warn!(
|
||||||
"Kubernetes watcher exited, restarting in {} seconds",
|
"Kubernetes watcher exited, restarting in {} seconds",
|
||||||
config_arc.check_interval.as_secs()
|
config_arc.check_interval.as_secs()
|
||||||
@@ -354,9 +325,7 @@ async fn handle_pod_event(
|
|||||||
) {
|
) {
|
||||||
let worker_url = pod_info.worker_url(port);
|
let worker_url = pod_info.worker_url(port);
|
||||||
|
|
||||||
// If pod is healthy, try to add it (with atomic check-and-insert)
|
|
||||||
if pod_info.is_healthy() {
|
if pod_info.is_healthy() {
|
||||||
// Atomic check-and-insert to prevent race conditions
|
|
||||||
let should_add = {
|
let should_add = {
|
||||||
let mut tracker = match tracked_pods.lock() {
|
let mut tracker = match tracked_pods.lock() {
|
||||||
Ok(tracker) => tracker,
|
Ok(tracker) => tracker,
|
||||||
@@ -367,9 +336,8 @@ async fn handle_pod_event(
|
|||||||
};
|
};
|
||||||
|
|
||||||
if tracker.contains(pod_info) {
|
if tracker.contains(pod_info) {
|
||||||
false // Already tracked
|
false
|
||||||
} else {
|
} else {
|
||||||
// Reserve the spot to prevent other threads from adding the same pod
|
|
||||||
tracker.insert(pod_info.clone());
|
tracker.insert(pod_info.clone());
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
@@ -381,7 +349,6 @@ async fn handle_pod_event(
|
|||||||
pod_info.name, pod_info.pod_type, worker_url
|
pod_info.name, pod_info.pod_type, worker_url
|
||||||
);
|
);
|
||||||
|
|
||||||
// Build worker config based on pod type and routing mode
|
|
||||||
let worker_type = if pd_mode {
|
let worker_type = if pd_mode {
|
||||||
match &pod_info.pod_type {
|
match &pod_info.pod_type {
|
||||||
Some(PodType::Prefill) => Some("prefill".to_string()),
|
Some(PodType::Prefill) => Some("prefill".to_string()),
|
||||||
@@ -392,7 +359,6 @@ async fn handle_pod_event(
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// Only set bootstrap_port for prefill workers in PD mode
|
|
||||||
let bootstrap_port = if pd_mode {
|
let bootstrap_port = if pd_mode {
|
||||||
match &pod_info.pod_type {
|
match &pod_info.pod_type {
|
||||||
Some(PodType::Prefill) => pod_info.bootstrap_port,
|
Some(PodType::Prefill) => pod_info.bootstrap_port,
|
||||||
@@ -425,7 +391,6 @@ async fn handle_pod_event(
|
|||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to add worker {} to router: {}", worker_url, e);
|
error!("Failed to add worker {} to router: {}", worker_url, e);
|
||||||
// Remove from tracking since addition failed
|
|
||||||
if let Ok(mut tracker) = tracked_pods.lock() {
|
if let Ok(mut tracker) = tracked_pods.lock() {
|
||||||
tracker.remove(pod_info);
|
tracker.remove(pod_info);
|
||||||
}
|
}
|
||||||
@@ -464,8 +429,6 @@ async fn handle_pod_deletion(
|
|||||||
error!("Failed to remove worker {}: {}", worker_url, e);
|
error!("Failed to remove worker {}: {}", worker_url, e);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// This case might occur if a pod is deleted before it was ever marked healthy and added.
|
|
||||||
// Or if the event is duplicated. No action needed on the router if it wasn't tracked (and thus not added).
|
|
||||||
debug!(
|
debug!(
|
||||||
"Pod deletion event for untracked/already removed pod: {} (type: {:?}). Worker URL: {}",
|
"Pod deletion event for untracked/already removed pod: {} (type: {:?}). Worker URL: {}",
|
||||||
pod_info.name, pod_info.pod_type, worker_url
|
pod_info.name, pod_info.pod_type, worker_url
|
||||||
@@ -480,7 +443,6 @@ mod tests {
|
|||||||
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
|
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
|
||||||
use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time;
|
use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time;
|
||||||
|
|
||||||
// Helper function to create a Pod for testing PodInfo::from_pod
|
|
||||||
fn create_k8s_pod(
|
fn create_k8s_pod(
|
||||||
name: Option<&str>,
|
name: Option<&str>,
|
||||||
ip: Option<&str>,
|
ip: Option<&str>,
|
||||||
@@ -523,7 +485,6 @@ mod tests {
|
|||||||
pod
|
pod
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to create a Pod with PD-specific labels and annotations
|
|
||||||
fn create_pd_k8s_pod(name: &str, ip: &str, pod_type: &str, bootstrap_port: Option<u16>) -> Pod {
|
fn create_pd_k8s_pod(name: &str, ip: &str, pod_type: &str, bootstrap_port: Option<u16>) -> Pod {
|
||||||
let mut labels = std::collections::BTreeMap::new();
|
let mut labels = std::collections::BTreeMap::new();
|
||||||
labels.insert("app".to_string(), "sglang".to_string());
|
labels.insert("app".to_string(), "sglang".to_string());
|
||||||
@@ -559,18 +520,15 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper to create an AppContext instance for testing event handlers
|
|
||||||
async fn create_test_app_context() -> Arc<AppContext> {
|
async fn create_test_app_context() -> Arc<AppContext> {
|
||||||
use crate::config::RouterConfig;
|
use crate::config::RouterConfig;
|
||||||
use crate::middleware::TokenBucket;
|
use crate::middleware::TokenBucket;
|
||||||
|
|
||||||
// Create a minimal RouterConfig for testing with very short timeout
|
|
||||||
let router_config = RouterConfig {
|
let router_config = RouterConfig {
|
||||||
worker_startup_timeout_secs: 1,
|
worker_startup_timeout_secs: 1,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}; // Very short timeout for tests
|
};
|
||||||
|
|
||||||
// Create AppContext with minimal components
|
|
||||||
Arc::new(AppContext {
|
Arc::new(AppContext {
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
router_config: router_config.clone(),
|
router_config: router_config.clone(),
|
||||||
@@ -579,16 +537,15 @@ mod tests {
|
|||||||
policy_registry: Arc::new(crate::policies::PolicyRegistry::new(
|
policy_registry: Arc::new(crate::policies::PolicyRegistry::new(
|
||||||
router_config.policy.clone(),
|
router_config.policy.clone(),
|
||||||
)),
|
)),
|
||||||
tokenizer: None, // HTTP mode doesn't need tokenizer
|
tokenizer: None,
|
||||||
reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser
|
reasoning_parser_factory: None,
|
||||||
tool_parser_registry: None, // HTTP mode doesn't need tool parser
|
tool_parser_registry: None,
|
||||||
router_manager: None, // Test doesn't need router manager
|
router_manager: None,
|
||||||
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
|
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
|
||||||
load_monitor: None,
|
load_monitor: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper to create a PD config for testing
|
|
||||||
fn create_pd_config() -> ServiceDiscoveryConfig {
|
fn create_pd_config() -> ServiceDiscoveryConfig {
|
||||||
let mut prefill_selector = HashMap::new();
|
let mut prefill_selector = HashMap::new();
|
||||||
prefill_selector.insert("app".to_string(), "sglang".to_string());
|
prefill_selector.insert("app".to_string(), "sglang".to_string());
|
||||||
@@ -615,19 +572,15 @@ mod tests {
|
|||||||
fn test_pod_info_should_include() {
|
fn test_pod_info_should_include() {
|
||||||
let config = create_pd_config();
|
let config = create_pd_config();
|
||||||
|
|
||||||
// Test prefill pod should be included
|
|
||||||
let prefill_pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", Some(8081));
|
let prefill_pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", Some(8081));
|
||||||
assert!(PodInfo::should_include(&prefill_pod, &config));
|
assert!(PodInfo::should_include(&prefill_pod, &config));
|
||||||
|
|
||||||
// Test decode pod should be included
|
|
||||||
let decode_pod = create_pd_k8s_pod("decode-pod", "10.0.0.2", "decode", None);
|
let decode_pod = create_pd_k8s_pod("decode-pod", "10.0.0.2", "decode", None);
|
||||||
assert!(PodInfo::should_include(&decode_pod, &config));
|
assert!(PodInfo::should_include(&decode_pod, &config));
|
||||||
|
|
||||||
// Test unmatched pod should not be included
|
|
||||||
let unmatched_pod = create_pd_k8s_pod("other-pod", "10.0.0.3", "other", None);
|
let unmatched_pod = create_pd_k8s_pod("other-pod", "10.0.0.3", "other", None);
|
||||||
assert!(!PodInfo::should_include(&unmatched_pod, &config));
|
assert!(!PodInfo::should_include(&unmatched_pod, &config));
|
||||||
|
|
||||||
// Test regular mode
|
|
||||||
let mut regular_config = ServiceDiscoveryConfig::default();
|
let mut regular_config = ServiceDiscoveryConfig::default();
|
||||||
regular_config
|
regular_config
|
||||||
.selector
|
.selector
|
||||||
@@ -654,7 +607,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pod_type_enum() {
|
fn test_pod_type_enum() {
|
||||||
// Test that PodType enum has expected variants
|
|
||||||
let prefill = PodType::Prefill;
|
let prefill = PodType::Prefill;
|
||||||
let decode = PodType::Decode;
|
let decode = PodType::Decode;
|
||||||
let regular = PodType::Regular;
|
let regular = PodType::Regular;
|
||||||
@@ -714,7 +666,7 @@ mod tests {
|
|||||||
fn test_pod_info_from_pod_with_pd_config_regular_mode() {
|
fn test_pod_info_from_pod_with_pd_config_regular_mode() {
|
||||||
let k8s_pod = create_pd_k8s_pod("regular-pod", "10.0.0.3", "worker", None);
|
let k8s_pod = create_pd_k8s_pod("regular-pod", "10.0.0.3", "worker", None);
|
||||||
let mut config = create_pd_config();
|
let mut config = create_pd_config();
|
||||||
config.pd_mode = false; // Set to regular mode
|
config.pd_mode = false;
|
||||||
|
|
||||||
let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap();
|
let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap();
|
||||||
assert_eq!(pod_info.name, "regular-pod");
|
assert_eq!(pod_info.name, "regular-pod");
|
||||||
@@ -742,7 +694,6 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_pod_info_from_pod_with_pd_config_invalid_bootstrap_port() {
|
fn test_pod_info_from_pod_with_pd_config_invalid_bootstrap_port() {
|
||||||
let mut pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", None);
|
let mut pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", None);
|
||||||
// Add invalid bootstrap port annotation
|
|
||||||
pod.metadata.annotations.as_mut().unwrap().insert(
|
pod.metadata.annotations.as_mut().unwrap().insert(
|
||||||
"sglang.ai/bootstrap-port".to_string(),
|
"sglang.ai/bootstrap-port".to_string(),
|
||||||
"invalid".to_string(),
|
"invalid".to_string(),
|
||||||
@@ -751,7 +702,7 @@ mod tests {
|
|||||||
|
|
||||||
let pod_info = PodInfo::from_pod(&pod, Some(&config)).unwrap();
|
let pod_info = PodInfo::from_pod(&pod, Some(&config)).unwrap();
|
||||||
assert_eq!(pod_info.pod_type, Some(PodType::Prefill));
|
assert_eq!(pod_info.pod_type, Some(PodType::Prefill));
|
||||||
assert!(pod_info.bootstrap_port.is_none()); // Should be None for invalid port
|
assert!(pod_info.bootstrap_port.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -1077,7 +1028,6 @@ mod tests {
|
|||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Pod should not be tracked since add_worker_from_url will fail for non-running server
|
|
||||||
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user