[router] remove old/oudated/useless comments (#10967)
This commit is contained in:
@@ -67,58 +67,47 @@ struct Router {
|
||||
decode_policy: Option<PolicyType>,
|
||||
max_concurrent_requests: usize,
|
||||
cors_allowed_origins: Vec<String>,
|
||||
// Retry configuration
|
||||
retry_max_retries: u32,
|
||||
retry_initial_backoff_ms: u64,
|
||||
retry_max_backoff_ms: u64,
|
||||
retry_backoff_multiplier: f32,
|
||||
retry_jitter_factor: f32,
|
||||
disable_retries: bool,
|
||||
// Circuit breaker configuration
|
||||
cb_failure_threshold: u32,
|
||||
cb_success_threshold: u32,
|
||||
cb_timeout_duration_secs: u64,
|
||||
cb_window_duration_secs: u64,
|
||||
disable_circuit_breaker: bool,
|
||||
// Health check configuration
|
||||
health_failure_threshold: u32,
|
||||
health_success_threshold: u32,
|
||||
health_check_timeout_secs: u64,
|
||||
health_check_interval_secs: u64,
|
||||
health_check_endpoint: String,
|
||||
// IGW (Inference Gateway) configuration
|
||||
enable_igw: bool,
|
||||
queue_size: usize,
|
||||
queue_timeout_secs: u64,
|
||||
rate_limit_tokens_per_second: Option<usize>,
|
||||
// Connection mode (determined from worker URLs)
|
||||
connection_mode: config::ConnectionMode,
|
||||
// Model path for tokenizer
|
||||
model_path: Option<String>,
|
||||
// Explicit tokenizer path
|
||||
tokenizer_path: Option<String>,
|
||||
}
|
||||
|
||||
impl Router {
|
||||
/// Determine connection mode from worker URLs
|
||||
fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
|
||||
// Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
|
||||
for url in worker_urls {
|
||||
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
|
||||
return config::ConnectionMode::Grpc;
|
||||
}
|
||||
}
|
||||
// Default to HTTP for all other cases (including http://, https://, or no scheme)
|
||||
config::ConnectionMode::Http
|
||||
}
|
||||
|
||||
/// Convert PyO3 Router to RouterConfig
|
||||
pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
|
||||
use config::{
|
||||
DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
|
||||
};
|
||||
|
||||
// Convert policy helper function
|
||||
let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
|
||||
match policy {
|
||||
PolicyType::Random => ConfigPolicyConfig::Random,
|
||||
@@ -131,14 +120,12 @@ impl Router {
|
||||
max_tree_size: self.max_tree_size,
|
||||
},
|
||||
PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 5, // Default value
|
||||
load_check_interval_secs: 5,
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
// Determine routing mode
|
||||
let mode = if self.enable_igw {
|
||||
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
}
|
||||
@@ -155,10 +142,8 @@ impl Router {
|
||||
}
|
||||
};
|
||||
|
||||
// Convert main policy
|
||||
let policy = convert_policy(&self.policy);
|
||||
|
||||
// Service discovery configuration
|
||||
let discovery = if self.service_discovery {
|
||||
Some(DiscoveryConfig {
|
||||
enabled: true,
|
||||
@@ -174,7 +159,6 @@ impl Router {
|
||||
None
|
||||
};
|
||||
|
||||
// Metrics configuration
|
||||
let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
|
||||
(Some(port), Some(host)) => Some(MetricsConfig {
|
||||
port,
|
||||
@@ -251,7 +235,7 @@ impl Router {
|
||||
balance_rel_threshold = 1.5,
|
||||
eviction_interval_secs = 120,
|
||||
max_tree_size = 2usize.pow(26),
|
||||
max_payload_size = 512 * 1024 * 1024, // 512MB default for large batches
|
||||
max_payload_size = 512 * 1024 * 1024,
|
||||
dp_aware = false,
|
||||
api_key = None,
|
||||
log_dir = None,
|
||||
@@ -265,40 +249,35 @@ impl Router {
|
||||
bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
|
||||
prometheus_port = None,
|
||||
prometheus_host = None,
|
||||
request_timeout_secs = 1800, // Add configurable request timeout
|
||||
request_id_headers = None, // Custom request ID headers
|
||||
pd_disaggregation = false, // New flag for PD mode
|
||||
request_timeout_secs = 1800,
|
||||
request_id_headers = None,
|
||||
pd_disaggregation = false,
|
||||
prefill_urls = None,
|
||||
decode_urls = None,
|
||||
prefill_policy = None,
|
||||
decode_policy = None,
|
||||
max_concurrent_requests = 256,
|
||||
cors_allowed_origins = vec![],
|
||||
// Retry defaults
|
||||
retry_max_retries = 5,
|
||||
retry_initial_backoff_ms = 50,
|
||||
retry_max_backoff_ms = 30_000,
|
||||
retry_backoff_multiplier = 1.5,
|
||||
retry_jitter_factor = 0.2,
|
||||
disable_retries = false,
|
||||
// Circuit breaker defaults
|
||||
cb_failure_threshold = 10,
|
||||
cb_success_threshold = 3,
|
||||
cb_timeout_duration_secs = 60,
|
||||
cb_window_duration_secs = 120,
|
||||
disable_circuit_breaker = false,
|
||||
// Health check defaults
|
||||
health_failure_threshold = 3,
|
||||
health_success_threshold = 2,
|
||||
health_check_timeout_secs = 5,
|
||||
health_check_interval_secs = 60,
|
||||
health_check_endpoint = String::from("/health"),
|
||||
// IGW defaults
|
||||
enable_igw = false,
|
||||
queue_size = 100,
|
||||
queue_timeout_secs = 60,
|
||||
rate_limit_tokens_per_second = None,
|
||||
// Tokenizer defaults
|
||||
model_path = None,
|
||||
tokenizer_path = None,
|
||||
))]
|
||||
@@ -361,17 +340,14 @@ impl Router {
|
||||
model_path: Option<String>,
|
||||
tokenizer_path: Option<String>,
|
||||
) -> PyResult<Self> {
|
||||
// Determine connection mode from worker URLs
|
||||
let mut all_urls = worker_urls.clone();
|
||||
|
||||
// Add prefill URLs if in PD mode
|
||||
if let Some(ref prefill_urls) = prefill_urls {
|
||||
for (url, _) in prefill_urls {
|
||||
all_urls.push(url.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Add decode URLs if in PD mode
|
||||
if let Some(ref decode_urls) = decode_urls {
|
||||
all_urls.extend(decode_urls.clone());
|
||||
}
|
||||
@@ -440,12 +416,10 @@ impl Router {
|
||||
}
|
||||
|
||||
fn start(&self) -> PyResult<()> {
|
||||
// Convert to RouterConfig and validate
|
||||
let router_config = self.to_router_config().map_err(|e| {
|
||||
pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
|
||||
})?;
|
||||
|
||||
// Validate the configuration
|
||||
router_config.validate().map_err(|e| {
|
||||
pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Configuration validation failed: {}",
|
||||
@@ -453,7 +427,6 @@ impl Router {
|
||||
))
|
||||
})?;
|
||||
|
||||
// Create service discovery config if enabled
|
||||
let service_discovery_config = if self.service_discovery {
|
||||
Some(service_discovery::ServiceDiscoveryConfig {
|
||||
enabled: true,
|
||||
@@ -470,7 +443,6 @@ impl Router {
|
||||
None
|
||||
};
|
||||
|
||||
// Create Prometheus config if enabled
|
||||
let prometheus_config = Some(PrometheusConfig {
|
||||
port: self.prometheus_port.unwrap_or(29000),
|
||||
host: self
|
||||
@@ -479,11 +451,9 @@ impl Router {
|
||||
.unwrap_or_else(|| "127.0.0.1".to_string()),
|
||||
});
|
||||
|
||||
// Use tokio runtime instead of actix-web System for better compatibility
|
||||
let runtime = tokio::runtime::Runtime::new()
|
||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
|
||||
|
||||
// Block on the async startup function
|
||||
runtime.block_on(async move {
|
||||
server::startup(server::ServerConfig {
|
||||
host: self.host.clone(),
|
||||
|
||||
@@ -8,20 +8,13 @@ use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use tracing_subscriber::{EnvFilter, Layer};
|
||||
|
||||
/// Configuration for the logging system
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LoggingConfig {
|
||||
/// Log level for the application (default: INFO)
|
||||
pub level: Level,
|
||||
/// Whether to use json format for logs (default: false)
|
||||
pub json_format: bool,
|
||||
/// Path to store log files. If None, logs will only go to stdout/stderr
|
||||
pub log_dir: Option<String>,
|
||||
/// Whether to colorize logs when output is a terminal (default: true)
|
||||
pub colorize: bool,
|
||||
/// Log file name to use if log_dir is specified (default: "sgl-router")
|
||||
pub log_file_name: String,
|
||||
/// Custom log targets to filter (default: "sglang_router_rs")
|
||||
pub log_targets: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
@@ -38,30 +31,14 @@ impl Default for LoggingConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Guard that keeps the file appender worker thread alive
|
||||
///
|
||||
/// This must be kept in scope for the duration of the program
|
||||
/// to ensure logs are properly written to files
|
||||
#[allow(dead_code)]
|
||||
pub struct LogGuard {
|
||||
_file_guard: Option<WorkerGuard>,
|
||||
}
|
||||
|
||||
/// Initialize the logging system with the given configuration
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `config` - Configuration for the logging system
|
||||
///
|
||||
/// # Returns
|
||||
/// A LogGuard that must be kept alive for the duration of the program
|
||||
///
|
||||
/// # Panics
|
||||
/// Will not panic, as initialization errors are handled gracefully
|
||||
pub fn init_logging(config: LoggingConfig) -> LogGuard {
|
||||
// Forward logs to tracing - ignore errors to allow for multiple initialization
|
||||
let _ = LogTracer::init();
|
||||
|
||||
// Convert log level to filter string
|
||||
let level_filter = match config.level {
|
||||
Level::TRACE => "trace",
|
||||
Level::DEBUG => "debug",
|
||||
@@ -70,9 +47,7 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
|
||||
Level::ERROR => "error",
|
||||
};
|
||||
|
||||
// Create env filter
|
||||
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| {
|
||||
// Format: <target>=<level>,<target2>=<level2>,...
|
||||
let filter_string = if let Some(targets) = &config.log_targets {
|
||||
targets
|
||||
.iter()
|
||||
@@ -92,13 +67,10 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
|
||||
EnvFilter::new(filter_string)
|
||||
});
|
||||
|
||||
// Setup stdout/stderr layer
|
||||
let mut layers = Vec::new();
|
||||
|
||||
// Standard timestamp format: YYYY-MM-DD HH:MM:SS
|
||||
let time_format = "%Y-%m-%d %H:%M:%S".to_string();
|
||||
|
||||
// Configure the console stdout layer
|
||||
let stdout_layer = tracing_subscriber::fmt::layer()
|
||||
.with_ansi(config.colorize)
|
||||
.with_file(true)
|
||||
@@ -113,14 +85,12 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
|
||||
|
||||
layers.push(stdout_layer);
|
||||
|
||||
// Create a file appender if log_dir is specified
|
||||
let mut file_guard = None;
|
||||
|
||||
if let Some(log_dir) = &config.log_dir {
|
||||
let file_name = config.log_file_name.clone();
|
||||
let log_dir = PathBuf::from(log_dir);
|
||||
|
||||
// Create log directory if it doesn't exist
|
||||
if !log_dir.exists() {
|
||||
if let Err(e) = std::fs::create_dir_all(&log_dir) {
|
||||
eprintln!("Failed to create log directory: {}", e);
|
||||
@@ -134,7 +104,7 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
|
||||
file_guard = Some(guard);
|
||||
|
||||
let file_layer = tracing_subscriber::fmt::layer()
|
||||
.with_ansi(false) // Never use ANSI colors in log files
|
||||
.with_ansi(false)
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
.with_timer(ChronoUtc::new(time_format))
|
||||
@@ -149,14 +119,11 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
|
||||
layers.push(file_layer);
|
||||
}
|
||||
|
||||
// Initialize the subscriber with all layers
|
||||
// Use try_init to handle errors gracefully in case another subscriber is already set
|
||||
let _ = tracing_subscriber::registry()
|
||||
.with(env_filter)
|
||||
.with(layers)
|
||||
.try_init();
|
||||
|
||||
// Return the guard to keep the file appender worker thread alive
|
||||
LogGuard {
|
||||
_file_guard: file_guard,
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ use sglang_router_rs::server::{self, ServerConfig};
|
||||
use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
|
||||
use std::collections::HashMap;
|
||||
|
||||
// Helper function to parse prefill arguments from command line
|
||||
fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let mut prefill_entries = Vec::new();
|
||||
@@ -19,12 +18,11 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
|
||||
if args[i] == "--prefill" && i + 1 < args.len() {
|
||||
let url = args[i + 1].clone();
|
||||
let bootstrap_port = if i + 2 < args.len() && !args[i + 2].starts_with("--") {
|
||||
// Check if next arg is a port number
|
||||
if let Ok(port) = args[i + 2].parse::<u16>() {
|
||||
i += 1; // Skip the port argument
|
||||
i += 1;
|
||||
Some(port)
|
||||
} else if args[i + 2].to_lowercase() == "none" {
|
||||
i += 1; // Skip the "none" argument
|
||||
i += 1;
|
||||
None
|
||||
} else {
|
||||
None
|
||||
@@ -33,7 +31,7 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
|
||||
None
|
||||
};
|
||||
prefill_entries.push((url, bootstrap_port));
|
||||
i += 2; // Skip --prefill and URL
|
||||
i += 2;
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
@@ -101,252 +99,186 @@ Examples:
|
||||
|
||||
"#)]
|
||||
struct CliArgs {
|
||||
/// Host address to bind the router server
|
||||
#[arg(long, default_value = "127.0.0.1")]
|
||||
host: String,
|
||||
|
||||
/// Port number to bind the router server
|
||||
#[arg(long, default_value_t = 30000)]
|
||||
port: u16,
|
||||
|
||||
/// List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)
|
||||
#[arg(long, num_args = 0..)]
|
||||
worker_urls: Vec<String>,
|
||||
|
||||
/// Load balancing policy to use
|
||||
#[arg(long, default_value = "cache_aware", value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
|
||||
policy: String,
|
||||
|
||||
/// Enable PD (Prefill-Decode) disaggregated mode
|
||||
#[arg(long, default_value_t = false)]
|
||||
pd_disaggregation: bool,
|
||||
|
||||
/// Decode server URL (can be specified multiple times)
|
||||
#[arg(long, action = ArgAction::Append)]
|
||||
decode: Vec<String>,
|
||||
|
||||
/// Specific policy for prefill nodes in PD mode
|
||||
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
|
||||
prefill_policy: Option<String>,
|
||||
|
||||
/// Specific policy for decode nodes in PD mode
|
||||
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
|
||||
decode_policy: Option<String>,
|
||||
|
||||
/// Timeout in seconds for worker startup
|
||||
#[arg(long, default_value_t = 600)]
|
||||
worker_startup_timeout_secs: u64,
|
||||
|
||||
/// Interval in seconds between checks for worker startup
|
||||
#[arg(long, default_value_t = 30)]
|
||||
worker_startup_check_interval: u64,
|
||||
|
||||
/// Cache threshold (0.0-1.0) for cache-aware routing
|
||||
#[arg(long, default_value_t = 0.3)]
|
||||
cache_threshold: f32,
|
||||
|
||||
/// Absolute threshold for load balancing
|
||||
#[arg(long, default_value_t = 64)]
|
||||
balance_abs_threshold: usize,
|
||||
|
||||
/// Relative threshold for load balancing
|
||||
#[arg(long, default_value_t = 1.5)]
|
||||
balance_rel_threshold: f32,
|
||||
|
||||
/// Interval in seconds between cache eviction operations
|
||||
#[arg(long, default_value_t = 120)]
|
||||
eviction_interval: u64,
|
||||
|
||||
/// Maximum size of the approximation tree for cache-aware routing
|
||||
#[arg(long, default_value_t = 67108864)] // 2^26
|
||||
#[arg(long, default_value_t = 67108864)]
|
||||
max_tree_size: usize,
|
||||
|
||||
/// Maximum payload size in bytes
|
||||
#[arg(long, default_value_t = 536870912)] // 512MB
|
||||
#[arg(long, default_value_t = 536870912)]
|
||||
max_payload_size: usize,
|
||||
|
||||
/// Enable data parallelism aware schedule
|
||||
#[arg(long, default_value_t = false)]
|
||||
dp_aware: bool,
|
||||
|
||||
/// API key for worker authorization
|
||||
#[arg(long)]
|
||||
api_key: Option<String>,
|
||||
|
||||
/// Backend to route requests to (sglang, vllm, trtllm, openai, anthropic)
|
||||
#[arg(long, value_enum, default_value_t = Backend::Sglang, alias = "runtime")]
|
||||
backend: Backend,
|
||||
|
||||
/// Directory to store log files
|
||||
#[arg(long)]
|
||||
log_dir: Option<String>,
|
||||
|
||||
/// Set the logging level
|
||||
#[arg(long, default_value = "info", value_parser = ["debug", "info", "warn", "error"])]
|
||||
log_level: String,
|
||||
|
||||
/// Enable Kubernetes service discovery
|
||||
#[arg(long, default_value_t = false)]
|
||||
service_discovery: bool,
|
||||
|
||||
/// Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)
|
||||
#[arg(long, num_args = 0..)]
|
||||
selector: Vec<String>,
|
||||
|
||||
/// Port to use for discovered worker pods
|
||||
#[arg(long, default_value_t = 80)]
|
||||
service_discovery_port: u16,
|
||||
|
||||
/// Kubernetes namespace to watch for pods
|
||||
#[arg(long)]
|
||||
service_discovery_namespace: Option<String>,
|
||||
|
||||
/// Label selector for prefill server pods in PD mode
|
||||
#[arg(long, num_args = 0..)]
|
||||
prefill_selector: Vec<String>,
|
||||
|
||||
/// Label selector for decode server pods in PD mode
|
||||
#[arg(long, num_args = 0..)]
|
||||
decode_selector: Vec<String>,
|
||||
|
||||
/// Port to expose Prometheus metrics
|
||||
#[arg(long, default_value_t = 29000)]
|
||||
prometheus_port: u16,
|
||||
|
||||
/// Host address to bind the Prometheus metrics server
|
||||
#[arg(long, default_value = "127.0.0.1")]
|
||||
prometheus_host: String,
|
||||
|
||||
/// Custom HTTP headers to check for request IDs
|
||||
#[arg(long, num_args = 0..)]
|
||||
request_id_headers: Vec<String>,
|
||||
|
||||
/// Request timeout in seconds
|
||||
#[arg(long, default_value_t = 1800)]
|
||||
request_timeout_secs: u64,
|
||||
|
||||
/// Maximum number of concurrent requests allowed
|
||||
#[arg(long, default_value_t = 256)]
|
||||
max_concurrent_requests: usize,
|
||||
|
||||
/// CORS allowed origins
|
||||
#[arg(long, num_args = 0..)]
|
||||
cors_allowed_origins: Vec<String>,
|
||||
|
||||
// Retry configuration
|
||||
/// Maximum number of retries
|
||||
#[arg(long, default_value_t = 5)]
|
||||
retry_max_retries: u32,
|
||||
|
||||
/// Initial backoff in milliseconds for retries
|
||||
#[arg(long, default_value_t = 50)]
|
||||
retry_initial_backoff_ms: u64,
|
||||
|
||||
/// Maximum backoff in milliseconds for retries
|
||||
#[arg(long, default_value_t = 30000)]
|
||||
retry_max_backoff_ms: u64,
|
||||
|
||||
/// Backoff multiplier for exponential backoff
|
||||
#[arg(long, default_value_t = 1.5)]
|
||||
retry_backoff_multiplier: f32,
|
||||
|
||||
/// Jitter factor for retry backoff
|
||||
#[arg(long, default_value_t = 0.2)]
|
||||
retry_jitter_factor: f32,
|
||||
|
||||
/// Disable retries
|
||||
#[arg(long, default_value_t = false)]
|
||||
disable_retries: bool,
|
||||
|
||||
// Circuit breaker configuration
|
||||
/// Number of failures before circuit breaker opens
|
||||
#[arg(long, default_value_t = 10)]
|
||||
cb_failure_threshold: u32,
|
||||
|
||||
/// Number of successes before circuit breaker closes
|
||||
#[arg(long, default_value_t = 3)]
|
||||
cb_success_threshold: u32,
|
||||
|
||||
/// Timeout duration in seconds for circuit breaker
|
||||
#[arg(long, default_value_t = 60)]
|
||||
cb_timeout_duration_secs: u64,
|
||||
|
||||
/// Window duration in seconds for circuit breaker
|
||||
#[arg(long, default_value_t = 120)]
|
||||
cb_window_duration_secs: u64,
|
||||
|
||||
/// Disable circuit breaker
|
||||
#[arg(long, default_value_t = false)]
|
||||
disable_circuit_breaker: bool,
|
||||
|
||||
// Health check configuration
|
||||
/// Number of consecutive health check failures before marking worker unhealthy
|
||||
#[arg(long, default_value_t = 3)]
|
||||
health_failure_threshold: u32,
|
||||
|
||||
/// Number of consecutive health check successes before marking worker healthy
|
||||
#[arg(long, default_value_t = 2)]
|
||||
health_success_threshold: u32,
|
||||
|
||||
/// Timeout in seconds for health check requests
|
||||
#[arg(long, default_value_t = 5)]
|
||||
health_check_timeout_secs: u64,
|
||||
|
||||
/// Interval in seconds between runtime health checks
|
||||
#[arg(long, default_value_t = 60)]
|
||||
health_check_interval_secs: u64,
|
||||
|
||||
/// Health check endpoint path
|
||||
#[arg(long, default_value = "/health")]
|
||||
health_check_endpoint: String,
|
||||
|
||||
// IGW (Inference Gateway) configuration
|
||||
/// Enable Inference Gateway mode
|
||||
#[arg(long, default_value_t = false)]
|
||||
enable_igw: bool,
|
||||
|
||||
// Tokenizer configuration
|
||||
/// Model path for loading tokenizer (HuggingFace model ID or local path)
|
||||
#[arg(long)]
|
||||
model_path: Option<String>,
|
||||
|
||||
/// Explicit tokenizer path (overrides model_path tokenizer if provided)
|
||||
#[arg(long)]
|
||||
tokenizer_path: Option<String>,
|
||||
|
||||
/// History backend configuration (memory, none, or oracle)
|
||||
#[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])]
|
||||
history_backend: String,
|
||||
|
||||
/// Directory containing the Oracle ATP wallet/config files (optional)
|
||||
#[arg(long, env = "ATP_WALLET_PATH")]
|
||||
oracle_wallet_path: Option<String>,
|
||||
|
||||
/// Wallet TNS alias to use (e.g. `<db_name>_low`)
|
||||
#[arg(long, env = "ATP_TNS_ALIAS")]
|
||||
oracle_tns_alias: Option<String>,
|
||||
|
||||
/// Oracle connection descriptor / DSN (e.g. `tcps://host:port/service_name`)
|
||||
#[arg(long, env = "ATP_DSN")]
|
||||
oracle_dsn: Option<String>,
|
||||
|
||||
/// Oracle ATP username
|
||||
#[arg(long, env = "ATP_USER")]
|
||||
oracle_user: Option<String>,
|
||||
|
||||
/// Oracle ATP password
|
||||
#[arg(long, env = "ATP_PASSWORD")]
|
||||
oracle_password: Option<String>,
|
||||
|
||||
/// Minimum number of pooled ATP connections (defaults to 1 when omitted)
|
||||
#[arg(long, env = "ATP_POOL_MIN")]
|
||||
oracle_pool_min: Option<usize>,
|
||||
|
||||
/// Maximum number of pooled ATP connections (defaults to 16 when omitted)
|
||||
#[arg(long, env = "ATP_POOL_MAX")]
|
||||
oracle_pool_max: Option<usize>,
|
||||
|
||||
/// Connection acquisition timeout in seconds (defaults to 30 when omitted)
|
||||
#[arg(long, env = "ATP_POOL_TIMEOUT_SECS")]
|
||||
oracle_pool_timeout_secs: Option<u64>,
|
||||
}
|
||||
@@ -357,19 +289,15 @@ enum OracleConnectSource {
|
||||
}
|
||||
|
||||
impl CliArgs {
|
||||
/// Determine connection mode from worker URLs
|
||||
fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode {
|
||||
// Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
|
||||
for url in worker_urls {
|
||||
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
|
||||
return ConnectionMode::Grpc;
|
||||
}
|
||||
}
|
||||
// Default to HTTP for all other cases (including http://, https://, or no scheme)
|
||||
ConnectionMode::Http
|
||||
}
|
||||
|
||||
/// Parse selector strings into HashMap
|
||||
fn parse_selector(selector_list: &[String]) -> HashMap<String, String> {
|
||||
let mut map = HashMap::new();
|
||||
for item in selector_list {
|
||||
@@ -382,7 +310,6 @@ impl CliArgs {
|
||||
map
|
||||
}
|
||||
|
||||
/// Convert policy string to PolicyConfig
|
||||
fn parse_policy(&self, policy_str: &str) -> PolicyConfig {
|
||||
match policy_str {
|
||||
"random" => PolicyConfig::Random,
|
||||
@@ -395,9 +322,9 @@ impl CliArgs {
|
||||
max_tree_size: self.max_tree_size,
|
||||
},
|
||||
"power_of_two" => PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 5, // Default value
|
||||
load_check_interval_secs: 5,
|
||||
},
|
||||
_ => PolicyConfig::RoundRobin, // Fallback
|
||||
_ => PolicyConfig::RoundRobin,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -482,26 +409,21 @@ impl CliArgs {
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert CLI arguments to RouterConfig
|
||||
fn to_router_config(
|
||||
&self,
|
||||
prefill_urls: Vec<(String, Option<u16>)>,
|
||||
) -> ConfigResult<RouterConfig> {
|
||||
// Determine routing mode
|
||||
let mode = if self.enable_igw {
|
||||
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
}
|
||||
} else if matches!(self.backend, Backend::Openai) {
|
||||
// OpenAI backend mode - use worker_urls as base(s)
|
||||
RoutingMode::OpenAI {
|
||||
worker_urls: self.worker_urls.clone(),
|
||||
}
|
||||
} else if self.pd_disaggregation {
|
||||
let decode_urls = self.decode.clone();
|
||||
|
||||
// Validate PD configuration if not using service discovery
|
||||
if !self.service_discovery && (prefill_urls.is_empty() || decode_urls.is_empty()) {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "PD disaggregation mode requires --prefill and --decode URLs when not using service discovery".to_string(),
|
||||
@@ -515,7 +437,6 @@ impl CliArgs {
|
||||
decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)),
|
||||
}
|
||||
} else {
|
||||
// Regular mode
|
||||
if !self.service_discovery && self.worker_urls.is_empty() {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "Regular mode requires --worker-urls when not using service discovery"
|
||||
@@ -527,10 +448,8 @@ impl CliArgs {
|
||||
}
|
||||
};
|
||||
|
||||
// Main policy
|
||||
let policy = self.parse_policy(&self.policy);
|
||||
|
||||
// Service discovery configuration
|
||||
let discovery = if self.service_discovery {
|
||||
Some(DiscoveryConfig {
|
||||
enabled: true,
|
||||
@@ -546,13 +465,11 @@ impl CliArgs {
|
||||
None
|
||||
};
|
||||
|
||||
// Metrics configuration
|
||||
let metrics = Some(MetricsConfig {
|
||||
port: self.prometheus_port,
|
||||
host: self.prometheus_host.clone(),
|
||||
});
|
||||
|
||||
// Determine connection mode from all worker URLs
|
||||
let mut all_urls = Vec::new();
|
||||
match &mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
@@ -568,9 +485,7 @@ impl CliArgs {
|
||||
}
|
||||
all_urls.extend(decode_urls.clone());
|
||||
}
|
||||
RoutingMode::OpenAI { .. } => {
|
||||
// For connection-mode detection, skip URLs; OpenAI forces HTTP below.
|
||||
}
|
||||
RoutingMode::OpenAI { .. } => {}
|
||||
}
|
||||
let connection_mode = match &mode {
|
||||
RoutingMode::OpenAI { .. } => ConnectionMode::Http,
|
||||
@@ -589,7 +504,6 @@ impl CliArgs {
|
||||
None
|
||||
};
|
||||
|
||||
// Build RouterConfig
|
||||
Ok(RouterConfig {
|
||||
mode,
|
||||
policy,
|
||||
@@ -612,8 +526,8 @@ impl CliArgs {
|
||||
Some(self.request_id_headers.clone())
|
||||
},
|
||||
max_concurrent_requests: self.max_concurrent_requests,
|
||||
queue_size: 100, // Default queue size
|
||||
queue_timeout_secs: 60, // Default timeout
|
||||
queue_size: 100,
|
||||
queue_timeout_secs: 60,
|
||||
cors_allowed_origins: self.cors_allowed_origins.clone(),
|
||||
retry: RetryConfig {
|
||||
max_retries: self.retry_max_retries,
|
||||
@@ -646,9 +560,7 @@ impl CliArgs {
|
||||
})
|
||||
}
|
||||
|
||||
/// Create ServerConfig from CLI args and RouterConfig
|
||||
fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig {
|
||||
// Create service discovery config if enabled
|
||||
let service_discovery_config = if self.service_discovery {
|
||||
Some(ServiceDiscoveryConfig {
|
||||
enabled: true,
|
||||
@@ -665,7 +577,6 @@ impl CliArgs {
|
||||
None
|
||||
};
|
||||
|
||||
// Create Prometheus config
|
||||
let prometheus_config = Some(PrometheusConfig {
|
||||
port: self.prometheus_port,
|
||||
host: self.prometheus_host.clone(),
|
||||
@@ -691,19 +602,15 @@ impl CliArgs {
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Parse prefill arguments manually before clap parsing
|
||||
let prefill_urls = parse_prefill_args();
|
||||
|
||||
// Filter out prefill arguments and their values before passing to clap
|
||||
let mut filtered_args: Vec<String> = Vec::new();
|
||||
let raw_args: Vec<String> = std::env::args().collect();
|
||||
let mut i = 0;
|
||||
|
||||
while i < raw_args.len() {
|
||||
if raw_args[i] == "--prefill" && i + 1 < raw_args.len() {
|
||||
// Skip --prefill and its URL
|
||||
i += 2;
|
||||
// Also skip bootstrap port if present
|
||||
if i < raw_args.len()
|
||||
&& !raw_args[i].starts_with("--")
|
||||
&& (raw_args[i].parse::<u16>().is_ok() || raw_args[i].to_lowercase() == "none")
|
||||
@@ -716,10 +623,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
}
|
||||
}
|
||||
|
||||
// Parse CLI arguments with clap using filtered args
|
||||
let cli_args = CliArgs::parse_from(filtered_args);
|
||||
|
||||
// Print startup info
|
||||
println!("SGLang Router starting...");
|
||||
println!("Host: {}:{}", cli_args.host, cli_args.port);
|
||||
let mode_str = if cli_args.enable_igw {
|
||||
@@ -733,7 +638,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
};
|
||||
println!("Mode: {}", mode_str);
|
||||
|
||||
// Warn for runtimes that are parsed but not yet implemented
|
||||
match cli_args.backend {
|
||||
Backend::Vllm | Backend::Trtllm | Backend::Anthropic => {
|
||||
println!(
|
||||
@@ -754,19 +658,10 @@ Provide --worker-urls or PD flags as usual.",
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to RouterConfig
|
||||
let router_config = cli_args.to_router_config(prefill_urls)?;
|
||||
|
||||
// Validate configuration
|
||||
router_config.validate()?;
|
||||
|
||||
// Create ServerConfig
|
||||
let server_config = cli_args.to_server_config(router_config);
|
||||
|
||||
// Create a new runtime for the server (like Python binding does)
|
||||
let runtime = tokio::runtime::Runtime::new()?;
|
||||
|
||||
// Block on the async startup function
|
||||
runtime.block_on(async move { server::startup(server_config).await })?;
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -19,7 +19,6 @@ impl Default for PrometheusConfig {
|
||||
}
|
||||
|
||||
pub fn init_metrics() {
|
||||
// Request metrics
|
||||
describe_counter!(
|
||||
"sgl_router_requests_total",
|
||||
"Total number of requests by route and method"
|
||||
@@ -45,7 +44,6 @@ pub fn init_metrics() {
|
||||
"Total number of requests that exhausted retries by route"
|
||||
);
|
||||
|
||||
// Circuit breaker metrics
|
||||
describe_gauge!(
|
||||
"sgl_router_cb_state",
|
||||
"Circuit breaker state per worker (0=closed, 1=open, 2=half_open)"
|
||||
@@ -59,7 +57,6 @@ pub fn init_metrics() {
|
||||
"Total number of circuit breaker outcomes by worker and outcome type (success/failure)"
|
||||
);
|
||||
|
||||
// Worker metrics
|
||||
describe_gauge!(
|
||||
"sgl_router_active_workers",
|
||||
"Number of currently active workers"
|
||||
@@ -74,7 +71,6 @@ pub fn init_metrics() {
|
||||
"Total requests processed by each worker"
|
||||
);
|
||||
|
||||
// Policy metrics
|
||||
describe_counter!(
|
||||
"sgl_router_policy_decisions_total",
|
||||
"Total routing policy decisions by policy and worker"
|
||||
@@ -92,7 +88,6 @@ pub fn init_metrics() {
|
||||
describe_gauge!("sgl_router_max_load", "Maximum worker load");
|
||||
describe_gauge!("sgl_router_min_load", "Minimum worker load");
|
||||
|
||||
// PD-specific metrics
|
||||
describe_counter!("sgl_router_pd_requests_total", "Total PD requests by route");
|
||||
describe_counter!(
|
||||
"sgl_router_pd_prefill_requests_total",
|
||||
@@ -123,7 +118,6 @@ pub fn init_metrics() {
|
||||
"PD request duration by route"
|
||||
);
|
||||
|
||||
// Service discovery metrics
|
||||
describe_counter!(
|
||||
"sgl_router_discovery_updates_total",
|
||||
"Total service discovery update events"
|
||||
@@ -137,13 +131,11 @@ pub fn init_metrics() {
|
||||
"Number of workers removed in last discovery update"
|
||||
);
|
||||
|
||||
// Generate request specific metrics
|
||||
describe_histogram!(
|
||||
"sgl_router_generate_duration_seconds",
|
||||
"Generate request duration"
|
||||
);
|
||||
|
||||
// Embedding request specific metrics
|
||||
describe_counter!("sgl_router_embeddings_total", "Total embedding requests");
|
||||
describe_histogram!(
|
||||
"sgl_router_embeddings_duration_seconds",
|
||||
@@ -155,13 +147,11 @@ pub fn init_metrics() {
|
||||
);
|
||||
describe_gauge!("sgl_router_embeddings_queue_size", "Embedding queue size");
|
||||
|
||||
// Running requests gauge for cache-aware policy
|
||||
describe_gauge!(
|
||||
"sgl_router_running_requests",
|
||||
"Number of running requests per worker"
|
||||
);
|
||||
|
||||
// Tokenizer metrics
|
||||
describe_histogram!(
|
||||
"sgl_tokenizer_encode_duration_seconds",
|
||||
"Time to encode text to tokens"
|
||||
@@ -207,7 +197,6 @@ pub fn init_metrics() {
|
||||
"Vocabulary size of the loaded tokenizer"
|
||||
);
|
||||
|
||||
// Stop sequence detection metrics
|
||||
describe_counter!(
|
||||
"sgl_tokenizer_stop_sequences_detected_total",
|
||||
"Total stop sequences detected by type"
|
||||
@@ -221,7 +210,6 @@ pub fn init_metrics() {
|
||||
"Time to check for stop sequences per token"
|
||||
);
|
||||
|
||||
// Streaming decode metrics
|
||||
describe_counter!(
|
||||
"sgl_tokenizer_stream_tokens_total",
|
||||
"Total tokens processed in streaming decode"
|
||||
@@ -235,7 +223,6 @@ pub fn init_metrics() {
|
||||
"Time per streaming decode step"
|
||||
);
|
||||
|
||||
// Factory metrics
|
||||
describe_counter!(
|
||||
"sgl_tokenizer_factory_loads_total",
|
||||
"Total tokenizer loads by file type"
|
||||
@@ -251,7 +238,6 @@ pub fn init_metrics() {
|
||||
}
|
||||
|
||||
pub fn start_prometheus(config: PrometheusConfig) {
|
||||
// Initialize metric descriptions
|
||||
init_metrics();
|
||||
|
||||
let duration_matcher = Matcher::Suffix(String::from("duration_seconds"));
|
||||
@@ -280,7 +266,6 @@ pub struct RouterMetrics;
|
||||
pub struct TokenizerMetrics;
|
||||
|
||||
impl RouterMetrics {
|
||||
// Request metrics
|
||||
pub fn record_request(route: &str) {
|
||||
counter!("sgl_router_requests_total",
|
||||
"route" => route.to_string()
|
||||
@@ -324,7 +309,6 @@ impl RouterMetrics {
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
// Worker metrics
|
||||
pub fn set_active_workers(count: usize) {
|
||||
gauge!("sgl_router_active_workers").set(count as f64);
|
||||
}
|
||||
@@ -350,7 +334,6 @@ impl RouterMetrics {
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
// Policy metrics
|
||||
pub fn record_policy_decision(policy: &str, worker: &str) {
|
||||
counter!("sgl_router_policy_decisions_total",
|
||||
"policy" => policy.to_string(),
|
||||
@@ -383,7 +366,6 @@ impl RouterMetrics {
|
||||
gauge!("sgl_router_min_load").set(min_load as f64);
|
||||
}
|
||||
|
||||
// PD-specific metrics
|
||||
pub fn record_pd_request(route: &str) {
|
||||
counter!("sgl_router_pd_requests_total",
|
||||
"route" => route.to_string()
|
||||
@@ -440,19 +422,16 @@ impl RouterMetrics {
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
// Service discovery metrics
|
||||
pub fn record_discovery_update(added: usize, removed: usize) {
|
||||
counter!("sgl_router_discovery_updates_total").increment(1);
|
||||
gauge!("sgl_router_discovery_workers_added").set(added as f64);
|
||||
gauge!("sgl_router_discovery_workers_removed").set(removed as f64);
|
||||
}
|
||||
|
||||
// Generate request metrics
|
||||
pub fn record_generate_duration(duration: Duration) {
|
||||
histogram!("sgl_router_generate_duration_seconds").record(duration.as_secs_f64());
|
||||
}
|
||||
|
||||
// Embeddings metrics
|
||||
pub fn record_embeddings_request() {
|
||||
counter!("sgl_router_embeddings_total").increment(1);
|
||||
}
|
||||
@@ -473,7 +452,6 @@ impl RouterMetrics {
|
||||
gauge!("sgl_router_embeddings_queue_size").set(size as f64);
|
||||
}
|
||||
|
||||
// Running requests for cache-aware policy
|
||||
pub fn set_running_requests(worker: &str, count: usize) {
|
||||
gauge!("sgl_router_running_requests",
|
||||
"worker" => worker.to_string()
|
||||
@@ -481,7 +459,6 @@ impl RouterMetrics {
|
||||
.set(count as f64);
|
||||
}
|
||||
|
||||
// Circuit breaker metrics
|
||||
pub fn set_cb_state(worker: &str, state_code: u8) {
|
||||
gauge!("sgl_router_cb_state",
|
||||
"worker" => worker.to_string()
|
||||
@@ -508,7 +485,6 @@ impl RouterMetrics {
|
||||
}
|
||||
|
||||
impl TokenizerMetrics {
|
||||
// Encoding metrics
|
||||
pub fn record_encode_request(tokenizer_type: &str) {
|
||||
counter!("sgl_tokenizer_encode_requests_total",
|
||||
"tokenizer_type" => tokenizer_type.to_string()
|
||||
@@ -535,7 +511,6 @@ impl TokenizerMetrics {
|
||||
histogram!("sgl_tokenizer_chars_per_encode").record(char_count as f64);
|
||||
}
|
||||
|
||||
// Decoding metrics
|
||||
pub fn record_decode_request(tokenizer_type: &str) {
|
||||
counter!("sgl_tokenizer_decode_requests_total",
|
||||
"tokenizer_type" => tokenizer_type.to_string()
|
||||
@@ -558,7 +533,6 @@ impl TokenizerMetrics {
|
||||
histogram!("sgl_tokenizer_tokens_per_decode").record(token_count as f64);
|
||||
}
|
||||
|
||||
// Batch encoding metrics
|
||||
pub fn record_encode_batch_duration(duration: Duration, batch_size: usize) {
|
||||
histogram!("sgl_tokenizer_encode_batch_duration_seconds",
|
||||
"batch_size" => batch_size.to_string()
|
||||
@@ -566,7 +540,6 @@ impl TokenizerMetrics {
|
||||
.record(duration.as_secs_f64());
|
||||
}
|
||||
|
||||
// Stop sequence detection metrics
|
||||
pub fn record_stop_sequence_detected(stop_type: &str) {
|
||||
counter!("sgl_tokenizer_stop_sequences_detected_total",
|
||||
"type" => stop_type.to_string()
|
||||
@@ -582,7 +555,6 @@ impl TokenizerMetrics {
|
||||
histogram!("sgl_tokenizer_stop_detection_duration_seconds").record(duration.as_secs_f64());
|
||||
}
|
||||
|
||||
// Streaming decode metrics
|
||||
pub fn record_stream_token() {
|
||||
counter!("sgl_tokenizer_stream_tokens_total").increment(1);
|
||||
}
|
||||
@@ -595,7 +567,6 @@ impl TokenizerMetrics {
|
||||
histogram!("sgl_tokenizer_stream_step_duration_seconds").record(duration.as_secs_f64());
|
||||
}
|
||||
|
||||
// Factory metrics
|
||||
pub fn record_factory_load(file_type: &str) {
|
||||
counter!("sgl_tokenizer_factory_loads_total",
|
||||
"file_type" => file_type.to_string()
|
||||
@@ -614,7 +585,6 @@ impl TokenizerMetrics {
|
||||
histogram!("sgl_tokenizer_factory_load_duration_seconds").record(duration.as_secs_f64());
|
||||
}
|
||||
|
||||
// Vocabulary metrics
|
||||
pub fn set_vocab_size(tokenizer_type: &str, size: usize) {
|
||||
gauge!("sgl_tokenizer_vocab_size",
|
||||
"tokenizer_type" => tokenizer_type.to_string()
|
||||
@@ -705,7 +675,6 @@ mod tests {
|
||||
.parse()
|
||||
.unwrap_or(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
|
||||
|
||||
// Should fall back to 0.0.0.0
|
||||
assert_eq!(ip_addr, IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
|
||||
}
|
||||
}
|
||||
@@ -780,7 +749,6 @@ mod tests {
|
||||
fn test_duration_suffix_matcher() {
|
||||
let matcher = Matcher::Suffix(String::from("duration_seconds"));
|
||||
|
||||
// Test matching behavior
|
||||
let _matching_metrics = [
|
||||
"request_duration_seconds",
|
||||
"response_duration_seconds",
|
||||
@@ -789,8 +757,6 @@ mod tests {
|
||||
|
||||
let _non_matching_metrics = ["duration_total", "duration_seconds_total", "other_metric"];
|
||||
|
||||
// Note: We can't directly test Matcher matching without the internals,
|
||||
// but we can verify the matcher is created correctly
|
||||
match matcher {
|
||||
Matcher::Suffix(suffix) => assert_eq!(suffix, "duration_seconds"),
|
||||
_ => panic!("Expected Suffix matcher"),
|
||||
@@ -801,7 +767,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_prometheus_builder_configuration() {
|
||||
// This test verifies the builder configuration without actually starting Prometheus
|
||||
let _config = PrometheusConfig::default();
|
||||
|
||||
let duration_matcher = Matcher::Suffix(String::from("duration_seconds"));
|
||||
@@ -810,10 +775,8 @@ mod tests {
|
||||
60.0, 90.0, 120.0, 180.0, 240.0,
|
||||
];
|
||||
|
||||
// Verify bucket configuration
|
||||
assert_eq!(duration_bucket.len(), 20);
|
||||
|
||||
// Verify matcher is suffix type
|
||||
match duration_matcher {
|
||||
Matcher::Suffix(s) => assert_eq!(s, "duration_seconds"),
|
||||
_ => panic!("Expected Suffix matcher"),
|
||||
@@ -832,14 +795,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_custom_buckets_for_different_metrics() {
|
||||
// Test that we can create different bucket configurations
|
||||
let request_buckets = [0.001, 0.01, 0.1, 1.0, 10.0];
|
||||
let generate_buckets = [0.1, 0.5, 1.0, 5.0, 30.0, 60.0];
|
||||
|
||||
assert_eq!(request_buckets.len(), 5);
|
||||
assert_eq!(generate_buckets.len(), 6);
|
||||
|
||||
// Verify each set is sorted
|
||||
for i in 1..request_buckets.len() {
|
||||
assert!(request_buckets[i] > request_buckets[i - 1]);
|
||||
}
|
||||
@@ -853,7 +814,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_metrics_static_methods() {
|
||||
// Test that all static methods can be called without panic
|
||||
RouterMetrics::record_request("/generate");
|
||||
RouterMetrics::record_request_duration("/generate", Duration::from_millis(100));
|
||||
RouterMetrics::record_request_error("/generate", "timeout");
|
||||
@@ -887,41 +847,32 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_tokenizer_metrics_static_methods() {
|
||||
// Test that all tokenizer metric methods can be called without panic
|
||||
|
||||
// Encoding metrics
|
||||
TokenizerMetrics::record_encode_request("huggingface");
|
||||
TokenizerMetrics::record_encode_duration(Duration::from_millis(10));
|
||||
TokenizerMetrics::record_encode_error("invalid_input");
|
||||
TokenizerMetrics::record_tokens_per_encode(100);
|
||||
TokenizerMetrics::record_chars_per_encode(500);
|
||||
|
||||
// Decoding metrics
|
||||
TokenizerMetrics::record_decode_request("huggingface");
|
||||
TokenizerMetrics::record_decode_duration(Duration::from_millis(5));
|
||||
TokenizerMetrics::record_decode_error("invalid_tokens");
|
||||
TokenizerMetrics::record_tokens_per_decode(50);
|
||||
|
||||
// Batch encoding
|
||||
TokenizerMetrics::record_encode_batch_duration(Duration::from_millis(100), 10);
|
||||
|
||||
// Stop sequence detection
|
||||
TokenizerMetrics::record_stop_sequence_detected("token");
|
||||
TokenizerMetrics::record_stop_sequence_detected("string");
|
||||
TokenizerMetrics::record_partial_match();
|
||||
TokenizerMetrics::record_stop_detection_duration(Duration::from_micros(100));
|
||||
|
||||
// Streaming decode
|
||||
TokenizerMetrics::record_stream_token();
|
||||
TokenizerMetrics::record_incomplete_utf8();
|
||||
TokenizerMetrics::record_stream_step_duration(Duration::from_micros(50));
|
||||
|
||||
// Factory metrics
|
||||
TokenizerMetrics::record_factory_load("json");
|
||||
TokenizerMetrics::record_factory_error("unsupported_format");
|
||||
TokenizerMetrics::record_factory_load_duration(Duration::from_millis(200));
|
||||
|
||||
// Vocabulary metrics
|
||||
TokenizerMetrics::set_vocab_size("huggingface", 50000);
|
||||
}
|
||||
|
||||
@@ -929,17 +880,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_port_already_in_use() {
|
||||
// Skip this test if we can't bind to the port
|
||||
let port = 29123; // Use a different port to avoid conflicts
|
||||
let port = 29123;
|
||||
|
||||
if let Ok(_listener) = TcpListener::bind(("127.0.0.1", port)) {
|
||||
// Port is available, we can test
|
||||
let config = PrometheusConfig {
|
||||
port,
|
||||
host: "127.0.0.1".to_string(),
|
||||
};
|
||||
|
||||
// Just verify config is created correctly
|
||||
assert_eq!(config.port, port);
|
||||
}
|
||||
}
|
||||
@@ -948,8 +896,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_metrics_endpoint_accessibility() {
|
||||
// This would be an integration test in practice
|
||||
// Here we just verify the configuration
|
||||
let config = PrometheusConfig {
|
||||
port: 29000,
|
||||
host: "127.0.0.1".to_string(),
|
||||
@@ -963,7 +909,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_concurrent_metric_updates() {
|
||||
// Test that metric updates can be called concurrently
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
@@ -984,11 +929,9 @@ mod tests {
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Let threads run briefly
|
||||
thread::sleep(Duration::from_millis(10));
|
||||
done.store(true, Ordering::Relaxed);
|
||||
|
||||
// Wait for all threads
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
@@ -998,7 +941,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_empty_string_metrics() {
|
||||
// Test that empty strings don't cause issues
|
||||
RouterMetrics::record_request("");
|
||||
RouterMetrics::set_worker_health("", true);
|
||||
RouterMetrics::record_policy_decision("", "");
|
||||
@@ -1030,7 +972,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_extreme_metric_values() {
|
||||
// Test extreme values
|
||||
RouterMetrics::set_active_workers(0);
|
||||
RouterMetrics::set_active_workers(usize::MAX);
|
||||
|
||||
@@ -1038,7 +979,6 @@ mod tests {
|
||||
RouterMetrics::set_worker_load("worker", usize::MAX);
|
||||
|
||||
RouterMetrics::record_request_duration("route", Duration::from_nanos(1));
|
||||
// 24 hours
|
||||
RouterMetrics::record_request_duration("route", Duration::from_secs(86400));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ use tokio::task;
|
||||
use tokio::time;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
/// Represents the service discovery configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ServiceDiscoveryConfig {
|
||||
pub enabled: bool,
|
||||
@@ -41,8 +40,8 @@ impl Default for ServiceDiscoveryConfig {
|
||||
enabled: false,
|
||||
selector: HashMap::new(),
|
||||
check_interval: Duration::from_secs(60),
|
||||
port: 8000, // Standard port for modern services
|
||||
namespace: None, // None means watch all namespaces
|
||||
port: 8000,
|
||||
namespace: None,
|
||||
pd_mode: false,
|
||||
prefill_selector: HashMap::new(),
|
||||
decode_selector: HashMap::new(),
|
||||
@@ -51,7 +50,6 @@ impl Default for ServiceDiscoveryConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Pod type for PD mode service discovery
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum PodType {
|
||||
Prefill,
|
||||
@@ -59,7 +57,6 @@ pub enum PodType {
|
||||
Regular,
|
||||
}
|
||||
|
||||
/// Represents a Kubernetes pod's information used for worker management
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct PodInfo {
|
||||
pub name: String,
|
||||
@@ -71,7 +68,6 @@ pub struct PodInfo {
|
||||
}
|
||||
|
||||
impl PodInfo {
|
||||
/// Check if a pod matches any of the given selectors
|
||||
fn matches_selector(pod: &Pod, selector: &HashMap<String, String>) -> bool {
|
||||
if selector.is_empty() {
|
||||
return false;
|
||||
@@ -83,19 +79,15 @@ impl PodInfo {
|
||||
.is_some_and(|labels| selector.iter().all(|(k, v)| labels.get(k) == Some(v)))
|
||||
}
|
||||
|
||||
/// Check if a pod should be included in service discovery
|
||||
pub fn should_include(pod: &Pod, config: &ServiceDiscoveryConfig) -> bool {
|
||||
if config.pd_mode {
|
||||
// In PD mode, at least one selector must be non-empty
|
||||
if config.prefill_selector.is_empty() && config.decode_selector.is_empty() {
|
||||
warn!("PD mode enabled but both prefill_selector and decode_selector are empty");
|
||||
return false;
|
||||
}
|
||||
// In PD mode, pod must match either prefill or decode selector
|
||||
Self::matches_selector(pod, &config.prefill_selector)
|
||||
|| Self::matches_selector(pod, &config.decode_selector)
|
||||
} else {
|
||||
// In regular mode, pod must match the general selector
|
||||
if config.selector.is_empty() {
|
||||
warn!("Regular mode enabled but selector is empty");
|
||||
return false;
|
||||
@@ -104,7 +96,6 @@ impl PodInfo {
|
||||
}
|
||||
}
|
||||
|
||||
/// Unified PodInfo creation with optional PD configuration
|
||||
pub fn from_pod(pod: &Pod, config: Option<&ServiceDiscoveryConfig>) -> Option<Self> {
|
||||
let name = pod.metadata.name.clone()?;
|
||||
let status = pod.status.clone()?;
|
||||
@@ -120,10 +111,8 @@ impl PodInfo {
|
||||
|
||||
let pod_status = status.phase.unwrap_or_else(|| "Unknown".to_string());
|
||||
|
||||
// Determine pod type based on labels if config is provided and in PD mode
|
||||
let pod_type = if let Some(config) = config {
|
||||
if config.pd_mode {
|
||||
// Use simplified helper methods for cleaner logic
|
||||
if Self::matches_selector(pod, &config.prefill_selector) {
|
||||
Some(PodType::Prefill)
|
||||
} else if Self::matches_selector(pod, &config.decode_selector) {
|
||||
@@ -135,11 +124,9 @@ impl PodInfo {
|
||||
Some(PodType::Regular)
|
||||
}
|
||||
} else {
|
||||
// No config provided, default to None (for backwards compatibility)
|
||||
None
|
||||
};
|
||||
|
||||
// Extract bootstrap port from annotations for prefill pods
|
||||
let bootstrap_port = if matches!(pod_type, Some(PodType::Prefill)) {
|
||||
if let Some(config) = config {
|
||||
pod.metadata
|
||||
@@ -164,12 +151,10 @@ impl PodInfo {
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns true if the pod is in a state where it can accept traffic
|
||||
pub fn is_healthy(&self) -> bool {
|
||||
self.is_ready && self.status == "Running"
|
||||
}
|
||||
|
||||
/// Generates a worker URL for this pod
|
||||
pub fn worker_url(&self, port: u16) -> String {
|
||||
format!("http://{}:{}", self.ip, port)
|
||||
}
|
||||
@@ -179,9 +164,7 @@ pub async fn start_service_discovery(
|
||||
config: ServiceDiscoveryConfig,
|
||||
app_context: Arc<AppContext>,
|
||||
) -> Result<task::JoinHandle<()>, kube::Error> {
|
||||
// Don't initialize anything if service discovery is disabled
|
||||
if !config.enabled {
|
||||
// Return a generic error when service discovery is disabled
|
||||
return Err(kube::Error::Api(kube::error::ErrorResponse {
|
||||
status: "Disabled".to_string(),
|
||||
message: "Service discovery is disabled".to_string(),
|
||||
@@ -192,7 +175,6 @@ pub async fn start_service_discovery(
|
||||
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
// Initialize Kubernetes client
|
||||
let client = Client::try_default().await?;
|
||||
|
||||
// Log the appropriate selectors based on mode
|
||||
@@ -229,12 +211,9 @@ pub async fn start_service_discovery(
|
||||
);
|
||||
}
|
||||
|
||||
// Create the task that will run in the background
|
||||
let handle = task::spawn(async move {
|
||||
// We'll track pods we've already added to avoid duplicates
|
||||
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||
|
||||
// Create a watcher for pods
|
||||
let pods: Api<Pod> = if let Some(namespace) = &config.namespace {
|
||||
Api::namespaced(client, namespace)
|
||||
} else {
|
||||
@@ -243,23 +222,19 @@ pub async fn start_service_discovery(
|
||||
|
||||
debug!("K8s service discovery initialized");
|
||||
|
||||
// Create Arcs for configuration data
|
||||
let config_arc = Arc::new(config.clone());
|
||||
let port = config.port;
|
||||
|
||||
let mut retry_delay = Duration::from_secs(1);
|
||||
const MAX_RETRY_DELAY: Duration = Duration::from_secs(300); // 5 minutes max
|
||||
const MAX_RETRY_DELAY: Duration = Duration::from_secs(300);
|
||||
|
||||
loop {
|
||||
// Create a watcher with the proper parameters according to the kube-rs API
|
||||
let watcher_config = Config::default();
|
||||
let watcher_stream = watcher(pods.clone(), watcher_config).applied_objects();
|
||||
|
||||
// Clone Arcs for the closures
|
||||
let config_clone = Arc::clone(&config_arc);
|
||||
let tracked_pods_clone = Arc::clone(&tracked_pods);
|
||||
|
||||
// Simplified label selector filter using helper method
|
||||
let filtered_stream = watcher_stream.filter_map(move |obj_res| {
|
||||
let config_inner = Arc::clone(&config_clone);
|
||||
|
||||
@@ -277,7 +252,6 @@ pub async fn start_service_discovery(
|
||||
}
|
||||
});
|
||||
|
||||
// Clone again for the next closure
|
||||
let tracked_pods_clone2 = Arc::clone(&tracked_pods_clone);
|
||||
let app_context_clone = Arc::clone(&app_context);
|
||||
let config_clone2 = Arc::clone(&config_arc);
|
||||
@@ -317,7 +291,6 @@ pub async fn start_service_discovery(
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
// Reset retry delay on success
|
||||
retry_delay = Duration::from_secs(1);
|
||||
}
|
||||
Err(err) => {
|
||||
@@ -328,12 +301,10 @@ pub async fn start_service_discovery(
|
||||
);
|
||||
time::sleep(retry_delay).await;
|
||||
|
||||
// Exponential backoff with jitter
|
||||
retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY);
|
||||
}
|
||||
}
|
||||
|
||||
// If the watcher exits for some reason, wait a bit before restarting
|
||||
warn!(
|
||||
"Kubernetes watcher exited, restarting in {} seconds",
|
||||
config_arc.check_interval.as_secs()
|
||||
@@ -354,9 +325,7 @@ async fn handle_pod_event(
|
||||
) {
|
||||
let worker_url = pod_info.worker_url(port);
|
||||
|
||||
// If pod is healthy, try to add it (with atomic check-and-insert)
|
||||
if pod_info.is_healthy() {
|
||||
// Atomic check-and-insert to prevent race conditions
|
||||
let should_add = {
|
||||
let mut tracker = match tracked_pods.lock() {
|
||||
Ok(tracker) => tracker,
|
||||
@@ -367,9 +336,8 @@ async fn handle_pod_event(
|
||||
};
|
||||
|
||||
if tracker.contains(pod_info) {
|
||||
false // Already tracked
|
||||
false
|
||||
} else {
|
||||
// Reserve the spot to prevent other threads from adding the same pod
|
||||
tracker.insert(pod_info.clone());
|
||||
true
|
||||
}
|
||||
@@ -381,7 +349,6 @@ async fn handle_pod_event(
|
||||
pod_info.name, pod_info.pod_type, worker_url
|
||||
);
|
||||
|
||||
// Build worker config based on pod type and routing mode
|
||||
let worker_type = if pd_mode {
|
||||
match &pod_info.pod_type {
|
||||
Some(PodType::Prefill) => Some("prefill".to_string()),
|
||||
@@ -392,7 +359,6 @@ async fn handle_pod_event(
|
||||
None
|
||||
};
|
||||
|
||||
// Only set bootstrap_port for prefill workers in PD mode
|
||||
let bootstrap_port = if pd_mode {
|
||||
match &pod_info.pod_type {
|
||||
Some(PodType::Prefill) => pod_info.bootstrap_port,
|
||||
@@ -425,7 +391,6 @@ async fn handle_pod_event(
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to add worker {} to router: {}", worker_url, e);
|
||||
// Remove from tracking since addition failed
|
||||
if let Ok(mut tracker) = tracked_pods.lock() {
|
||||
tracker.remove(pod_info);
|
||||
}
|
||||
@@ -464,8 +429,6 @@ async fn handle_pod_deletion(
|
||||
error!("Failed to remove worker {}: {}", worker_url, e);
|
||||
}
|
||||
} else {
|
||||
// This case might occur if a pod is deleted before it was ever marked healthy and added.
|
||||
// Or if the event is duplicated. No action needed on the router if it wasn't tracked (and thus not added).
|
||||
debug!(
|
||||
"Pod deletion event for untracked/already removed pod: {} (type: {:?}). Worker URL: {}",
|
||||
pod_info.name, pod_info.pod_type, worker_url
|
||||
@@ -480,7 +443,6 @@ mod tests {
|
||||
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
|
||||
use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time;
|
||||
|
||||
// Helper function to create a Pod for testing PodInfo::from_pod
|
||||
fn create_k8s_pod(
|
||||
name: Option<&str>,
|
||||
ip: Option<&str>,
|
||||
@@ -523,7 +485,6 @@ mod tests {
|
||||
pod
|
||||
}
|
||||
|
||||
// Helper function to create a Pod with PD-specific labels and annotations
|
||||
fn create_pd_k8s_pod(name: &str, ip: &str, pod_type: &str, bootstrap_port: Option<u16>) -> Pod {
|
||||
let mut labels = std::collections::BTreeMap::new();
|
||||
labels.insert("app".to_string(), "sglang".to_string());
|
||||
@@ -559,18 +520,15 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to create an AppContext instance for testing event handlers
|
||||
async fn create_test_app_context() -> Arc<AppContext> {
|
||||
use crate::config::RouterConfig;
|
||||
use crate::middleware::TokenBucket;
|
||||
|
||||
// Create a minimal RouterConfig for testing with very short timeout
|
||||
let router_config = RouterConfig {
|
||||
worker_startup_timeout_secs: 1,
|
||||
..Default::default()
|
||||
}; // Very short timeout for tests
|
||||
};
|
||||
|
||||
// Create AppContext with minimal components
|
||||
Arc::new(AppContext {
|
||||
client: reqwest::Client::new(),
|
||||
router_config: router_config.clone(),
|
||||
@@ -579,16 +537,15 @@ mod tests {
|
||||
policy_registry: Arc::new(crate::policies::PolicyRegistry::new(
|
||||
router_config.policy.clone(),
|
||||
)),
|
||||
tokenizer: None, // HTTP mode doesn't need tokenizer
|
||||
reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser
|
||||
tool_parser_registry: None, // HTTP mode doesn't need tool parser
|
||||
router_manager: None, // Test doesn't need router manager
|
||||
tokenizer: None,
|
||||
reasoning_parser_factory: None,
|
||||
tool_parser_registry: None,
|
||||
router_manager: None,
|
||||
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
|
||||
load_monitor: None,
|
||||
})
|
||||
}
|
||||
|
||||
// Helper to create a PD config for testing
|
||||
fn create_pd_config() -> ServiceDiscoveryConfig {
|
||||
let mut prefill_selector = HashMap::new();
|
||||
prefill_selector.insert("app".to_string(), "sglang".to_string());
|
||||
@@ -615,19 +572,15 @@ mod tests {
|
||||
fn test_pod_info_should_include() {
|
||||
let config = create_pd_config();
|
||||
|
||||
// Test prefill pod should be included
|
||||
let prefill_pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", Some(8081));
|
||||
assert!(PodInfo::should_include(&prefill_pod, &config));
|
||||
|
||||
// Test decode pod should be included
|
||||
let decode_pod = create_pd_k8s_pod("decode-pod", "10.0.0.2", "decode", None);
|
||||
assert!(PodInfo::should_include(&decode_pod, &config));
|
||||
|
||||
// Test unmatched pod should not be included
|
||||
let unmatched_pod = create_pd_k8s_pod("other-pod", "10.0.0.3", "other", None);
|
||||
assert!(!PodInfo::should_include(&unmatched_pod, &config));
|
||||
|
||||
// Test regular mode
|
||||
let mut regular_config = ServiceDiscoveryConfig::default();
|
||||
regular_config
|
||||
.selector
|
||||
@@ -654,7 +607,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_pod_type_enum() {
|
||||
// Test that PodType enum has expected variants
|
||||
let prefill = PodType::Prefill;
|
||||
let decode = PodType::Decode;
|
||||
let regular = PodType::Regular;
|
||||
@@ -714,7 +666,7 @@ mod tests {
|
||||
fn test_pod_info_from_pod_with_pd_config_regular_mode() {
|
||||
let k8s_pod = create_pd_k8s_pod("regular-pod", "10.0.0.3", "worker", None);
|
||||
let mut config = create_pd_config();
|
||||
config.pd_mode = false; // Set to regular mode
|
||||
config.pd_mode = false;
|
||||
|
||||
let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap();
|
||||
assert_eq!(pod_info.name, "regular-pod");
|
||||
@@ -742,7 +694,6 @@ mod tests {
|
||||
#[test]
|
||||
fn test_pod_info_from_pod_with_pd_config_invalid_bootstrap_port() {
|
||||
let mut pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", None);
|
||||
// Add invalid bootstrap port annotation
|
||||
pod.metadata.annotations.as_mut().unwrap().insert(
|
||||
"sglang.ai/bootstrap-port".to_string(),
|
||||
"invalid".to_string(),
|
||||
@@ -751,7 +702,7 @@ mod tests {
|
||||
|
||||
let pod_info = PodInfo::from_pod(&pod, Some(&config)).unwrap();
|
||||
assert_eq!(pod_info.pod_type, Some(PodType::Prefill));
|
||||
assert!(pod_info.bootstrap_port.is_none()); // Should be None for invalid port
|
||||
assert!(pod_info.bootstrap_port.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1077,7 +1028,6 @@ mod tests {
|
||||
)
|
||||
.await;
|
||||
|
||||
// Pod should not be tracked since add_worker_from_url will fail for non-running server
|
||||
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user