[router] add worker abstraction (#7960)

This commit is contained in:
Simo Lin
2025-07-11 20:17:48 -07:00
committed by GitHub
parent 2a2d3478af
commit f2d5c4920e
11 changed files with 960 additions and 410 deletions

View File

@@ -12,7 +12,7 @@
mod test_pd_routing {
use rand::Rng;
use serde_json::json;
use sglang_router_rs::pd_types::{EngineInfo, EngineType, PDSelectionPolicy};
use sglang_router_rs::pd_types::PDSelectionPolicy;
use sglang_router_rs::router::{PolicyConfig, Router};
// Test-only struct to help validate PD request parsing
@@ -51,40 +51,35 @@ mod test_pd_routing {
// ========================================================================
#[test]
fn test_engine_info_creation() {
// Test EngineInfo creation for prefill servers
let prefill_engine = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
match prefill_engine.engine_type {
EngineType::Prefill => (),
_ => panic!("Expected Prefill engine type"),
}
assert_eq!(prefill_engine.url, "http://prefill:8080");
assert_eq!(prefill_engine.bootstrap_port, Some(9000));
assert_eq!(prefill_engine.get_hostname(), "prefill");
fn test_worker_types() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
// Test EngineInfo creation for decode servers
let decode_engine = EngineInfo::new_decode("http://decode:8080".to_string());
match decode_engine.engine_type {
EngineType::Decode => (),
_ => panic!("Expected Decode engine type"),
// Test worker creation for prefill servers
let prefill_worker =
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
assert_eq!(prefill_worker.url(), "http://prefill:8080");
match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => {
assert_eq!(bootstrap_port, Some(9000));
}
_ => panic!("Expected Prefill worker type"),
}
assert_eq!(decode_engine.url, "http://decode:8080");
assert_eq!(decode_engine.bootstrap_port, None);
assert_eq!(decode_engine.get_hostname(), "decode");
// Test API path generation
assert_eq!(
prefill_engine.api_path("/generate"),
"http://prefill:8080/generate"
);
assert_eq!(
prefill_engine.api_path("health"),
"http://prefill:8080/health"
);
assert_eq!(
decode_engine.api_path("/v1/chat/completions"),
"http://decode:8080/v1/chat/completions"
);
// Test worker creation for decode servers
let decode_worker = WorkerFactory::create_decode("http://decode:8080".to_string());
assert_eq!(decode_worker.url(), "http://decode:8080");
match decode_worker.worker_type() {
WorkerType::Decode => (),
_ => panic!("Expected Decode worker type"),
}
// Test regular worker creation
let regular_worker = WorkerFactory::create_regular("http://regular:8080".to_string());
assert_eq!(regular_worker.url(), "http://regular:8080");
match regular_worker.worker_type() {
WorkerType::Regular => (),
_ => panic!("Expected Regular worker type"),
}
}
#[test]
@@ -230,6 +225,9 @@ mod test_pd_routing {
#[test]
fn test_bootstrap_injection_simulation() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::pd_types::get_hostname;
// Since we can't test the actual inject_bootstrap_fields function here
// (it's private in the router module), we'll test the expected behavior
@@ -240,15 +238,24 @@ mod test_pd_routing {
"temperature": 0.7
});
// Create a prefill worker to simulate injection
let prefill_worker =
WorkerFactory::create_prefill("http://prefill1:8080".to_string(), Some(9000));
// Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
// Simulate what inject_bootstrap_fields would do
let prefill_info = EngineInfo::new_prefill("http://prefill1:8080".to_string(), Some(9000));
single_json["bootstrap_host"] = json!(prefill_info.get_hostname());
single_json["bootstrap_port"] = json!(prefill_info.bootstrap_port);
single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url()));
single_json["bootstrap_port"] = json!(bootstrap_port);
single_json["bootstrap_room"] = json!(12345u64); // Random room ID
// Verify bootstrap fields are added correctly
assert_eq!(single_json["bootstrap_host"], "prefill1");
assert_eq!(single_json["bootstrap_port"], 9000);
assert_eq!(single_json["bootstrap_port"], json!(Some(9000)));
assert!(single_json["bootstrap_room"].is_u64());
assert_eq!(single_json["temperature"], 0.7); // Original field preserved
@@ -259,8 +266,9 @@ mod test_pd_routing {
});
let batch_size = 3;
batch_json["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
batch_json["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
let hostname = get_hostname(prefill_worker.url());
batch_json["bootstrap_host"] = json!(vec![hostname; batch_size]);
batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);
// Verify batch bootstrap fields
@@ -306,7 +314,9 @@ mod test_pd_routing {
}
#[test]
fn test_engine_info_hostname_extraction() {
fn test_hostname_extraction() {
use sglang_router_rs::pd_types::get_hostname;
// Test various URL formats
let test_cases = vec![
("http://localhost:8080", "localhost"),
@@ -318,8 +328,7 @@ mod test_pd_routing {
];
for (url, expected_hostname) in test_cases {
let engine = EngineInfo::new_prefill(url.to_string(), None);
assert_eq!(engine.get_hostname(), expected_hostname);
assert_eq!(get_hostname(url), expected_hostname);
}
}
@@ -652,6 +661,9 @@ mod test_pd_routing {
#[test]
fn test_bootstrap_injection_with_benchmark_requests() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::pd_types::get_hostname;
// Test bootstrap injection with actual benchmark request patterns
let mut benchmark_request = json!({
"input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
@@ -664,12 +676,20 @@ mod test_pd_routing {
"stream": true
});
// Simulate bootstrap injection
let prefill_info = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
let batch_size = 16;
// Create a prefill worker to simulate injection
let prefill_worker =
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
benchmark_request["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
benchmark_request["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
// Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let batch_size = 16;
let hostname = get_hostname(prefill_worker.url());
benchmark_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
benchmark_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
benchmark_request["bootstrap_room"] =
json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());
@@ -770,6 +790,9 @@ mod test_pd_routing {
#[test]
fn test_large_batch_bootstrap_injection() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::pd_types::get_hostname;
// Test bootstrap injection performance with very large batches
// This simulates the bench_one_batch_server.py scenario
let large_batch_sizes = vec![1024, 4096, 8192];
@@ -787,14 +810,19 @@ mod test_pd_routing {
"stream": true
});
// Simulate bootstrap injection
let prefill_info =
EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
// Create a prefill worker to simulate injection
let prefill_worker =
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
large_batch_request["bootstrap_host"] =
json!(vec![prefill_info.get_hostname(); batch_size]);
large_batch_request["bootstrap_port"] =
json!(vec![prefill_info.bootstrap_port; batch_size]);
// Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let hostname = get_hostname(prefill_worker.url());
large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
large_batch_request["bootstrap_room"] = json!((0..batch_size)
.map(|_| rand::thread_rng().gen::<u64>())
.collect::<Vec<_>>());