[router] Add IGW (Inference Gateway) Feature Flag (#9371)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
2
.github/workflows/pr-test-rust.yml
vendored
2
.github/workflows/pr-test-rust.yml
vendored
@@ -53,7 +53,7 @@ jobs:
|
|||||||
cargo check --benches
|
cargo check --benches
|
||||||
|
|
||||||
- name: Quick benchmark sanity check
|
- name: Quick benchmark sanity check
|
||||||
timeout-minutes: 10
|
timeout-minutes: 15
|
||||||
run: |
|
run: |
|
||||||
source "$HOME/.cargo/env"
|
source "$HOME/.cargo/env"
|
||||||
cd sgl-router/
|
cd sgl-router/
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ pub struct RouterConfig {
|
|||||||
pub disable_circuit_breaker: bool,
|
pub disable_circuit_breaker: bool,
|
||||||
/// Health check configuration
|
/// Health check configuration
|
||||||
pub health_check: HealthCheckConfig,
|
pub health_check: HealthCheckConfig,
|
||||||
|
/// Enable Inference Gateway mode (false = proxy mode, true = IGW mode)
|
||||||
|
#[serde(default)]
|
||||||
|
pub enable_igw: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Routing mode configuration
|
/// Routing mode configuration
|
||||||
@@ -323,6 +326,7 @@ impl Default for RouterConfig {
|
|||||||
disable_retries: false,
|
disable_retries: false,
|
||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: HealthCheckConfig::default(),
|
health_check: HealthCheckConfig::default(),
|
||||||
|
enable_igw: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -377,6 +381,11 @@ impl RouterConfig {
|
|||||||
}
|
}
|
||||||
cfg
|
cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if running in IGW (Inference Gateway) mode
|
||||||
|
pub fn is_igw_mode(&self) -> bool {
|
||||||
|
self.enable_igw
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -456,6 +465,7 @@ mod tests {
|
|||||||
disable_retries: false,
|
disable_retries: false,
|
||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: HealthCheckConfig::default(),
|
health_check: HealthCheckConfig::default(),
|
||||||
|
enable_igw: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let json = serde_json::to_string(&config).unwrap();
|
let json = serde_json::to_string(&config).unwrap();
|
||||||
@@ -888,6 +898,7 @@ mod tests {
|
|||||||
disable_retries: false,
|
disable_retries: false,
|
||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: HealthCheckConfig::default(),
|
health_check: HealthCheckConfig::default(),
|
||||||
|
enable_igw: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(config.mode.is_pd_mode());
|
assert!(config.mode.is_pd_mode());
|
||||||
@@ -944,6 +955,7 @@ mod tests {
|
|||||||
disable_retries: false,
|
disable_retries: false,
|
||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: HealthCheckConfig::default(),
|
health_check: HealthCheckConfig::default(),
|
||||||
|
enable_igw: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(!config.mode.is_pd_mode());
|
assert!(!config.mode.is_pd_mode());
|
||||||
@@ -996,6 +1008,7 @@ mod tests {
|
|||||||
disable_retries: false,
|
disable_retries: false,
|
||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: HealthCheckConfig::default(),
|
health_check: HealthCheckConfig::default(),
|
||||||
|
enable_igw: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(config.has_service_discovery());
|
assert!(config.has_service_discovery());
|
||||||
|
|||||||
@@ -344,6 +344,11 @@ impl ConfigValidator {
|
|||||||
|
|
||||||
/// Validate compatibility between different configuration sections
|
/// Validate compatibility between different configuration sections
|
||||||
fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> {
|
fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> {
|
||||||
|
// IGW mode is independent - skip other compatibility checks when enabled
|
||||||
|
if config.enable_igw {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
// All policies are now supported for both router types thanks to the unified trait design
|
// All policies are now supported for both router types thanks to the unified trait design
|
||||||
// No mode/policy restrictions needed anymore
|
// No mode/policy restrictions needed anymore
|
||||||
|
|
||||||
|
|||||||
@@ -82,6 +82,8 @@ struct Router {
|
|||||||
health_check_timeout_secs: u64,
|
health_check_timeout_secs: u64,
|
||||||
health_check_interval_secs: u64,
|
health_check_interval_secs: u64,
|
||||||
health_check_endpoint: String,
|
health_check_endpoint: String,
|
||||||
|
// IGW (Inference Gateway) configuration
|
||||||
|
enable_igw: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Router {
|
impl Router {
|
||||||
@@ -110,7 +112,12 @@ impl Router {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Determine routing mode
|
// Determine routing mode
|
||||||
let mode = if self.pd_disaggregation {
|
let mode = if self.enable_igw {
|
||||||
|
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
|
||||||
|
RoutingMode::Regular {
|
||||||
|
worker_urls: vec![],
|
||||||
|
}
|
||||||
|
} else if self.pd_disaggregation {
|
||||||
RoutingMode::PrefillDecode {
|
RoutingMode::PrefillDecode {
|
||||||
prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
|
prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
|
||||||
decode_urls: self.decode_urls.clone().unwrap_or_default(),
|
decode_urls: self.decode_urls.clone().unwrap_or_default(),
|
||||||
@@ -191,6 +198,7 @@ impl Router {
|
|||||||
check_interval_secs: self.health_check_interval_secs,
|
check_interval_secs: self.health_check_interval_secs,
|
||||||
endpoint: self.health_check_endpoint.clone(),
|
endpoint: self.health_check_endpoint.clone(),
|
||||||
},
|
},
|
||||||
|
enable_igw: self.enable_igw,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -252,6 +260,8 @@ impl Router {
|
|||||||
health_check_timeout_secs = 5,
|
health_check_timeout_secs = 5,
|
||||||
health_check_interval_secs = 60,
|
health_check_interval_secs = 60,
|
||||||
health_check_endpoint = String::from("/health"),
|
health_check_endpoint = String::from("/health"),
|
||||||
|
// IGW defaults
|
||||||
|
enable_igw = false,
|
||||||
))]
|
))]
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn new(
|
fn new(
|
||||||
@@ -305,6 +315,7 @@ impl Router {
|
|||||||
health_check_timeout_secs: u64,
|
health_check_timeout_secs: u64,
|
||||||
health_check_interval_secs: u64,
|
health_check_interval_secs: u64,
|
||||||
health_check_endpoint: String,
|
health_check_endpoint: String,
|
||||||
|
enable_igw: bool,
|
||||||
) -> PyResult<Self> {
|
) -> PyResult<Self> {
|
||||||
Ok(Router {
|
Ok(Router {
|
||||||
host,
|
host,
|
||||||
@@ -357,6 +368,7 @@ impl Router {
|
|||||||
health_check_timeout_secs,
|
health_check_timeout_secs,
|
||||||
health_check_interval_secs,
|
health_check_interval_secs,
|
||||||
health_check_endpoint,
|
health_check_endpoint,
|
||||||
|
enable_igw,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ Examples:
|
|||||||
--decode http://127.0.0.3:30003 \
|
--decode http://127.0.0.3:30003 \
|
||||||
--decode http://127.0.0.4:30004 \
|
--decode http://127.0.0.4:30004 \
|
||||||
--prefill-policy cache_aware --decode-policy power_of_two
|
--prefill-policy cache_aware --decode-policy power_of_two
|
||||||
|
|
||||||
"#)]
|
"#)]
|
||||||
struct CliArgs {
|
struct CliArgs {
|
||||||
/// Host address to bind the router server
|
/// Host address to bind the router server
|
||||||
@@ -266,6 +267,11 @@ struct CliArgs {
|
|||||||
/// Health check endpoint path
|
/// Health check endpoint path
|
||||||
#[arg(long, default_value = "/health")]
|
#[arg(long, default_value = "/health")]
|
||||||
health_check_endpoint: String,
|
health_check_endpoint: String,
|
||||||
|
|
||||||
|
// IGW (Inference Gateway) configuration
|
||||||
|
/// Enable Inference Gateway mode
|
||||||
|
#[arg(long, default_value_t = false)]
|
||||||
|
enable_igw: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CliArgs {
|
impl CliArgs {
|
||||||
@@ -307,7 +313,12 @@ impl CliArgs {
|
|||||||
prefill_urls: Vec<(String, Option<u16>)>,
|
prefill_urls: Vec<(String, Option<u16>)>,
|
||||||
) -> ConfigResult<RouterConfig> {
|
) -> ConfigResult<RouterConfig> {
|
||||||
// Determine routing mode
|
// Determine routing mode
|
||||||
let mode = if self.pd_disaggregation {
|
let mode = if self.enable_igw {
|
||||||
|
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
|
||||||
|
RoutingMode::Regular {
|
||||||
|
worker_urls: vec![],
|
||||||
|
}
|
||||||
|
} else if self.pd_disaggregation {
|
||||||
let decode_urls = self.decode.clone();
|
let decode_urls = self.decode.clone();
|
||||||
|
|
||||||
// Validate PD configuration if not using service discovery
|
// Validate PD configuration if not using service discovery
|
||||||
@@ -406,6 +417,7 @@ impl CliArgs {
|
|||||||
check_interval_secs: self.health_check_interval_secs,
|
check_interval_secs: self.health_check_interval_secs,
|
||||||
endpoint: self.health_check_endpoint.clone(),
|
endpoint: self.health_check_endpoint.clone(),
|
||||||
},
|
},
|
||||||
|
enable_igw: self.enable_igw,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -487,17 +499,22 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
println!("Host: {}:{}", cli_args.host, cli_args.port);
|
println!("Host: {}:{}", cli_args.host, cli_args.port);
|
||||||
println!(
|
println!(
|
||||||
"Mode: {}",
|
"Mode: {}",
|
||||||
if cli_args.pd_disaggregation {
|
if cli_args.enable_igw {
|
||||||
|
"IGW (Inference Gateway)"
|
||||||
|
} else if cli_args.pd_disaggregation {
|
||||||
"PD Disaggregated"
|
"PD Disaggregated"
|
||||||
} else {
|
} else {
|
||||||
"Regular"
|
"Regular"
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
println!("Policy: {}", cli_args.policy);
|
|
||||||
|
|
||||||
if cli_args.pd_disaggregation && !prefill_urls.is_empty() {
|
if !cli_args.enable_igw {
|
||||||
println!("Prefill nodes: {:?}", prefill_urls);
|
println!("Policy: {}", cli_args.policy);
|
||||||
println!("Decode nodes: {:?}", cli_args.decode);
|
|
||||||
|
if cli_args.pd_disaggregation && !prefill_urls.is_empty() {
|
||||||
|
println!("Prefill nodes: {:?}", prefill_urls);
|
||||||
|
println!("Decode nodes: {:?}", cli_args.decode);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert to RouterConfig
|
// Convert to RouterConfig
|
||||||
|
|||||||
@@ -12,6 +12,12 @@ pub struct RouterFactory;
|
|||||||
impl RouterFactory {
|
impl RouterFactory {
|
||||||
/// Create a router instance from application context
|
/// Create a router instance from application context
|
||||||
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
|
// Check if IGW mode is enabled
|
||||||
|
if ctx.router_config.enable_igw {
|
||||||
|
return Self::create_igw_router(ctx).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to proxy mode
|
||||||
match &ctx.router_config.mode {
|
match &ctx.router_config.mode {
|
||||||
RoutingMode::Regular { worker_urls } => {
|
RoutingMode::Regular { worker_urls } => {
|
||||||
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await
|
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await
|
||||||
@@ -94,4 +100,10 @@ impl RouterFactory {
|
|||||||
|
|
||||||
Ok(Box::new(router))
|
Ok(Box::new(router))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create an IGW router (placeholder for future implementation)
|
||||||
|
async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
|
// For now, return an error indicating IGW is not yet implemented
|
||||||
|
Err("IGW mode is not yet implemented".to_string())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ impl TestContext {
|
|||||||
disable_retries: false,
|
disable_retries: false,
|
||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||||
|
enable_igw: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
Self::new_with_config(config, worker_configs).await
|
Self::new_with_config(config, worker_configs).await
|
||||||
@@ -1093,6 +1094,7 @@ mod error_tests {
|
|||||||
disable_retries: false,
|
disable_retries: false,
|
||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||||
|
enable_igw: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = TestContext::new_with_config(
|
let ctx = TestContext::new_with_config(
|
||||||
@@ -1444,6 +1446,7 @@ mod pd_mode_tests {
|
|||||||
disable_retries: false,
|
disable_retries: false,
|
||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||||
|
enable_igw: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create app context
|
// Create app context
|
||||||
@@ -1599,6 +1602,7 @@ mod request_id_tests {
|
|||||||
disable_retries: false,
|
disable_retries: false,
|
||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||||
|
enable_igw: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = TestContext::new_with_config(
|
let ctx = TestContext::new_with_config(
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ impl TestContext {
|
|||||||
disable_retries: false,
|
disable_retries: false,
|
||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||||
|
enable_igw: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut workers = Vec::new();
|
let mut workers = Vec::new();
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ impl TestContext {
|
|||||||
disable_retries: false,
|
disable_retries: false,
|
||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||||
|
enable_igw: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut workers = Vec::new();
|
let mut workers = Vec::new();
|
||||||
|
|||||||
@@ -184,6 +184,7 @@ mod test_pd_routing {
|
|||||||
disable_retries: false,
|
disable_retries: false,
|
||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||||
|
enable_igw: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Router creation will fail due to health checks, but config should be valid
|
// Router creation will fail due to health checks, but config should be valid
|
||||||
|
|||||||
Reference in New Issue
Block a user