[router] remove old/outdated/useless comments across code base (#10968)
This commit is contained in:
@@ -9,7 +9,6 @@ mod test_pd_routing {
|
||||
use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy;
|
||||
use sglang_router_rs::routers::RouterFactory;
|
||||
|
||||
// Test-only struct to help validate PD request parsing
|
||||
#[derive(Debug)]
|
||||
struct PDRequest {
|
||||
pub is_stream: bool,
|
||||
@@ -17,14 +16,12 @@ mod test_pd_routing {
|
||||
}
|
||||
|
||||
impl PDRequest {
|
||||
// Extract PD-relevant info from JSON for testing
|
||||
pub fn from_json(json: &serde_json::Value) -> Self {
|
||||
let is_stream = json
|
||||
.get("stream")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// Detect batch size from text or input_ids
|
||||
let batch_size = if let Some(text) = json.get("text") {
|
||||
text.as_array().map(|arr| arr.len())
|
||||
} else if let Some(input_ids) = json.get("input_ids") {
|
||||
@@ -40,15 +37,10 @@ mod test_pd_routing {
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Phase 1: Basic PD Components and Router Creation
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_worker_types() {
|
||||
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
|
||||
|
||||
// Test worker creation for prefill servers
|
||||
let prefill_worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
@@ -65,7 +57,6 @@ mod test_pd_routing {
|
||||
_ => panic!("Expected Prefill worker type"),
|
||||
}
|
||||
|
||||
// Test worker creation for decode servers
|
||||
let decode_worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://decode:8080")
|
||||
.worker_type(WorkerType::Decode)
|
||||
@@ -78,7 +69,6 @@ mod test_pd_routing {
|
||||
_ => panic!("Expected Decode worker type"),
|
||||
}
|
||||
|
||||
// Test regular worker creation
|
||||
let regular_worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://regular:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
@@ -94,7 +84,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_pd_selection_policies() {
|
||||
// Test all PD selection policy variants
|
||||
// Note: These policies are only used when pd_disaggregation=true
|
||||
let policies = vec![
|
||||
PDSelectionPolicy::Random,
|
||||
@@ -107,7 +96,6 @@ mod test_pd_routing {
|
||||
];
|
||||
|
||||
for policy in policies {
|
||||
// Verify each policy can be created and matched
|
||||
match &policy {
|
||||
PDSelectionPolicy::Random => {
|
||||
assert!(matches!(policy, PDSelectionPolicy::Random));
|
||||
@@ -126,7 +114,6 @@ mod test_pd_routing {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pd_router_configuration() {
|
||||
// Test PD router configuration with various policies
|
||||
// In the new structure, RoutingMode and PolicyConfig are separate
|
||||
let test_cases = vec![
|
||||
(
|
||||
@@ -221,7 +208,6 @@ mod test_pd_routing {
|
||||
"Router creation should succeed with empty worker"
|
||||
);
|
||||
|
||||
// Verify that no workers are registered since we didn't initialize them
|
||||
let stats = app_context.worker_registry.stats();
|
||||
assert_eq!(
|
||||
stats.total_workers, 0,
|
||||
@@ -230,13 +216,8 @@ mod test_pd_routing {
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Phase 2: Bootstrap Injection and Request Handling
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_pd_request_from_json() {
|
||||
// Test PDRequest parsing from single text request
|
||||
let single_json = json!({
|
||||
"text": "Hello world",
|
||||
"stream": false,
|
||||
@@ -248,7 +229,6 @@ mod test_pd_routing {
|
||||
assert!(!pd_req.is_stream);
|
||||
assert_eq!(pd_req.batch_size, None);
|
||||
|
||||
// Test PDRequest parsing from batch text request
|
||||
let batch_json = json!({
|
||||
"text": ["Hello", "World", "Test"],
|
||||
"stream": true,
|
||||
@@ -259,7 +239,6 @@ mod test_pd_routing {
|
||||
assert!(pd_req.is_stream);
|
||||
assert_eq!(pd_req.batch_size, Some(3));
|
||||
|
||||
// Test PDRequest parsing from input_ids request
|
||||
let ids_json = json!({
|
||||
"input_ids": [[1, 2, 3], [4, 5, 6]],
|
||||
"stream": false
|
||||
@@ -269,7 +248,6 @@ mod test_pd_routing {
|
||||
assert!(!pd_req.is_stream);
|
||||
assert_eq!(pd_req.batch_size, Some(2));
|
||||
|
||||
// Test PDRequest parsing from chat request
|
||||
let chat_json = json!({
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
@@ -288,14 +266,12 @@ mod test_pd_routing {
|
||||
// Since we can't test the actual inject_bootstrap_fields function here
|
||||
// (it's private in the router module), we'll test the expected behavior
|
||||
|
||||
// Simulate bootstrap injection for single request
|
||||
let mut single_json = json!({
|
||||
"text": "Hello world",
|
||||
"stream": false,
|
||||
"temperature": 0.7
|
||||
});
|
||||
|
||||
// Create a prefill worker to simulate injection
|
||||
let prefill_worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill1:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
@@ -305,24 +281,20 @@ mod test_pd_routing {
|
||||
.build(),
|
||||
);
|
||||
|
||||
// Extract bootstrap port from worker type
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// Simulate what inject_bootstrap_fields would do
|
||||
single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url()));
|
||||
single_json["bootstrap_port"] = json!(bootstrap_port);
|
||||
single_json["bootstrap_room"] = json!(12345u64); // Random room ID
|
||||
|
||||
// Verify bootstrap fields are added correctly
|
||||
assert_eq!(single_json["bootstrap_host"], "prefill1");
|
||||
assert_eq!(single_json["bootstrap_port"], json!(Some(9000)));
|
||||
assert!(single_json["bootstrap_room"].is_u64());
|
||||
assert_eq!(single_json["temperature"], 0.7); // Original field preserved
|
||||
|
||||
// Simulate bootstrap injection for batch request
|
||||
let mut batch_json = json!({
|
||||
"text": ["Hello", "World", "Test"],
|
||||
"stream": true
|
||||
@@ -334,7 +306,6 @@ mod test_pd_routing {
|
||||
batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||
batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);
|
||||
|
||||
// Verify batch bootstrap fields
|
||||
assert!(batch_json["bootstrap_host"].is_array());
|
||||
assert_eq!(
|
||||
batch_json["bootstrap_host"].as_array().unwrap().len(),
|
||||
@@ -347,7 +318,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_request_serialization() {
|
||||
// Test that requests can be properly serialized and deserialized
|
||||
let request = json!({
|
||||
"text": "Test prompt",
|
||||
"stream": false,
|
||||
@@ -360,13 +330,10 @@ mod test_pd_routing {
|
||||
"bootstrap_room": 12345u64
|
||||
});
|
||||
|
||||
// Convert to bytes (as would happen in the router)
|
||||
let bytes = serde_json::to_vec(&request).unwrap();
|
||||
|
||||
// Parse back from bytes
|
||||
let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
|
||||
|
||||
// Verify all fields are preserved
|
||||
assert_eq!(parsed["text"], "Test prompt");
|
||||
assert_eq!(parsed["stream"], false);
|
||||
assert_eq!(parsed["temperature"], 0.7);
|
||||
@@ -378,7 +345,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_hostname_extraction() {
|
||||
// Test various URL formats
|
||||
let test_cases = vec![
|
||||
("http://localhost:8080", "localhost"),
|
||||
("http://10.0.0.1:8080", "10.0.0.1"),
|
||||
@@ -395,13 +361,11 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_pd_request_edge_cases() {
|
||||
// Test empty request
|
||||
let empty_json = json!({});
|
||||
let pd_req = PDRequest::from_json(&empty_json);
|
||||
assert!(!pd_req.is_stream);
|
||||
assert_eq!(pd_req.batch_size, None);
|
||||
|
||||
// Test request with only stream field
|
||||
let stream_only = json!({
|
||||
"stream": true
|
||||
});
|
||||
@@ -409,14 +373,12 @@ mod test_pd_routing {
|
||||
assert!(pd_req.is_stream);
|
||||
assert_eq!(pd_req.batch_size, None);
|
||||
|
||||
// Test request with empty text array
|
||||
let empty_batch = json!({
|
||||
"text": []
|
||||
});
|
||||
let pd_req = PDRequest::from_json(&empty_batch);
|
||||
assert_eq!(pd_req.batch_size, Some(0));
|
||||
|
||||
// Test request with non-array text (should be None)
|
||||
let non_array_text = json!({
|
||||
"text": "single string"
|
||||
});
|
||||
@@ -424,29 +386,21 @@ mod test_pd_routing {
|
||||
assert_eq!(pd_req.batch_size, None);
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Phase 2: Background Load Monitoring Tests
|
||||
// ========================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_background_load_monitoring() {
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::watch;
|
||||
|
||||
// Create a watch channel for testing
|
||||
let (tx, rx) = watch::channel(HashMap::new());
|
||||
|
||||
// Simulate load updates
|
||||
let mut loads = HashMap::new();
|
||||
loads.insert("http://prefill1:8080".to_string(), 10);
|
||||
loads.insert("http://prefill2:8080".to_string(), 20);
|
||||
loads.insert("http://decode1:8080".to_string(), 5);
|
||||
loads.insert("http://decode2:8080".to_string(), 15);
|
||||
|
||||
// Send the loads
|
||||
tx.send(loads.clone()).unwrap();
|
||||
|
||||
// Verify receiver gets the update
|
||||
let received_loads = rx.borrow();
|
||||
assert_eq!(received_loads.get("http://prefill1:8080"), Some(&10));
|
||||
assert_eq!(received_loads.get("http://prefill2:8080"), Some(&20));
|
||||
@@ -456,7 +410,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_load_monitoring_configuration() {
|
||||
// Test that load monitoring is only enabled for PowerOfTwo policy
|
||||
let policies = vec![
|
||||
(PDSelectionPolicy::Random, false),
|
||||
(PDSelectionPolicy::PowerOfTwo, true),
|
||||
@@ -483,42 +436,31 @@ mod test_pd_routing {
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::watch;
|
||||
|
||||
// Test watch channel's broadcast behavior
|
||||
let (tx, rx1) = watch::channel(HashMap::new());
|
||||
let rx2 = rx1.clone();
|
||||
|
||||
// Initial state - empty map
|
||||
assert!(rx1.borrow().is_empty());
|
||||
assert!(rx2.borrow().is_empty());
|
||||
|
||||
// Update 1
|
||||
let mut loads = HashMap::new();
|
||||
loads.insert("worker1".to_string(), 10);
|
||||
tx.send(loads.clone()).unwrap();
|
||||
|
||||
// Both receivers see the update
|
||||
assert_eq!(rx1.borrow().get("worker1"), Some(&10));
|
||||
assert_eq!(rx2.borrow().get("worker1"), Some(&10));
|
||||
|
||||
// Update 2 - overwrites previous
|
||||
loads.insert("worker1".to_string(), 20);
|
||||
loads.insert("worker2".to_string(), 30);
|
||||
tx.send(loads).unwrap();
|
||||
|
||||
// Both receivers see the latest state
|
||||
assert_eq!(rx1.borrow().get("worker1"), Some(&20));
|
||||
assert_eq!(rx2.borrow().get("worker2"), Some(&30));
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Tests based on bench_one_batch_server.py patterns
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_generate_request_formats() {
|
||||
// Based on bench_one_batch_server.py request patterns
|
||||
|
||||
// Test 1: Batch request with input_ids (most common in benchmarks)
|
||||
let batch_request = json!({
|
||||
"input_ids": [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
|
||||
"sampling_params": {
|
||||
@@ -534,7 +476,6 @@ mod test_pd_routing {
|
||||
assert!(pd_req.is_stream);
|
||||
assert_eq!(pd_req.batch_size, Some(3));
|
||||
|
||||
// Test 2: Request with return_logprob (critical for PD)
|
||||
let logprob_request = json!({
|
||||
"input_ids": [[1, 2, 3]],
|
||||
"sampling_params": {
|
||||
@@ -548,7 +489,6 @@ mod test_pd_routing {
|
||||
assert_eq!(logprob_request["return_logprob"], true);
|
||||
assert_eq!(logprob_request["stream"], false);
|
||||
|
||||
// Test 3: Large batch sizes from benchmark
|
||||
let batch_sizes = vec![1, 16, 64]; // From bench_one_batch_server.py
|
||||
for bs in batch_sizes {
|
||||
let request = json!({
|
||||
@@ -567,7 +507,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_sampling_params_handling() {
|
||||
// Test various sampling parameters from bench_one_batch_server.py
|
||||
let sampling_params_variations = vec![
|
||||
json!({
|
||||
"temperature": 0.0,
|
||||
@@ -595,14 +534,12 @@ mod test_pd_routing {
|
||||
"stream": false
|
||||
});
|
||||
|
||||
// Verify params are preserved
|
||||
assert_eq!(request["sampling_params"], params);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_streaming_response_parsing() {
|
||||
// Test SSE format parsing from streaming responses
|
||||
let sse_chunks = ["data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}",
|
||||
"data: {\"text\":\" world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}",
|
||||
"data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}",
|
||||
@@ -615,13 +552,11 @@ mod test_pd_routing {
|
||||
assert!(parsed["meta_info"]["completion_tokens"].is_u64());
|
||||
}
|
||||
|
||||
// Test [DONE] detection
|
||||
assert_eq!(sse_chunks[3], "data: [DONE]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ttft_calculation() {
|
||||
// Test Time To First Token calculation pattern
|
||||
let first_token_response = json!({
|
||||
"text": "Hello",
|
||||
"meta_info": {
|
||||
@@ -637,7 +572,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_throughput_metrics() {
|
||||
// Test throughput calculation patterns from bench_one_batch_server.py
|
||||
let batch_size = 16;
|
||||
let input_len = 1024;
|
||||
let output_len = 16;
|
||||
@@ -655,7 +589,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_error_response_handling() {
|
||||
// Test error response format from bench_one_batch_server.py
|
||||
let error_response = json!({
|
||||
"error": "Request has failed. Invalid input format."
|
||||
});
|
||||
@@ -666,7 +599,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_structured_output_request() {
|
||||
// Test structured output format (json_schema)
|
||||
let structured_request = json!({
|
||||
"text": "What is the capital of France? Answer in JSON.",
|
||||
"sampling_params": {
|
||||
@@ -687,7 +619,6 @@ mod test_pd_routing {
|
||||
fn test_bootstrap_injection_with_benchmark_requests() {
|
||||
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
|
||||
|
||||
// Test bootstrap injection with actual benchmark request patterns
|
||||
let mut benchmark_request = json!({
|
||||
"input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
|
||||
"sampling_params": {
|
||||
@@ -699,7 +630,6 @@ mod test_pd_routing {
|
||||
"stream": true
|
||||
});
|
||||
|
||||
// Create a prefill worker to simulate injection
|
||||
let prefill_worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
@@ -709,7 +639,6 @@ mod test_pd_routing {
|
||||
.build(),
|
||||
);
|
||||
|
||||
// Extract bootstrap port from worker type
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
@@ -722,7 +651,6 @@ mod test_pd_routing {
|
||||
benchmark_request["bootstrap_room"] =
|
||||
json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());
|
||||
|
||||
// Verify bootstrap fields match batch size
|
||||
assert_eq!(
|
||||
benchmark_request["bootstrap_host"]
|
||||
.as_array()
|
||||
@@ -745,14 +673,12 @@ mod test_pd_routing {
|
||||
batch_size
|
||||
);
|
||||
|
||||
// Verify original fields are preserved
|
||||
assert_eq!(benchmark_request["return_logprob"], true);
|
||||
assert_eq!(benchmark_request["stream"], true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_server_info_response_format() {
|
||||
// Test server info format expected by bench_one_batch_server.py
|
||||
let server_info = json!({
|
||||
"internal_states": [{
|
||||
"avg_spec_accept_length": 3.5,
|
||||
@@ -769,16 +695,13 @@ mod test_pd_routing {
|
||||
]
|
||||
});
|
||||
|
||||
// Verify structure matches what benchmark expects
|
||||
assert!(server_info["internal_states"][0]["avg_spec_accept_length"].is_f64());
|
||||
assert!(server_info["internal_states"][0]["last_gen_throughput"].is_f64());
|
||||
assert!(server_info["prefill"].is_array());
|
||||
assert!(server_info["decode"].is_array());
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Comprehensive Endpoint Coverage Test
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_pd_endpoints_coverage() {
|
||||
@@ -807,7 +730,6 @@ mod test_pd_routing {
|
||||
assert_eq!(implemented_count, 10);
|
||||
assert_eq!(total_count, 11);
|
||||
|
||||
// Document the missing endpoint
|
||||
let missing: Vec<_> = implemented_endpoints
|
||||
.iter()
|
||||
.filter(|(_, _, impl_status)| !impl_status)
|
||||
@@ -819,14 +741,12 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_large_batch_bootstrap_injection() {
|
||||
// Test bootstrap injection performance with very large batches
|
||||
// This simulates the bench_one_batch_server.py scenario
|
||||
let large_batch_sizes = vec![1024, 4096, 8192];
|
||||
|
||||
for batch_size in large_batch_sizes {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
// Simulate a large batch request
|
||||
let mut large_batch_request = json!({
|
||||
"input_ids": vec![vec![1, 2, 3, 4]; batch_size],
|
||||
"sampling_params": {
|
||||
@@ -836,7 +756,6 @@ mod test_pd_routing {
|
||||
"stream": true
|
||||
});
|
||||
|
||||
// Create a prefill worker to simulate injection
|
||||
let prefill_worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://prefill:8080")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
@@ -846,7 +765,6 @@ mod test_pd_routing {
|
||||
.build(),
|
||||
);
|
||||
|
||||
// Extract bootstrap port from worker type
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
@@ -861,7 +779,6 @@ mod test_pd_routing {
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
// Verify bootstrap fields are correctly sized
|
||||
assert_eq!(
|
||||
large_batch_request["bootstrap_host"]
|
||||
.as_array()
|
||||
@@ -899,7 +816,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_payload_size_calculation() {
|
||||
// Test payload size estimation for bench_one_batch_server.py scenarios
|
||||
let test_cases = vec![
|
||||
(1, 1024, 16), // Small batch
|
||||
(16, 1024, 16), // Medium batch
|
||||
@@ -937,14 +853,12 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_policy_type_to_pd_selection_policy_mapping() {
|
||||
// Test that PDSelectionPolicy doesn't include RoundRobin
|
||||
let pd_policy_count = 3; // Random, PowerOfTwo, CacheAware
|
||||
assert_eq!(
|
||||
pd_policy_count, 3,
|
||||
"PDSelectionPolicy should have exactly 3 variants"
|
||||
);
|
||||
|
||||
// Verify that each PDSelectionPolicy variant can be created
|
||||
let _random = PDSelectionPolicy::Random;
|
||||
let _po2 = PDSelectionPolicy::PowerOfTwo;
|
||||
let _cache_aware = PDSelectionPolicy::CacheAware {
|
||||
|
||||
Reference in New Issue
Block a user