[router] Refactor router and policy traits with dependency injection (#7987)
Co-authored-by: Jin Pan <jpan236@wisc.edu>
Co-authored-by: Keru Yang <rukeyang@gmail.com>
Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>
Co-authored-by: Philip Zhu <phlipzhux@gmail.com>
This commit is contained in:
@@ -6,7 +6,7 @@ use sglang_router_rs::openai_api_types::{
|
||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
||||
SamplingParams, StringOrArray, UserMessageContent,
|
||||
};
|
||||
use sglang_router_rs::request_adapter::{RouteableRequest, ToPdRequest};
|
||||
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};
|
||||
|
||||
#[test]
|
||||
fn test_benchmark_request_creation() {
|
||||
|
||||
@@ -8,12 +8,18 @@
|
||||
//! Note: PD mode is enabled via the pd_disaggregation flag, not as a policy type.
|
||||
//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode.
|
||||
|
||||
// TODO: This test file needs to be updated for the new configuration structure
|
||||
// where RoutingMode and PolicyConfig are separate
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_pd_routing {
|
||||
use rand::Rng;
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::pd_types::PDSelectionPolicy;
|
||||
use sglang_router_rs::router::{PolicyConfig, Router};
|
||||
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::routers::pd_types::get_hostname;
|
||||
use sglang_router_rs::routers::pd_types::PDSelectionPolicy;
|
||||
use sglang_router_rs::routers::RouterFactory;
|
||||
|
||||
// Test-only struct to help validate PD request parsing
|
||||
#[derive(Debug)]
|
||||
@@ -116,49 +122,68 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_pd_router_configuration() {
|
||||
// Test PrefillDecodeConfig creation with various policies
|
||||
// This config is used when pd_disaggregation=true
|
||||
let configs = vec![
|
||||
PolicyConfig::PrefillDecodeConfig {
|
||||
selection_policy: PDSelectionPolicy::Random,
|
||||
prefill_urls: vec![
|
||||
("http://prefill1:8080".to_string(), Some(9000)),
|
||||
("http://prefill2:8080".to_string(), None),
|
||||
],
|
||||
decode_urls: vec![
|
||||
"http://decode1:8080".to_string(),
|
||||
"http://decode2:8080".to_string(),
|
||||
],
|
||||
timeout_secs: 10,
|
||||
interval_secs: 1,
|
||||
},
|
||||
PolicyConfig::PrefillDecodeConfig {
|
||||
selection_policy: PDSelectionPolicy::PowerOfTwo,
|
||||
prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))],
|
||||
decode_urls: vec!["http://decode:8080".to_string()],
|
||||
timeout_secs: 5,
|
||||
interval_secs: 1,
|
||||
},
|
||||
PolicyConfig::PrefillDecodeConfig {
|
||||
selection_policy: PDSelectionPolicy::CacheAware {
|
||||
// Test PD router configuration with various policies
|
||||
// In the new structure, RoutingMode and PolicyConfig are separate
|
||||
let test_cases = vec![
|
||||
(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![
|
||||
("http://prefill1:8080".to_string(), Some(9000)),
|
||||
("http://prefill2:8080".to_string(), None),
|
||||
],
|
||||
decode_urls: vec![
|
||||
"http://decode1:8080".to_string(),
|
||||
"http://decode2:8080".to_string(),
|
||||
],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
),
|
||||
(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))],
|
||||
decode_urls: vec!["http://decode:8080".to_string()],
|
||||
},
|
||||
PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 5,
|
||||
},
|
||||
),
|
||||
(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![
|
||||
("http://p1:8080".to_string(), Some(9000)),
|
||||
("http://p2:8080".to_string(), Some(9001)),
|
||||
("http://p3:8080".to_string(), Some(9002)),
|
||||
],
|
||||
decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
|
||||
},
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold: 0.7,
|
||||
balance_abs_threshold: 20,
|
||||
balance_rel_threshold: 1.2,
|
||||
eviction_interval_secs: 60,
|
||||
max_tree_size: 1000000,
|
||||
},
|
||||
prefill_urls: vec![
|
||||
("http://p1:8080".to_string(), Some(9000)),
|
||||
("http://p2:8080".to_string(), Some(9001)),
|
||||
("http://p3:8080".to_string(), Some(9002)),
|
||||
],
|
||||
decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
|
||||
timeout_secs: 10,
|
||||
interval_secs: 2,
|
||||
},
|
||||
),
|
||||
];
|
||||
|
||||
for config in configs {
|
||||
for (mode, policy) in test_cases {
|
||||
let config = RouterConfig {
|
||||
mode,
|
||||
policy,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 3001,
|
||||
max_payload_size: 1024 * 1024,
|
||||
request_timeout_secs: 60,
|
||||
worker_startup_timeout_secs: 10,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
log_level: None,
|
||||
};
|
||||
|
||||
// Router creation will fail due to health checks, but config should be valid
|
||||
let result = Router::new(vec![], config);
|
||||
let result = RouterFactory::create_router(&config);
|
||||
assert!(result.is_err());
|
||||
let error_msg = result.unwrap_err();
|
||||
// Error should be about health/timeout, not configuration
|
||||
@@ -225,9 +250,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_bootstrap_injection_simulation() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Since we can't test the actual inject_bootstrap_fields function here
|
||||
// (it's private in the router module), we'll test the expected behavior
|
||||
|
||||
@@ -315,8 +337,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_hostname_extraction() {
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Test various URL formats
|
||||
let test_cases = vec![
|
||||
("http://localhost:8080", "localhost"),
|
||||
@@ -662,7 +682,6 @@ mod test_pd_routing {
|
||||
#[test]
|
||||
fn test_bootstrap_injection_with_benchmark_requests() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Test bootstrap injection with actual benchmark request patterns
|
||||
let mut benchmark_request = json!({
|
||||
@@ -790,9 +809,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_large_batch_bootstrap_injection() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Test bootstrap injection performance with very large batches
|
||||
// This simulates the bench_one_batch_server.py scenario
|
||||
let large_batch_sizes = vec![1024, 4096, 8192];
|
||||
|
||||
Reference in New Issue
Block a user