diff --git a/.github/workflows/pr-test-pd-router.yml b/.github/workflows/pr-test-pd-router.yml index 271a8b3d9..91e809123 100644 --- a/.github/workflows/pr-test-pd-router.yml +++ b/.github/workflows/pr-test-pd-router.yml @@ -131,110 +131,199 @@ jobs: SERVER_PID=$! echo "server_pid=$SERVER_PID" >> $GITHUB_OUTPUT - echo "Waiting for router to become healthy..." - TIMEOUT=300 - ELAPSED=0 - while [ $ELAPSED -lt $TIMEOUT ]; do - if curl --connect-timeout 5 --silent http://127.0.0.9:8000 > /dev/null 2>&1; then - echo "✓ Router is reachable" - break + # Wait for all 8 servers to be healthy (script already does this) + wait_count=0 + while [ $wait_count -lt 30 ]; do + if ps -p $SERVER_PID > /dev/null; then + # Check if the startup script printed success message + sleep 2 + wait_count=$((wait_count + 1)) + else + # Script exited - check if it was successful + wait $SERVER_PID + exit_code=$? + if [ $exit_code -eq 0 ]; then + echo "✓ All disaggregation servers are healthy" + break + else + echo "Error: Server startup failed with code $exit_code" + exit 1 + fi fi - if ! ps -p $SERVER_PID > /dev/null; then - echo "Error: Server processes failed to start" - exit 1 - fi - echo "Waiting for router... (${ELAPSED}s/${TIMEOUT}s)" - sleep 10 - ELAPSED=$((ELAPSED + 10)) done - if [ $ELAPSED -ge $TIMEOUT ]; then - echo "Error: Router health check timeout after ${TIMEOUT}s" - exit 1 - fi + echo "✓ Servers started (PID: $SERVER_PID)" - echo "✓ Servers started and healthy (PID: $SERVER_PID)" - - - name: Test API functionality - timeout-minutes: 5 + - name: Test all policies sequentially + timeout-minutes: 30 run: | + POLICIES=("random" "round_robin" "cache_aware" "power_of_two") BASE_URL="http://127.0.0.9:8000" - echo "Testing API completions..." 
- response=$(curl -s -X POST "$BASE_URL/v1/chat/completions" \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer test-token" \ - -d '{ - "model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct", - "messages": [ - {"role": "user", "content": "Write a Python function to calculate fibonacci numbers recursively"} - ], - "stream": false, - "max_tokens": 100 - }') + for policy in "${POLICIES[@]}"; do + echo "" + echo "==================================================" + echo "Testing policy: $policy" + echo "==================================================" - if echo "$response" | jq -e '.choices[0].message.content' > /dev/null 2>&1; then - echo "✓ API test passed" - else - echo "✗ API test failed: $response" - exit 1 - fi + # Start router with the current policy + echo "Starting router with policy: $policy..." + python3 -m sglang_router.launch_router \ + --pd-disaggregation \ + --policy "$policy" \ + --prefill http://127.0.0.1:30001 9001 \ + --prefill http://127.0.0.2:30002 9002 \ + --prefill http://127.0.0.3:30003 9003 \ + --prefill http://127.0.0.4:30004 9004 \ + --decode http://127.0.0.5:30005 \ + --decode http://127.0.0.6:30006 \ + --decode http://127.0.0.7:30007 \ + --decode http://127.0.0.8:30008 \ + --host 127.0.0.9 \ + --port 8000 & + ROUTER_PID=$! - echo "Testing streaming API..." - stream_response=$(timeout 30 curl -s -X POST "$BASE_URL/v1/chat/completions" \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer test-token" \ - -d '{ - "model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct", - "messages": [ - {"role": "user", "content": "Count from 1 to 5"} - ], - "stream": true, - "max_tokens": 50 - }') + # Wait for router to become healthy + echo "Waiting for router to become healthy..." + TIMEOUT=60 + ELAPSED=0 + while [ $ELAPSED -lt $TIMEOUT ]; do + if curl --connect-timeout 5 --silent http://127.0.0.9:8000 > /dev/null 2>&1; then + echo "✓ Router is reachable" + break + fi + if ! 
ps -p $ROUTER_PID > /dev/null; then + echo "Error: Router process died" + exit 1 + fi + sleep 5 + ELAPSED=$((ELAPSED + 5)) + done - if echo "$stream_response" | grep -q "data:"; then - echo "✓ Streaming API test passed" - else - echo "✗ Streaming API test failed" - exit 1 - fi + if [ $ELAPSED -ge $TIMEOUT ]; then + echo "Error: Router health check timeout" + kill $ROUTER_PID 2>/dev/null || true + exit 1 + fi - - name: Run benchmark test - timeout-minutes: 5 - run: | - echo "Running benchmark test..." - benchmark_output=$(python3 -m sglang.bench_one_batch_server \ - --model-path "/raid/models/meta-llama/Llama-3.1-8B-Instruct" \ - --base-url "http://127.0.0.9:8000" \ - --batch-size 8 \ - --input-len 4096 \ - --output-len 5 \ - --skip-warmup) + # Test API functionality + echo "Testing API completions for $policy..." + response=$(curl -s -X POST "$BASE_URL/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer test-token" \ + -d '{ + "model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct", + "messages": [ + {"role": "user", "content": "Write a Python function to calculate fibonacci numbers recursively"} + ], + "stream": false, + "max_tokens": 100 + }') - echo "$benchmark_output" + if echo "$response" | jq -e '.choices[0].message.content' > /dev/null 2>&1; then + echo "✓ API test passed for $policy" + else + echo "✗ API test failed for $policy: $response" + kill $ROUTER_PID 2>/dev/null || true + exit 1 + fi - # Extract metrics from output - latency=$(echo "$benchmark_output" | grep "latency:" | awk '{print $2}' | sed 's/s//') - input_throughput=$(echo "$benchmark_output" | grep "input throughput:" | awk '{print $3}') - output_throughput=$(echo "$benchmark_output" | grep "output throughput:" | awk '{print $3}') + # Test streaming + echo "Testing streaming API for $policy..." 
+ stream_response=$(timeout 30 curl -s -X POST "$BASE_URL/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer test-token" \ + -d '{ + "model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct", + "messages": [ + {"role": "user", "content": "Count from 1 to 5"} + ], + "stream": true, + "max_tokens": 50 + }') - # Validate performance (latency<1.5s, input>20k, output>1k) - command -v bc >/dev/null || (apt-get update && apt-get install -y bc) + if echo "$stream_response" | grep -q "data:"; then + echo "✓ Streaming API test passed for $policy" + else + echo "✗ Streaming API test failed for $policy" + kill $ROUTER_PID 2>/dev/null || true + exit 1 + fi - echo "Performance: ${latency}s | ${input_throughput} | ${output_throughput} tok/s" + # Run benchmark + echo "Running benchmark for $policy..." + benchmark_output=$(python3 -m sglang.bench_one_batch_server \ + --model-path "/raid/models/meta-llama/Llama-3.1-8B-Instruct" \ + --base-url "http://127.0.0.9:8000" \ + --batch-size 8 \ + --input-len 4096 \ + --output-len 5 \ + --skip-warmup) - fail="" - (( $(echo "$latency > 1.5" | bc -l) )) && fail="Latency too high (${latency}s>1.5s) " - (( $(echo "$input_throughput < 20000" | bc -l) )) && fail="${fail}Input too low (${input_throughput}<20k) " - (( $(echo "$output_throughput < 1000" | bc -l) )) && fail="${fail}Output too low (${output_throughput}<1k) " + echo "$benchmark_output" - if [ -n "$fail" ]; then - echo "✗ Benchmark failed: $fail" - exit 1 - else - echo "✓ Performance validation passed" - fi + # Save benchmark output + echo "$benchmark_output" > "benchmark_${policy}.txt" + + # Extract and validate metrics + latency=$(echo "$benchmark_output" | grep "latency:" | awk '{print $2}' | sed 's/s//') + input_throughput=$(echo "$benchmark_output" | grep "input throughput:" | awk '{print $3}') + output_throughput=$(echo "$benchmark_output" | grep "output throughput:" | awk '{print $3}') + + command -v bc >/dev/null || (apt-get update && 
apt-get install -y bc) + + echo "Performance for $policy: ${latency}s | ${input_throughput} | ${output_throughput} tok/s" + + # Validate performance + fail="" + (( $(echo "$latency > 1.5" | bc -l) )) && fail="Latency too high (${latency}s>1.5s) " + (( $(echo "$input_throughput < 20000" | bc -l) )) && fail="${fail}Input too low (${input_throughput}<20k) " + (( $(echo "$output_throughput < 1000" | bc -l) )) && fail="${fail}Output too low (${output_throughput}<1k) " + + if [ -n "$fail" ]; then + echo "✗ Benchmark failed for $policy: $fail" + kill $ROUTER_PID 2>/dev/null || true + exit 1 + else + echo "✓ Performance validation passed for $policy" + fi + + # Stop router before testing next policy + echo "Stopping router for $policy..." + # First try graceful shutdown + kill $ROUTER_PID 2>/dev/null || true + + # Wait up to 5 seconds for graceful shutdown + for i in {1..5}; do + if ! ps -p $ROUTER_PID > /dev/null 2>&1; then + echo "Router stopped gracefully" + break + fi + sleep 1 + done + + # Force kill if still running + if ps -p $ROUTER_PID > /dev/null 2>&1; then + echo "Force killing router..." + kill -9 $ROUTER_PID 2>/dev/null || true + fi + + # Short delay to ensure port is released + sleep 2 + + echo "✓ Completed testing for $policy" + done + + echo "" + echo "✅ All policies tested successfully!" + + + - name: Upload benchmark results + if: success() + uses: actions/upload-artifact@v4 + with: + name: benchmark-results-all-policies + path: benchmark_*.txt - name: Cleanup servers if: always() @@ -247,3 +336,34 @@ jobs: sleep 5 remaining=$(ps aux | grep -c "sglang.launch_server" || echo "0") echo "Cleanup completed. 
Remaining processes: $remaining" + + summarize-benchmarks: + needs: test-disaggregation + runs-on: ubuntu-latest + if: success() + + steps: + - name: Download benchmark results + uses: actions/download-artifact@v4 + with: + name: benchmark-results-all-policies + + - name: Create benchmark summary + run: | + echo "## PD Router Benchmark Results Summary" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "| Policy | Latency (s) | Input Throughput (tok/s) | Output Throughput (tok/s) |" >> $GITHUB_STEP_SUMMARY + echo "|--------|-------------|-------------------------|--------------------------|" >> $GITHUB_STEP_SUMMARY + + for policy in random round_robin cache_aware power_of_two; do + if [ -f "benchmark_${policy}.txt" ]; then + latency=$(grep "latency:" "benchmark_${policy}.txt" | awk '{print $2}') + input_throughput=$(grep "input throughput:" "benchmark_${policy}.txt" | awk '{print $3}') + output_throughput=$(grep "output throughput:" "benchmark_${policy}.txt" | awk '{print $3}') + + echo "| ${policy} | ${latency} | ${input_throughput} | ${output_throughput} |" >> $GITHUB_STEP_SUMMARY + fi + done + + echo "" >> $GITHUB_STEP_SUMMARY + echo "✅ All policies tested successfully!" >> $GITHUB_STEP_SUMMARY diff --git a/scripts/ci_start_disaggregation_servers.sh b/scripts/ci_start_disaggregation_servers.sh index f652a4f04..22643e0df 100755 --- a/scripts/ci_start_disaggregation_servers.sh +++ b/scripts/ci_start_disaggregation_servers.sh @@ -87,20 +87,8 @@ while true; do fi done -# Launch the router -echo "Launching router at 127.0.0.9:8000..." 
-python3 -m sglang_router.launch_router \ - --pd-disaggregation \ - --policy power_of_two \ - --prefill http://127.0.0.1:30001 9001 \ - --prefill http://127.0.0.2:30002 9002 \ - --prefill http://127.0.0.3:30003 9003 \ - --prefill http://127.0.0.4:30004 9004 \ - --decode http://127.0.0.5:30005 \ - --decode http://127.0.0.6:30006 \ - --decode http://127.0.0.7:30007 \ - --decode http://127.0.0.8:30008 \ - --host 127.0.0.9 \ - --port 8000 & +# Don't launch router here - just keep servers running +echo "✅ All disaggregation servers are ready and waiting for router connections" -wait # Wait for all background jobs to finish +# Keep the script running +wait # Wait for all background server jobs diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index c2cee90d5..576d07d2f 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -6,7 +6,7 @@ use sglang_router_rs::openai_api_types::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, SamplingParams, StringOrArray, UserMessageContent, }; -use sglang_router_rs::request_adapter::{RouteableRequest, ToPdRequest}; +use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest}; // Sample request data for benchmarks fn create_sample_generate_request() -> GenerateRequest { diff --git a/sgl-router/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py index eb2018283..14a0fa12d 100644 --- a/sgl-router/py_test/test_launch_router.py +++ b/sgl-router/py_test/test_launch_router.py @@ -164,56 +164,47 @@ class TestLaunchRouter(unittest.TestCase): """Test that policy validation works correctly for PD and regular modes.""" from sglang_router.launch_router import RouterArgs, launch_router - # Test 1: PowerOfTwo is only valid in PD mode + # Test 1: PowerOfTwo requires at least 2 workers args = self.create_router_args( pd_disaggregation=False, policy="power_of_two", - 
worker_urls=["http://localhost:8000"], + worker_urls=["http://localhost:8000"], # Only 1 worker ) # Should raise error with self.assertRaises(ValueError) as cm: launch_router(args) self.assertIn( - "PowerOfTwo policy is only supported in PD disaggregated mode", + "Power-of-two policy requires at least 2 workers", str(cm.exception), ) - # Test 2: RoundRobin is not valid in PD mode + # Test 2: PowerOfTwo with sufficient workers should succeed args = self.create_router_args( - pd_disaggregation=True, - policy="round_robin", - prefill=[["http://prefill1:8080", "9000"]], - decode=[["http://decode1:8081"]], - worker_urls=[], + pd_disaggregation=False, + policy="power_of_two", + worker_urls=["http://localhost:8000", "http://localhost:8001"], # 2 workers ) + # This should not raise an error (validation passes) - # Should raise error - with self.assertRaises(ValueError) as cm: - launch_router(args) - self.assertIn( - "RoundRobin policy is not supported in PD disaggregated mode", - str(cm.exception), - ) - - # Test 3: Valid combinations should not raise errors + # Test 3: All policies now work in both modes # Regular mode with RoundRobin args = self.create_router_args( pd_disaggregation=False, policy="round_robin", worker_urls=["http://localhost:8000"], ) - # This should not raise (though it may fail to connect) + # This should not raise validation error - # PD mode with PowerOfTwo + # PD mode with RoundRobin (now supported!) 
args = self.create_router_args( pd_disaggregation=True, - policy="power_of_two", + policy="round_robin", prefill=[["http://prefill1:8080", "9000"]], decode=[["http://decode1:8081"]], worker_urls=[], ) - # This should not raise (though it may fail to connect) + # This should not raise validation error def test_pd_service_discovery_args_parsing(self): """Test PD service discovery CLI argument parsing.""" diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index 9d57f439d..6b24a5fd1 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -1,4 +1,4 @@ -use super::{ConfigError, ConfigResult}; +use super::ConfigResult; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -215,6 +215,7 @@ impl RouterConfig { self.metrics.is_some() } + /* Commented out - no longer needed without compatibility layer /// Convert to routing PolicyConfig for internal use pub fn to_routing_policy_config(&self) -> ConfigResult { match (&self.mode, &self.policy) { @@ -291,4 +292,5 @@ impl RouterConfig { } } } + */ } diff --git a/sgl-router/src/config/validation.rs b/sgl-router/src/config/validation.rs index 838742722..381fcce07 100644 --- a/sgl-router/src/config/validation.rs +++ b/sgl-router/src/config/validation.rs @@ -255,29 +255,8 @@ impl ConfigValidator { /// Validate compatibility between different configuration sections fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> { - // Check mode and policy compatibility - match (&config.mode, &config.policy) { - (RoutingMode::Regular { .. }, PolicyConfig::PowerOfTwo { .. }) => { - // PowerOfTwo is only supported in PD mode - return Err(ConfigError::IncompatibleConfig { - reason: "PowerOfTwo policy is only supported in PD disaggregated mode" - .to_string(), - }); - } - (RoutingMode::PrefillDecode { .. 
}, PolicyConfig::RoundRobin) => { - return Err(ConfigError::IncompatibleConfig { - reason: "RoundRobin policy is not supported in PD disaggregated mode" - .to_string(), - }); - } - (RoutingMode::PrefillDecode { .. }, PolicyConfig::CacheAware { .. }) => { - return Err(ConfigError::IncompatibleConfig { - reason: "CacheAware policy is not supported in PD disaggregated mode" - .to_string(), - }); - } - _ => {} - } + // All policies are now supported for both router types thanks to the unified trait design + // No mode/policy restrictions needed anymore // Check if service discovery is enabled for worker count validation let has_service_discovery = config.discovery.as_ref().map_or(false, |d| d.enabled); @@ -459,8 +438,8 @@ mod tests { } #[test] - fn test_validate_incompatible_policy() { - // RoundRobin with PD mode + fn test_validate_roundrobin_with_pd_mode() { + // RoundRobin with PD mode is now supported let config = RouterConfig::new( RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill:8000".to_string(), None)], @@ -470,16 +449,12 @@ mod tests { ); let result = ConfigValidator::validate(&config); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("RoundRobin policy is not supported in PD disaggregated mode")); + assert!(result.is_ok()); } #[test] fn test_validate_cache_aware_with_pd_mode() { - // CacheAware with PD mode should fail + // CacheAware with PD mode is now supported let config = RouterConfig::new( RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill:8000".to_string(), None)], @@ -495,16 +470,12 @@ mod tests { ); let result = ConfigValidator::validate(&config); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("CacheAware policy is not supported in PD disaggregated mode")); + assert!(result.is_ok()); } #[test] fn test_validate_power_of_two_with_regular_mode() { - // PowerOfTwo with Regular mode should fail + // PowerOfTwo with Regular mode is now supported 
let config = RouterConfig::new( RoutingMode::Regular { worker_urls: vec![ @@ -518,10 +489,6 @@ mod tests { ); let result = ConfigValidator::validate(&config); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("PowerOfTwo policy is only supported in PD disaggregated mode")); + assert!(result.is_ok()); } } diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 2b1bcffce..49e8cc573 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -4,11 +4,9 @@ pub mod logging; use std::collections::HashMap; pub mod core; pub mod openai_api_types; -pub mod pd_router; -pub mod pd_types; +pub mod policies; pub mod prometheus; -pub mod request_adapter; -pub mod router; +pub mod routers; pub mod server; pub mod service_discovery; pub mod tree; @@ -241,11 +239,6 @@ impl Router { )) })?; - // Convert to internal policy config - let policy_config = router_config - .to_routing_policy_config() - .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; - // Create service discovery config if enabled let service_discovery_config = if self.service_discovery { Some(service_discovery::ServiceDiscoveryConfig { @@ -282,8 +275,7 @@ impl Router { server::startup(server::ServerConfig { host: self.host.clone(), port: self.port, - worker_urls: self.worker_urls.clone(), - policy_config, + router_config, max_payload_size: self.max_payload_size, log_dir: self.log_dir.clone(), log_level: self.log_level.clone(), diff --git a/sgl-router/src/policies/cache_aware.rs b/sgl-router/src/policies/cache_aware.rs new file mode 100644 index 000000000..db5972ba6 --- /dev/null +++ b/sgl-router/src/policies/cache_aware.rs @@ -0,0 +1,399 @@ +/* + Cache-Aware Load Balancing Router + + This router combines two strategies to optimize both cache utilization and request distribution: + + 1. Cache-Aware Routing (Approximate Tree) + 2. 
Load Balancing (Shortest Queue with Balance Thresholds) + + The router dynamically switches between these strategies based on load conditions: + - Uses load balancing when the system is imbalanced + - Uses cache-aware routing when the system is balanced + + A system is considered imbalanced if both conditions are met: + 1. (max - min) > abs_threshold + 2. max > rel_threshold * min + + Strategy Details: + + 1. Cache-Aware Routing (Approximate Tree) + ------------------------------------------- + This strategy maintains an approximate radix tree for each worker based on request history, + eliminating the need for direct cache state queries. The tree stores raw text characters + instead of token IDs to avoid tokenization overhead. + + Process: + a. For each request, find the worker with the highest prefix match + b. If match rate > cache_threshold: + Route to the worker with highest match (likely has relevant data cached) + c. If match rate ≤ cache_threshold: + Route to the worker with smallest tree size (most available cache capacity) + d. Background maintenance: + Periodically evict least recently used leaf nodes to prevent memory overflow + + 2. Load Balancing (Shortest Queue) + ------------------------------------------- + This strategy tracks pending request counts per worker and routes new requests + to the least busy worker when the system is detected to be imbalanced. + + Configuration Parameters: + ------------------------ + 1. cache_threshold: (float, 0.0 to 1.0) + Minimum prefix match ratio to use highest-match routing. + Below this threshold, routes to worker with most available cache space. + + 2. balance_abs_threshold: (integer) + Absolute difference threshold for load imbalance detection. + System is potentially imbalanced if (max_load - min_load) > abs_threshold + + 3. balance_rel_threshold: (float) + Relative ratio threshold for load imbalance detection. 
+ System is potentially imbalanced if max_load > min_load * rel_threshold + Used in conjunction with abs_threshold to determine final imbalance state. + + 4. eviction_interval_secs: (integer) + Interval between LRU eviction cycles for the approximate trees. + + 5. max_tree_size: (integer) + Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted + during the next eviction cycle. +*/ + +use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy}; +use crate::core::Worker; +use crate::tree::Tree; +use metrics::{counter, gauge}; +use std::sync::{Arc, Mutex}; +use std::thread; +use std::time::Duration; +use tracing::{debug, info}; + +/// Cache-aware routing policy +/// +/// Routes requests based on cache affinity when load is balanced, +/// switches to shortest-queue routing when load is imbalanced. +#[derive(Debug)] +pub struct CacheAwarePolicy { + config: CacheAwareConfig, + tree: Arc>, + eviction_handle: Option>, +} + +impl CacheAwarePolicy { + pub fn new() -> Self { + Self::with_config(CacheAwareConfig::default()) + } + + pub fn with_config(config: CacheAwareConfig) -> Self { + let tree = Arc::new(Mutex::new(Tree::new())); + + // Start background eviction thread if configured + let eviction_handle = if config.eviction_interval_secs > 0 { + let tree_clone = Arc::clone(&tree); + let max_tree_size = config.max_tree_size; + let interval = config.eviction_interval_secs; + + Some(thread::spawn(move || loop { + thread::sleep(Duration::from_secs(interval)); + + if let Ok(tree_guard) = tree_clone.lock() { + tree_guard.evict_tenant_by_size(max_tree_size); + debug!("Cache eviction completed, max_size: {}", max_tree_size); + } + })) + } else { + None + }; + + Self { + config, + tree, + eviction_handle, + } + } + + /// Initialize the tree with worker URLs + pub fn init_workers(&self, workers: &[Box]) { + if let Ok(tree) = self.tree.lock() { + for worker in workers { + tree.insert("", worker.url()); + } + } + } + + /// Remove a worker from the tree + 
pub fn remove_worker(&self, url: &str) { + if let Ok(tree) = self.tree.lock() { + tree.remove_tenant(url); + } + } + + /// Run cache eviction to prevent unbounded growth + pub fn evict_cache(&self, max_size: usize) { + if let Ok(tree) = self.tree.lock() { + tree.evict_tenant_by_size(max_size); + } + } +} + +impl LoadBalancingPolicy for CacheAwarePolicy { + fn select_worker( + &self, + workers: &[Box], + request_text: Option<&str>, + ) -> Option { + let healthy_indices = get_healthy_worker_indices(workers); + + if healthy_indices.is_empty() { + return None; + } + + // Get current load statistics + let loads: Vec = workers.iter().map(|w| w.load()).collect(); + let max_load = *loads.iter().max().unwrap_or(&0); + let min_load = *loads.iter().min().unwrap_or(&0); + + // Check if load is imbalanced + let is_imbalanced = max_load.saturating_sub(min_load) > self.config.balance_abs_threshold + && (max_load as f32) > (min_load as f32 * self.config.balance_rel_threshold); + + if is_imbalanced { + // Log load balancing trigger + let worker_loads: Vec<(String, usize)> = workers + .iter() + .map(|w| (w.url().to_string(), w.load())) + .collect(); + + info!( + "Load balancing triggered due to workload imbalance:\n\ + Max load: {}, Min load: {}\n\ + Current worker loads: {:?}", + max_load, min_load, worker_loads + ); + + counter!("sgl_router_load_balancing_events_total").increment(1); + gauge!("sgl_router_max_load").set(max_load as f64); + gauge!("sgl_router_min_load").set(min_load as f64); + + // Use shortest queue when imbalanced + let min_load_idx = healthy_indices + .iter() + .min_by_key(|&&idx| workers[idx].load()) + .copied()?; + + // Increment processed counter + workers[min_load_idx].increment_processed(); + counter!("sgl_router_processed_requests_total", "worker" => workers[min_load_idx].url().to_string()) + .increment(1); + + return Some(min_load_idx); + } + + // Use cache-aware routing when balanced + let text = request_text.unwrap_or(""); + + if let Ok(tree) = 
self.tree.lock() { + let (matched_text, matched_worker) = tree.prefix_match(text); + let match_rate = if text.is_empty() { + 0.0 + } else { + matched_text.chars().count() as f32 / text.chars().count() as f32 + }; + + let selected_url = if match_rate > self.config.cache_threshold { + counter!("sgl_router_cache_hits_total").increment(1); + matched_worker.to_string() + } else { + counter!("sgl_router_cache_misses_total").increment(1); + tree.get_smallest_tenant() + }; + + // Find the index of the selected worker + let selected_idx = workers.iter().position(|w| w.url() == selected_url)?; + + // Only proceed if the worker is healthy + if !workers[selected_idx].is_healthy() { + return healthy_indices.first().copied(); + } + + // Update the tree with this request + tree.insert(text, &selected_url); + + // Increment processed counter + workers[selected_idx].increment_processed(); + counter!("sgl_router_processed_requests_total", "worker" => selected_url).increment(1); + + return Some(selected_idx); + } + + // Fallback to first healthy worker if tree operations fail + healthy_indices.first().copied() + } + + fn name(&self) -> &'static str { + "cache_aware" + } + + fn on_request_complete(&self, worker_url: &str, success: bool) { + // Could track success rates per worker for more intelligent routing + if !success { + // Optionally reduce affinity for failed requests + tracing::debug!( + "Request to {} completed with success={}", + worker_url, + success + ); + } + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn select_worker_pair( + &self, + prefill_workers: &[Box], + decode_workers: &[Box], + request_text: Option<&str>, + ) -> Option<(usize, usize)> { + // In PD mode: + // - Prefill: Use cache-aware routing for better cache utilization + // - Decode: Use least-load routing for better load distribution + + // Select prefill worker using cache-aware logic + let prefill_idx = self.select_worker(prefill_workers, request_text)?; + + // Select decode worker using 
least-load logic + let healthy_decode = get_healthy_worker_indices(decode_workers); + if healthy_decode.is_empty() { + return None; + } + + let decode_idx = healthy_decode + .iter() + .min_by_key(|&&idx| decode_workers[idx].load()) + .copied()?; + + Some((prefill_idx, decode_idx)) + } +} + +impl Default for CacheAwarePolicy { + fn default() -> Self { + Self::new() + } +} + +impl Drop for CacheAwarePolicy { + fn drop(&mut self) { + // Note: We can't properly stop the eviction thread since it's in an infinite loop + // In a production system, we'd use a channel or atomic flag to signal shutdown + if let Some(handle) = self.eviction_handle.take() { + // The thread will continue running until the program exits + // This is acceptable for now since the router typically runs for the lifetime of the program + drop(handle); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::{BasicWorker, WorkerType}; + + #[test] + fn test_cache_aware_with_balanced_load() { + // Create policy without eviction thread for testing + let config = CacheAwareConfig { + eviction_interval_secs: 0, // Disable eviction thread + ..Default::default() + }; + let policy = CacheAwarePolicy::with_config(config); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // Initialize the policy with workers + policy.init_workers(&workers); + + // First request should be distributed + let idx1 = policy.select_worker(&workers, Some("hello world")).unwrap(); + + // Same request should go to same worker (cache hit) + let idx2 = policy.select_worker(&workers, Some("hello world")).unwrap(); + assert_eq!(idx1, idx2); + + // Similar request should also go to same worker + let idx3 = policy.select_worker(&workers, Some("hello")).unwrap(); + assert_eq!(idx1, idx3); + } + + #[test] + fn test_cache_aware_with_imbalanced_load() { + let 
policy = CacheAwarePolicy::with_config(CacheAwareConfig { + cache_threshold: 0.5, + balance_abs_threshold: 5, + balance_rel_threshold: 2.0, + eviction_interval_secs: 0, // Disable eviction thread + max_tree_size: 10000, + }); + + let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular); + let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular); + + // Create significant load imbalance + for _ in 0..20 { + worker1.increment_load(); + } + // worker2 has load 0 + + let workers: Vec> = vec![Box::new(worker1), Box::new(worker2)]; + policy.init_workers(&workers); + + // Should select worker2 (lower load) despite cache affinity + for _ in 0..5 { + let idx = policy.select_worker(&workers, Some("test")).unwrap(); + assert_eq!(idx, 1); // Should always pick worker2 + } + } + + #[test] + fn test_cache_aware_worker_removal() { + let config = CacheAwareConfig { + eviction_interval_secs: 0, // Disable eviction thread + ..Default::default() + }; + let policy = CacheAwarePolicy::with_config(config); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + ]; + + policy.init_workers(&workers); + + // Route some requests + policy.select_worker(&workers, Some("test1")); + policy.select_worker(&workers, Some("test2")); + + // Remove a worker + policy.remove_worker("http://w1:8000"); + workers[0].set_healthy(false); + + // All requests should now go to worker2 + let idx = policy.select_worker(&workers, Some("test1")).unwrap(); + assert_eq!(idx, 1); + } +} diff --git a/sgl-router/src/policies/factory.rs b/sgl-router/src/policies/factory.rs new file mode 100644 index 000000000..c65785d63 --- /dev/null +++ b/sgl-router/src/policies/factory.rs @@ -0,0 +1,94 @@ +//! 
Factory for creating load balancing policies + +use super::{ + CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy, + RoundRobinPolicy, +}; +use crate::config::PolicyConfig; +use std::sync::Arc; + +/// Factory for creating policy instances +pub struct PolicyFactory; + +impl PolicyFactory { + /// Create a policy from configuration + pub fn create_from_config(config: &PolicyConfig) -> Arc { + match config { + PolicyConfig::Random => Arc::new(RandomPolicy::new()), + PolicyConfig::RoundRobin => Arc::new(RoundRobinPolicy::new()), + PolicyConfig::PowerOfTwo { .. } => Arc::new(PowerOfTwoPolicy::new()), + PolicyConfig::CacheAware { + cache_threshold, + balance_abs_threshold, + balance_rel_threshold, + eviction_interval_secs, + max_tree_size, + } => { + let config = CacheAwareConfig { + cache_threshold: *cache_threshold, + balance_abs_threshold: *balance_abs_threshold, + balance_rel_threshold: *balance_rel_threshold, + eviction_interval_secs: *eviction_interval_secs, + max_tree_size: *max_tree_size, + }; + Arc::new(CacheAwarePolicy::with_config(config)) + } + } + } + + /// Create a policy by name (for dynamic loading) + pub fn create_by_name(name: &str) -> Option> { + match name.to_lowercase().as_str() { + "random" => Some(Arc::new(RandomPolicy::new())), + "round_robin" | "roundrobin" => Some(Arc::new(RoundRobinPolicy::new())), + "power_of_two" | "poweroftwo" => Some(Arc::new(PowerOfTwoPolicy::new())), + "cache_aware" | "cacheaware" => Some(Arc::new(CacheAwarePolicy::new())), + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_create_from_config() { + // Test Random + let policy = PolicyFactory::create_from_config(&PolicyConfig::Random); + assert_eq!(policy.name(), "random"); + + // Test RoundRobin + let policy = PolicyFactory::create_from_config(&PolicyConfig::RoundRobin); + assert_eq!(policy.name(), "round_robin"); + + // Test PowerOfTwo + let policy = 
PolicyFactory::create_from_config(&PolicyConfig::PowerOfTwo { + load_check_interval_secs: 60, + }); + assert_eq!(policy.name(), "power_of_two"); + + // Test CacheAware + let policy = PolicyFactory::create_from_config(&PolicyConfig::CacheAware { + cache_threshold: 0.7, + balance_abs_threshold: 10, + balance_rel_threshold: 1.5, + eviction_interval_secs: 30, + max_tree_size: 1000, + }); + assert_eq!(policy.name(), "cache_aware"); + } + + #[test] + fn test_create_by_name() { + assert!(PolicyFactory::create_by_name("random").is_some()); + assert!(PolicyFactory::create_by_name("RANDOM").is_some()); + assert!(PolicyFactory::create_by_name("round_robin").is_some()); + assert!(PolicyFactory::create_by_name("RoundRobin").is_some()); + assert!(PolicyFactory::create_by_name("power_of_two").is_some()); + assert!(PolicyFactory::create_by_name("PowerOfTwo").is_some()); + assert!(PolicyFactory::create_by_name("cache_aware").is_some()); + assert!(PolicyFactory::create_by_name("CacheAware").is_some()); + assert!(PolicyFactory::create_by_name("unknown").is_none()); + } +} diff --git a/sgl-router/src/policies/mod.rs b/sgl-router/src/policies/mod.rs new file mode 100644 index 000000000..83fdd95b0 --- /dev/null +++ b/sgl-router/src/policies/mod.rs @@ -0,0 +1,143 @@ +//! Load balancing policies for SGLang router +//! +//! This module provides a unified abstraction for routing policies that work +//! across both regular and prefill-decode (PD) routing modes. 
+ +use crate::core::Worker; +use std::fmt::Debug; + +mod cache_aware; +mod factory; +mod power_of_two; +mod random; +mod round_robin; + +pub use cache_aware::CacheAwarePolicy; +pub use factory::PolicyFactory; +pub use power_of_two::PowerOfTwoPolicy; +pub use random::RandomPolicy; +pub use round_robin::RoundRobinPolicy; + +/// Core trait for load balancing policies +/// +/// This trait provides a unified interface for implementing routing algorithms +/// that can work with both regular single-worker selection and PD dual-worker selection. +pub trait LoadBalancingPolicy: Send + Sync + Debug { + /// Select a single worker from the available workers + /// + /// This is used for regular routing mode where requests go to a single worker. + fn select_worker( + &self, + workers: &[Box], + request_text: Option<&str>, + ) -> Option; + + /// Select a pair of workers (prefill and decode) for PD routing + /// + /// Returns indices of (prefill_worker, decode_worker) from their respective arrays. + /// Default implementation uses select_worker for each array independently. + fn select_worker_pair( + &self, + prefill_workers: &[Box], + decode_workers: &[Box], + request_text: Option<&str>, + ) -> Option<(usize, usize)> { + // Default implementation: independently select from each pool + let prefill_idx = self.select_worker(prefill_workers, request_text)?; + let decode_idx = self.select_worker(decode_workers, request_text)?; + Some((prefill_idx, decode_idx)) + } + + /// Update policy state after request completion + /// + /// This is called when a request completes (successfully or not) to allow + /// policies to update their internal state. + fn on_request_complete(&self, _worker_url: &str, _success: bool) { + // Default: no-op for stateless policies + } + + /// Get policy name for metrics and debugging + fn name(&self) -> &'static str; + + /// Update worker load information + /// + /// This is called periodically with current load information for load-aware policies. 
+ fn update_loads(&self, _loads: &std::collections::HashMap) { + // Default: no-op for policies that don't use load information + } + + /// Reset any internal state + /// + /// This is useful for policies that maintain state (e.g., round-robin counters). + fn reset(&self) { + // Default: no-op for stateless policies + } + + /// Get as Any for downcasting + fn as_any(&self) -> &dyn std::any::Any; +} + +/// Configuration for cache-aware policy +#[derive(Debug, Clone)] +pub struct CacheAwareConfig { + pub cache_threshold: f32, + pub balance_abs_threshold: usize, + pub balance_rel_threshold: f32, + pub eviction_interval_secs: u64, + pub max_tree_size: usize, +} + +impl Default for CacheAwareConfig { + fn default() -> Self { + Self { + cache_threshold: 0.5, + balance_abs_threshold: 32, + balance_rel_threshold: 1.1, + eviction_interval_secs: 30, + max_tree_size: 10000, + } + } +} + +/// Helper function to filter healthy workers and return their indices +pub(crate) fn get_healthy_worker_indices(workers: &[Box]) -> Vec { + workers + .iter() + .enumerate() + .filter(|(_, w)| w.is_healthy()) + .map(|(idx, _)| idx) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::{BasicWorker, WorkerType}; + + #[test] + fn test_get_healthy_worker_indices() { + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w3:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // All healthy initially + let indices = get_healthy_worker_indices(&workers); + assert_eq!(indices, vec![0, 1, 2]); + + // Mark one unhealthy + workers[1].set_healthy(false); + let indices = get_healthy_worker_indices(&workers); + assert_eq!(indices, vec![0, 2]); + } +} diff --git a/sgl-router/src/policies/power_of_two.rs b/sgl-router/src/policies/power_of_two.rs new file mode 100644 index 
000000000..53c846196 --- /dev/null +++ b/sgl-router/src/policies/power_of_two.rs @@ -0,0 +1,201 @@ +//! Power-of-two choices load balancing policy + +use super::{get_healthy_worker_indices, LoadBalancingPolicy}; +use crate::core::Worker; +use metrics::counter; +use rand::Rng; +use std::collections::HashMap; +use std::sync::RwLock; +use tracing::info; + +/// Power-of-two choices policy +/// +/// Randomly selects two workers and routes to the one with lower load. +/// This provides good load distribution with minimal coordination overhead. +#[derive(Debug)] +pub struct PowerOfTwoPolicy { + /// Cached load information from external monitoring + cached_loads: RwLock>, +} + +impl PowerOfTwoPolicy { + pub fn new() -> Self { + Self { + cached_loads: RwLock::new(HashMap::new()), + } + } + + fn get_worker_load(&self, worker: &dyn Worker) -> isize { + // First check cached loads (from external monitoring) + if let Ok(loads) = self.cached_loads.read() { + if let Some(&load) = loads.get(worker.url()) { + return load; + } + } + + // Fall back to local load counter + worker.load() as isize + } +} + +impl LoadBalancingPolicy for PowerOfTwoPolicy { + fn select_worker( + &self, + workers: &[Box], + _request_text: Option<&str>, + ) -> Option { + let healthy_indices = get_healthy_worker_indices(workers); + + if healthy_indices.is_empty() { + return None; + } + + if healthy_indices.len() == 1 { + return Some(healthy_indices[0]); + } + + // Select two random workers + let mut rng = rand::thread_rng(); + let idx1 = rng.gen_range(0..healthy_indices.len()); + let mut idx2 = rng.gen_range(0..healthy_indices.len()); + + // Ensure we pick two different workers + while idx2 == idx1 { + idx2 = rng.gen_range(0..healthy_indices.len()); + } + + let worker_idx1 = healthy_indices[idx1]; + let worker_idx2 = healthy_indices[idx2]; + + // Compare loads and select the less loaded one + let load1 = self.get_worker_load(workers[worker_idx1].as_ref()); + let load2 = 
self.get_worker_load(workers[worker_idx2].as_ref()); + + // Log selection for debugging + let selected_idx = if load1 <= load2 { + worker_idx1 + } else { + worker_idx2 + }; + + info!( + "Power-of-two selection: {}={} vs {}={} -> selected {}", + workers[worker_idx1].url(), + load1, + workers[worker_idx2].url(), + load2, + workers[selected_idx].url() + ); + + // Increment processed counter + workers[selected_idx].increment_processed(); + counter!("sgl_router_processed_requests_total", "worker" => workers[selected_idx].url().to_string()) + .increment(1); + + Some(selected_idx) + } + + fn name(&self) -> &'static str { + "power_of_two" + } + + fn update_loads(&self, loads: &HashMap) { + if let Ok(mut cached) = self.cached_loads.write() { + *cached = loads.clone(); + } + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +impl Default for PowerOfTwoPolicy { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::{BasicWorker, WorkerType}; + + #[test] + fn test_power_of_two_selection() { + let policy = PowerOfTwoPolicy::new(); + let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular); + let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular); + let worker3 = BasicWorker::new("http://w3:8000".to_string(), WorkerType::Regular); + + // Set different loads + for _ in 0..10 { + worker1.increment_load(); + } + for _ in 0..5 { + worker2.increment_load(); + } + // worker3 has load 0 + + let workers: Vec> = + vec![Box::new(worker1), Box::new(worker2), Box::new(worker3)]; + + // Run multiple selections + let mut selected_counts = vec![0; 3]; + for _ in 0..100 { + if let Some(idx) = policy.select_worker(&workers, None) { + selected_counts[idx] += 1; + } + } + + // Worker with lowest load (worker3) should be selected most often + assert!(selected_counts[2] > selected_counts[1]); + assert!(selected_counts[1] > selected_counts[0]); + } + + #[test] + fn 
test_power_of_two_with_cached_loads() { + let policy = PowerOfTwoPolicy::new(); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // Update cached loads + let mut loads = HashMap::new(); + loads.insert("http://w1:8000".to_string(), 100); + loads.insert("http://w2:8000".to_string(), 10); + policy.update_loads(&loads); + + // Should prefer worker2 with lower cached load + let mut w2_selected = 0; + for _ in 0..50 { + if let Some(idx) = policy.select_worker(&workers, None) { + if idx == 1 { + w2_selected += 1; + } + } + } + + // Worker2 should be selected significantly more often + assert!(w2_selected > 35); // Should win most of the time + } + + #[test] + fn test_power_of_two_single_worker() { + let policy = PowerOfTwoPolicy::new(); + let workers: Vec> = vec![Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + ))]; + + // With single worker, should always select it + assert_eq!(policy.select_worker(&workers, None), Some(0)); + } +} diff --git a/sgl-router/src/policies/random.rs b/sgl-router/src/policies/random.rs new file mode 100644 index 000000000..50920bdf1 --- /dev/null +++ b/sgl-router/src/policies/random.rs @@ -0,0 +1,116 @@ +//! Random load balancing policy + +use super::{get_healthy_worker_indices, LoadBalancingPolicy}; +use crate::core::Worker; +use rand::Rng; + +/// Random selection policy +/// +/// Selects workers randomly with uniform distribution among healthy workers. 
+#[derive(Debug, Default)] +pub struct RandomPolicy; + +impl RandomPolicy { + pub fn new() -> Self { + Self + } +} + +impl LoadBalancingPolicy for RandomPolicy { + fn select_worker( + &self, + workers: &[Box], + _request_text: Option<&str>, + ) -> Option { + let healthy_indices = get_healthy_worker_indices(workers); + + if healthy_indices.is_empty() { + return None; + } + + let mut rng = rand::thread_rng(); + let random_idx = rng.gen_range(0..healthy_indices.len()); + Some(healthy_indices[random_idx]) + } + + fn name(&self) -> &'static str { + "random" + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::{BasicWorker, WorkerType}; + use std::collections::HashMap; + + #[test] + fn test_random_selection() { + let policy = RandomPolicy::new(); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w3:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // Test multiple selections to ensure randomness + let mut counts = HashMap::new(); + for _ in 0..100 { + if let Some(idx) = policy.select_worker(&workers, None) { + *counts.entry(idx).or_insert(0) += 1; + } + } + + // All workers should be selected at least once + assert_eq!(counts.len(), 3); + assert!(counts.values().all(|&count| count > 0)); + } + + #[test] + fn test_random_with_unhealthy_workers() { + let policy = RandomPolicy::new(); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // Mark first worker as unhealthy + workers[0].set_healthy(false); + + // Should always select the healthy worker (index 1) + for _ in 0..10 { + assert_eq!(policy.select_worker(&workers, None), Some(1)); + } 
+ } + + #[test] + fn test_random_no_healthy_workers() { + let policy = RandomPolicy::new(); + let workers: Vec> = vec![Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + ))]; + + workers[0].set_healthy(false); + assert_eq!(policy.select_worker(&workers, None), None); + } +} diff --git a/sgl-router/src/policies/round_robin.rs b/sgl-router/src/policies/round_robin.rs new file mode 100644 index 000000000..4401605f0 --- /dev/null +++ b/sgl-router/src/policies/round_robin.rs @@ -0,0 +1,136 @@ +//! Round-robin load balancing policy + +use super::{get_healthy_worker_indices, LoadBalancingPolicy}; +use crate::core::Worker; +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// Round-robin selection policy +/// +/// Selects workers in sequential order, cycling through all healthy workers. +#[derive(Debug, Default)] +pub struct RoundRobinPolicy { + counter: AtomicUsize, +} + +impl RoundRobinPolicy { + pub fn new() -> Self { + Self { + counter: AtomicUsize::new(0), + } + } +} + +impl LoadBalancingPolicy for RoundRobinPolicy { + fn select_worker( + &self, + workers: &[Box], + _request_text: Option<&str>, + ) -> Option { + let healthy_indices = get_healthy_worker_indices(workers); + + if healthy_indices.is_empty() { + return None; + } + + // Get and increment counter atomically + let count = self.counter.fetch_add(1, Ordering::Relaxed); + let selected_idx = count % healthy_indices.len(); + + Some(healthy_indices[selected_idx]) + } + + fn name(&self) -> &'static str { + "round_robin" + } + + fn reset(&self) { + self.counter.store(0, Ordering::Relaxed); + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::{BasicWorker, WorkerType}; + + #[test] + fn test_round_robin_selection() { + let policy = RoundRobinPolicy::new(); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + 
"http://w2:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w3:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // Should select workers in order: 0, 1, 2, 0, 1, 2, ... + assert_eq!(policy.select_worker(&workers, None), Some(0)); + assert_eq!(policy.select_worker(&workers, None), Some(1)); + assert_eq!(policy.select_worker(&workers, None), Some(2)); + assert_eq!(policy.select_worker(&workers, None), Some(0)); + assert_eq!(policy.select_worker(&workers, None), Some(1)); + } + + #[test] + fn test_round_robin_with_unhealthy_workers() { + let policy = RoundRobinPolicy::new(); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w3:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // Mark middle worker as unhealthy + workers[1].set_healthy(false); + + // Should skip unhealthy worker: 0, 2, 0, 2, ... 
+ assert_eq!(policy.select_worker(&workers, None), Some(0)); + assert_eq!(policy.select_worker(&workers, None), Some(2)); + assert_eq!(policy.select_worker(&workers, None), Some(0)); + assert_eq!(policy.select_worker(&workers, None), Some(2)); + } + + #[test] + fn test_round_robin_reset() { + let policy = RoundRobinPolicy::new(); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // Advance the counter + assert_eq!(policy.select_worker(&workers, None), Some(0)); + assert_eq!(policy.select_worker(&workers, None), Some(1)); + + // Reset should start from beginning + policy.reset(); + assert_eq!(policy.select_worker(&workers, None), Some(0)); + } +} diff --git a/sgl-router/src/router.rs b/sgl-router/src/router.rs deleted file mode 100644 index e8b68d7c5..000000000 --- a/sgl-router/src/router.rs +++ /dev/null @@ -1,1376 +0,0 @@ -use crate::core::{HealthChecker, Worker, WorkerFactory}; -use crate::pd_router::PDRouter; -use crate::pd_types::PDSelectionPolicy; -use crate::tree::Tree; -use ::metrics::{counter, gauge, histogram}; -use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; -use actix_web::{HttpRequest, HttpResponse}; -use futures_util::{StreamExt, TryStreamExt}; -use std::fmt::Debug; -use std::sync::atomic::AtomicUsize; -use std::sync::{Arc, Mutex, RwLock}; -use std::thread; -use std::time::Duration; -use std::time::Instant; -use tokio; -use tracing::{debug, error, info, warn}; - -pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> { - req.headers() - .iter() - .filter_map(|(name, value)| { - value - .to_str() - .ok() - .map(|v| (name.to_string(), v.to_string())) - }) - .collect() -} - -#[derive(Debug)] -pub enum Router { - RoundRobin { - workers: Arc>>>, - current_index: AtomicUsize, - timeout_secs: u64, - interval_secs: u64, - _health_checker: Option, - }, - Random { - 
workers: Arc>>>, - timeout_secs: u64, - interval_secs: u64, - _health_checker: Option, - }, - PrefillDecode { - pd_router: Arc, - }, - CacheAware { - /* - Cache-Aware Load Balancing Router - - This router combines two strategies to optimize both cache utilization and request distribution: - - 1. Cache-Aware Routing (Approximate Tree) - 2. Load Balancing (Shortest Queue with Balance Thresholds) - - The router dynamically switches between these strategies based on load conditions: - - Uses load balancing when the system is imbalanced - - Uses cache-aware routing when the system is balanced - - A system is considered imbalanced if both conditions are met: - 1. (max - min) > abs_threshold - 2. max > rel_threshold * min - - Strategy Details: - - 1. Cache-Aware Routing (Approximate Tree) - ------------------------------------------- - This strategy maintains an approximate radix tree for each worker based on request history, - eliminating the need for direct cache state queries. The tree stores raw text characters - instead of token IDs to avoid tokenization overhead. - - Process: - a. For each request, find the worker with the highest prefix match - b. If match rate > cache_threshold: - Route to the worker with highest match (likely has relevant data cached) - c. If match rate ≤ cache_threshold: - Route to the worker with smallest tree size (most available cache capacity) - d. Background maintenance: - Periodically evict least recently used leaf nodes to prevent memory overflow - - 2. Load Balancing (Shortest Queue) - ------------------------------------------- - This strategy tracks pending request counts per worker and routes new requests - to the least busy worker when the system is detected to be imbalanced. - - Configuration Parameters: - ------------------------ - 1. cache_threshold: (float, 0.0 to 1.0) - Minimum prefix match ratio to use highest-match routing. - Below this threshold, routes to worker with most available cache space. - - 2. 
balance_abs_threshold: (integer) - Absolute difference threshold for load imbalance detection. - System is potentially imbalanced if (max_load - min_load) > abs_threshold - - 3. balance_rel_threshold: (float) - Relative ratio threshold for load imbalance detection. - System is potentially imbalanced if max_load > min_load * rel_threshold - Used in conjunction with abs_threshold to determine final imbalance state. - - 4. eviction_interval_secs: (integer) - Interval between LRU eviction cycles for the approximate trees. - - 5. max_tree_size: (integer) - Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted - during the next eviction cycle. - */ - workers: Arc>>>, - tree: Arc>, - cache_threshold: f32, - balance_abs_threshold: usize, - balance_rel_threshold: f32, - timeout_secs: u64, - interval_secs: u64, - _eviction_thread: Option>, - _health_checker: Option, - }, -} - -#[derive(Debug, Clone)] -pub enum PolicyConfig { - RandomConfig { - timeout_secs: u64, - interval_secs: u64, - }, - RoundRobinConfig { - timeout_secs: u64, - interval_secs: u64, - }, - CacheAwareConfig { - cache_threshold: f32, - balance_abs_threshold: usize, - balance_rel_threshold: f32, - eviction_interval_secs: u64, - max_tree_size: usize, - timeout_secs: u64, - interval_secs: u64, - }, - PrefillDecodeConfig { - selection_policy: PDSelectionPolicy, - prefill_urls: Vec<(String, Option)>, // (url, bootstrap_port) - decode_urls: Vec, - timeout_secs: u64, - interval_secs: u64, - }, -} - -impl Router { - pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Result { - // Update active workers gauge - gauge!("sgl_router_active_workers").set(worker_urls.len() as f64); - - // Get timeout and interval from policy config - let (timeout_secs, interval_secs) = match &policy_config { - PolicyConfig::RandomConfig { - timeout_secs, - interval_secs, - } => (*timeout_secs, *interval_secs), - PolicyConfig::RoundRobinConfig { - timeout_secs, - interval_secs, - } => (*timeout_secs, *interval_secs), - 
PolicyConfig::CacheAwareConfig { - timeout_secs, - interval_secs, - .. - } => (*timeout_secs, *interval_secs), - PolicyConfig::PrefillDecodeConfig { - timeout_secs, - interval_secs, - .. - } => (*timeout_secs, *interval_secs), - }; - - // For PrefillDecode, we need to handle workers differently - match &policy_config { - PolicyConfig::PrefillDecodeConfig { .. } => { - // PD mode doesn't use the worker_urls parameter - // We'll validate PD workers separately - } - _ => { - // Wait until all workers are healthy for regular modes - let worker_urls = worker_urls.clone(); - std::thread::spawn(move || { - Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs) - }) - .join() - .map_err(|e| { - error!("Health-check thread panicked: {:?}", e); - format!("Health-check thread panicked: {e:?}") - })??; - } - } - - // Create Worker trait objects from URLs - let workers: Vec> = worker_urls - .iter() - .map(|url| WorkerFactory::create_regular(url.clone())) - .collect(); - - // Create router based on policy... 
- Ok(match policy_config { - PolicyConfig::RandomConfig { - timeout_secs, - interval_secs, - } => { - let workers = Arc::new(RwLock::new(workers)); - let health_checker = - crate::core::start_health_checker(Arc::clone(&workers), interval_secs); - Router::Random { - workers, - timeout_secs, - interval_secs, - _health_checker: Some(health_checker), - } - } - PolicyConfig::RoundRobinConfig { - timeout_secs, - interval_secs, - } => { - let workers = Arc::new(RwLock::new(workers)); - let health_checker = - crate::core::start_health_checker(Arc::clone(&workers), interval_secs); - Router::RoundRobin { - workers, - current_index: std::sync::atomic::AtomicUsize::new(0), - timeout_secs, - interval_secs, - _health_checker: Some(health_checker), - } - } - PolicyConfig::CacheAwareConfig { - cache_threshold, - balance_abs_threshold, - balance_rel_threshold, - eviction_interval_secs, - max_tree_size, - timeout_secs, - interval_secs, - } => { - let tree = Arc::new(Mutex::new(Tree::new())); - - // Create background eviction thread - let tree_clone = Arc::clone(&tree); - let workers = Arc::new(RwLock::new(workers)); - let workers_clone = Arc::clone(&workers); - let eviction_thread = thread::spawn(move || { - loop { - // Sleep for the specified interval - thread::sleep(Duration::from_secs(eviction_interval_secs)); - - let locked_tree_clone = tree_clone.lock().unwrap(); - // Run eviction - locked_tree_clone.evict_tenant_by_size(max_tree_size); - drop(locked_tree_clone); - - // Log worker loads and processed requests - let workers_guard = workers_clone.read().unwrap(); - let loads: Vec<(String, usize)> = workers_guard - .iter() - .map(|w| (w.url().to_string(), w.load())) - .collect(); - info!("Worker loads: {:?}", loads); - - let processed: Vec<(String, usize)> = workers_guard - .iter() - .map(|w| (w.url().to_string(), w.processed_requests())) - .collect(); - info!("Processed requests: {:?}", processed); - } - }); - - for worker in workers.read().unwrap().iter() { - 
tree.lock().unwrap().insert("", worker.url()); - } - - let health_checker = - crate::core::start_health_checker(Arc::clone(&workers), interval_secs); - - Router::CacheAware { - workers, - tree, - cache_threshold, - balance_abs_threshold, - balance_rel_threshold, - timeout_secs, - interval_secs, - _eviction_thread: Some(eviction_thread), - _health_checker: Some(health_checker), - } - } - PolicyConfig::PrefillDecodeConfig { - selection_policy, - prefill_urls, - decode_urls, - timeout_secs, - interval_secs, - } => { - // Create PDRouter instance - let pd_router = PDRouter::new( - prefill_urls, - decode_urls, - selection_policy, - timeout_secs, - interval_secs, - )?; - - Router::PrefillDecode { - pd_router: Arc::new(pd_router), - } - } - }) - } - - /// Get the current list of worker URLs - pub fn get_worker_urls(&self) -> Vec { - match self { - Router::RoundRobin { workers, .. } - | Router::Random { workers, .. } - | Router::CacheAware { workers, .. } => workers - .read() - .unwrap() - .iter() - .map(|w| w.url().to_string()) - .collect(), - Router::PrefillDecode { .. } => Vec::new(), - } - } - - pub fn wait_for_healthy_workers( - worker_urls: &[String], - timeout_secs: u64, - interval_secs: u64, - ) -> Result<(), String> { - let start_time = std::time::Instant::now(); - let sync_client = reqwest::blocking::Client::builder() - .timeout(Duration::from_secs(timeout_secs)) - .build() - .map_err(|e| format!("Failed to create HTTP client: {}", e))?; - - loop { - if start_time.elapsed() > Duration::from_secs(timeout_secs) { - error!( - "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", - timeout_secs, worker_urls - ); - return Err(format!( - "Timeout {}s waiting for workers {:?} to become healthy. 
Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", - timeout_secs, worker_urls - )); - } - - let mut all_healthy = true; - let mut unhealthy_workers = Vec::new(); - - for url in worker_urls { - match sync_client.get(&format!("{}/health", url)).send() { - Ok(res) => { - if !res.status().is_success() { - let msg = format!( - "Worker heatlh check is pending with status {}", - res.status() - ); - info!("{}", msg); - all_healthy = false; - unhealthy_workers.push((url, msg)); - } - } - Err(_) => { - let msg = format!("Worker is not ready yet"); - info!("{}", msg); - all_healthy = false; - unhealthy_workers.push((url, msg)); - } - } - } - - if all_healthy { - info!("All workers are healthy"); - return Ok(()); - } else { - info!("Initializing workers:"); - for (url, reason) in &unhealthy_workers { - info!(" {} - {}", url, reason); - } - thread::sleep(Duration::from_secs(interval_secs)); - } - } - } - - fn select_first_worker(&self) -> Result { - match self { - Router::RoundRobin { workers, .. } - | Router::Random { workers, .. } - | Router::CacheAware { workers, .. } => { - let workers_guard = workers.read().unwrap(); - if workers_guard.is_empty() { - Err("No workers are available".to_string()) - } else { - Ok(workers_guard[0].url().to_string()) - } - } - Router::PrefillDecode { .. 
} => { - // For PD mode, we don't need this method as routing is handled by PDRouter - Err("PrefillDecode mode doesn't use select_first_worker".to_string()) - } - } - } - - pub async fn send_request( - &self, - client: &reqwest::Client, - worker_url: &str, - route: &str, - req: &HttpRequest, - ) -> HttpResponse { - let start = Instant::now(); - let mut request_builder = client.get(format!("{}{}", worker_url, route)); - - // Copy all headers from original request except for /health because it does not need authorization - if route != "/health" { - for (name, value) in copy_request_headers(req) { - // Skip Content-Type and Content-Length as .json() sets them - if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" - { - request_builder = request_builder.header(name, value); - } - } - } - - let response = match request_builder.send().await { - Ok(res) => { - let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); - - match res.bytes().await { - Ok(body) => HttpResponse::build(status).body(body.to_vec()), - Err(e) => HttpResponse::InternalServerError() - .body(format!("Failed to read response body: {}", e)), - } - } - Err(e) => HttpResponse::InternalServerError().body(format!( - "Failed to send request to worker {}: {}", - worker_url, e - )), - }; - - // Record request metrics - if route != "/health" { - let duration = start.elapsed(); - counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1); - histogram!("sgl_router_request_duration_seconds", "route" => route.to_string()) - .record(duration.as_secs_f64()); - - if !response.status().is_success() { - counter!("sgl_router_request_errors_total", "route" => route.to_string()) - .increment(1); - } - } - response - } - - pub async fn route_to_first( - &self, - client: &reqwest::Client, - route: &str, - req: &HttpRequest, - ) -> HttpResponse { - const MAX_REQUEST_RETRIES: u32 = 3; - const 
MAX_TOTAL_RETRIES: u32 = 6; - let mut total_retries = 0; - - while total_retries < MAX_TOTAL_RETRIES { - match self.select_first_worker() { - Ok(worker_url) => { - let mut request_retries = 0; - - // Try the same worker multiple times - while request_retries < MAX_REQUEST_RETRIES { - if total_retries >= 1 { - info!("Retrying request after {} failed attempts", total_retries); - } - - let response = self.send_request(client, &worker_url, route, req).await; - - if response.status().is_success() { - return response; - } else { - // if the worker is healthy, it means the request is bad, so return the error response - let health_response = - self.send_request(client, &worker_url, "/health", req).await; - if health_response.status().is_success() { - return response; - } - } - - warn!( - "Request to {} failed (attempt {}/{})", - worker_url, - request_retries + 1, - MAX_REQUEST_RETRIES - ); - - request_retries += 1; - total_retries += 1; - - if request_retries == MAX_REQUEST_RETRIES { - warn!("Removing failed worker: {}", worker_url); - self.remove_worker(&worker_url); - break; - } - } - } - Err(e) => return HttpResponse::InternalServerError().body(e), - } - } - - HttpResponse::InternalServerError().body("All retry attempts failed") - } - - pub async fn route_to_all( - &self, - client: &reqwest::Client, - route: &str, - req: &HttpRequest, - ) -> HttpResponse { - // Get all worker URLs based on router type - let worker_urls = match self { - Router::PrefillDecode { .. 
} => { - // For PD mode, route_to_all is not supported directly - // It should be handled by PDRouter if needed - return HttpResponse::NotImplemented() - .body("route_to_all not implemented for PrefillDecode mode"); - } - _ => self.get_worker_urls(), - }; - - // Send requests to all workers concurrently - let mut tasks = Vec::new(); - for worker_url in &worker_urls { - let mut request_builder = client.post(format!("{}{}", worker_url, route)); - - // Copy headers from original request - for (name, value) in copy_request_headers(req) { - request_builder = request_builder.header(name, value); - } - - tasks.push(request_builder.send()); - } - - // Wait for all responses - let results = futures_util::future::join_all(tasks).await; - - // Check if all succeeded - let all_success = results.iter().all(|r| { - r.as_ref() - .map(|res| res.status().is_success()) - .unwrap_or(false) - }); - - if all_success { - HttpResponse::Ok().body("Operation completed on all servers") - } else { - HttpResponse::InternalServerError().body("Operation failed on one or more servers") - } - } - - pub async fn get_all_loads( - &self, - client: &reqwest::Client, - _req: &HttpRequest, - ) -> HttpResponse { - // For PD mode, delegate to PDRouter - match self { - Router::PrefillDecode { pd_router } => { - return pd_router.get_loads(client).await; - } - _ => { - // For non-PD routers, handle normally - } - } - - let urls = self.get_worker_urls(); - let prefill_urls: Vec = Vec::new(); - let decode_urls = urls; - - // Collect loads from all servers - let mut prefill_loads = Vec::new(); - let mut decode_loads = Vec::new(); - - // Get prefill loads - for url in &prefill_urls { - let load = self.get_worker_load(client, url).await.unwrap_or(-1); - prefill_loads.push(serde_json::json!({ - "engine": format!("(Prefill@{})", url), - "load": load as i64 - })); - } - - // Get decode loads - for url in &decode_urls { - let load = self.get_worker_load(client, url).await.unwrap_or(-1); - 
decode_loads.push(serde_json::json!({ - "engine": format!("(Decode@{})", url), - "load": load as i64 - })); - } - - HttpResponse::Ok().json(serde_json::json!({ - "prefill": prefill_loads, - "decode": decode_loads - })) - } - - // New method to route typed requests directly - pub async fn route_typed_request< - T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone, - >( - &self, - client: &reqwest::Client, - req: &HttpRequest, - typed_req: &T, - route: &str, - ) -> HttpResponse { - match self { - Router::PrefillDecode { .. } => HttpResponse::InternalServerError() - .body("PD routing should use specialized typed handlers"), - _ => { - // Handle retries like the original implementation - let start = Instant::now(); - const MAX_REQUEST_RETRIES: u32 = 3; - const MAX_TOTAL_RETRIES: u32 = 6; - let mut total_retries = 0; - - while total_retries < MAX_TOTAL_RETRIES { - // Extract routing text directly from typed request - let text = typed_req.extract_text_for_routing(); - let is_stream = typed_req.is_stream(); - - // Select worker based on text - let worker_url = self.select_generate_worker_from_text(&text); - let mut request_retries = 0; - - // Try the same worker multiple times - while request_retries < MAX_REQUEST_RETRIES { - if total_retries >= 1 { - info!("Retrying request after {} failed attempts", total_retries); - counter!("sgl_router_retries_total", "route" => route.to_string()) - .increment(1); - } - - // For CacheAware router, increment load before request - let load_incremented = match self { - Router::CacheAware { workers, .. 
} => { - let workers_guard = workers.read().unwrap(); - if let Some(worker) = - workers_guard.iter().find(|w| w.url() == &worker_url) - { - worker.increment_load(); - gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) - .set(worker.load() as f64); - true - } else { - false - } - } - _ => false, - }; - - // Send typed request directly - let response = self - .send_typed_request( - client, - req, - typed_req, - route, - &worker_url, - is_stream, - load_incremented, - ) - .await; - - if response.status().is_success() { - let duration = start.elapsed(); - histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()) - .record(duration.as_secs_f64()); - return response; - } else { - // if the worker is healthy, it means the request is bad, so return the error response - let health_response = - self.send_request(client, &worker_url, "/health", req).await; - if health_response.status().is_success() { - counter!("sgl_router_request_errors_total", "route" => route.to_string()) - .increment(1); - return response; - } - } - - warn!( - "Generate request to {} failed (attempt {}/{})", - worker_url, - request_retries + 1, - MAX_REQUEST_RETRIES - ); - - request_retries += 1; - total_retries += 1; - - if request_retries == MAX_REQUEST_RETRIES { - warn!("Removing failed worker: {}", worker_url); - self.remove_worker(&worker_url); - break; - } - } - } - - counter!("sgl_router_request_errors_total", "route" => route.to_string()) - .increment(1); - HttpResponse::InternalServerError().body("All retry attempts failed") - } - } - } - - // Helper method to select worker from text (returns index for RoundRobin/Random, URL for CacheAware) - fn select_generate_worker_from_text(&self, text: &str) -> String { - match self { - Router::RoundRobin { - workers, - current_index, - .. 
- } => { - let workers_guard = workers.read().unwrap(); - let idx = current_index - .fetch_update( - std::sync::atomic::Ordering::SeqCst, - std::sync::atomic::Ordering::SeqCst, - |x| Some((x + 1) % workers_guard.len()), - ) - .unwrap(); - workers_guard[idx].url().to_string() - } - - Router::Random { workers, .. } => { - let workers_guard = workers.read().unwrap(); - workers_guard[rand::random::() % workers_guard.len()] - .url() - .to_string() - } - - Router::CacheAware { - workers, - tree, - cache_threshold, - balance_abs_threshold, - balance_rel_threshold, - .. - } => { - let tree = tree.lock().unwrap(); - let workers_guard = workers.read().unwrap(); - - // Get current load statistics from workers - let loads: Vec = workers_guard.iter().map(|w| w.load()).collect(); - let max_load = *loads.iter().max().unwrap_or(&0); - let min_load = *loads.iter().min().unwrap_or(&0); - - // Load is considered imbalanced if: - // 1. (max - min) > abs_threshold AND - // 2. max > rel_threshold * min - let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold - && (max_load as f32) > (min_load as f32 * balance_rel_threshold); - - let selected_url = if is_imbalanced { - // Log load balancing trigger and current queue state - let worker_loads: Vec<(String, usize)> = workers_guard - .iter() - .map(|w| (w.url().to_string(), w.load())) - .collect(); - - info!( - "Load balancing triggered due to workload imbalance:\n\ - Max load: {}, Min load: {}\n\ - Current worker loads: {:?}", - max_load, min_load, worker_loads - ); - - counter!("sgl_router_load_balancing_events_total").increment(1); - gauge!("sgl_router_max_load").set(max_load as f64); - gauge!("sgl_router_min_load").set(min_load as f64); - - // Use shortest queue routing when load is imbalanced - workers_guard - .iter() - .min_by_key(|w| w.load()) - .map(|w| w.url().to_string()) - .unwrap_or_else(|| workers_guard[0].url().to_string()) - } else { - // Use cache-aware routing when load is balanced - let 
(matched_text, matched_worker) = tree.prefix_match(&text); - let matched_rate = - matched_text.chars().count() as f32 / text.chars().count() as f32; - - if matched_rate > *cache_threshold { - counter!("sgl_router_cache_hits_total").increment(1); - matched_worker.to_string() - } else { - counter!("sgl_router_cache_misses_total").increment(1); - tree.get_smallest_tenant() - } - }; - - // Find the selected worker and increment processed counter only - if let Some(worker) = workers_guard.iter().find(|w| w.url() == &selected_url) { - worker.increment_processed(); - counter!("sgl_router_processed_requests_total", "worker" => selected_url.to_string()) - .increment(1); - } - - tree.insert(&text, &selected_url); - - selected_url - } - Router::PrefillDecode { .. } => { - // For PD mode, we don't use this method - return "PD_MODE_ERROR".to_string(); - } - } - } - - // Send typed request directly without conversion - async fn send_typed_request( - &self, - client: &reqwest::Client, - req: &HttpRequest, - typed_req: &T, - route: &str, - worker_url: &str, - is_stream: bool, - load_incremented: bool, // Whether load was incremented for this request - ) -> HttpResponse { - let start = Instant::now(); - - // Debug: Log what we're sending - if let Ok(json_str) = serde_json::to_string_pretty(typed_req) { - debug!("Sending request to {}: {}", route, json_str); - } - - let mut request_builder = client - .post(format!("{}{}", worker_url, route)) - .json(typed_req); // Use json() directly with typed request - - // Copy all headers from original request - for (name, value) in copy_request_headers(req) { - // Skip Content-Type and Content-Length as .json() sets them - if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" { - request_builder = request_builder.header(&name, &value); - } - } - - let res = match request_builder.send().await { - Ok(res) => res, - Err(e) => { - error!("Failed to send request to {}: {}", worker_url, e); - - // Decrement load on error 
for CacheAware router - if load_incremented { - if let Router::CacheAware { workers, .. } = self { - if let Ok(workers_guard) = workers.read() { - if let Some(worker) = - workers_guard.iter().find(|w| w.url() == worker_url) - { - worker.decrement_load(); - gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) - .set(worker.load() as f64); - } - } - } - } - - return HttpResponse::InternalServerError().body(format!("Request failed: {}", e)); - } - }; - - let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); - - if !is_stream { - // For non-streaming requests, get response first - let response = match res.bytes().await { - Ok(body) => HttpResponse::build(status).body(body.to_vec()), - Err(e) => { - let error_msg = format!("Failed to get response body: {}", e); - HttpResponse::InternalServerError().body(error_msg) - } - }; - - // Decrement load counter for non-streaming CacheAware requests - if load_incremented && !is_stream { - if let Router::CacheAware { workers, .. } = self { - if let Ok(workers_guard) = workers.read() { - if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) { - worker.decrement_load(); - gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) - .set(worker.load() as f64); - } - } - } - } - - // Record metrics - let duration = start.elapsed(); - histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()) - .record(duration.as_secs_f64()); - counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1); - - response - } else if let Router::CacheAware { workers, .. 
} = self { - // For streaming with CacheAware router, we need to manually decrement when done - let workers = Arc::clone(workers); - let worker_url = worker_url.to_string(); - - HttpResponse::build(status) - .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) - .streaming( - res.bytes_stream() - .map_err(|_| { - actix_web::error::ErrorInternalServerError("Failed to read stream") - }) - .inspect(move |bytes| { - if let Ok(bytes) = bytes { - if bytes - .as_ref() - .windows(12) - .any(|window| window == b"data: [DONE]") - { - if let Ok(workers_guard) = workers.read() { - if let Some(worker) = - workers_guard.iter().find(|w| w.url() == &worker_url) - { - worker.decrement_load(); - gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) - .set(worker.load() as f64); - debug!("Streaming is done!!") - } - } - } - } - }), - ) - } else { - // For non-CacheAware routers, just stream without load tracking - HttpResponse::build(status) - .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) - .streaming(res.bytes_stream().map_err(|_| { - actix_web::error::ErrorInternalServerError("Failed to read stream") - })) - } - } - - pub async fn add_worker(&self, worker_url: &str) -> Result { - let (timeout_secs, interval_secs) = match self { - Router::Random { - timeout_secs, - interval_secs, - .. - } => (*timeout_secs, *interval_secs), - Router::RoundRobin { - timeout_secs, - interval_secs, - .. - } => (*timeout_secs, *interval_secs), - Router::CacheAware { - timeout_secs, - interval_secs, - .. - } => (*timeout_secs, *interval_secs), - Router::PrefillDecode { .. } => { - // For PD mode, we don't support adding workers via this method - return Err("Adding workers to PrefillDecode router not supported via add_worker. 
Use dedicated PD management methods.".to_string()); - } - }; - - let start_time = std::time::Instant::now(); - let client = reqwest::Client::builder() - .timeout(Duration::from_secs(timeout_secs)) - .build() - .map_err(|e| format!("Failed to create HTTP client: {}", e))?; - - loop { - if start_time.elapsed() > Duration::from_secs(timeout_secs) { - error!( - "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", - timeout_secs, worker_url - ); - return Err(format!( - "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", - timeout_secs, worker_url - )); - } - - match client.get(&format!("{}/health", worker_url)).send().await { - Ok(res) => { - if res.status().is_success() { - match self { - Router::RoundRobin { workers, .. } - | Router::Random { workers, .. } - | Router::CacheAware { workers, .. } => { - info!("Worker {} health check passed", worker_url); - let mut workers_guard = workers.write().unwrap(); - if workers_guard.iter().any(|w| w.url() == worker_url) { - return Err(format!("Worker {} already exists", worker_url)); - } - info!("Added worker: {}", worker_url); - let new_worker = - WorkerFactory::create_regular(worker_url.to_string()); - workers_guard.push(new_worker); - gauge!("sgl_router_active_workers").set(workers_guard.len() as f64); - } - Router::PrefillDecode { .. } => { - return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string()); - } - } - - // If cache aware, add worker to tree - if let Router::CacheAware { tree, .. 
} = self { - // Add worker to tree - tree.lock().unwrap().insert("", worker_url); - } - - return Ok(format!("Successfully added worker: {}", worker_url)); - } else { - info!( - "Worker {} health check is pending with status: {}.", - worker_url, - res.status() - ); - // if the url does not have http or https prefix, warn users - if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") - { - warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url); - } - - tokio::time::sleep(Duration::from_secs(interval_secs)).await; - continue; - } - } - Err(e) => { - info!( - "Worker {} health check is pending with error: {}", - worker_url, e - ); - - // if the url does not have http or https prefix, warn users - if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") { - warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url); - } - - tokio::time::sleep(Duration::from_secs(interval_secs)).await; - continue; - } - } - } - } - - pub fn remove_worker(&self, worker_url: &str) { - match self { - Router::RoundRobin { workers, .. } - | Router::Random { workers, .. } - | Router::CacheAware { workers, .. } => { - let mut workers_guard = workers.write().unwrap(); - if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) { - workers_guard.remove(index); - info!("Removed worker: {}", worker_url); - gauge!("sgl_router_active_workers").set(workers_guard.len() as f64); - } else { - warn!("Worker {} not found, skipping removal", worker_url); - return; - } - } - Router::PrefillDecode { .. } => { - warn!("Removing workers from PrefillDecode router not supported via remove_worker. Use dedicated PD management methods."); - return; - } - } - - // if cache aware, remove the worker from the tree - if let Router::CacheAware { tree, .. 
} = self { - tree.lock().unwrap().remove_tenant(&worker_url); - info!("Removed worker from tree: {}", worker_url); - } - } - - /// Add a worker with PD mode support - pub async fn add_pd_worker( - &self, - worker_url: &str, - pod_type: crate::service_discovery::PodType, - bootstrap_port: Option, - ) -> Result { - match self { - Router::PrefillDecode { pd_router } => match pod_type { - crate::service_discovery::PodType::Prefill => pd_router - .add_prefill_server(worker_url.to_string(), bootstrap_port) - .await - .map_err(|e| e.to_string()), - crate::service_discovery::PodType::Decode => pd_router - .add_decode_server(worker_url.to_string()) - .await - .map_err(|e| e.to_string()), - crate::service_discovery::PodType::Regular => { - Err("Regular pod type not supported in PD mode".to_string()) - } - }, - _ => Err("add_pd_worker only supported in PD mode".to_string()), - } - } - - /// Remove a worker with PD mode support - pub async fn remove_pd_worker( - &self, - worker_url: &str, - pod_type: crate::service_discovery::PodType, - ) -> Result { - match self { - Router::PrefillDecode { pd_router } => match pod_type { - crate::service_discovery::PodType::Prefill => pd_router - .remove_prefill_server(worker_url) - .await - .map_err(|e| e.to_string()), - crate::service_discovery::PodType::Decode => pd_router - .remove_decode_server(worker_url) - .await - .map_err(|e| e.to_string()), - crate::service_discovery::PodType::Regular => { - Err("Regular pod type not supported in PD mode".to_string()) - } - }, - _ => Err("remove_pd_worker only supported in PD mode".to_string()), - } - } - - async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option { - match client.get(&format!("{}/get_load", worker_url)).send().await { - Ok(res) if res.status().is_success() => match res.bytes().await { - Ok(bytes) => match serde_json::from_slice::(&bytes) { - Ok(data) => data - .get("load") - .and_then(|v| v.as_i64()) - .map(|v| v as isize), - Err(e) => { - debug!("Failed 
to parse load response from {}: {}", worker_url, e); - None - } - }, - Err(e) => { - debug!("Failed to read load response from {}: {}", worker_url, e); - None - } - }, - Ok(res) => { - debug!( - "Worker {} returned non-success status: {}", - worker_url, - res.status() - ); - None - } - Err(e) => { - debug!("Failed to get load from {}: {}", worker_url, e); - None - } - } - } - - // PD-specific wrapper methods that delegate to PDRouter - pub async fn route_pd_health_generate( - &self, - _client: &reqwest::Client, - _req: &HttpRequest, - ) -> HttpResponse { - match self { - Router::PrefillDecode { pd_router } => { - pd_router.health_generate(&pd_router.http_client).await - } - _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), - } - } - - pub async fn route_pd_generate_typed( - &self, - _client: &reqwest::Client, - req: &HttpRequest, - typed_req: crate::pd_types::GenerateReqInput, - route: &str, - ) -> HttpResponse { - match self { - Router::PrefillDecode { pd_router } => { - pd_router - .route_generate(&pd_router.http_client, req, typed_req, route) - .await - } - _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), - } - } - - pub async fn route_pd_chat_typed( - &self, - _client: &reqwest::Client, - req: &HttpRequest, - typed_req: crate::pd_types::ChatReqInput, - route: &str, - ) -> HttpResponse { - match self { - Router::PrefillDecode { pd_router } => { - pd_router - .route_chat(&pd_router.http_client, req, typed_req, route) - .await - } - _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), - } - } - - pub async fn get_pd_server_info( - &self, - _client: &reqwest::Client, - _req: &HttpRequest, - ) -> HttpResponse { - match self { - Router::PrefillDecode { pd_router } => { - pd_router.get_server_info(&pd_router.http_client).await - } - _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), - } - } - - pub async fn get_pd_models( - &self, - _client: &reqwest::Client, - req: 
&HttpRequest, - ) -> HttpResponse { - match self { - Router::PrefillDecode { pd_router } => { - pd_router.get_models(&pd_router.http_client, req).await - } - _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), - } - } - - pub async fn route_pd_flush_cache(&self, _client: &reqwest::Client) -> HttpResponse { - match self { - Router::PrefillDecode { pd_router } => { - pd_router.flush_cache(&pd_router.http_client).await - } - _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), - } - } - - pub async fn get_pd_model_info( - &self, - _client: &reqwest::Client, - req: &HttpRequest, - ) -> HttpResponse { - match self { - Router::PrefillDecode { pd_router } => { - pd_router.get_model_info(&pd_router.http_client, req).await - } - _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::service_discovery::PodType; - - fn create_test_regular_router() -> Router { - let workers = vec![ - WorkerFactory::create_regular("http://worker1:8080".to_string()), - WorkerFactory::create_regular("http://worker2:8080".to_string()), - ]; - Router::Random { - workers: Arc::new(RwLock::new(workers)), - timeout_secs: 5, - interval_secs: 1, - _health_checker: None, - } - } - - #[test] - fn test_router_get_worker_urls_regular() { - let router = create_test_regular_router(); - let urls = router.get_worker_urls(); - - assert_eq!(urls.len(), 2); - assert!(urls.contains(&"http://worker1:8080".to_string())); - assert!(urls.contains(&"http://worker2:8080".to_string())); - } - - // #[test] - // fn test_router_get_worker_urls_pd_mode() { - // // For PD mode, get_worker_urls returns empty list - // // Note: PDRouter::new requires health checks which fail in tests - // // This test would need a mock server or different test setup - // } - - #[tokio::test] - async fn test_add_pd_worker_with_regular_router() { - let router = create_test_regular_router(); - - let result = 
router - .add_pd_worker("http://new-worker:8080", PodType::Prefill, Some(8081)) - .await; - - assert!(result.is_err()); - assert!(result - .unwrap_err() - .contains("add_pd_worker only supported in PD mode")); - } - - #[tokio::test] - async fn test_remove_pd_worker_with_regular_router() { - let router = create_test_regular_router(); - - let result = router - .remove_pd_worker("http://worker:8080", PodType::Decode) - .await; - - assert!(result.is_err()); - assert!(result - .unwrap_err() - .contains("remove_pd_worker only supported in PD mode")); - } - - // #[tokio::test] - // async fn test_add_pd_worker_with_pd_router_regular_type() { - // // Note: PDRouter::new requires health checks which fail in tests - // // This test would need a mock server or different test setup - // } - - // #[tokio::test] - // async fn test_remove_pd_worker_with_pd_router_regular_type() { - // // Note: PDRouter::new requires health checks which fail in tests - // // This test would need a mock server or different test setup - // } - - #[test] - fn test_select_first_worker_regular() { - let router = create_test_regular_router(); - let result = router.select_first_worker(); - - assert!(result.is_ok()); - assert_eq!(result.unwrap(), "http://worker1:8080"); - } - - // #[test] - // fn test_select_first_worker_pd_mode() { - // // Note: PDRouter::new requires health checks which fail in tests - // // This test would need a mock server or different test setup - // } - - #[test] - fn test_wait_for_healthy_workers_empty_list() { - let result = Router::wait_for_healthy_workers(&[], 1, 1); - assert!(result.is_ok()); - } - - #[test] - fn test_wait_for_healthy_workers_invalid_urls() { - // This test will timeout quickly since the URLs are invalid - let result = - Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("Timeout")); - } -} diff --git a/sgl-router/src/routers/factory.rs 
b/sgl-router/src/routers/factory.rs new file mode 100644 index 000000000..201240121 --- /dev/null +++ b/sgl-router/src/routers/factory.rs @@ -0,0 +1,66 @@ +//! Factory for creating router instances + +use super::{pd_router::PDRouter, router::Router, RouterTrait}; +use crate::config::{PolicyConfig, RouterConfig, RoutingMode}; +use crate::policies::PolicyFactory; + +/// Factory for creating router instances based on configuration +pub struct RouterFactory; + +impl RouterFactory { + /// Create a router instance from configuration + pub fn create_router(config: &RouterConfig) -> Result, String> { + match &config.mode { + RoutingMode::Regular { worker_urls } => { + Self::create_regular_router(worker_urls, &config.policy, config) + } + RoutingMode::PrefillDecode { + prefill_urls, + decode_urls, + } => Self::create_pd_router(prefill_urls, decode_urls, &config.policy, config), + } + } + + /// Create a regular router with injected policy + fn create_regular_router( + worker_urls: &[String], + policy_config: &PolicyConfig, + router_config: &RouterConfig, + ) -> Result, String> { + // Create policy + let policy = PolicyFactory::create_from_config(policy_config); + + // Create regular router with injected policy + let router = Router::new( + worker_urls.to_vec(), + policy, + router_config.worker_startup_timeout_secs, + router_config.worker_startup_check_interval_secs, + )?; + + Ok(Box::new(router)) + } + + /// Create a PD router with injected policy + fn create_pd_router( + prefill_urls: &[(String, Option)], + decode_urls: &[String], + policy_config: &PolicyConfig, + router_config: &RouterConfig, + ) -> Result, String> { + // Create policy directly from PolicyConfig + // All policies now support PD mode through the select_worker_pair method + let policy = PolicyFactory::create_from_config(policy_config); + + // Create PD router with injected policy + let router = PDRouter::new( + prefill_urls.to_vec(), + decode_urls.to_vec(), + policy, + 
router_config.worker_startup_timeout_secs, + router_config.worker_startup_check_interval_secs, + )?; + + Ok(Box::new(router)) + } +} diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs new file mode 100644 index 000000000..ffb6d93c7 --- /dev/null +++ b/sgl-router/src/routers/mod.rs @@ -0,0 +1,101 @@ +//! Router implementations + +use actix_web::{HttpRequest, HttpResponse}; +use async_trait::async_trait; +use reqwest::Client; +use std::fmt::Debug; + +pub mod factory; +pub mod pd_router; +pub mod pd_types; +pub mod request_adapter; +pub mod router; + +pub use factory::RouterFactory; + +/// Worker management trait for administrative operations +/// +/// This trait is separate from RouterTrait to allow Send futures +/// for use in service discovery and other background tasks +#[async_trait] +pub trait WorkerManagement: Send + Sync { + /// Add a worker to the router + async fn add_worker(&self, worker_url: &str) -> Result; + + /// Remove a worker from the router + fn remove_worker(&self, worker_url: &str); + + /// Get all worker URLs + fn get_worker_urls(&self) -> Vec; +} + +/// Core trait for all router implementations +/// +/// This trait provides a unified interface for routing requests, +/// regardless of whether it's a regular router or PD router. 
+#[async_trait(?Send)] +pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { + /// Get a reference to self as Any for downcasting + fn as_any(&self) -> &dyn std::any::Any; + /// Route a health check request + async fn health(&self, client: &Client, req: &HttpRequest) -> HttpResponse; + + /// Route a health generate request + async fn health_generate(&self, client: &Client, req: &HttpRequest) -> HttpResponse; + + /// Get server information + async fn get_server_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse; + + /// Get available models + async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse; + + /// Get model information + async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse; + + /// Route a generate request + async fn route_generate( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse; + + /// Route a chat completion request + async fn route_chat( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse; + + /// Route a completion request + async fn route_completion( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse; + + /// Flush cache on all workers + async fn flush_cache(&self, client: &Client) -> HttpResponse; + + /// Get worker loads (for monitoring) + async fn get_worker_loads(&self, client: &Client) -> HttpResponse; + + /// Get router type name + fn router_type(&self) -> &'static str; + + /// Check if this is a PD router + fn is_pd_mode(&self) -> bool { + self.router_type() == "pd" + } + + /// Server liveness check - is the server process running + fn liveness(&self) -> HttpResponse { + // Simple liveness check - if we can respond, we're alive + HttpResponse::Ok().body("OK") + } + + /// Server readiness check - is the server ready to handle requests + fn readiness(&self) -> HttpResponse; +} diff --git a/sgl-router/src/pd_router.rs 
b/sgl-router/src/routers/pd_router.rs similarity index 67% rename from sgl-router/src/pd_router.rs rename to sgl-router/src/routers/pd_router.rs index a1f04c7d2..2ac8f9027 100644 --- a/sgl-router/src/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -1,10 +1,11 @@ // PD (Prefill-Decode) Router Implementation // This module handles routing for disaggregated prefill-decode systems +use super::pd_types::{api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError}; +use super::request_adapter::ToPdRequest; use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard}; -use crate::pd_types::{ - api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError, PDSelectionPolicy, -}; +use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; +use crate::policies::LoadBalancingPolicy; use crate::tree::Tree; use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::{HttpRequest, HttpResponse}; @@ -17,13 +18,11 @@ use std::time::{Duration, Instant}; use tracing::{debug, error, info, warn}; use uuid::Uuid; -// Removed over-engineered ProxyResponse - using HttpResponse directly - #[derive(Debug)] pub struct PDRouter { pub prefill_workers: Arc>>>, pub decode_workers: Arc>>>, - pub selection_policy: PDSelectionPolicy, + pub policy: Arc, pub prefill_tree: Option>>, pub timeout_secs: u64, pub interval_secs: u64, @@ -42,7 +41,7 @@ impl PDRouter { bootstrap_port: Option, ) -> Result { // Wait for the new server to be healthy - crate::router::Router::wait_for_healthy_workers( + crate::routers::router::Router::wait_for_healthy_workers( &[url.clone()], self.timeout_secs, self.interval_secs, @@ -78,7 +77,7 @@ impl PDRouter { pub async fn add_decode_server(&self, url: String) -> Result { // Wait for the new server to be healthy - crate::router::Router::wait_for_healthy_workers( + crate::routers::router::Router::wait_for_healthy_workers( &[url.clone()], self.timeout_secs, self.interval_secs, @@ -103,9 +102,6 @@ 
impl PDRouter { workers.push(worker); - // Initialize load tracking - // Worker tracks its own load internally - info!("Added decode server: {}", url); Ok(format!("Successfully added decode server: {}", url)) } @@ -128,9 +124,6 @@ impl PDRouter { }); } - // Remove from load tracking - // Worker load tracking is internal - // Remove from cache tree if using cache-aware policy if let Some(ref tree) = self.prefill_tree { // Note: Tree doesn't have a remove method, so we rebuild it @@ -170,7 +163,7 @@ impl PDRouter { pub fn new( prefill_urls: Vec<(String, Option)>, decode_urls: Vec, - selection_policy: PDSelectionPolicy, + policy: Arc, timeout_secs: u64, interval_secs: u64, ) -> Result { @@ -185,25 +178,38 @@ impl PDRouter { .map(WorkerFactory::create_decode) .collect(); - // Wait for PD workers to be healthy + // Wait for PD workers to be healthy (skip if empty - for service discovery mode) let all_urls: Vec = prefill_workers .iter() .chain(decode_workers.iter()) .map(|worker| worker.url().to_string()) .collect(); - crate::router::Router::wait_for_healthy_workers(&all_urls, timeout_secs, interval_secs)?; + if !all_urls.is_empty() { + crate::routers::router::Router::wait_for_healthy_workers( + &all_urls, + timeout_secs, + interval_secs, + )?; + } // Initialize cache-aware components if needed - let prefill_tree = match &selection_policy { - PDSelectionPolicy::CacheAware { .. 
} => { - let tree = Arc::new(Mutex::new(Tree::new())); - // Initialize tree with prefill workers - for worker in &prefill_workers { - tree.lock().unwrap().insert("", worker.url()); - } - Some(tree) + let prefill_tree = if policy.name() == "cache_aware" { + // Initialize the policy's internal tree with prefill workers + if let Some(cache_policy) = policy + .as_any() + .downcast_ref::() + { + cache_policy.init_workers(&prefill_workers); } - _ => None, + + let tree = Arc::new(Mutex::new(Tree::new())); + // Initialize tree with prefill workers + for worker in &prefill_workers { + tree.lock().unwrap().insert("", worker.url()); + } + Some(tree) + } else { + None }; // Set up background load monitoring for power-of-two selection @@ -216,10 +222,11 @@ impl PDRouter { .build() .map_err(|e| format!("Failed to create HTTP client: {}", e))?; - let load_monitor_handle = if matches!(selection_policy, PDSelectionPolicy::PowerOfTwo) { + let load_monitor_handle = if policy.name() == "power_of_two" { let monitor_urls = all_urls.clone(); let monitor_interval = interval_secs; let monitor_client = http_client.clone(); + let policy_clone = Arc::clone(&policy); Some(Arc::new(tokio::spawn(async move { Self::monitor_worker_loads_with_client( @@ -227,6 +234,7 @@ impl PDRouter { tx, monitor_interval, monitor_client, + policy_clone, ) .await; }))) @@ -246,7 +254,7 @@ impl PDRouter { Ok(PDRouter { prefill_workers, decode_workers, - selection_policy, + policy, prefill_tree, timeout_secs, interval_secs, @@ -270,15 +278,21 @@ impl PDRouter { let _request_id = Uuid::new_v4(); // Get stream flag and return_logprob flag before moving the request - let is_stream = typed_req.is_stream(); + let is_stream = typed_req.stream; let return_logprob = typed_req .other .get("return_logprob") .and_then(|v| v.as_bool()) .unwrap_or(false); + // Extract text for cache-aware routing from the typed request + let request_text = typed_req.text.as_ref().and_then(|t| match t { + super::pd_types::InputText::Single(s) => 
Some(s.as_str()), + super::pd_types::InputText::Batch(v) => v.first().map(|s| s.as_str()), + }); + // Select servers - let (prefill, decode) = match self.select_pd_pair(client).await { + let (prefill, decode) = match self.select_pd_pair(client, request_text).await { Ok(pair) => pair, Err(e) => { error!("Failed to select PD pair: {}", e); @@ -339,15 +353,24 @@ impl PDRouter { let start = Instant::now(); // Get stream flag and return_logprob flag before moving the request - let is_stream = typed_req.is_stream(); + let is_stream = typed_req.stream; let return_logprob = typed_req .other .get("return_logprob") .and_then(|v| v.as_bool()) .unwrap_or(false); + // Extract text for cache-aware routing from chat messages + let request_text = typed_req + .other + .get("messages") + .and_then(|messages| messages.as_array()) + .and_then(|arr| arr.first()) + .and_then(|msg| msg.get("content")) + .and_then(|content| content.as_str()); + // Select servers - let (prefill, decode) = match self.select_pd_pair(client).await { + let (prefill, decode) = match self.select_pd_pair(client, request_text).await { Ok(pair) => pair, Err(e) => { error!("Failed to select PD pair: {}", e); @@ -424,7 +447,7 @@ impl PDRouter { .json(&json_request); // Copy headers from original request - for (name, value) in crate::router::copy_request_headers(req) { + for (name, value) in crate::routers::router::copy_request_headers(req) { if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" { prefill_request = prefill_request.header(&name, &value); decode_request = decode_request.header(&name, &value); @@ -620,104 +643,47 @@ impl PDRouter { async fn select_pd_pair( &self, _client: &reqwest::Client, + request_text: Option<&str>, ) -> Result<(Box, Box), String> { - // Check we have workers - if self + // Get read locks for both worker lists + let prefill_workers = self .prefill_workers .read() - .map_err(|e| format!("Failed to acquire prefill workers lock: {}", e))? 
- .is_empty() - { - return Err("No prefill workers available. Please check if prefill servers are configured and healthy.".to_string()); - } - if self + .map_err(|e| format!("Failed to acquire prefill workers lock: {}", e))?; + let decode_workers = self .decode_workers .read() - .map_err(|e| format!("Failed to acquire decode workers lock: {}", e))? - .is_empty() - { + .map_err(|e| format!("Failed to acquire decode workers lock: {}", e))?; + + // Check we have workers + if prefill_workers.is_empty() { + return Err("No prefill workers available. Please check if prefill servers are configured and healthy.".to_string()); + } + if decode_workers.is_empty() { return Err("No decode workers available. Please check if decode servers are configured and healthy.".to_string()); } - match &self.selection_policy { - PDSelectionPolicy::Random => self.select_random(), - PDSelectionPolicy::PowerOfTwo => self.select_power_of_two().await, - PDSelectionPolicy::CacheAware { .. } => { - // TODO: Implement cache-aware selection - self.select_power_of_two().await + // Use the policy to select worker pair + match self + .policy + .select_worker_pair(&prefill_workers, &decode_workers, request_text) + { + Some((prefill_idx, decode_idx)) => { + let prefill = prefill_workers[prefill_idx].clone_worker(); + let decode = decode_workers[decode_idx].clone_worker(); + Ok((prefill, decode)) } + None => Err("Failed to select worker pair".to_string()), } } - fn select_random(&self) -> Result<(Box, Box), String> { - let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?; - let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?; - - let prefill = prefill_list[rand::random::() % prefill_list.len()].clone_worker(); - let decode = decode_list[rand::random::() % decode_list.len()].clone_worker(); - - Ok((prefill, decode)) - } - - async fn select_power_of_two(&self) -> Result<(Box, Box), String> { - let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?; 
- let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?; - - let (p1_idx, p2_idx) = get_two_random_indices(prefill_list.len()); - let (d1_idx, d2_idx) = get_two_random_indices(decode_list.len()); - - let loads = self.worker_loads.borrow(); - - let p1_load = loads - .get(prefill_list[p1_idx].url()) - .copied() - .unwrap_or(isize::MAX); - let p2_load = loads - .get(prefill_list[p2_idx].url()) - .copied() - .unwrap_or(isize::MAX); - let d1_load = loads - .get(decode_list[d1_idx].url()) - .copied() - .unwrap_or(isize::MAX); - let d2_load = loads - .get(decode_list[d2_idx].url()) - .copied() - .unwrap_or(isize::MAX); - - info!( - "Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}", - prefill_list[p1_idx].url(), - p1_load, - prefill_list[p2_idx].url(), - p2_load, - decode_list[d1_idx].url(), - d1_load, - decode_list[d2_idx].url(), - d2_load - ); - - let selected_prefill = if p1_load <= p2_load { - prefill_list[p1_idx].clone_worker() - } else { - prefill_list[p2_idx].clone_worker() - }; - - let selected_decode = if d1_load <= d2_load { - decode_list[d1_idx].clone_worker() - } else { - decode_list[d2_idx].clone_worker() - }; - - Ok((selected_prefill, selected_decode)) - } - // Background task to monitor worker loads with shared client async fn monitor_worker_loads_with_client( worker_urls: Vec, tx: tokio::sync::watch::Sender>, interval_secs: u64, client: reqwest::Client, + policy: Arc, ) { loop { let mut loads = HashMap::new(); @@ -742,6 +708,9 @@ impl PDRouter { debug!("Worker loads updated: {:?}", loads); + // Update the policy with current loads + policy.update_loads(&loads); + // Check if receiver is still active if tx.send(loads).is_err() { info!("Load monitor receiver dropped, shutting down monitor task"); @@ -792,18 +761,6 @@ impl PDRouter { } // Helper functions -fn get_two_random_indices(len: usize) -> (usize, usize) { - if len == 1 { - (0, 0) - } else { - let idx1 = rand::random::() % len; - let mut idx2 = 
rand::random::() % len; - while idx2 == idx1 { - idx2 = rand::random::() % len; - } - (idx1, idx2) - } -} async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option { match client.get(format!("{}/get_load", worker_url)).send().await { @@ -841,61 +798,72 @@ async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option HttpResponse { - let mut all_healthy = true; - let mut unhealthy_servers = Vec::new(); + // Test model generation capability by selecting a random pair and testing them + // Note: This endpoint actually causes the model to generate tokens, so we only test one pair - // Collect all worker URLs with their types - let mut worker_infos = Vec::new(); + // Select a random worker pair using the policy + let (prefill, decode) = match self.select_pd_pair(client, None).await { + Ok(pair) => pair, + Err(e) => { + return HttpResponse::ServiceUnavailable() + .body(format!("No healthy worker pair available: {}", e)); + } + }; - for worker in self.prefill_workers.read().unwrap().iter() { - worker_infos.push((worker.url().to_string(), "prefill")); - } + // Test prefill server's health_generate + let prefill_url = format!("{}/health_generate", prefill.url()); + let prefill_result = client.get(&prefill_url).send().await; - for worker in self.decode_workers.read().unwrap().iter() { - worker_infos.push((worker.url().to_string(), "decode")); - } + // Test decode server's health_generate + let decode_url = format!("{}/health_generate", decode.url()); + let decode_result = client.get(&decode_url).send().await; - // Create tasks with URL tracking - let tasks: Vec<_> = worker_infos - .iter() - .map(|(url, _)| { - let health_url = format!("{}/health_generate", url); - client.get(&health_url).send() - }) - .collect(); + // Check results + let mut errors = Vec::new(); - let results = futures_util::future::join_all(tasks).await; - - for ((url, worker_type), result) in worker_infos.iter().zip(results.into_iter()) { - match result { - Ok(res) if 
res.status().is_success() => { - debug!("Health check passed for {} server: {}", worker_type, url); - } - Ok(res) => { - all_healthy = false; - let msg = format!( - "{} server {} returned status {}", - worker_type, - url, - res.status() - ); - error!("{}", msg); - unhealthy_servers.push(msg); - } - Err(e) => { - all_healthy = false; - let msg = format!("{} server {} error: {}", worker_type, url, e); - error!("{}", msg); - unhealthy_servers.push(msg); - } + match prefill_result { + Ok(res) if res.status().is_success() => { + debug!( + "Health generate passed for prefill server: {}", + prefill.url() + ); + } + Ok(res) => { + errors.push(format!( + "Prefill {} returned status {}", + prefill.url(), + res.status() + )); + } + Err(e) => { + errors.push(format!("Prefill {} error: {}", prefill.url(), e)); } } - if all_healthy { - HttpResponse::Ok().body("Health check passed on all servers") + match decode_result { + Ok(res) if res.status().is_success() => { + debug!("Health generate passed for decode server: {}", decode.url()); + } + Ok(res) => { + errors.push(format!( + "Decode {} returned status {}", + decode.url(), + res.status() + )); + } + Err(e) => { + errors.push(format!("Decode {} error: {}", decode.url(), e)); + } + } + + if errors.is_empty() { + HttpResponse::Ok().body(format!( + "Health generate passed on selected pair: prefill={}, decode={}", + prefill.url(), + decode.url() + )) } else { - HttpResponse::ServiceUnavailable() - .body(format!("Health check failed: {:?}", unhealthy_servers)) + HttpResponse::ServiceUnavailable().body(format!("Health generate failed: {:?}", errors)) } } @@ -955,7 +923,7 @@ impl PDRouter { if let Some(worker_url) = first_worker_url { // Send request directly without going through Router let mut request_builder = client.get(format!("{}/v1/models", worker_url)); - for (name, value) in crate::router::copy_request_headers(req) { + for (name, value) in crate::routers::router::copy_request_headers(req) { if name.to_lowercase() != 
"content-type" && name.to_lowercase() != "content-length" { request_builder = request_builder.header(name, value); @@ -1035,7 +1003,7 @@ impl PDRouter { if let Some(worker_url) = first_worker_url { let mut request_builder = client.get(format!("{}/get_model_info", worker_url)); - for (name, value) in crate::router::copy_request_headers(req) { + for (name, value) in crate::routers::router::copy_request_headers(req) { if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" { request_builder = request_builder.header(name, value); @@ -1102,3 +1070,324 @@ impl PDRouter { } } } + +use crate::routers::{RouterTrait, WorkerManagement}; +use async_trait::async_trait; +use reqwest::Client; + +#[async_trait] +impl WorkerManagement for PDRouter { + async fn add_worker(&self, _worker_url: &str) -> Result { + // For PD router, we don't support adding workers via this generic method + Err( + "PD router requires specific add_prefill_server or add_decode_server methods" + .to_string(), + ) + } + + fn remove_worker(&self, worker_url: &str) { + // For PD router, we would need to know if it's a prefill or decode server + // For now, try both + if let Ok(mut workers) = self.prefill_workers.write() { + if let Some(index) = workers.iter().position(|w| w.url() == worker_url) { + workers.remove(index); + info!("Removed prefill worker: {}", worker_url); + return; + } + } + + if let Ok(mut workers) = self.decode_workers.write() { + if let Some(index) = workers.iter().position(|w| w.url() == worker_url) { + workers.remove(index); + info!("Removed decode worker: {}", worker_url); + } + } + } + + fn get_worker_urls(&self) -> Vec { + let mut urls = Vec::new(); + + // Add prefill worker URLs + if let Ok(workers) = self.prefill_workers.read() { + for worker in workers.iter() { + urls.push(worker.url().to_string()); + } + } + + // Add decode worker URLs + if let Ok(workers) = self.decode_workers.read() { + for worker in workers.iter() { + 
urls.push(worker.url().to_string()); + } + } + + urls + } +} + +#[async_trait(?Send)] +impl RouterTrait for PDRouter { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + async fn health(&self, _client: &Client, _req: &HttpRequest) -> HttpResponse { + // This is a server readiness check - checking if we have healthy workers + // Workers handle their own health checks in the background + let mut all_healthy = true; + let mut unhealthy_servers = Vec::new(); + + // Check prefill servers + for worker in self.prefill_workers.read().unwrap().iter() { + if !worker.is_healthy() { + all_healthy = false; + unhealthy_servers.push(format!("Prefill: {}", worker.url())); + } + } + + // Check decode servers + for worker in self.decode_workers.read().unwrap().iter() { + if !worker.is_healthy() { + all_healthy = false; + unhealthy_servers.push(format!("Decode: {}", worker.url())); + } + } + + if all_healthy { + HttpResponse::Ok().body("All servers healthy") + } else { + HttpResponse::ServiceUnavailable() + .body(format!("Unhealthy servers: {:?}", unhealthy_servers)) + } + } + + async fn health_generate(&self, client: &Client, _req: &HttpRequest) -> HttpResponse { + // Use the existing PDRouter health_generate method + PDRouter::health_generate(self, client).await + } + + async fn get_server_info(&self, client: &Client, _req: &HttpRequest) -> HttpResponse { + // Use the existing PDRouter get_server_info method + PDRouter::get_server_info(self, client).await + } + + async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse { + // Get first prefill worker URL to avoid holding lock across await + let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { + workers.first().map(|w| w.url().to_string()) + } else { + return HttpResponse::InternalServerError().body("Failed to access prefill workers"); + }; + + if let Some(worker_url) = first_worker_url { + // Send request directly without going through Router + let mut request_builder = 
client.get(format!("{}/v1/models", worker_url)); + for (name, value) in crate::routers::router::copy_request_headers(req) { + if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" + { + request_builder = request_builder.header(name, value); + } + } + match request_builder.send().await { + Ok(res) => { + let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + match res.bytes().await { + Ok(body) => HttpResponse::build(status).body(body.to_vec()), + Err(e) => HttpResponse::InternalServerError() + .body(format!("Failed to read response body: {}", e)), + } + } + Err(e) => HttpResponse::InternalServerError() + .body(format!("Failed to send request: {}", e)), + } + } else { + HttpResponse::ServiceUnavailable().body("No prefill servers available") + } + } + + async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse { + // For PD router, get model info from the first prefill server + // Get first prefill worker URL to avoid holding lock across await + let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { + workers.first().map(|w| w.url().to_string()) + } else { + return HttpResponse::InternalServerError().body("Failed to access prefill workers"); + }; + + if let Some(worker_url) = first_worker_url { + let mut request_builder = client.get(format!("{}/get_model_info", worker_url)); + for (name, value) in crate::routers::router::copy_request_headers(req) { + if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" + { + request_builder = request_builder.header(name, value); + } + } + match request_builder.send().await { + Ok(res) => { + let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + match res.bytes().await { + Ok(body) => HttpResponse::build(status).body(body.to_vec()), + Err(e) => 
HttpResponse::InternalServerError() + .body(format!("Failed to read response body: {}", e)), + } + } + Err(e) => HttpResponse::InternalServerError() + .body(format!("Failed to send request: {}", e)), + } + } else { + HttpResponse::ServiceUnavailable().body("No prefill servers available") + } + } + + async fn route_generate( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse { + match serde_json::from_value::(body.clone()) { + Ok(openai_req) => { + // Convert OpenAI format to PD format + let pd_req = openai_req.to_pd_request(); + PDRouter::route_generate(self, client, req, pd_req, "/generate").await + } + Err(_) => { + // If that fails, try to deserialize directly as PD format (for backwards compatibility) + match serde_json::from_value::(body) { + Ok(pd_req) => { + PDRouter::route_generate(self, client, req, pd_req, "/generate").await + } + Err(e) => { + HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)) + } + } + } + } + } + + async fn route_chat( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse { + match serde_json::from_value::(body.clone()) { + Ok(openai_req) => { + // Convert OpenAI format to PD format + let pd_req = openai_req.to_pd_request(); + PDRouter::route_chat(self, client, req, pd_req, "/v1/chat/completions").await + } + Err(_) => { + // If that fails, try to deserialize directly as PD format (for backwards compatibility) + match serde_json::from_value::(body) { + Ok(pd_req) => { + PDRouter::route_chat(self, client, req, pd_req, "/v1/chat/completions") + .await + } + Err(e) => { + HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)) + } + } + } + } + } + + async fn route_completion( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse { + match serde_json::from_value::(body.clone()) { + Ok(openai_req) => { + // Convert OpenAI format to PD format (CompletionRequest -> 
GenerateReqInput) + let pd_req = openai_req.to_pd_request(); + PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await + } + Err(_) => { + // If that fails, try to deserialize directly as PD format (for backwards compatibility) + match serde_json::from_value::(body) { + Ok(pd_req) => { + PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await + } + Err(e) => { + HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)) + } + } + } + } + } + + async fn flush_cache(&self, client: &Client) -> HttpResponse { + // Use the existing PDRouter flush_cache method + PDRouter::flush_cache(self, client).await + } + + async fn get_worker_loads(&self, client: &Client) -> HttpResponse { + // Use the existing PDRouter get_loads method + PDRouter::get_loads(self, client).await + } + + fn router_type(&self) -> &'static str { + "pd" + } + + fn readiness(&self) -> HttpResponse { + // PD router is ready if it has at least one healthy prefill AND one healthy decode worker + let healthy_prefill_count = self + .prefill_workers + .read() + .unwrap() + .iter() + .filter(|w| w.is_healthy()) + .count(); + + let healthy_decode_count = self + .decode_workers + .read() + .unwrap() + .iter() + .filter(|w| w.is_healthy()) + .count(); + + let total_prefill = self.prefill_workers.read().unwrap().len(); + let total_decode = self.decode_workers.read().unwrap().len(); + + if healthy_prefill_count > 0 && healthy_decode_count > 0 { + HttpResponse::Ok().json(serde_json::json!({ + "status": "ready", + "prefill": { + "healthy": healthy_prefill_count, + "total": total_prefill + }, + "decode": { + "healthy": healthy_decode_count, + "total": total_decode + } + })) + } else { + let mut reasons = Vec::new(); + if healthy_prefill_count == 0 { + reasons.push("no healthy prefill workers"); + } + if healthy_decode_count == 0 { + reasons.push("no healthy decode workers"); + } + + HttpResponse::ServiceUnavailable().json(serde_json::json!({ + "status": 
"not_ready", + "reason": reasons.join(", "), + "prefill": { + "healthy": healthy_prefill_count, + "total": total_prefill + }, + "decode": { + "healthy": healthy_decode_count, + "total": total_decode + } + })) + } + } +} diff --git a/sgl-router/src/pd_types.rs b/sgl-router/src/routers/pd_types.rs similarity index 100% rename from sgl-router/src/pd_types.rs rename to sgl-router/src/routers/pd_types.rs diff --git a/sgl-router/src/request_adapter.rs b/sgl-router/src/routers/request_adapter.rs similarity index 99% rename from sgl-router/src/request_adapter.rs rename to sgl-router/src/routers/request_adapter.rs index 4396cc4d7..f5611bbe4 100644 --- a/sgl-router/src/request_adapter.rs +++ b/sgl-router/src/routers/request_adapter.rs @@ -1,9 +1,9 @@ // Request adapter to bridge OpenAI API types with PD routing requirements +use super::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch}; use crate::openai_api_types::{ ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, StringOrArray, }; -use crate::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch}; use serde_json::Value; /// Adapter trait to convert OpenAI requests to PD-compatible requests diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs new file mode 100644 index 000000000..ef44348ec --- /dev/null +++ b/sgl-router/src/routers/router.rs @@ -0,0 +1,1055 @@ +use crate::core::{HealthChecker, Worker, WorkerFactory}; +use crate::policies::LoadBalancingPolicy; +use ::metrics::{counter, gauge, histogram}; +use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; +use actix_web::{HttpRequest, HttpResponse}; +use futures_util::{StreamExt, TryStreamExt}; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; +use std::thread; +use std::time::{Duration, Instant}; +use tracing::{debug, error, info, warn}; + +pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> { + req.headers() + .iter() + .filter_map(|(name, 
value)| { + value + .to_str() + .ok() + .map(|v| (name.to_string(), v.to_string())) + }) + .collect() +} + +/// Regular router that uses injected load balancing policies +#[derive(Debug)] +pub struct Router { + workers: Arc>>>, + policy: Arc, + timeout_secs: u64, + interval_secs: u64, + _worker_loads: Arc>>, + _load_monitor_handle: Option>>, + _health_checker: Option, +} + +impl Router { + /// Create a new router with injected policy + pub fn new( + worker_urls: Vec, + policy: Arc, + timeout_secs: u64, + interval_secs: u64, + ) -> Result { + // Update active workers gauge + gauge!("sgl_router_active_workers").set(worker_urls.len() as f64); + + // Wait for workers to be healthy (skip if empty - for service discovery mode) + if !worker_urls.is_empty() { + Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?; + } + + // Create Worker trait objects from URLs + let workers: Vec> = worker_urls + .iter() + .map(|url| WorkerFactory::create_regular(url.clone())) + .collect(); + + // Initialize policy with workers if needed (e.g., for cache-aware) + if let Some(cache_aware) = policy + .as_any() + .downcast_ref::() + { + cache_aware.init_workers(&workers); + } + + let workers = Arc::new(RwLock::new(workers)); + let health_checker = crate::core::start_health_checker(Arc::clone(&workers), interval_secs); + + // Setup load monitoring for PowerOfTwo policy + let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); + let worker_loads = Arc::new(rx); + + let load_monitor_handle = if policy.name() == "power_of_two" { + let monitor_urls = worker_urls.clone(); + let monitor_interval = interval_secs; + let policy_clone = Arc::clone(&policy); + + Some(Arc::new(tokio::spawn(async move { + Self::monitor_worker_loads(monitor_urls, tx, monitor_interval, policy_clone).await; + }))) + } else { + None + }; + + Ok(Router { + workers, + policy, + timeout_secs, + interval_secs, + _worker_loads: worker_loads, + _load_monitor_handle: load_monitor_handle, + _health_checker: 
Some(health_checker), + }) + } + + /// Get the current list of worker URLs + pub fn get_worker_urls(&self) -> Vec { + self.workers + .read() + .unwrap() + .iter() + .map(|w| w.url().to_string()) + .collect() + } + + pub fn wait_for_healthy_workers( + worker_urls: &[String], + timeout_secs: u64, + interval_secs: u64, + ) -> Result<(), String> { + let start_time = std::time::Instant::now(); + let sync_client = reqwest::blocking::Client::builder() + .timeout(Duration::from_secs(timeout_secs)) + .build() + .map_err(|e| format!("Failed to create HTTP client: {}", e))?; + + loop { + if start_time.elapsed() > Duration::from_secs(timeout_secs) { + error!( + "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_urls + ); + return Err(format!( + "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_urls + )); + } + + let mut all_healthy = true; + let mut unhealthy_workers = Vec::new(); + + for url in worker_urls { + match sync_client.get(&format!("{}/health", url)).send() { + Ok(res) => { + if !res.status().is_success() { + let msg = format!( + "Worker health check is pending with status {}", + res.status() + ); + info!("{}", msg); + all_healthy = false; + unhealthy_workers.push((url, msg)); + } + } + Err(_) => { + let msg = format!("Worker is not ready yet"); + info!("{}", msg); + all_healthy = false; + unhealthy_workers.push((url, msg)); + } + } + } + + if all_healthy { + info!("All workers are healthy"); + return Ok(()); + } else { + info!("Initializing workers:"); + for (url, reason) in &unhealthy_workers { + info!(" {} - {}", url, reason); + } + thread::sleep(Duration::from_secs(interval_secs)); + } + } + } + + fn
select_first_worker(&self) -> Result { + let workers_guard = self.workers.read().unwrap(); + if workers_guard.is_empty() { + Err("No workers are available".to_string()) + } else { + Ok(workers_guard[0].url().to_string()) + } + } + + pub async fn send_request( + &self, + client: &reqwest::Client, + worker_url: &str, + route: &str, + req: &HttpRequest, + ) -> HttpResponse { + let start = Instant::now(); + let mut request_builder = client.get(format!("{}{}", worker_url, route)); + + // Copy all headers from original request except for /health because it does not need authorization + if route != "/health" { + for (name, value) in copy_request_headers(req) { + // Skip Content-Type and Content-Length as .json() sets them + if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" + { + request_builder = request_builder.header(name, value); + } + } + } + + let response = match request_builder.send().await { + Ok(res) => { + let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + + match res.bytes().await { + Ok(body) => HttpResponse::build(status).body(body.to_vec()), + Err(e) => HttpResponse::InternalServerError() + .body(format!("Failed to read response body: {}", e)), + } + } + Err(e) => HttpResponse::InternalServerError().body(format!( + "Failed to send request to worker {}: {}", + worker_url, e + )), + }; + + // Record request metrics + if route != "/health" { + let duration = start.elapsed(); + counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1); + histogram!("sgl_router_request_duration_seconds", "route" => route.to_string()) + .record(duration.as_secs_f64()); + + if !response.status().is_success() { + counter!("sgl_router_request_errors_total", "route" => route.to_string()) + .increment(1); + } + } + response + } + + pub async fn route_to_first( + &self, + client: &reqwest::Client, + route: &str, + req: &HttpRequest, + ) -> 
HttpResponse { + const MAX_REQUEST_RETRIES: u32 = 3; + const MAX_TOTAL_RETRIES: u32 = 6; + let mut total_retries = 0; + + while total_retries < MAX_TOTAL_RETRIES { + match self.select_first_worker() { + Ok(worker_url) => { + let mut request_retries = 0; + + // Try the same worker multiple times + while request_retries < MAX_REQUEST_RETRIES { + if total_retries >= 1 { + info!("Retrying request after {} failed attempts", total_retries); + } + + let response = self.send_request(client, &worker_url, route, req).await; + + if response.status().is_success() { + return response; + } else { + // if the worker is healthy, it means the request is bad, so return the error response + let health_response = + self.send_request(client, &worker_url, "/health", req).await; + if health_response.status().is_success() { + return response; + } + } + + warn!( + "Request to {} failed (attempt {}/{})", + worker_url, + request_retries + 1, + MAX_REQUEST_RETRIES + ); + + request_retries += 1; + total_retries += 1; + + if request_retries == MAX_REQUEST_RETRIES { + warn!("Removing failed worker: {}", worker_url); + self.remove_worker(&worker_url); + break; + } + } + } + Err(e) => return HttpResponse::InternalServerError().body(e), + } + } + + HttpResponse::InternalServerError().body("All retry attempts failed") + } + + pub async fn route_to_all( + &self, + client: &reqwest::Client, + route: &str, + req: &HttpRequest, + ) -> HttpResponse { + // Get all worker URLs + let worker_urls = self.get_worker_urls(); + + // Send requests to all workers concurrently + let mut tasks = Vec::new(); + for worker_url in &worker_urls { + let mut request_builder = client.post(format!("{}{}", worker_url, route)); + + // Copy headers from original request + for (name, value) in copy_request_headers(req) { + request_builder = request_builder.header(name, value); + } + + tasks.push(request_builder.send()); + } + + // Wait for all responses + let results = futures_util::future::join_all(tasks).await; + + // Check if 
all succeeded + let all_success = results.iter().all(|r| { + r.as_ref() + .map(|res| res.status().is_success()) + .unwrap_or(false) + }); + + if all_success { + HttpResponse::Ok().body("Operation completed on all servers") + } else { + HttpResponse::InternalServerError().body("Operation failed on one or more servers") + } + } + + pub async fn get_all_loads( + &self, + client: &reqwest::Client, + _req: &HttpRequest, + ) -> HttpResponse { + let urls = self.get_worker_urls(); + let prefill_urls: Vec = Vec::new(); + let decode_urls = urls; + + // Collect loads from all servers + let mut prefill_loads = Vec::new(); + let mut decode_loads = Vec::new(); + + // Get prefill loads + for url in &prefill_urls { + let load = self.get_worker_load(client, url).await.unwrap_or(-1); + prefill_loads.push(serde_json::json!({ + "engine": format!("(Prefill@{})", url), + "load": load as i64 + })); + } + + // Get decode loads + for url in &decode_urls { + let load = self.get_worker_load(client, url).await.unwrap_or(-1); + decode_loads.push(serde_json::json!({ + "engine": format!("(Decode@{})", url), + "load": load as i64 + })); + } + + HttpResponse::Ok().json(serde_json::json!({ + "prefill": prefill_loads, + "decode": decode_loads + })) + } + + // New method to route typed requests directly + pub async fn route_typed_request< + T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone, + >( + &self, + client: &reqwest::Client, + req: &HttpRequest, + typed_req: &T, + route: &str, + ) -> HttpResponse { + // Handle retries like the original implementation + let start = Instant::now(); + const MAX_REQUEST_RETRIES: u32 = 3; + const MAX_TOTAL_RETRIES: u32 = 6; + let mut total_retries = 0; + + while total_retries < MAX_TOTAL_RETRIES { + // Extract routing text directly from typed request + let text = typed_req.extract_text_for_routing(); + let is_stream = typed_req.is_stream(); + + // Select worker based on text + let worker_url = self.select_generate_worker_from_text(&text); + 
let mut request_retries = 0; + + // Try the same worker multiple times + while request_retries < MAX_REQUEST_RETRIES { + if total_retries >= 1 { + info!("Retrying request after {} failed attempts", total_retries); + counter!("sgl_router_retries_total", "route" => route.to_string()).increment(1); + } + + // Increment load before request if using RAII load tracking + let load_incremented = if self.policy.name() == "cache_aware" { + let workers_guard = self.workers.read().unwrap(); + if let Some(worker) = workers_guard.iter().find(|w| w.url() == &worker_url) { + worker.increment_load(); + gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) + .set(worker.load() as f64); + true + } else { + false + } + } else { + false + }; + + // Send typed request directly + let response = self + .send_typed_request( + client, + req, + typed_req, + route, + &worker_url, + is_stream, + load_incremented, + ) + .await; + + if response.status().is_success() { + let duration = start.elapsed(); + histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()) + .record(duration.as_secs_f64()); + return response; + } else { + // if the worker is healthy, it means the request is bad, so return the error response + let health_response = + self.send_request(client, &worker_url, "/health", req).await; + if health_response.status().is_success() { + counter!("sgl_router_request_errors_total", "route" => route.to_string()) + .increment(1); + return response; + } + } + + warn!( + "Generate request to {} failed (attempt {}/{})", + worker_url, + request_retries + 1, + MAX_REQUEST_RETRIES + ); + + request_retries += 1; + total_retries += 1; + + if request_retries == MAX_REQUEST_RETRIES { + warn!("Removing failed worker: {}", worker_url); + self.remove_worker(&worker_url); + break; + } + } + } + + counter!("sgl_router_request_errors_total", "route" => route.to_string()).increment(1); + HttpResponse::InternalServerError().body("All retry attempts failed") + } + + // 
Helper method to select worker from text using the policy + fn select_generate_worker_from_text(&self, text: &str) -> String { + let workers = self.workers.read().unwrap(); + + match self.policy.select_worker(&workers, Some(text)) { + Some(idx) => workers[idx].url().to_string(), + None => { + warn!("No healthy workers available"); + String::new() + } + } + } + + // Send typed request directly without conversion + async fn send_typed_request( + &self, + client: &reqwest::Client, + req: &HttpRequest, + typed_req: &T, + route: &str, + worker_url: &str, + is_stream: bool, + load_incremented: bool, // Whether load was incremented for this request + ) -> HttpResponse { + let start = Instant::now(); + + // Debug: Log what we're sending + if let Ok(json_str) = serde_json::to_string_pretty(typed_req) { + debug!("Sending request to {}: {}", route, json_str); + } + + let mut request_builder = client + .post(format!("{}{}", worker_url, route)) + .json(typed_req); // Use json() directly with typed request + + // Copy all headers from original request + for (name, value) in copy_request_headers(req) { + // Skip Content-Type and Content-Length as .json() sets them + if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" { + request_builder = request_builder.header(&name, &value); + } + } + + let res = match request_builder.send().await { + Ok(res) => res, + Err(e) => { + error!("Failed to send request to {}: {}", worker_url, e); + + // Decrement load on error if it was incremented + if load_incremented { + if let Ok(workers_guard) = self.workers.read() { + if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) { + worker.decrement_load(); + gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) + .set(worker.load() as f64); + } + } + } + + return HttpResponse::InternalServerError().body(format!("Request failed: {}", e)); + } + }; + + let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) + 
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + + if !is_stream { + // For non-streaming requests, get response first + let response = match res.bytes().await { + Ok(body) => HttpResponse::build(status).body(body.to_vec()), + Err(e) => { + let error_msg = format!("Failed to get response body: {}", e); + HttpResponse::InternalServerError().body(error_msg) + } + }; + + // Decrement load counter for non-streaming requests if it was incremented + if load_incremented && !is_stream { + if let Ok(workers_guard) = self.workers.read() { + if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) { + worker.decrement_load(); + gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) + .set(worker.load() as f64); + } + } + } + + // Record metrics + let duration = start.elapsed(); + histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()) + .record(duration.as_secs_f64()); + counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1); + + response + } else if load_incremented { + // For streaming with load tracking, we need to manually decrement when done + let workers = Arc::clone(&self.workers); + let worker_url = worker_url.to_string(); + + HttpResponse::build(status) + .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) + .streaming( + res.bytes_stream() + .map_err(|_| { + actix_web::error::ErrorInternalServerError("Failed to read stream") + }) + .inspect(move |bytes| { + if let Ok(bytes) = bytes { + if bytes + .as_ref() + .windows(12) + .any(|window| window == b"data: [DONE]") + { + if let Ok(workers_guard) = workers.read() { + if let Some(worker) = + workers_guard.iter().find(|w| w.url() == &worker_url) + { + worker.decrement_load(); + gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) + .set(worker.load() as f64); + debug!("Streaming is done!!") + } + } + } + } + }), + ) + } else { + // For requests without load tracking, just 
stream + HttpResponse::build(status) + .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) + .streaming(res.bytes_stream().map_err(|_| { + actix_web::error::ErrorInternalServerError("Failed to read stream") + })) + } + } + + pub async fn add_worker(&self, worker_url: &str) -> Result { + let start_time = std::time::Instant::now(); + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(self.timeout_secs)) + .build() + .map_err(|e| format!("Failed to create HTTP client: {}", e))?; + + loop { + if start_time.elapsed() > Duration::from_secs(self.timeout_secs) { + error!( + "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + self.timeout_secs, worker_url + ); + return Err(format!( + "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + self.timeout_secs, worker_url + )); + } + + match client.get(&format!("{}/health", worker_url)).send().await { + Ok(res) => { + if res.status().is_success() { + info!("Worker {} health check passed", worker_url); + let mut workers_guard = self.workers.write().unwrap(); + if workers_guard.iter().any(|w| w.url() == worker_url) { + return Err(format!("Worker {} already exists", worker_url)); + } + info!("Added worker: {}", worker_url); + let new_worker = WorkerFactory::create_regular(worker_url.to_string()); + workers_guard.push(new_worker); + gauge!("sgl_router_active_workers").set(workers_guard.len() as f64); + + // If cache aware policy, initialize the worker in the tree + if let Some(cache_aware) = + self.policy + .as_any() + .downcast_ref::() + { + // Get updated workers after adding + drop(workers_guard); + let workers_guard = self.workers.read().unwrap(); + 
cache_aware.init_workers(&workers_guard); + } + + return Ok(format!("Successfully added worker: {}", worker_url)); + } else { + info!( + "Worker {} health check is pending with status: {}.", + worker_url, + res.status() + ); + // if the url does not have http or https prefix, warn users + if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") + { + warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url); + } + + tokio::time::sleep(Duration::from_secs(self.interval_secs)).await; + continue; + } + } + Err(e) => { + info!( + "Worker {} health check is pending with error: {}", + worker_url, e + ); + + // if the url does not have http or https prefix, warn users + if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") { + warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url); + } + + tokio::time::sleep(Duration::from_secs(self.interval_secs)).await; + continue; + } + } + } + } + + pub fn remove_worker(&self, worker_url: &str) { + let mut workers_guard = self.workers.write().unwrap(); + if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) { + workers_guard.remove(index); + info!("Removed worker: {}", worker_url); + gauge!("sgl_router_active_workers").set(workers_guard.len() as f64); + } else { + warn!("Worker {} not found, skipping removal", worker_url); + return; + } + + // If cache aware policy, remove the worker from the tree + if let Some(cache_aware) = self + .policy + .as_any() + .downcast_ref::() + { + cache_aware.remove_worker(worker_url); + info!("Removed worker from tree: {}", worker_url); + } + } + + async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option { + match client.get(&format!("{}/get_load", worker_url)).send().await { + Ok(res) if res.status().is_success() => match res.bytes().await { + Ok(bytes) => match serde_json::from_slice::(&bytes) { + 
Ok(data) => data + .get("load") + .and_then(|v| v.as_i64()) + .map(|v| v as isize), + Err(e) => { + debug!("Failed to parse load response from {}: {}", worker_url, e); + None + } + }, + Err(e) => { + debug!("Failed to read load response from {}: {}", worker_url, e); + None + } + }, + Ok(res) => { + debug!( + "Worker {} returned non-success status: {}", + worker_url, + res.status() + ); + None + } + Err(e) => { + debug!("Failed to get load from {}: {}", worker_url, e); + None + } + } + } + + // Background task to monitor worker loads + async fn monitor_worker_loads( + worker_urls: Vec, + tx: tokio::sync::watch::Sender>, + interval_secs: u64, + policy: Arc, + ) { + let client = match reqwest::Client::builder() + .timeout(Duration::from_secs(5)) + .build() + { + Ok(c) => c, + Err(e) => { + error!("Failed to create HTTP client for load monitoring: {}", e); + return; + } + }; + + let mut interval = tokio::time::interval(Duration::from_secs(interval_secs)); + + loop { + interval.tick().await; + + let mut loads = HashMap::new(); + for url in &worker_urls { + if let Some(load) = Self::get_worker_load_static(&client, url).await { + loads.insert(url.clone(), load); + debug!("Worker {} load: {}", url, load); + } + } + + if !loads.is_empty() { + // Update policy with new loads + policy.update_loads(&loads); + + // Send to watchers + if let Err(e) = tx.send(loads) { + error!("Failed to send load update: {}", e); + } + } + } + } + + // Static version of get_worker_load for use in monitoring task + async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option { + match client.get(&format!("{}/get_load", worker_url)).send().await { + Ok(res) if res.status().is_success() => match res.bytes().await { + Ok(bytes) => match serde_json::from_slice::(&bytes) { + Ok(data) => data + .get("load") + .and_then(|v| v.as_i64()) + .map(|v| v as isize), + Err(e) => { + debug!("Failed to parse load response from {}: {}", worker_url, e); + None + } + }, + Err(e) => { + 
debug!("Failed to read load response from {}: {}", worker_url, e); + None + } + }, + Ok(res) => { + debug!( + "Worker {} returned non-success status: {}", + worker_url, + res.status() + ); + None + } + Err(e) => { + debug!("Failed to get load from {}: {}", worker_url, e); + None + } + } + } +} + +use crate::routers::{RouterTrait, WorkerManagement}; +use async_trait::async_trait; +use reqwest::Client; + +#[async_trait] +impl WorkerManagement for Router { + async fn add_worker(&self, worker_url: &str) -> Result { + Router::add_worker(self, worker_url).await + } + + fn remove_worker(&self, worker_url: &str) { + Router::remove_worker(self, worker_url) + } + + fn get_worker_urls(&self) -> Vec { + Router::get_worker_urls(self) + } +} + +#[async_trait(?Send)] +impl RouterTrait for Router { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + async fn health(&self, _client: &Client, _req: &HttpRequest) -> HttpResponse { + // Check local health state of all workers (consistent with PD router) + // Note: This uses cached health status from background health checks, not live checks + let mut all_healthy = true; + let mut unhealthy_servers = Vec::new(); + + for worker in self.workers.read().unwrap().iter() { + if !worker.is_healthy() { + all_healthy = false; + unhealthy_servers.push(worker.url().to_string()); + } + } + + if all_healthy { + HttpResponse::Ok().body("All servers healthy") + } else { + HttpResponse::ServiceUnavailable() + .body(format!("Unhealthy servers: {:?}", unhealthy_servers)) + } + } + + async fn health_generate(&self, client: &Client, req: &HttpRequest) -> HttpResponse { + // Test model generation capability by sending to first available worker + // Note: This endpoint actually causes the model to generate a token, so we only test one worker + self.route_to_first(client, "/health_generate", req).await + } + + async fn get_server_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse { + self.route_to_first(client, "/get_server_info", 
req).await + } + + async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse { + self.route_to_first(client, "/v1/models", req).await + } + + async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse { + self.route_to_first(client, "/get_model_info", req).await + } + + async fn route_generate( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse { + // Convert JSON to typed request + match serde_json::from_value::(body) { + Ok(typed_req) => { + self.route_typed_request(client, req, &typed_req, "/generate") + .await + } + Err(e) => HttpResponse::BadRequest().body(format!("Invalid request: {}", e)), + } + } + + async fn route_chat( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse { + // Convert JSON to typed request + match serde_json::from_value::(body) { + Ok(typed_req) => { + self.route_typed_request(client, req, &typed_req, "/v1/chat/completions") + .await + } + Err(e) => HttpResponse::BadRequest().body(format!("Invalid request: {}", e)), + } + } + + async fn route_completion( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse { + // Convert JSON to typed request + match serde_json::from_value::(body) { + Ok(typed_req) => { + self.route_typed_request(client, req, &typed_req, "/v1/completions") + .await + } + Err(e) => HttpResponse::BadRequest().body(format!("Invalid request: {}", e)), + } + } + + async fn flush_cache(&self, client: &Client) -> HttpResponse { + // Get all worker URLs + let worker_urls = self.get_worker_urls(); + + // Send requests to all workers concurrently without headers + let mut tasks = Vec::new(); + for worker_url in &worker_urls { + let request_builder = client.post(format!("{}/flush_cache", worker_url)); + tasks.push(request_builder.send()); + } + + // Wait for all responses + let results = futures_util::future::join_all(tasks).await; + + // Check if all 
succeeded + let all_success = results.iter().all(|r| { + r.as_ref() + .map(|res| res.status().is_success()) + .unwrap_or(false) + }); + + if all_success { + HttpResponse::Ok().body("Cache flushed on all servers") + } else { + HttpResponse::InternalServerError().body("Cache flush failed on one or more servers") + } + } + + async fn get_worker_loads(&self, client: &Client) -> HttpResponse { + let urls = self.get_worker_urls(); + let mut loads = Vec::new(); + + // Get loads from all workers + for url in &urls { + let load = self.get_worker_load(client, url).await.unwrap_or(-1); + loads.push(serde_json::json!({ + "worker": url, + "load": load + })); + } + + HttpResponse::Ok().json(serde_json::json!({ + "workers": loads + })) + } + + fn router_type(&self) -> &'static str { + "regular" + } + + fn readiness(&self) -> HttpResponse { + // Regular router is ready if it has at least one healthy worker + let healthy_count = self + .workers + .read() + .unwrap() + .iter() + .filter(|w| w.is_healthy()) + .count(); + + if healthy_count > 0 { + HttpResponse::Ok().json(serde_json::json!({ + "status": "ready", + "healthy_workers": healthy_count, + "total_workers": self.workers.read().unwrap().len() + })) + } else { + HttpResponse::ServiceUnavailable().json(serde_json::json!({ + "status": "not_ready", + "reason": "no healthy workers available", + "total_workers": self.workers.read().unwrap().len() + })) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::policies::RandomPolicy; + use std::collections::HashMap; + + fn create_test_regular_router() -> Router { + let workers = vec![ + WorkerFactory::create_regular("http://worker1:8080".to_string()), + WorkerFactory::create_regular("http://worker2:8080".to_string()), + ]; + let (_, rx) = tokio::sync::watch::channel(HashMap::new()); + Router { + workers: Arc::new(RwLock::new(workers)), + policy: Arc::new(RandomPolicy::new()), + timeout_secs: 5, + interval_secs: 1, + _worker_loads: Arc::new(rx), + _load_monitor_handle: 
None, + _health_checker: None, + } + } + + #[test] + fn test_router_get_worker_urls_regular() { + let router = create_test_regular_router(); + let urls = router.get_worker_urls(); + + assert_eq!(urls.len(), 2); + assert!(urls.contains(&"http://worker1:8080".to_string())); + assert!(urls.contains(&"http://worker2:8080".to_string())); + } + + #[test] + fn test_select_first_worker_regular() { + let router = create_test_regular_router(); + let result = router.select_first_worker(); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "http://worker1:8080"); + } + + #[test] + fn test_wait_for_healthy_workers_empty_list() { + let result = Router::wait_for_healthy_workers(&[], 1, 1); + assert!(result.is_ok()); + } + + #[test] + fn test_wait_for_healthy_workers_invalid_urls() { + // This test will timeout quickly since the URLs are invalid + let result = + Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Timeout")); + } +} diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index bb2695b93..69340eefe 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -1,9 +1,8 @@ +use crate::config::RouterConfig; use crate::logging::{self, LoggingConfig}; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::prometheus::{self, PrometheusConfig}; -use crate::request_adapter::ToPdRequest; -use crate::router::PolicyConfig; -use crate::router::Router; +use crate::routers::{RouterFactory, RouterTrait}; use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; use actix_web::{ error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder, @@ -19,27 +18,19 @@ use tracing::{error, info, warn, Level}; #[derive(Debug)] pub struct AppState { - router: Arc, + router: Arc, client: Client, - is_pd_mode: bool, // Add flag to track PD mode } impl AppState { - pub fn new( - 
worker_urls: Vec, - client: Client, - policy_config: PolicyConfig, - ) -> Result { - // Check if this is PD mode from policy config - let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. }); + pub fn new(router_config: RouterConfig, client: Client) -> Result { + // Use RouterFactory to create the appropriate router type + let router = RouterFactory::create_router(&router_config)?; - // Create router based on policy - let router = Arc::new(Router::new(worker_urls, policy_config)?); - Ok(Self { - router, - client, - is_pd_mode, - }) + // Convert Box to Arc + let router = Arc::from(router); + + Ok(Self { router, client }) } } @@ -76,65 +67,39 @@ fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error } } +#[get("/liveness")] +async fn liveness(_req: HttpRequest, data: web::Data) -> impl Responder { + data.router.liveness() +} + +#[get("/readiness")] +async fn readiness(_req: HttpRequest, data: web::Data) -> impl Responder { + data.router.readiness() +} + #[get("/health")] async fn health(req: HttpRequest, data: web::Data) -> impl Responder { - data.router - .route_to_first(&data.client, "/health", &req) - .await + data.router.health(&data.client, &req).await } #[get("/health_generate")] async fn health_generate(req: HttpRequest, data: web::Data) -> impl Responder { - // Check if we're in PD mode - if data.is_pd_mode { - // For PD mode, check health on all servers - data.router - .route_pd_health_generate(&data.client, &req) - .await - } else { - // Regular mode - data.router - .route_to_first(&data.client, "/health_generate", &req) - .await - } + data.router.health_generate(&data.client, &req).await } #[get("/get_server_info")] async fn get_server_info(req: HttpRequest, data: web::Data) -> impl Responder { - if data.is_pd_mode { - // For PD mode, aggregate info from both prefill and decode servers - data.router.get_pd_server_info(&data.client, &req).await - } else { - // Regular mode - return first server's info - 
data.router - .route_to_first(&data.client, "/get_server_info", &req) - .await - } + data.router.get_server_info(&data.client, &req).await } #[get("/v1/models")] async fn v1_models(req: HttpRequest, data: web::Data) -> impl Responder { - if data.is_pd_mode { - // For PD mode, return models from the first prefill server - data.router.get_pd_models(&data.client, &req).await - } else { - // Regular mode - data.router - .route_to_first(&data.client, "/v1/models", &req) - .await - } + data.router.get_models(&data.client, &req).await } #[get("/get_model_info")] async fn get_model_info(req: HttpRequest, data: web::Data) -> impl Responder { - if data.is_pd_mode { - // For PD mode, get model info from the first prefill server - data.router.get_pd_model_info(&data.client, &req).await - } else { - data.router - .route_to_first(&data.client, "/get_model_info", &req) - .await - } + data.router.get_model_info(&data.client, &req).await } #[post("/generate")] @@ -143,24 +108,12 @@ async fn generate( body: web::Json, state: web::Data, ) -> Result { - let client = &state.client; - let router = &state.router; - - // Use typed request directly for both PD and regular routing - if state.is_pd_mode { - // For PD mode, convert to PD request with bootstrap - let pd_request = body.into_inner().to_pd_request(); - - Ok(router - .route_pd_generate_typed(&client, &req, pd_request, "/generate") - .await) - } else { - // For regular mode, use typed request directly - let request = body.into_inner(); - Ok(router - .route_typed_request(&client, &req, &request, "/generate") - .await) - } + let json_body = serde_json::to_value(body.into_inner()) + .map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?; + Ok(state + .router + .route_generate(&state.client, &req, json_body) + .await) } #[post("/v1/chat/completions")] @@ -169,24 +122,12 @@ async fn v1_chat_completions( body: web::Json, state: web::Data, ) -> Result { - let client = &state.client; - let router = &state.router; - - // Use 
typed request directly for both PD and regular routing - if state.is_pd_mode { - // For PD mode, convert to PD request with bootstrap - let pd_request = body.into_inner().to_pd_request(); - - Ok(router - .route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions") - .await) - } else { - // For regular mode, use typed request directly - let request = body.into_inner(); - Ok(router - .route_typed_request(&client, &req, &request, "/v1/chat/completions") - .await) - } + let json_body = serde_json::to_value(body.into_inner()) + .map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?; + Ok(state + .router + .route_chat(&state.client, &req, json_body) + .await) } #[post("/v1/completions")] @@ -195,24 +136,12 @@ async fn v1_completions( body: web::Json, state: web::Data, ) -> Result { - let client = &state.client; - let router = &state.router; - - // Use typed request directly for both PD and regular routing - if state.is_pd_mode { - // For PD mode, convert to PD request with bootstrap - let pd_request = body.into_inner().to_pd_request(); - - Ok(router - .route_pd_generate_typed(&client, &req, pd_request, "/v1/completions") - .await) - } else { - // For regular mode, use typed request directly - let request = body.into_inner(); - Ok(router - .route_typed_request(&client, &req, &request, "/v1/completions") - .await) - } + let json_body = serde_json::to_value(body.into_inner()) + .map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?; + Ok(state + .router + .route_completion(&state.client, &req, json_body) + .await) } #[post("/add_worker")] @@ -254,29 +183,19 @@ async fn remove_worker( } #[post("/flush_cache")] -async fn flush_cache(req: HttpRequest, data: web::Data) -> impl Responder { - if data.is_pd_mode { - // For PD mode, flush cache on both prefill and decode servers - data.router.route_pd_flush_cache(&data.client).await - } else { - // Route to all workers for cache flushing - data.router - .route_to_all(&data.client, 
"/flush_cache", &req) - .await - } +async fn flush_cache(_req: HttpRequest, data: web::Data) -> impl Responder { + data.router.flush_cache(&data.client).await } #[get("/get_loads")] -async fn get_loads(req: HttpRequest, data: web::Data) -> impl Responder { - // Get loads from all workers - data.router.get_all_loads(&data.client, &req).await +async fn get_loads(_req: HttpRequest, data: web::Data) -> impl Responder { + data.router.get_worker_loads(&data.client).await } pub struct ServerConfig { pub host: String, pub port: u16, - pub worker_urls: Vec, - pub policy_config: PolicyConfig, + pub router_config: RouterConfig, pub max_payload_size: usize, pub log_dir: Option, pub log_level: Option, @@ -324,8 +243,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { } info!("🚧 Initializing router on {}:{}", config.host, config.port); - info!("🚧 Initializing workers on {:?}", config.worker_urls); - info!("🚧 Policy Config: {:?}", config.policy_config); + info!("🚧 Router mode: {:?}", config.router_config.mode); + info!("🚧 Policy: {:?}", config.router_config.policy); info!( "🚧 Max payload size: {} MB", config.max_payload_size / (1024 * 1024) @@ -345,12 +264,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { .build() .expect("Failed to create HTTP client"); - let app_state_init = AppState::new( - config.worker_urls.clone(), - client.clone(), - config.policy_config.clone(), - ) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + let app_state_init = AppState::new(config.router_config.clone(), client.clone()) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; let router_arc = Arc::clone(&app_state_init.router); let app_state = web::Data::new(app_state_init); @@ -397,6 +312,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { .service(v1_completions) .service(v1_models) .service(get_model_info) + .service(liveness) + .service(readiness) .service(health) .service(health_generate) 
.service(get_server_info) diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index 0e78717ce..72d78b490 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -1,4 +1,4 @@ -use crate::router::Router; +use crate::routers::RouterTrait; use futures::{StreamExt, TryStreamExt}; use k8s_openapi::api::core::v1::Pod; @@ -176,7 +176,7 @@ impl PodInfo { pub async fn start_service_discovery( config: ServiceDiscoveryConfig, - router: Arc, + router: Arc, ) -> Result, kube::Error> { // Don't initialize anything if service discovery is disabled if !config.enabled { @@ -346,7 +346,7 @@ pub async fn start_service_discovery( async fn handle_pod_event( pod_info: &PodInfo, tracked_pods: Arc>>, - router: Arc, + router: Arc, port: u16, pd_mode: bool, ) { @@ -379,17 +379,32 @@ async fn handle_pod_event( pod_info.name, pod_info.pod_type, worker_url ); + // Handle PD mode with specific pod types let result = if pd_mode && pod_info.pod_type.is_some() { - // Use PD-aware worker management - if let Some(pod_type) = &pod_info.pod_type { - router - .add_pd_worker(&worker_url, pod_type.clone(), pod_info.bootstrap_port) - .await + // Need to import PDRouter type + use crate::routers::pd_router::PDRouter; + + // Try to downcast to PDRouter + if let Some(pd_router) = router.as_any().downcast_ref::() { + match &pod_info.pod_type { + Some(PodType::Prefill) => pd_router + .add_prefill_server(worker_url.clone(), pod_info.bootstrap_port) + .await + .map_err(|e| e.to_string()), + Some(PodType::Decode) => pd_router + .add_decode_server(worker_url.clone()) + .await + .map_err(|e| e.to_string()), + Some(PodType::Regular) | None => { + // Fall back to regular add_worker for regular pods + router.add_worker(&worker_url).await + } + } } else { - Err("Pod type is None in PD mode".to_string()) + Err("PD mode enabled but router is not a PDRouter".to_string()) } } else { - // Fallback to regular worker management + // Regular mode or no pod 
type specified router.add_worker(&worker_url).await }; @@ -412,7 +427,7 @@ async fn handle_pod_event( async fn handle_pod_deletion( pod_info: &PodInfo, tracked_pods: Arc>>, - router: Arc, + router: Arc, port: u16, pd_mode: bool, ) { @@ -435,18 +450,34 @@ async fn handle_pod_deletion( pod_info.name, pod_info.pod_type, worker_url ); + // Handle PD mode removal if pd_mode && pod_info.pod_type.is_some() { - // Use PD-aware worker removal - if let Some(pod_type) = &pod_info.pod_type { - if let Err(e) = router.remove_pd_worker(&worker_url, pod_type.clone()).await { - error!( - "Failed to remove PD worker {} from router: {}", - worker_url, e - ); + use crate::routers::pd_router::PDRouter; + + // Try to downcast to PDRouter for PD-specific removal + if let Some(pd_router) = router.as_any().downcast_ref::() { + match &pod_info.pod_type { + Some(PodType::Prefill) => { + if let Err(e) = pd_router.remove_prefill_server(&worker_url).await { + error!("Failed to remove prefill server {}: {}", worker_url, e); + } + } + Some(PodType::Decode) => { + if let Err(e) = pd_router.remove_decode_server(&worker_url).await { + error!("Failed to remove decode server {}: {}", worker_url, e); + } + } + Some(PodType::Regular) | None => { + // Fall back to regular remove_worker + router.remove_worker(&worker_url); + } } + } else { + // PD mode but not a PDRouter, use generic removal + router.remove_worker(&worker_url); } } else { - // Fallback to regular worker removal + // Regular mode removal router.remove_worker(&worker_url); } } else { @@ -462,11 +493,9 @@ async fn handle_pod_deletion( #[cfg(test)] mod tests { use super::*; - use crate::router::Router; use k8s_openapi::api::core::v1::{Pod, PodCondition, PodSpec, PodStatus}; use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta; use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time; - use std::sync::RwLock; // Helper function to create a Pod for testing PodInfo::from_pod fn create_k8s_pod( @@ -546,14 +575,14 @@ mod tests { } // 
Helper to create a Router instance for testing event handlers - fn create_test_router() -> Arc { - let workers = Arc::new(RwLock::new(Vec::new())); - Arc::new(Router::Random { - workers, - timeout_secs: 5, - interval_secs: 1, - _health_checker: None, - }) + fn create_test_router() -> Arc { + use crate::config::PolicyConfig; + use crate::policies::PolicyFactory; + use crate::routers::router::Router; + + let policy = PolicyFactory::create_from_config(&PolicyConfig::Random); + let router = Router::new(vec![], policy, 5, 1).unwrap(); + Arc::new(router) as Arc } // Helper to create a PD config for testing diff --git a/sgl-router/tests/benchmark_integration.rs b/sgl-router/tests/benchmark_integration.rs index b21c93fcf..317859000 100644 --- a/sgl-router/tests/benchmark_integration.rs +++ b/sgl-router/tests/benchmark_integration.rs @@ -6,7 +6,7 @@ use sglang_router_rs::openai_api_types::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, SamplingParams, StringOrArray, UserMessageContent, }; -use sglang_router_rs::request_adapter::{RouteableRequest, ToPdRequest}; +use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest}; #[test] fn test_benchmark_request_creation() { diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 02b8c99f5..ceb5fe9e6 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -8,12 +8,18 @@ //! Note: PD mode is enabled via the pd_disaggregation flag, not as a policy type. //! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode. 
+// TODO: This test file needs to be updated for the new configuration structure +// where RoutingMode and PolicyConfig are separate + #[cfg(test)] mod test_pd_routing { use rand::Rng; use serde_json::json; - use sglang_router_rs::pd_types::PDSelectionPolicy; - use sglang_router_rs::router::{PolicyConfig, Router}; + use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode}; + use sglang_router_rs::core::{WorkerFactory, WorkerType}; + use sglang_router_rs::routers::pd_types::get_hostname; + use sglang_router_rs::routers::pd_types::PDSelectionPolicy; + use sglang_router_rs::routers::RouterFactory; // Test-only struct to help validate PD request parsing #[derive(Debug)] @@ -116,49 +122,68 @@ mod test_pd_routing { #[test] fn test_pd_router_configuration() { - // Test PrefillDecodeConfig creation with various policies - // This config is used when pd_disaggregation=true - let configs = vec![ - PolicyConfig::PrefillDecodeConfig { - selection_policy: PDSelectionPolicy::Random, - prefill_urls: vec![ - ("http://prefill1:8080".to_string(), Some(9000)), - ("http://prefill2:8080".to_string(), None), - ], - decode_urls: vec![ - "http://decode1:8080".to_string(), - "http://decode2:8080".to_string(), - ], - timeout_secs: 10, - interval_secs: 1, - }, - PolicyConfig::PrefillDecodeConfig { - selection_policy: PDSelectionPolicy::PowerOfTwo, - prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))], - decode_urls: vec!["http://decode:8080".to_string()], - timeout_secs: 5, - interval_secs: 1, - }, - PolicyConfig::PrefillDecodeConfig { - selection_policy: PDSelectionPolicy::CacheAware { + // Test PD router configuration with various policies + // In the new structure, RoutingMode and PolicyConfig are separate + let test_cases = vec![ + ( + RoutingMode::PrefillDecode { + prefill_urls: vec![ + ("http://prefill1:8080".to_string(), Some(9000)), + ("http://prefill2:8080".to_string(), None), + ], + decode_urls: vec![ + "http://decode1:8080".to_string(), + 
"http://decode2:8080".to_string(), + ], + }, + PolicyConfig::Random, + ), + ( + RoutingMode::PrefillDecode { + prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))], + decode_urls: vec!["http://decode:8080".to_string()], + }, + PolicyConfig::PowerOfTwo { + load_check_interval_secs: 5, + }, + ), + ( + RoutingMode::PrefillDecode { + prefill_urls: vec![ + ("http://p1:8080".to_string(), Some(9000)), + ("http://p2:8080".to_string(), Some(9001)), + ("http://p3:8080".to_string(), Some(9002)), + ], + decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()], + }, + PolicyConfig::CacheAware { cache_threshold: 0.7, balance_abs_threshold: 20, balance_rel_threshold: 1.2, + eviction_interval_secs: 60, + max_tree_size: 1000000, }, - prefill_urls: vec![ - ("http://p1:8080".to_string(), Some(9000)), - ("http://p2:8080".to_string(), Some(9001)), - ("http://p3:8080".to_string(), Some(9002)), - ], - decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()], - timeout_secs: 10, - interval_secs: 2, - }, + ), ]; - for config in configs { + for (mode, policy) in test_cases { + let config = RouterConfig { + mode, + policy, + host: "127.0.0.1".to_string(), + port: 3001, + max_payload_size: 1024 * 1024, + request_timeout_secs: 60, + worker_startup_timeout_secs: 10, + worker_startup_check_interval_secs: 1, + discovery: None, + metrics: None, + log_dir: None, + log_level: None, + }; + // Router creation will fail due to health checks, but config should be valid - let result = Router::new(vec![], config); + let result = RouterFactory::create_router(&config); assert!(result.is_err()); let error_msg = result.unwrap_err(); // Error should be about health/timeout, not configuration @@ -225,9 +250,6 @@ mod test_pd_routing { #[test] fn test_bootstrap_injection_simulation() { - use sglang_router_rs::core::{WorkerFactory, WorkerType}; - use sglang_router_rs::pd_types::get_hostname; - // Since we can't test the actual inject_bootstrap_fields 
function here // (it's private in the router module), we'll test the expected behavior @@ -315,8 +337,6 @@ mod test_pd_routing { #[test] fn test_hostname_extraction() { - use sglang_router_rs::pd_types::get_hostname; - // Test various URL formats let test_cases = vec![ ("http://localhost:8080", "localhost"), @@ -662,7 +682,6 @@ mod test_pd_routing { #[test] fn test_bootstrap_injection_with_benchmark_requests() { use sglang_router_rs::core::{WorkerFactory, WorkerType}; - use sglang_router_rs::pd_types::get_hostname; // Test bootstrap injection with actual benchmark request patterns let mut benchmark_request = json!({ @@ -790,9 +809,6 @@ mod test_pd_routing { #[test] fn test_large_batch_bootstrap_injection() { - use sglang_router_rs::core::{WorkerFactory, WorkerType}; - use sglang_router_rs::pd_types::get_hostname; - // Test bootstrap injection performance with very large batches // This simulates the bench_one_batch_server.py scenario let large_batch_sizes = vec![1024, 4096, 8192];