From 3828db4309908c29d174749ed25790c09977875a Mon Sep 17 00:00:00 2001 From: Keyang Ru Date: Wed, 20 Aug 2025 17:38:57 -0700 Subject: [PATCH] [router] Add IGW (Inference Gateway) Feature Flag (#9371) Co-authored-by: Yineng Zhang --- .github/workflows/pr-test-rust.yml | 2 +- sgl-router/src/config/types.rs | 13 +++++++++++ sgl-router/src/config/validation.rs | 5 ++++ sgl-router/src/lib.rs | 14 +++++++++++- sgl-router/src/main.rs | 29 +++++++++++++++++++----- sgl-router/src/routers/factory.rs | 12 ++++++++++ sgl-router/tests/api_endpoints_test.rs | 4 ++++ sgl-router/tests/request_formats_test.rs | 1 + sgl-router/tests/streaming_tests.rs | 1 + sgl-router/tests/test_pd_routing.rs | 1 + 10 files changed, 74 insertions(+), 8 deletions(-) diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml index e3ea0305f..85107ed30 100644 --- a/.github/workflows/pr-test-rust.yml +++ b/.github/workflows/pr-test-rust.yml @@ -53,7 +53,7 @@ jobs: cargo check --benches - name: Quick benchmark sanity check - timeout-minutes: 10 + timeout-minutes: 15 run: | source "$HOME/.cargo/env" cd sgl-router/ diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index 336ba10d7..45e7e8d96 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -51,6 +51,9 @@ pub struct RouterConfig { pub disable_circuit_breaker: bool, /// Health check configuration pub health_check: HealthCheckConfig, + /// Enable Inference Gateway mode (false = proxy mode, true = IGW mode) + #[serde(default)] + pub enable_igw: bool, } /// Routing mode configuration @@ -323,6 +326,7 @@ impl Default for RouterConfig { disable_retries: false, disable_circuit_breaker: false, health_check: HealthCheckConfig::default(), + enable_igw: false, } } } @@ -377,6 +381,11 @@ impl RouterConfig { } cfg } + + /// Check if running in IGW (Inference Gateway) mode + pub fn is_igw_mode(&self) -> bool { + self.enable_igw + } } #[cfg(test)] @@ -456,6 +465,7 @@ mod tests { disable_retries: false, disable_circuit_breaker: false, health_check: HealthCheckConfig::default(), + enable_igw: false, }; let json = serde_json::to_string(&config).unwrap(); @@ -888,6 +898,7 @@ mod tests { disable_retries: false, disable_circuit_breaker: false, health_check: HealthCheckConfig::default(), + enable_igw: false, }; assert!(config.mode.is_pd_mode()); @@ -944,6 +955,7 @@ mod tests { disable_retries: false, disable_circuit_breaker: false, health_check: HealthCheckConfig::default(), + enable_igw: false, }; assert!(!config.mode.is_pd_mode()); @@ -996,6 +1008,7 @@ mod tests { disable_retries: false, disable_circuit_breaker: false, health_check: HealthCheckConfig::default(), + enable_igw: false, }; assert!(config.has_service_discovery()); diff --git a/sgl-router/src/config/validation.rs b/sgl-router/src/config/validation.rs index da2a12523..542e2e467 100644 --- a/sgl-router/src/config/validation.rs +++ b/sgl-router/src/config/validation.rs @@ -344,6 +344,11 @@ impl ConfigValidator { /// Validate compatibility between different configuration sections fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> { + // IGW mode is independent - skip other compatibility checks when enabled + if config.enable_igw { + return Ok(()); + } + // All policies are now supported for both router types thanks to the unified trait design // No mode/policy restrictions needed anymore diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index e41942c14..4644ea257 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -82,6 +82,8 @@ struct Router { health_check_timeout_secs: u64, health_check_interval_secs: u64, health_check_endpoint: String, + // IGW (Inference Gateway) configuration + enable_igw: bool, } impl Router { @@ -110,7 +112,12 @@ impl Router { }; // Determine routing mode - let mode = if self.pd_disaggregation { + let mode = if self.enable_igw { + // IGW mode - routing mode is not used in IGW, but we need to provide a placeholder + RoutingMode::Regular { + worker_urls: vec![], + } + } else if self.pd_disaggregation { RoutingMode::PrefillDecode { prefill_urls: self.prefill_urls.clone().unwrap_or_default(), decode_urls: self.decode_urls.clone().unwrap_or_default(), @@ -191,6 +198,7 @@ impl Router { check_interval_secs: self.health_check_interval_secs, endpoint: self.health_check_endpoint.clone(), }, + enable_igw: self.enable_igw, }) } } @@ -252,6 +260,8 @@ impl Router { health_check_timeout_secs = 5, health_check_interval_secs = 60, health_check_endpoint = String::from("/health"), + // IGW defaults + enable_igw = false, ))] #[allow(clippy::too_many_arguments)] fn new( @@ -305,6 +315,7 @@ impl Router { health_check_timeout_secs: u64, health_check_interval_secs: u64, health_check_endpoint: String, + enable_igw: bool, ) -> PyResult { Ok(Router { host, @@ -357,6 +368,7 @@ impl Router { health_check_timeout_secs, health_check_interval_secs, health_check_endpoint, + enable_igw, }) } diff --git a/sgl-router/src/main.rs b/sgl-router/src/main.rs index 6c6f9fb95..a2956e88c 100644 --- a/sgl-router/src/main.rs +++ b/sgl-router/src/main.rs @@ -70,6 +70,7 @@ Examples: --decode http://127.0.0.3:30003 \ --decode http://127.0.0.4:30004 \ --prefill-policy cache_aware --decode-policy power_of_two + "#)] struct CliArgs { /// Host address to bind the router server @@ -266,6 +267,11 @@ struct CliArgs { /// Health check endpoint path #[arg(long, default_value = "/health")] health_check_endpoint: String, + + // IGW (Inference Gateway) configuration + /// Enable Inference Gateway mode + #[arg(long, default_value_t = false)] + enable_igw: bool, } impl CliArgs { @@ -307,7 +313,12 @@ impl CliArgs { prefill_urls: Vec<(String, Option)>, ) -> ConfigResult { // Determine routing mode - let mode = if self.pd_disaggregation { + let mode = if self.enable_igw { + // IGW mode - routing mode is not used in IGW, but we need to provide a placeholder + RoutingMode::Regular { + worker_urls: vec![], + } + } else if self.pd_disaggregation { let decode_urls = self.decode.clone(); // Validate PD configuration if not using service discovery @@ -406,6 +417,7 @@ impl CliArgs { check_interval_secs: self.health_check_interval_secs, endpoint: self.health_check_endpoint.clone(), }, + enable_igw: self.enable_igw, }) } @@ -487,17 +499,22 @@ fn main() -> Result<(), Box> { println!("Host: {}:{}", cli_args.host, cli_args.port); println!( "Mode: {}", - if cli_args.pd_disaggregation { + if cli_args.enable_igw { + "IGW (Inference Gateway)" + } else if cli_args.pd_disaggregation { "PD Disaggregated" } else { "Regular" } ); - println!("Policy: {}", cli_args.policy); - if cli_args.pd_disaggregation && !prefill_urls.is_empty() { - println!("Prefill nodes: {:?}", prefill_urls); - println!("Decode nodes: {:?}", cli_args.decode); + if !cli_args.enable_igw { + println!("Policy: {}", cli_args.policy); + + if cli_args.pd_disaggregation && !prefill_urls.is_empty() { + println!("Prefill nodes: {:?}", prefill_urls); + println!("Decode nodes: {:?}", cli_args.decode); + } } // Convert to RouterConfig diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index a96e89b27..7b4f848bc 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -12,6 +12,12 @@ pub struct RouterFactory; impl RouterFactory { /// Create a router instance from application context pub async fn create_router(ctx: &Arc) -> Result, String> { + // Check if IGW mode is enabled + if ctx.router_config.enable_igw { + return Self::create_igw_router(ctx).await; + } + + // Default to proxy mode match &ctx.router_config.mode { RoutingMode::Regular { worker_urls } => { Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await @@ -94,4 +100,10 @@ impl RouterFactory { Ok(Box::new(router)) } + + /// Create an IGW router (placeholder for future implementation) + async fn create_igw_router(_ctx: &Arc) -> Result, String> { + // For now, return an error indicating IGW is not yet implemented + Err("IGW mode is not yet implemented".to_string()) + } } diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index c67080d56..6a4d8d66c 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -51,6 +51,7 @@ impl TestContext { disable_retries: false, disable_circuit_breaker: false, health_check: sglang_router_rs::config::HealthCheckConfig::default(), + enable_igw: false, }; Self::new_with_config(config, worker_configs).await @@ -1093,6 +1094,7 @@ mod error_tests { disable_retries: false, disable_circuit_breaker: false, health_check: sglang_router_rs::config::HealthCheckConfig::default(), + enable_igw: false, }; let ctx = TestContext::new_with_config( @@ -1444,6 +1446,7 @@ mod pd_mode_tests { disable_retries: false, disable_circuit_breaker: false, health_check: sglang_router_rs::config::HealthCheckConfig::default(), + enable_igw: false, }; // Create app context @@ -1599,6 +1602,7 @@ mod request_id_tests { disable_retries: false, disable_circuit_breaker: false, health_check: sglang_router_rs::config::HealthCheckConfig::default(), + enable_igw: false, }; let ctx = TestContext::new_with_config( diff --git a/sgl-router/tests/request_formats_test.rs b/sgl-router/tests/request_formats_test.rs index c0217c590..c62461754 100644 --- a/sgl-router/tests/request_formats_test.rs +++ b/sgl-router/tests/request_formats_test.rs @@ -42,6 +42,7 @@ impl TestContext { disable_retries: false, disable_circuit_breaker: false, health_check: sglang_router_rs::config::HealthCheckConfig::default(), + enable_igw: false, }; let mut workers = Vec::new(); diff --git a/sgl-router/tests/streaming_tests.rs b/sgl-router/tests/streaming_tests.rs index 4d1e65cb0..5e7828952 100644 --- a/sgl-router/tests/streaming_tests.rs +++ b/sgl-router/tests/streaming_tests.rs @@ -43,6 +43,7 @@ impl TestContext { disable_retries: false, disable_circuit_breaker: false, health_check: sglang_router_rs::config::HealthCheckConfig::default(), + enable_igw: false, }; let mut workers = Vec::new(); diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 2bf47b187..33091824d 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -184,6 +184,7 @@ mod test_pd_routing { disable_retries: false, disable_circuit_breaker: false, health_check: sglang_router_rs::config::HealthCheckConfig::default(), + enable_igw: false, }; // Router creation will fail due to health checks, but config should be valid