[router] add worker abstraction (#7960)
This commit is contained in:
@@ -12,7 +12,7 @@
|
||||
mod test_pd_routing {
|
||||
use rand::Rng;
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::pd_types::{EngineInfo, EngineType, PDSelectionPolicy};
|
||||
use sglang_router_rs::pd_types::PDSelectionPolicy;
|
||||
use sglang_router_rs::router::{PolicyConfig, Router};
|
||||
|
||||
// Test-only struct to help validate PD request parsing
|
||||
@@ -51,40 +51,35 @@ mod test_pd_routing {
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_engine_info_creation() {
|
||||
// Test EngineInfo creation for prefill servers
|
||||
let prefill_engine = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
match prefill_engine.engine_type {
|
||||
EngineType::Prefill => (),
|
||||
_ => panic!("Expected Prefill engine type"),
|
||||
}
|
||||
assert_eq!(prefill_engine.url, "http://prefill:8080");
|
||||
assert_eq!(prefill_engine.bootstrap_port, Some(9000));
|
||||
assert_eq!(prefill_engine.get_hostname(), "prefill");
|
||||
fn test_worker_types() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
|
||||
// Test EngineInfo creation for decode servers
|
||||
let decode_engine = EngineInfo::new_decode("http://decode:8080".to_string());
|
||||
match decode_engine.engine_type {
|
||||
EngineType::Decode => (),
|
||||
_ => panic!("Expected Decode engine type"),
|
||||
// Test worker creation for prefill servers
|
||||
let prefill_worker =
|
||||
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
assert_eq!(prefill_worker.url(), "http://prefill:8080");
|
||||
match prefill_worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => {
|
||||
assert_eq!(bootstrap_port, Some(9000));
|
||||
}
|
||||
_ => panic!("Expected Prefill worker type"),
|
||||
}
|
||||
assert_eq!(decode_engine.url, "http://decode:8080");
|
||||
assert_eq!(decode_engine.bootstrap_port, None);
|
||||
assert_eq!(decode_engine.get_hostname(), "decode");
|
||||
|
||||
// Test API path generation
|
||||
assert_eq!(
|
||||
prefill_engine.api_path("/generate"),
|
||||
"http://prefill:8080/generate"
|
||||
);
|
||||
assert_eq!(
|
||||
prefill_engine.api_path("health"),
|
||||
"http://prefill:8080/health"
|
||||
);
|
||||
assert_eq!(
|
||||
decode_engine.api_path("/v1/chat/completions"),
|
||||
"http://decode:8080/v1/chat/completions"
|
||||
);
|
||||
// Test worker creation for decode servers
|
||||
let decode_worker = WorkerFactory::create_decode("http://decode:8080".to_string());
|
||||
assert_eq!(decode_worker.url(), "http://decode:8080");
|
||||
match decode_worker.worker_type() {
|
||||
WorkerType::Decode => (),
|
||||
_ => panic!("Expected Decode worker type"),
|
||||
}
|
||||
|
||||
// Test regular worker creation
|
||||
let regular_worker = WorkerFactory::create_regular("http://regular:8080".to_string());
|
||||
assert_eq!(regular_worker.url(), "http://regular:8080");
|
||||
match regular_worker.worker_type() {
|
||||
WorkerType::Regular => (),
|
||||
_ => panic!("Expected Regular worker type"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -230,6 +225,9 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_bootstrap_injection_simulation() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Since we can't test the actual inject_bootstrap_fields function here
|
||||
// (it's private in the router module), we'll test the expected behavior
|
||||
|
||||
@@ -240,15 +238,24 @@ mod test_pd_routing {
|
||||
"temperature": 0.7
|
||||
});
|
||||
|
||||
// Create a prefill worker to simulate injection
|
||||
let prefill_worker =
|
||||
WorkerFactory::create_prefill("http://prefill1:8080".to_string(), Some(9000));
|
||||
|
||||
// Extract bootstrap port from worker type
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// Simulate what inject_bootstrap_fields would do
|
||||
let prefill_info = EngineInfo::new_prefill("http://prefill1:8080".to_string(), Some(9000));
|
||||
single_json["bootstrap_host"] = json!(prefill_info.get_hostname());
|
||||
single_json["bootstrap_port"] = json!(prefill_info.bootstrap_port);
|
||||
single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url()));
|
||||
single_json["bootstrap_port"] = json!(bootstrap_port);
|
||||
single_json["bootstrap_room"] = json!(12345u64); // Random room ID
|
||||
|
||||
// Verify bootstrap fields are added correctly
|
||||
assert_eq!(single_json["bootstrap_host"], "prefill1");
|
||||
assert_eq!(single_json["bootstrap_port"], 9000);
|
||||
assert_eq!(single_json["bootstrap_port"], json!(Some(9000)));
|
||||
assert!(single_json["bootstrap_room"].is_u64());
|
||||
assert_eq!(single_json["temperature"], 0.7); // Original field preserved
|
||||
|
||||
@@ -259,8 +266,9 @@ mod test_pd_routing {
|
||||
});
|
||||
|
||||
let batch_size = 3;
|
||||
batch_json["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
|
||||
batch_json["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
|
||||
let hostname = get_hostname(prefill_worker.url());
|
||||
batch_json["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
||||
batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||
batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);
|
||||
|
||||
// Verify batch bootstrap fields
|
||||
@@ -306,7 +314,9 @@ mod test_pd_routing {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_engine_info_hostname_extraction() {
|
||||
fn test_hostname_extraction() {
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Test various URL formats
|
||||
let test_cases = vec![
|
||||
("http://localhost:8080", "localhost"),
|
||||
@@ -318,8 +328,7 @@ mod test_pd_routing {
|
||||
];
|
||||
|
||||
for (url, expected_hostname) in test_cases {
|
||||
let engine = EngineInfo::new_prefill(url.to_string(), None);
|
||||
assert_eq!(engine.get_hostname(), expected_hostname);
|
||||
assert_eq!(get_hostname(url), expected_hostname);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -652,6 +661,9 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_bootstrap_injection_with_benchmark_requests() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Test bootstrap injection with actual benchmark request patterns
|
||||
let mut benchmark_request = json!({
|
||||
"input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
|
||||
@@ -664,12 +676,20 @@ mod test_pd_routing {
|
||||
"stream": true
|
||||
});
|
||||
|
||||
// Simulate bootstrap injection
|
||||
let prefill_info = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
let batch_size = 16;
|
||||
// Create a prefill worker to simulate injection
|
||||
let prefill_worker =
|
||||
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
|
||||
benchmark_request["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
|
||||
benchmark_request["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
|
||||
// Extract bootstrap port from worker type
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
};
|
||||
let batch_size = 16;
|
||||
let hostname = get_hostname(prefill_worker.url());
|
||||
|
||||
benchmark_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
||||
benchmark_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||
benchmark_request["bootstrap_room"] =
|
||||
json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());
|
||||
|
||||
@@ -770,6 +790,9 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_large_batch_bootstrap_injection() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Test bootstrap injection performance with very large batches
|
||||
// This simulates the bench_one_batch_server.py scenario
|
||||
let large_batch_sizes = vec![1024, 4096, 8192];
|
||||
@@ -787,14 +810,19 @@ mod test_pd_routing {
|
||||
"stream": true
|
||||
});
|
||||
|
||||
// Simulate bootstrap injection
|
||||
let prefill_info =
|
||||
EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
// Create a prefill worker to simulate injection
|
||||
let prefill_worker =
|
||||
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
|
||||
large_batch_request["bootstrap_host"] =
|
||||
json!(vec![prefill_info.get_hostname(); batch_size]);
|
||||
large_batch_request["bootstrap_port"] =
|
||||
json!(vec![prefill_info.bootstrap_port; batch_size]);
|
||||
// Extract bootstrap port from worker type
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
};
|
||||
let hostname = get_hostname(prefill_worker.url());
|
||||
|
||||
large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
||||
large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||
large_batch_request["bootstrap_room"] = json!((0..batch_size)
|
||||
.map(|_| rand::thread_rng().gen::<u64>())
|
||||
.collect::<Vec<_>>());
|
||||
|
||||
Reference in New Issue
Block a user