[router] Add IGW (Inference Gateway) Feature Flag (#9371)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
2
.github/workflows/pr-test-rust.yml
vendored
2
.github/workflows/pr-test-rust.yml
vendored
@@ -53,7 +53,7 @@ jobs:
|
||||
cargo check --benches
|
||||
|
||||
- name: Quick benchmark sanity check
|
||||
timeout-minutes: 10
|
||||
timeout-minutes: 15
|
||||
run: |
|
||||
source "$HOME/.cargo/env"
|
||||
cd sgl-router/
|
||||
|
||||
@@ -51,6 +51,9 @@ pub struct RouterConfig {
|
||||
pub disable_circuit_breaker: bool,
|
||||
/// Health check configuration
|
||||
pub health_check: HealthCheckConfig,
|
||||
/// Enable Inference Gateway mode (false = proxy mode, true = IGW mode)
|
||||
#[serde(default)]
|
||||
pub enable_igw: bool,
|
||||
}
|
||||
|
||||
/// Routing mode configuration
|
||||
@@ -323,6 +326,7 @@ impl Default for RouterConfig {
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -377,6 +381,11 @@ impl RouterConfig {
|
||||
}
|
||||
cfg
|
||||
}
|
||||
|
||||
/// Check if running in IGW (Inference Gateway) mode.
///
/// Returns the `enable_igw` flag as-is: `true` means IGW mode, `false` means proxy mode.
|
||||
pub fn is_igw_mode(&self) -> bool {
|
||||
self.enable_igw
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -456,6 +465,7 @@ mod tests {
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
@@ -888,6 +898,7 @@ mod tests {
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
};
|
||||
|
||||
assert!(config.mode.is_pd_mode());
|
||||
@@ -944,6 +955,7 @@ mod tests {
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
};
|
||||
|
||||
assert!(!config.mode.is_pd_mode());
|
||||
@@ -996,6 +1008,7 @@ mod tests {
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
};
|
||||
|
||||
assert!(config.has_service_discovery());
|
||||
|
||||
@@ -344,6 +344,11 @@ impl ConfigValidator {
|
||||
|
||||
/// Validate compatibility between different configuration sections
|
||||
fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> {
|
||||
// IGW mode is independent - skip other compatibility checks when enabled
|
||||
if config.enable_igw {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// All policies are now supported for both router types thanks to the unified trait design
|
||||
// No mode/policy restrictions needed anymore
|
||||
|
||||
|
||||
@@ -82,6 +82,8 @@ struct Router {
|
||||
health_check_timeout_secs: u64,
|
||||
health_check_interval_secs: u64,
|
||||
health_check_endpoint: String,
|
||||
// IGW (Inference Gateway) configuration
|
||||
enable_igw: bool,
|
||||
}
|
||||
|
||||
impl Router {
|
||||
@@ -110,7 +112,12 @@ impl Router {
|
||||
};
|
||||
|
||||
// Determine routing mode
|
||||
let mode = if self.pd_disaggregation {
|
||||
let mode = if self.enable_igw {
|
||||
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
}
|
||||
} else if self.pd_disaggregation {
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
|
||||
decode_urls: self.decode_urls.clone().unwrap_or_default(),
|
||||
@@ -191,6 +198,7 @@ impl Router {
|
||||
check_interval_secs: self.health_check_interval_secs,
|
||||
endpoint: self.health_check_endpoint.clone(),
|
||||
},
|
||||
enable_igw: self.enable_igw,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -252,6 +260,8 @@ impl Router {
|
||||
health_check_timeout_secs = 5,
|
||||
health_check_interval_secs = 60,
|
||||
health_check_endpoint = String::from("/health"),
|
||||
// IGW defaults
|
||||
enable_igw = false,
|
||||
))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
@@ -305,6 +315,7 @@ impl Router {
|
||||
health_check_timeout_secs: u64,
|
||||
health_check_interval_secs: u64,
|
||||
health_check_endpoint: String,
|
||||
enable_igw: bool,
|
||||
) -> PyResult<Self> {
|
||||
Ok(Router {
|
||||
host,
|
||||
@@ -357,6 +368,7 @@ impl Router {
|
||||
health_check_timeout_secs,
|
||||
health_check_interval_secs,
|
||||
health_check_endpoint,
|
||||
enable_igw,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ Examples:
|
||||
--decode http://127.0.0.3:30003 \
|
||||
--decode http://127.0.0.4:30004 \
|
||||
--prefill-policy cache_aware --decode-policy power_of_two
|
||||
|
||||
"#)]
|
||||
struct CliArgs {
|
||||
/// Host address to bind the router server
|
||||
@@ -266,6 +267,11 @@ struct CliArgs {
|
||||
/// Health check endpoint path
|
||||
#[arg(long, default_value = "/health")]
|
||||
health_check_endpoint: String,
|
||||
|
||||
// IGW (Inference Gateway) configuration
|
||||
/// Enable Inference Gateway mode
|
||||
#[arg(long, default_value_t = false)]
|
||||
enable_igw: bool,
|
||||
}
|
||||
|
||||
impl CliArgs {
|
||||
@@ -307,7 +313,12 @@ impl CliArgs {
|
||||
prefill_urls: Vec<(String, Option<u16>)>,
|
||||
) -> ConfigResult<RouterConfig> {
|
||||
// Determine routing mode
|
||||
let mode = if self.pd_disaggregation {
|
||||
let mode = if self.enable_igw {
|
||||
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
}
|
||||
} else if self.pd_disaggregation {
|
||||
let decode_urls = self.decode.clone();
|
||||
|
||||
// Validate PD configuration if not using service discovery
|
||||
@@ -406,6 +417,7 @@ impl CliArgs {
|
||||
check_interval_secs: self.health_check_interval_secs,
|
||||
endpoint: self.health_check_endpoint.clone(),
|
||||
},
|
||||
enable_igw: self.enable_igw,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -487,17 +499,22 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("Host: {}:{}", cli_args.host, cli_args.port);
|
||||
println!(
|
||||
"Mode: {}",
|
||||
if cli_args.pd_disaggregation {
|
||||
if cli_args.enable_igw {
|
||||
"IGW (Inference Gateway)"
|
||||
} else if cli_args.pd_disaggregation {
|
||||
"PD Disaggregated"
|
||||
} else {
|
||||
"Regular"
|
||||
}
|
||||
);
|
||||
println!("Policy: {}", cli_args.policy);
|
||||
|
||||
if cli_args.pd_disaggregation && !prefill_urls.is_empty() {
|
||||
println!("Prefill nodes: {:?}", prefill_urls);
|
||||
println!("Decode nodes: {:?}", cli_args.decode);
|
||||
if !cli_args.enable_igw {
|
||||
println!("Policy: {}", cli_args.policy);
|
||||
|
||||
if cli_args.pd_disaggregation && !prefill_urls.is_empty() {
|
||||
println!("Prefill nodes: {:?}", prefill_urls);
|
||||
println!("Decode nodes: {:?}", cli_args.decode);
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to RouterConfig
|
||||
|
||||
@@ -12,6 +12,12 @@ pub struct RouterFactory;
|
||||
impl RouterFactory {
|
||||
/// Create a router instance from application context
|
||||
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Check if IGW mode is enabled
|
||||
if ctx.router_config.enable_igw {
|
||||
return Self::create_igw_router(ctx).await;
|
||||
}
|
||||
|
||||
// Default to proxy mode
|
||||
match &ctx.router_config.mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await
|
||||
@@ -94,4 +100,10 @@ impl RouterFactory {
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create an IGW router (placeholder for future implementation).
///
/// # Errors
///
/// Currently always returns `Err` with the message "IGW mode is not yet implemented";
/// the `_ctx` argument is accepted but not read.
|
||||
async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// No IGW router exists yet, so report that to the caller instead of constructing one.
|
||||
Err("IGW mode is not yet implemented".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,6 +51,7 @@ impl TestContext {
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
};
|
||||
|
||||
Self::new_with_config(config, worker_configs).await
|
||||
@@ -1093,6 +1094,7 @@ mod error_tests {
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
};
|
||||
|
||||
let ctx = TestContext::new_with_config(
|
||||
@@ -1444,6 +1446,7 @@ mod pd_mode_tests {
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
};
|
||||
|
||||
// Create app context
|
||||
@@ -1599,6 +1602,7 @@ mod request_id_tests {
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
};
|
||||
|
||||
let ctx = TestContext::new_with_config(
|
||||
|
||||
@@ -42,6 +42,7 @@ impl TestContext {
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
};
|
||||
|
||||
let mut workers = Vec::new();
|
||||
|
||||
@@ -43,6 +43,7 @@ impl TestContext {
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
};
|
||||
|
||||
let mut workers = Vec::new();
|
||||
|
||||
@@ -184,6 +184,7 @@ mod test_pd_routing {
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
};
|
||||
|
||||
// Router creation will fail due to health checks, but config should be valid
|
||||
|
||||
Reference in New Issue
Block a user