[router] Refactor router and policy traits with dependency injection (#7987)
Co-authored-by: Jin Pan <jpan236@wisc.edu> Co-authored-by: Keru Yang <rukeyang@gmail.com> Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com> Co-authored-by: Philip Zhu <phlipzhux@gmail.com>
This commit is contained in:
294
.github/workflows/pr-test-pd-router.yml
vendored
294
.github/workflows/pr-test-pd-router.yml
vendored
@@ -131,110 +131,199 @@ jobs:
|
||||
SERVER_PID=$!
|
||||
echo "server_pid=$SERVER_PID" >> $GITHUB_OUTPUT
|
||||
|
||||
echo "Waiting for router to become healthy..."
|
||||
TIMEOUT=300
|
||||
ELAPSED=0
|
||||
while [ $ELAPSED -lt $TIMEOUT ]; do
|
||||
if curl --connect-timeout 5 --silent http://127.0.0.9:8000 > /dev/null 2>&1; then
|
||||
echo "✓ Router is reachable"
|
||||
break
|
||||
# Wait for all 8 servers to be healthy (script already does this)
|
||||
wait_count=0
|
||||
while [ $wait_count -lt 30 ]; do
|
||||
if ps -p $SERVER_PID > /dev/null; then
|
||||
# Check if the startup script printed success message
|
||||
sleep 2
|
||||
wait_count=$((wait_count + 1))
|
||||
else
|
||||
# Script exited - check if it was successful
|
||||
wait $SERVER_PID
|
||||
exit_code=$?
|
||||
if [ $exit_code -eq 0 ]; then
|
||||
echo "✓ All disaggregation servers are healthy"
|
||||
break
|
||||
else
|
||||
echo "Error: Server startup failed with code $exit_code"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
if ! ps -p $SERVER_PID > /dev/null; then
|
||||
echo "Error: Server processes failed to start"
|
||||
exit 1
|
||||
fi
|
||||
echo "Waiting for router... (${ELAPSED}s/${TIMEOUT}s)"
|
||||
sleep 10
|
||||
ELAPSED=$((ELAPSED + 10))
|
||||
done
|
||||
|
||||
if [ $ELAPSED -ge $TIMEOUT ]; then
|
||||
echo "Error: Router health check timeout after ${TIMEOUT}s"
|
||||
exit 1
|
||||
fi
|
||||
echo "✓ Servers started (PID: $SERVER_PID)"
|
||||
|
||||
echo "✓ Servers started and healthy (PID: $SERVER_PID)"
|
||||
|
||||
- name: Test API functionality
|
||||
timeout-minutes: 5
|
||||
- name: Test all policies sequentially
|
||||
timeout-minutes: 30
|
||||
run: |
|
||||
POLICIES=("random" "round_robin" "cache_aware" "power_of_two")
|
||||
BASE_URL="http://127.0.0.9:8000"
|
||||
|
||||
echo "Testing API completions..."
|
||||
response=$(curl -s -X POST "$BASE_URL/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer test-token" \
|
||||
-d '{
|
||||
"model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Write a Python function to calculate fibonacci numbers recursively"}
|
||||
],
|
||||
"stream": false,
|
||||
"max_tokens": 100
|
||||
}')
|
||||
for policy in "${POLICIES[@]}"; do
|
||||
echo ""
|
||||
echo "=================================================="
|
||||
echo "Testing policy: $policy"
|
||||
echo "=================================================="
|
||||
|
||||
if echo "$response" | jq -e '.choices[0].message.content' > /dev/null 2>&1; then
|
||||
echo "✓ API test passed"
|
||||
else
|
||||
echo "✗ API test failed: $response"
|
||||
exit 1
|
||||
fi
|
||||
# Start router with the current policy
|
||||
echo "Starting router with policy: $policy..."
|
||||
python3 -m sglang_router.launch_router \
|
||||
--pd-disaggregation \
|
||||
--policy "$policy" \
|
||||
--prefill http://127.0.0.1:30001 9001 \
|
||||
--prefill http://127.0.0.2:30002 9002 \
|
||||
--prefill http://127.0.0.3:30003 9003 \
|
||||
--prefill http://127.0.0.4:30004 9004 \
|
||||
--decode http://127.0.0.5:30005 \
|
||||
--decode http://127.0.0.6:30006 \
|
||||
--decode http://127.0.0.7:30007 \
|
||||
--decode http://127.0.0.8:30008 \
|
||||
--host 127.0.0.9 \
|
||||
--port 8000 &
|
||||
ROUTER_PID=$!
|
||||
|
||||
echo "Testing streaming API..."
|
||||
stream_response=$(timeout 30 curl -s -X POST "$BASE_URL/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer test-token" \
|
||||
-d '{
|
||||
"model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Count from 1 to 5"}
|
||||
],
|
||||
"stream": true,
|
||||
"max_tokens": 50
|
||||
}')
|
||||
# Wait for router to become healthy
|
||||
echo "Waiting for router to become healthy..."
|
||||
TIMEOUT=60
|
||||
ELAPSED=0
|
||||
while [ $ELAPSED -lt $TIMEOUT ]; do
|
||||
if curl --connect-timeout 5 --silent http://127.0.0.9:8000 > /dev/null 2>&1; then
|
||||
echo "✓ Router is reachable"
|
||||
break
|
||||
fi
|
||||
if ! ps -p $ROUTER_PID > /dev/null; then
|
||||
echo "Error: Router process died"
|
||||
exit 1
|
||||
fi
|
||||
sleep 5
|
||||
ELAPSED=$((ELAPSED + 5))
|
||||
done
|
||||
|
||||
if echo "$stream_response" | grep -q "data:"; then
|
||||
echo "✓ Streaming API test passed"
|
||||
else
|
||||
echo "✗ Streaming API test failed"
|
||||
exit 1
|
||||
fi
|
||||
if [ $ELAPSED -ge $TIMEOUT ]; then
|
||||
echo "Error: Router health check timeout"
|
||||
kill $ROUTER_PID 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Run benchmark test
|
||||
timeout-minutes: 5
|
||||
run: |
|
||||
echo "Running benchmark test..."
|
||||
benchmark_output=$(python3 -m sglang.bench_one_batch_server \
|
||||
--model-path "/raid/models/meta-llama/Llama-3.1-8B-Instruct" \
|
||||
--base-url "http://127.0.0.9:8000" \
|
||||
--batch-size 8 \
|
||||
--input-len 4096 \
|
||||
--output-len 5 \
|
||||
--skip-warmup)
|
||||
# Test API functionality
|
||||
echo "Testing API completions for $policy..."
|
||||
response=$(curl -s -X POST "$BASE_URL/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer test-token" \
|
||||
-d '{
|
||||
"model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Write a Python function to calculate fibonacci numbers recursively"}
|
||||
],
|
||||
"stream": false,
|
||||
"max_tokens": 100
|
||||
}')
|
||||
|
||||
echo "$benchmark_output"
|
||||
if echo "$response" | jq -e '.choices[0].message.content' > /dev/null 2>&1; then
|
||||
echo "✓ API test passed for $policy"
|
||||
else
|
||||
echo "✗ API test failed for $policy: $response"
|
||||
kill $ROUTER_PID 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Extract metrics from output
|
||||
latency=$(echo "$benchmark_output" | grep "latency:" | awk '{print $2}' | sed 's/s//')
|
||||
input_throughput=$(echo "$benchmark_output" | grep "input throughput:" | awk '{print $3}')
|
||||
output_throughput=$(echo "$benchmark_output" | grep "output throughput:" | awk '{print $3}')
|
||||
# Test streaming
|
||||
echo "Testing streaming API for $policy..."
|
||||
stream_response=$(timeout 30 curl -s -X POST "$BASE_URL/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer test-token" \
|
||||
-d '{
|
||||
"model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Count from 1 to 5"}
|
||||
],
|
||||
"stream": true,
|
||||
"max_tokens": 50
|
||||
}')
|
||||
|
||||
# Validate performance (latency<1.5s, input>20k, output>1k)
|
||||
command -v bc >/dev/null || (apt-get update && apt-get install -y bc)
|
||||
if echo "$stream_response" | grep -q "data:"; then
|
||||
echo "✓ Streaming API test passed for $policy"
|
||||
else
|
||||
echo "✗ Streaming API test failed for $policy"
|
||||
kill $ROUTER_PID 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Performance: ${latency}s | ${input_throughput} | ${output_throughput} tok/s"
|
||||
# Run benchmark
|
||||
echo "Running benchmark for $policy..."
|
||||
benchmark_output=$(python3 -m sglang.bench_one_batch_server \
|
||||
--model-path "/raid/models/meta-llama/Llama-3.1-8B-Instruct" \
|
||||
--base-url "http://127.0.0.9:8000" \
|
||||
--batch-size 8 \
|
||||
--input-len 4096 \
|
||||
--output-len 5 \
|
||||
--skip-warmup)
|
||||
|
||||
fail=""
|
||||
(( $(echo "$latency > 1.5" | bc -l) )) && fail="Latency too high (${latency}s>1.5s) "
|
||||
(( $(echo "$input_throughput < 20000" | bc -l) )) && fail="${fail}Input too low (${input_throughput}<20k) "
|
||||
(( $(echo "$output_throughput < 1000" | bc -l) )) && fail="${fail}Output too low (${output_throughput}<1k) "
|
||||
echo "$benchmark_output"
|
||||
|
||||
if [ -n "$fail" ]; then
|
||||
echo "✗ Benchmark failed: $fail"
|
||||
exit 1
|
||||
else
|
||||
echo "✓ Performance validation passed"
|
||||
fi
|
||||
# Save benchmark output
|
||||
echo "$benchmark_output" > "benchmark_${policy}.txt"
|
||||
|
||||
# Extract and validate metrics
|
||||
latency=$(echo "$benchmark_output" | grep "latency:" | awk '{print $2}' | sed 's/s//')
|
||||
input_throughput=$(echo "$benchmark_output" | grep "input throughput:" | awk '{print $3}')
|
||||
output_throughput=$(echo "$benchmark_output" | grep "output throughput:" | awk '{print $3}')
|
||||
|
||||
command -v bc >/dev/null || (apt-get update && apt-get install -y bc)
|
||||
|
||||
echo "Performance for $policy: ${latency}s | ${input_throughput} | ${output_throughput} tok/s"
|
||||
|
||||
# Validate performance
|
||||
fail=""
|
||||
(( $(echo "$latency > 1.5" | bc -l) )) && fail="Latency too high (${latency}s>1.5s) "
|
||||
(( $(echo "$input_throughput < 20000" | bc -l) )) && fail="${fail}Input too low (${input_throughput}<20k) "
|
||||
(( $(echo "$output_throughput < 1000" | bc -l) )) && fail="${fail}Output too low (${output_throughput}<1k) "
|
||||
|
||||
if [ -n "$fail" ]; then
|
||||
echo "✗ Benchmark failed for $policy: $fail"
|
||||
kill $ROUTER_PID 2>/dev/null || true
|
||||
exit 1
|
||||
else
|
||||
echo "✓ Performance validation passed for $policy"
|
||||
fi
|
||||
|
||||
# Stop router before testing next policy
|
||||
echo "Stopping router for $policy..."
|
||||
# First try graceful shutdown
|
||||
kill $ROUTER_PID 2>/dev/null || true
|
||||
|
||||
# Wait up to 5 seconds for graceful shutdown
|
||||
for i in {1..5}; do
|
||||
if ! ps -p $ROUTER_PID > /dev/null 2>&1; then
|
||||
echo "Router stopped gracefully"
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
# Force kill if still running
|
||||
if ps -p $ROUTER_PID > /dev/null 2>&1; then
|
||||
echo "Force killing router..."
|
||||
kill -9 $ROUTER_PID 2>/dev/null || true
|
||||
fi
|
||||
|
||||
# Short delay to ensure port is released
|
||||
sleep 2
|
||||
|
||||
echo "✓ Completed testing for $policy"
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "✅ All policies tested successfully!"
|
||||
|
||||
|
||||
- name: Upload benchmark results
|
||||
if: success()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: benchmark-results-all-policies
|
||||
path: benchmark_*.txt
|
||||
|
||||
- name: Cleanup servers
|
||||
if: always()
|
||||
@@ -247,3 +336,34 @@ jobs:
|
||||
sleep 5
|
||||
remaining=$(ps aux | grep -c "sglang.launch_server" || echo "0")
|
||||
echo "Cleanup completed. Remaining processes: $remaining"
|
||||
|
||||
summarize-benchmarks:
|
||||
needs: test-disaggregation
|
||||
runs-on: ubuntu-latest
|
||||
if: success()
|
||||
|
||||
steps:
|
||||
- name: Download benchmark results
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: benchmark-results-all-policies
|
||||
|
||||
- name: Create benchmark summary
|
||||
run: |
|
||||
echo "## PD Router Benchmark Results Summary" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "| Policy | Latency (s) | Input Throughput (tok/s) | Output Throughput (tok/s) |" >> $GITHUB_STEP_SUMMARY
|
||||
echo "|--------|-------------|-------------------------|--------------------------|" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
for policy in random round_robin cache_aware power_of_two; do
|
||||
if [ -f "benchmark_${policy}.txt" ]; then
|
||||
latency=$(grep "latency:" "benchmark_${policy}.txt" | awk '{print $2}')
|
||||
input_throughput=$(grep "input throughput:" "benchmark_${policy}.txt" | awk '{print $3}')
|
||||
output_throughput=$(grep "output throughput:" "benchmark_${policy}.txt" | awk '{print $3}')
|
||||
|
||||
echo "| ${policy} | ${latency} | ${input_throughput} | ${output_throughput} |" >> $GITHUB_STEP_SUMMARY
|
||||
fi
|
||||
done
|
||||
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "✅ All policies tested successfully!" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
@@ -87,20 +87,8 @@ while true; do
|
||||
fi
|
||||
done
|
||||
|
||||
# Launch the router
|
||||
echo "Launching router at 127.0.0.9:8000..."
|
||||
python3 -m sglang_router.launch_router \
|
||||
--pd-disaggregation \
|
||||
--policy power_of_two \
|
||||
--prefill http://127.0.0.1:30001 9001 \
|
||||
--prefill http://127.0.0.2:30002 9002 \
|
||||
--prefill http://127.0.0.3:30003 9003 \
|
||||
--prefill http://127.0.0.4:30004 9004 \
|
||||
--decode http://127.0.0.5:30005 \
|
||||
--decode http://127.0.0.6:30006 \
|
||||
--decode http://127.0.0.7:30007 \
|
||||
--decode http://127.0.0.8:30008 \
|
||||
--host 127.0.0.9 \
|
||||
--port 8000 &
|
||||
# Don't launch router here - just keep servers running
|
||||
echo "✅ All disaggregation servers are ready and waiting for router connections"
|
||||
|
||||
wait # Wait for all background jobs to finish
|
||||
# Keep the script running
|
||||
wait # Wait for all background server jobs
|
||||
|
||||
@@ -6,7 +6,7 @@ use sglang_router_rs::openai_api_types::{
|
||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
||||
SamplingParams, StringOrArray, UserMessageContent,
|
||||
};
|
||||
use sglang_router_rs::request_adapter::{RouteableRequest, ToPdRequest};
|
||||
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};
|
||||
|
||||
// Sample request data for benchmarks
|
||||
fn create_sample_generate_request() -> GenerateRequest {
|
||||
|
||||
@@ -164,56 +164,47 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
"""Test that policy validation works correctly for PD and regular modes."""
|
||||
from sglang_router.launch_router import RouterArgs, launch_router
|
||||
|
||||
# Test 1: PowerOfTwo is only valid in PD mode
|
||||
# Test 1: PowerOfTwo requires at least 2 workers
|
||||
args = self.create_router_args(
|
||||
pd_disaggregation=False,
|
||||
policy="power_of_two",
|
||||
worker_urls=["http://localhost:8000"],
|
||||
worker_urls=["http://localhost:8000"], # Only 1 worker
|
||||
)
|
||||
|
||||
# Should raise error
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
launch_router(args)
|
||||
self.assertIn(
|
||||
"PowerOfTwo policy is only supported in PD disaggregated mode",
|
||||
"Power-of-two policy requires at least 2 workers",
|
||||
str(cm.exception),
|
||||
)
|
||||
|
||||
# Test 2: RoundRobin is not valid in PD mode
|
||||
# Test 2: PowerOfTwo with sufficient workers should succeed
|
||||
args = self.create_router_args(
|
||||
pd_disaggregation=True,
|
||||
policy="round_robin",
|
||||
prefill=[["http://prefill1:8080", "9000"]],
|
||||
decode=[["http://decode1:8081"]],
|
||||
worker_urls=[],
|
||||
pd_disaggregation=False,
|
||||
policy="power_of_two",
|
||||
worker_urls=["http://localhost:8000", "http://localhost:8001"], # 2 workers
|
||||
)
|
||||
# This should not raise an error (validation passes)
|
||||
|
||||
# Should raise error
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
launch_router(args)
|
||||
self.assertIn(
|
||||
"RoundRobin policy is not supported in PD disaggregated mode",
|
||||
str(cm.exception),
|
||||
)
|
||||
|
||||
# Test 3: Valid combinations should not raise errors
|
||||
# Test 3: All policies now work in both modes
|
||||
# Regular mode with RoundRobin
|
||||
args = self.create_router_args(
|
||||
pd_disaggregation=False,
|
||||
policy="round_robin",
|
||||
worker_urls=["http://localhost:8000"],
|
||||
)
|
||||
# This should not raise (though it may fail to connect)
|
||||
# This should not raise validation error
|
||||
|
||||
# PD mode with PowerOfTwo
|
||||
# PD mode with RoundRobin (now supported!)
|
||||
args = self.create_router_args(
|
||||
pd_disaggregation=True,
|
||||
policy="power_of_two",
|
||||
policy="round_robin",
|
||||
prefill=[["http://prefill1:8080", "9000"]],
|
||||
decode=[["http://decode1:8081"]],
|
||||
worker_urls=[],
|
||||
)
|
||||
# This should not raise (though it may fail to connect)
|
||||
# This should not raise validation error
|
||||
|
||||
def test_pd_service_discovery_args_parsing(self):
|
||||
"""Test PD service discovery CLI argument parsing."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::{ConfigError, ConfigResult};
|
||||
use super::ConfigResult;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@@ -215,6 +215,7 @@ impl RouterConfig {
|
||||
self.metrics.is_some()
|
||||
}
|
||||
|
||||
/* Commented out - no longer needed without compatibility layer
|
||||
/// Convert to routing PolicyConfig for internal use
|
||||
pub fn to_routing_policy_config(&self) -> ConfigResult<crate::router::PolicyConfig> {
|
||||
match (&self.mode, &self.policy) {
|
||||
@@ -291,4 +292,5 @@ impl RouterConfig {
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
@@ -255,29 +255,8 @@ impl ConfigValidator {
|
||||
|
||||
/// Validate compatibility between different configuration sections
|
||||
fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> {
|
||||
// Check mode and policy compatibility
|
||||
match (&config.mode, &config.policy) {
|
||||
(RoutingMode::Regular { .. }, PolicyConfig::PowerOfTwo { .. }) => {
|
||||
// PowerOfTwo is only supported in PD mode
|
||||
return Err(ConfigError::IncompatibleConfig {
|
||||
reason: "PowerOfTwo policy is only supported in PD disaggregated mode"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
(RoutingMode::PrefillDecode { .. }, PolicyConfig::RoundRobin) => {
|
||||
return Err(ConfigError::IncompatibleConfig {
|
||||
reason: "RoundRobin policy is not supported in PD disaggregated mode"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
(RoutingMode::PrefillDecode { .. }, PolicyConfig::CacheAware { .. }) => {
|
||||
return Err(ConfigError::IncompatibleConfig {
|
||||
reason: "CacheAware policy is not supported in PD disaggregated mode"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
// All policies are now supported for both router types thanks to the unified trait design
|
||||
// No mode/policy restrictions needed anymore
|
||||
|
||||
// Check if service discovery is enabled for worker count validation
|
||||
let has_service_discovery = config.discovery.as_ref().map_or(false, |d| d.enabled);
|
||||
@@ -459,8 +438,8 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_incompatible_policy() {
|
||||
// RoundRobin with PD mode
|
||||
fn test_validate_roundrobin_with_pd_mode() {
|
||||
// RoundRobin with PD mode is now supported
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
|
||||
@@ -470,16 +449,12 @@ mod tests {
|
||||
);
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("RoundRobin policy is not supported in PD disaggregated mode"));
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_cache_aware_with_pd_mode() {
|
||||
// CacheAware with PD mode should fail
|
||||
// CacheAware with PD mode is now supported
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
|
||||
@@ -495,16 +470,12 @@ mod tests {
|
||||
);
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("CacheAware policy is not supported in PD disaggregated mode"));
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_power_of_two_with_regular_mode() {
|
||||
// PowerOfTwo with Regular mode should fail
|
||||
// PowerOfTwo with Regular mode is now supported
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![
|
||||
@@ -518,10 +489,6 @@ mod tests {
|
||||
);
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("PowerOfTwo policy is only supported in PD disaggregated mode"));
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,11 +4,9 @@ pub mod logging;
|
||||
use std::collections::HashMap;
|
||||
pub mod core;
|
||||
pub mod openai_api_types;
|
||||
pub mod pd_router;
|
||||
pub mod pd_types;
|
||||
pub mod policies;
|
||||
pub mod prometheus;
|
||||
pub mod request_adapter;
|
||||
pub mod router;
|
||||
pub mod routers;
|
||||
pub mod server;
|
||||
pub mod service_discovery;
|
||||
pub mod tree;
|
||||
@@ -241,11 +239,6 @@ impl Router {
|
||||
))
|
||||
})?;
|
||||
|
||||
// Convert to internal policy config
|
||||
let policy_config = router_config
|
||||
.to_routing_policy_config()
|
||||
.map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
|
||||
|
||||
// Create service discovery config if enabled
|
||||
let service_discovery_config = if self.service_discovery {
|
||||
Some(service_discovery::ServiceDiscoveryConfig {
|
||||
@@ -282,8 +275,7 @@ impl Router {
|
||||
server::startup(server::ServerConfig {
|
||||
host: self.host.clone(),
|
||||
port: self.port,
|
||||
worker_urls: self.worker_urls.clone(),
|
||||
policy_config,
|
||||
router_config,
|
||||
max_payload_size: self.max_payload_size,
|
||||
log_dir: self.log_dir.clone(),
|
||||
log_level: self.log_level.clone(),
|
||||
|
||||
399
sgl-router/src/policies/cache_aware.rs
Normal file
399
sgl-router/src/policies/cache_aware.rs
Normal file
@@ -0,0 +1,399 @@
|
||||
/*
|
||||
Cache-Aware Load Balancing Router
|
||||
|
||||
This router combines two strategies to optimize both cache utilization and request distribution:
|
||||
|
||||
1. Cache-Aware Routing (Approximate Tree)
|
||||
2. Load Balancing (Shortest Queue with Balance Thresholds)
|
||||
|
||||
The router dynamically switches between these strategies based on load conditions:
|
||||
- Uses load balancing when the system is imbalanced
|
||||
- Uses cache-aware routing when the system is balanced
|
||||
|
||||
A system is considered imbalanced if both conditions are met:
|
||||
1. (max - min) > abs_threshold
|
||||
2. max > rel_threshold * min
|
||||
|
||||
Strategy Details:
|
||||
|
||||
1. Cache-Aware Routing (Approximate Tree)
|
||||
-------------------------------------------
|
||||
This strategy maintains an approximate radix tree for each worker based on request history,
|
||||
eliminating the need for direct cache state queries. The tree stores raw text characters
|
||||
instead of token IDs to avoid tokenization overhead.
|
||||
|
||||
Process:
|
||||
a. For each request, find the worker with the highest prefix match
|
||||
b. If match rate > cache_threshold:
|
||||
Route to the worker with highest match (likely has relevant data cached)
|
||||
c. If match rate ≤ cache_threshold:
|
||||
Route to the worker with smallest tree size (most available cache capacity)
|
||||
d. Background maintenance:
|
||||
Periodically evict least recently used leaf nodes to prevent memory overflow
|
||||
|
||||
2. Load Balancing (Shortest Queue)
|
||||
-------------------------------------------
|
||||
This strategy tracks pending request counts per worker and routes new requests
|
||||
to the least busy worker when the system is detected to be imbalanced.
|
||||
|
||||
Configuration Parameters:
|
||||
------------------------
|
||||
1. cache_threshold: (float, 0.0 to 1.0)
|
||||
Minimum prefix match ratio to use highest-match routing.
|
||||
Below this threshold, routes to worker with most available cache space.
|
||||
|
||||
2. balance_abs_threshold: (integer)
|
||||
Absolute difference threshold for load imbalance detection.
|
||||
System is potentially imbalanced if (max_load - min_load) > abs_threshold
|
||||
|
||||
3. balance_rel_threshold: (float)
|
||||
Relative ratio threshold for load imbalance detection.
|
||||
System is potentially imbalanced if max_load > min_load * rel_threshold
|
||||
Used in conjunction with abs_threshold to determine final imbalance state.
|
||||
|
||||
4. eviction_interval_secs: (integer)
|
||||
Interval between LRU eviction cycles for the approximate trees.
|
||||
|
||||
5. max_tree_size: (integer)
|
||||
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
|
||||
during the next eviction cycle.
|
||||
*/
|
||||
|
||||
use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use crate::tree::Tree;
|
||||
use metrics::{counter, gauge};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, info};
|
||||
|
||||
/// Cache-aware routing policy
|
||||
///
|
||||
/// Routes requests based on cache affinity when load is balanced,
|
||||
/// switches to shortest-queue routing when load is imbalanced.
|
||||
#[derive(Debug)]
|
||||
pub struct CacheAwarePolicy {
|
||||
config: CacheAwareConfig,
|
||||
tree: Arc<Mutex<Tree>>,
|
||||
eviction_handle: Option<thread::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl CacheAwarePolicy {
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(CacheAwareConfig::default())
|
||||
}
|
||||
|
||||
pub fn with_config(config: CacheAwareConfig) -> Self {
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
|
||||
// Start background eviction thread if configured
|
||||
let eviction_handle = if config.eviction_interval_secs > 0 {
|
||||
let tree_clone = Arc::clone(&tree);
|
||||
let max_tree_size = config.max_tree_size;
|
||||
let interval = config.eviction_interval_secs;
|
||||
|
||||
Some(thread::spawn(move || loop {
|
||||
thread::sleep(Duration::from_secs(interval));
|
||||
|
||||
if let Ok(tree_guard) = tree_clone.lock() {
|
||||
tree_guard.evict_tenant_by_size(max_tree_size);
|
||||
debug!("Cache eviction completed, max_size: {}", max_tree_size);
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Self {
|
||||
config,
|
||||
tree,
|
||||
eviction_handle,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize the tree with worker URLs
|
||||
pub fn init_workers(&self, workers: &[Box<dyn Worker>]) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
for worker in workers {
|
||||
tree.insert("", worker.url());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a worker from the tree
|
||||
pub fn remove_worker(&self, url: &str) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
tree.remove_tenant(url);
|
||||
}
|
||||
}
|
||||
|
||||
/// Run cache eviction to prevent unbounded growth
|
||||
pub fn evict_cache(&self, max_size: usize) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
tree.evict_tenant_by_size(max_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
|
||||
if healthy_indices.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Get current load statistics
|
||||
let loads: Vec<usize> = workers.iter().map(|w| w.load()).collect();
|
||||
let max_load = *loads.iter().max().unwrap_or(&0);
|
||||
let min_load = *loads.iter().min().unwrap_or(&0);
|
||||
|
||||
// Check if load is imbalanced
|
||||
let is_imbalanced = max_load.saturating_sub(min_load) > self.config.balance_abs_threshold
|
||||
&& (max_load as f32) > (min_load as f32 * self.config.balance_rel_threshold);
|
||||
|
||||
if is_imbalanced {
|
||||
// Log load balancing trigger
|
||||
let worker_loads: Vec<(String, usize)> = workers
|
||||
.iter()
|
||||
.map(|w| (w.url().to_string(), w.load()))
|
||||
.collect();
|
||||
|
||||
info!(
|
||||
"Load balancing triggered due to workload imbalance:\n\
|
||||
Max load: {}, Min load: {}\n\
|
||||
Current worker loads: {:?}",
|
||||
max_load, min_load, worker_loads
|
||||
);
|
||||
|
||||
counter!("sgl_router_load_balancing_events_total").increment(1);
|
||||
gauge!("sgl_router_max_load").set(max_load as f64);
|
||||
gauge!("sgl_router_min_load").set(min_load as f64);
|
||||
|
||||
// Use shortest queue when imbalanced
|
||||
let min_load_idx = healthy_indices
|
||||
.iter()
|
||||
.min_by_key(|&&idx| workers[idx].load())
|
||||
.copied()?;
|
||||
|
||||
// Increment processed counter
|
||||
workers[min_load_idx].increment_processed();
|
||||
counter!("sgl_router_processed_requests_total", "worker" => workers[min_load_idx].url().to_string())
|
||||
.increment(1);
|
||||
|
||||
return Some(min_load_idx);
|
||||
}
|
||||
|
||||
// Use cache-aware routing when balanced
|
||||
let text = request_text.unwrap_or("");
|
||||
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
let (matched_text, matched_worker) = tree.prefix_match(text);
|
||||
let match_rate = if text.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
matched_text.chars().count() as f32 / text.chars().count() as f32
|
||||
};
|
||||
|
||||
let selected_url = if match_rate > self.config.cache_threshold {
|
||||
counter!("sgl_router_cache_hits_total").increment(1);
|
||||
matched_worker.to_string()
|
||||
} else {
|
||||
counter!("sgl_router_cache_misses_total").increment(1);
|
||||
tree.get_smallest_tenant()
|
||||
};
|
||||
|
||||
// Find the index of the selected worker
|
||||
let selected_idx = workers.iter().position(|w| w.url() == selected_url)?;
|
||||
|
||||
// Only proceed if the worker is healthy
|
||||
if !workers[selected_idx].is_healthy() {
|
||||
return healthy_indices.first().copied();
|
||||
}
|
||||
|
||||
// Update the tree with this request
|
||||
tree.insert(text, &selected_url);
|
||||
|
||||
// Increment processed counter
|
||||
workers[selected_idx].increment_processed();
|
||||
counter!("sgl_router_processed_requests_total", "worker" => selected_url).increment(1);
|
||||
|
||||
return Some(selected_idx);
|
||||
}
|
||||
|
||||
// Fallback to first healthy worker if tree operations fail
|
||||
healthy_indices.first().copied()
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"cache_aware"
|
||||
}
|
||||
|
||||
fn on_request_complete(&self, worker_url: &str, success: bool) {
|
||||
// Could track success rates per worker for more intelligent routing
|
||||
if !success {
|
||||
// Optionally reduce affinity for failed requests
|
||||
tracing::debug!(
|
||||
"Request to {} completed with success={}",
|
||||
worker_url,
|
||||
success
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn select_worker_pair(
|
||||
&self,
|
||||
prefill_workers: &[Box<dyn Worker>],
|
||||
decode_workers: &[Box<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<(usize, usize)> {
|
||||
// In PD mode:
|
||||
// - Prefill: Use cache-aware routing for better cache utilization
|
||||
// - Decode: Use least-load routing for better load distribution
|
||||
|
||||
// Select prefill worker using cache-aware logic
|
||||
let prefill_idx = self.select_worker(prefill_workers, request_text)?;
|
||||
|
||||
// Select decode worker using least-load logic
|
||||
let healthy_decode = get_healthy_worker_indices(decode_workers);
|
||||
if healthy_decode.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let decode_idx = healthy_decode
|
||||
.iter()
|
||||
.min_by_key(|&&idx| decode_workers[idx].load())
|
||||
.copied()?;
|
||||
|
||||
Some((prefill_idx, decode_idx))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CacheAwarePolicy {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CacheAwarePolicy {
|
||||
fn drop(&mut self) {
|
||||
// Note: We can't properly stop the eviction thread since it's in an infinite loop
|
||||
// In a production system, we'd use a channel or atomic flag to signal shutdown
|
||||
if let Some(handle) = self.eviction_handle.take() {
|
||||
// The thread will continue running until the program exits
|
||||
// This is acceptable for now since the router typically runs for the lifetime of the program
|
||||
drop(handle);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    // Unit tests for CacheAwarePolicy. All tests disable the background
    // eviction thread via eviction_interval_secs = 0 so no thread is spawned.
    use super::*;
    use crate::core::{BasicWorker, WorkerType};

    // Repeated/similar request text should stick to the same worker while
    // loads are balanced (cache affinity wins).
    #[test]
    fn test_cache_aware_with_balanced_load() {
        // Create policy without eviction thread for testing
        let config = CacheAwareConfig {
            eviction_interval_secs: 0, // Disable eviction thread
            ..Default::default()
        };
        let policy = CacheAwarePolicy::with_config(config);
        let workers: Vec<Box<dyn Worker>> = vec![
            Box::new(BasicWorker::new(
                "http://w1:8000".to_string(),
                WorkerType::Regular,
            )),
            Box::new(BasicWorker::new(
                "http://w2:8000".to_string(),
                WorkerType::Regular,
            )),
        ];

        // Initialize the policy with workers
        policy.init_workers(&workers);

        // First request should be distributed
        let idx1 = policy.select_worker(&workers, Some("hello world")).unwrap();

        // Same request should go to same worker (cache hit)
        let idx2 = policy.select_worker(&workers, Some("hello world")).unwrap();
        assert_eq!(idx1, idx2);

        // Similar request should also go to same worker
        let idx3 = policy.select_worker(&workers, Some("hello")).unwrap();
        assert_eq!(idx1, idx3);
    }

    // When load imbalance exceeds both balance thresholds, load balancing
    // should override cache affinity and pick the idle worker.
    #[test]
    fn test_cache_aware_with_imbalanced_load() {
        let policy = CacheAwarePolicy::with_config(CacheAwareConfig {
            cache_threshold: 0.5,
            balance_abs_threshold: 5,
            balance_rel_threshold: 2.0,
            eviction_interval_secs: 0, // Disable eviction thread
            max_tree_size: 10000,
        });

        let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular);
        let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular);

        // Create significant load imbalance
        for _ in 0..20 {
            worker1.increment_load();
        }
        // worker2 has load 0

        let workers: Vec<Box<dyn Worker>> = vec![Box::new(worker1), Box::new(worker2)];
        policy.init_workers(&workers);

        // Should select worker2 (lower load) despite cache affinity
        for _ in 0..5 {
            let idx = policy.select_worker(&workers, Some("test")).unwrap();
            assert_eq!(idx, 1); // Should always pick worker2
        }
    }

    // Removing a worker (and marking it unhealthy) must redirect all
    // traffic to the remaining worker, even for previously cached text.
    #[test]
    fn test_cache_aware_worker_removal() {
        let config = CacheAwareConfig {
            eviction_interval_secs: 0, // Disable eviction thread
            ..Default::default()
        };
        let policy = CacheAwarePolicy::with_config(config);
        let workers: Vec<Box<dyn Worker>> = vec![
            Box::new(BasicWorker::new(
                "http://w1:8000".to_string(),
                WorkerType::Regular,
            )),
            Box::new(BasicWorker::new(
                "http://w2:8000".to_string(),
                WorkerType::Regular,
            )),
        ];

        policy.init_workers(&workers);

        // Route some requests
        policy.select_worker(&workers, Some("test1"));
        policy.select_worker(&workers, Some("test2"));

        // Remove a worker
        policy.remove_worker("http://w1:8000");
        workers[0].set_healthy(false);

        // All requests should now go to worker2
        let idx = policy.select_worker(&workers, Some("test1")).unwrap();
        assert_eq!(idx, 1);
    }
}
|
||||
94
sgl-router/src/policies/factory.rs
Normal file
94
sgl-router/src/policies/factory.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
//! Factory for creating load balancing policies
|
||||
|
||||
use super::{
|
||||
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
|
||||
RoundRobinPolicy,
|
||||
};
|
||||
use crate::config::PolicyConfig;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Factory for creating policy instances
|
||||
pub struct PolicyFactory;
|
||||
|
||||
impl PolicyFactory {
|
||||
/// Create a policy from configuration
|
||||
pub fn create_from_config(config: &PolicyConfig) -> Arc<dyn LoadBalancingPolicy> {
|
||||
match config {
|
||||
PolicyConfig::Random => Arc::new(RandomPolicy::new()),
|
||||
PolicyConfig::RoundRobin => Arc::new(RoundRobinPolicy::new()),
|
||||
PolicyConfig::PowerOfTwo { .. } => Arc::new(PowerOfTwoPolicy::new()),
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold,
|
||||
balance_abs_threshold,
|
||||
balance_rel_threshold,
|
||||
eviction_interval_secs,
|
||||
max_tree_size,
|
||||
} => {
|
||||
let config = CacheAwareConfig {
|
||||
cache_threshold: *cache_threshold,
|
||||
balance_abs_threshold: *balance_abs_threshold,
|
||||
balance_rel_threshold: *balance_rel_threshold,
|
||||
eviction_interval_secs: *eviction_interval_secs,
|
||||
max_tree_size: *max_tree_size,
|
||||
};
|
||||
Arc::new(CacheAwarePolicy::with_config(config))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a policy by name (for dynamic loading)
|
||||
pub fn create_by_name(name: &str) -> Option<Arc<dyn LoadBalancingPolicy>> {
|
||||
match name.to_lowercase().as_str() {
|
||||
"random" => Some(Arc::new(RandomPolicy::new())),
|
||||
"round_robin" | "roundrobin" => Some(Arc::new(RoundRobinPolicy::new())),
|
||||
"power_of_two" | "poweroftwo" => Some(Arc::new(PowerOfTwoPolicy::new())),
|
||||
"cache_aware" | "cacheaware" => Some(Arc::new(CacheAwarePolicy::new())),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Each PolicyConfig variant must produce a policy reporting the
    // matching name.
    #[test]
    fn test_create_from_config() {
        // Test Random
        let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
        assert_eq!(policy.name(), "random");

        // Test RoundRobin
        let policy = PolicyFactory::create_from_config(&PolicyConfig::RoundRobin);
        assert_eq!(policy.name(), "round_robin");

        // Test PowerOfTwo
        let policy = PolicyFactory::create_from_config(&PolicyConfig::PowerOfTwo {
            load_check_interval_secs: 60,
        });
        assert_eq!(policy.name(), "power_of_two");

        // Test CacheAware
        let policy = PolicyFactory::create_from_config(&PolicyConfig::CacheAware {
            cache_threshold: 0.7,
            balance_abs_threshold: 10,
            balance_rel_threshold: 1.5,
            eviction_interval_secs: 30,
            max_tree_size: 1000,
        });
        assert_eq!(policy.name(), "cache_aware");
    }

    // Name lookup is case-insensitive and accepts alternate spellings;
    // unknown names return None.
    #[test]
    fn test_create_by_name() {
        assert!(PolicyFactory::create_by_name("random").is_some());
        assert!(PolicyFactory::create_by_name("RANDOM").is_some());
        assert!(PolicyFactory::create_by_name("round_robin").is_some());
        assert!(PolicyFactory::create_by_name("RoundRobin").is_some());
        assert!(PolicyFactory::create_by_name("power_of_two").is_some());
        assert!(PolicyFactory::create_by_name("PowerOfTwo").is_some());
        assert!(PolicyFactory::create_by_name("cache_aware").is_some());
        assert!(PolicyFactory::create_by_name("CacheAware").is_some());
        assert!(PolicyFactory::create_by_name("unknown").is_none());
    }
}
|
||||
143
sgl-router/src/policies/mod.rs
Normal file
143
sgl-router/src/policies/mod.rs
Normal file
@@ -0,0 +1,143 @@
|
||||
//! Load balancing policies for SGLang router
|
||||
//!
|
||||
//! This module provides a unified abstraction for routing policies that work
|
||||
//! across both regular and prefill-decode (PD) routing modes.
|
||||
|
||||
use crate::core::Worker;
|
||||
use std::fmt::Debug;
|
||||
|
||||
mod cache_aware;
|
||||
mod factory;
|
||||
mod power_of_two;
|
||||
mod random;
|
||||
mod round_robin;
|
||||
|
||||
pub use cache_aware::CacheAwarePolicy;
|
||||
pub use factory::PolicyFactory;
|
||||
pub use power_of_two::PowerOfTwoPolicy;
|
||||
pub use random::RandomPolicy;
|
||||
pub use round_robin::RoundRobinPolicy;
|
||||
|
||||
/// Core trait for load balancing policies
///
/// This trait provides a unified interface for implementing routing algorithms
/// that can work with both regular single-worker selection and PD dual-worker selection.
pub trait LoadBalancingPolicy: Send + Sync + Debug {
    /// Select a single worker from the available workers
    ///
    /// This is used for regular routing mode where requests go to a single worker.
    /// `request_text` is an optional request payload that text-affinity policies
    /// may inspect; stateless policies ignore it. Returns the index into
    /// `workers`, or `None` when no worker can be selected.
    fn select_worker(
        &self,
        workers: &[Box<dyn Worker>],
        request_text: Option<&str>,
    ) -> Option<usize>;

    /// Select a pair of workers (prefill and decode) for PD routing
    ///
    /// Returns indices of (prefill_worker, decode_worker) from their respective arrays.
    /// Default implementation uses select_worker for each array independently.
    fn select_worker_pair(
        &self,
        prefill_workers: &[Box<dyn Worker>],
        decode_workers: &[Box<dyn Worker>],
        request_text: Option<&str>,
    ) -> Option<(usize, usize)> {
        // Default implementation: independently select from each pool
        let prefill_idx = self.select_worker(prefill_workers, request_text)?;
        let decode_idx = self.select_worker(decode_workers, request_text)?;
        Some((prefill_idx, decode_idx))
    }

    /// Update policy state after request completion
    ///
    /// This is called when a request completes (successfully or not) to allow
    /// policies to update their internal state.
    fn on_request_complete(&self, _worker_url: &str, _success: bool) {
        // Default: no-op for stateless policies
    }

    /// Get policy name for metrics and debugging
    fn name(&self) -> &'static str;

    /// Update worker load information
    ///
    /// This is called periodically with current load information for load-aware
    /// policies. Map keys are worker URLs.
    fn update_loads(&self, _loads: &std::collections::HashMap<String, isize>) {
        // Default: no-op for policies that don't use load information
    }

    /// Reset any internal state
    ///
    /// This is useful for policies that maintain state (e.g., round-robin counters).
    fn reset(&self) {
        // Default: no-op for stateless policies
    }

    /// Get as Any for downcasting
    fn as_any(&self) -> &dyn std::any::Any;
}
|
||||
|
||||
/// Configuration for cache-aware policy
#[derive(Debug, Clone)]
pub struct CacheAwareConfig {
    /// Match-ratio threshold for cache-affinity routing
    /// (assumed: prefix-match fraction cutoff — confirm in cache_aware.rs).
    pub cache_threshold: f32,
    /// Absolute load-gap threshold used when deciding to override affinity
    /// with load balancing (see imbalanced-load test).
    pub balance_abs_threshold: usize,
    /// Relative load-ratio threshold used together with the absolute one.
    pub balance_rel_threshold: f32,
    /// Period of the background eviction thread in seconds;
    /// 0 disables the thread (used by the unit tests).
    pub eviction_interval_secs: u64,
    /// Upper bound on the approximation tree size before eviction.
    pub max_tree_size: usize,
}
|
||||
|
||||
impl Default for CacheAwareConfig {
    /// Default tuning: 0.5 cache threshold, 32 / 1.1 balance thresholds,
    /// 30s eviction interval, 10k max tree size.
    fn default() -> Self {
        Self {
            cache_threshold: 0.5,
            balance_abs_threshold: 32,
            balance_rel_threshold: 1.1,
            eviction_interval_secs: 30,
            max_tree_size: 10000,
        }
    }
}
|
||||
|
||||
/// Helper function to filter healthy workers and return their indices
|
||||
pub(crate) fn get_healthy_worker_indices(workers: &[Box<dyn Worker>]) -> Vec<usize> {
|
||||
workers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, w)| w.is_healthy())
|
||||
.map(|(idx, _)| idx)
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{BasicWorker, WorkerType};

    // Index filtering must follow health state and preserve ordering.
    #[test]
    fn test_get_healthy_worker_indices() {
        let workers: Vec<Box<dyn Worker>> = vec![
            Box::new(BasicWorker::new(
                "http://w1:8000".to_string(),
                WorkerType::Regular,
            )),
            Box::new(BasicWorker::new(
                "http://w2:8000".to_string(),
                WorkerType::Regular,
            )),
            Box::new(BasicWorker::new(
                "http://w3:8000".to_string(),
                WorkerType::Regular,
            )),
        ];

        // All healthy initially
        let indices = get_healthy_worker_indices(&workers);
        assert_eq!(indices, vec![0, 1, 2]);

        // Mark one unhealthy
        workers[1].set_healthy(false);
        let indices = get_healthy_worker_indices(&workers);
        assert_eq!(indices, vec![0, 2]);
    }
}
|
||||
201
sgl-router/src/policies/power_of_two.rs
Normal file
201
sgl-router/src/policies/power_of_two.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
//! Power-of-two choices load balancing policy
|
||||
|
||||
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use metrics::counter;
|
||||
use rand::Rng;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
use tracing::info;
|
||||
|
||||
/// Power-of-two choices policy
///
/// Randomly selects two workers and routes to the one with lower load.
/// This provides good load distribution with minimal coordination overhead.
#[derive(Debug)]
pub struct PowerOfTwoPolicy {
    /// Cached load information from external monitoring,
    /// keyed by worker URL (populated via `update_loads`).
    cached_loads: RwLock<HashMap<String, isize>>,
}
|
||||
|
||||
impl PowerOfTwoPolicy {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
cached_loads: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_worker_load(&self, worker: &dyn Worker) -> isize {
|
||||
// First check cached loads (from external monitoring)
|
||||
if let Ok(loads) = self.cached_loads.read() {
|
||||
if let Some(&load) = loads.get(worker.url()) {
|
||||
return load;
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to local load counter
|
||||
worker.load() as isize
|
||||
}
|
||||
}
|
||||
|
||||
impl LoadBalancingPolicy for PowerOfTwoPolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
_request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
|
||||
if healthy_indices.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if healthy_indices.len() == 1 {
|
||||
return Some(healthy_indices[0]);
|
||||
}
|
||||
|
||||
// Select two random workers
|
||||
let mut rng = rand::thread_rng();
|
||||
let idx1 = rng.gen_range(0..healthy_indices.len());
|
||||
let mut idx2 = rng.gen_range(0..healthy_indices.len());
|
||||
|
||||
// Ensure we pick two different workers
|
||||
while idx2 == idx1 {
|
||||
idx2 = rng.gen_range(0..healthy_indices.len());
|
||||
}
|
||||
|
||||
let worker_idx1 = healthy_indices[idx1];
|
||||
let worker_idx2 = healthy_indices[idx2];
|
||||
|
||||
// Compare loads and select the less loaded one
|
||||
let load1 = self.get_worker_load(workers[worker_idx1].as_ref());
|
||||
let load2 = self.get_worker_load(workers[worker_idx2].as_ref());
|
||||
|
||||
// Log selection for debugging
|
||||
let selected_idx = if load1 <= load2 {
|
||||
worker_idx1
|
||||
} else {
|
||||
worker_idx2
|
||||
};
|
||||
|
||||
info!(
|
||||
"Power-of-two selection: {}={} vs {}={} -> selected {}",
|
||||
workers[worker_idx1].url(),
|
||||
load1,
|
||||
workers[worker_idx2].url(),
|
||||
load2,
|
||||
workers[selected_idx].url()
|
||||
);
|
||||
|
||||
// Increment processed counter
|
||||
workers[selected_idx].increment_processed();
|
||||
counter!("sgl_router_processed_requests_total", "worker" => workers[selected_idx].url().to_string())
|
||||
.increment(1);
|
||||
|
||||
Some(selected_idx)
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"power_of_two"
|
||||
}
|
||||
|
||||
fn update_loads(&self, loads: &HashMap<String, isize>) {
|
||||
if let Ok(mut cached) = self.cached_loads.write() {
|
||||
*cached = loads.clone();
|
||||
}
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PowerOfTwoPolicy {
    /// Equivalent to [`PowerOfTwoPolicy::new`].
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{BasicWorker, WorkerType};

    // NOTE(review): this is a statistical test over 100 random trials; the
    // strict ordering assertions could flake in rare runs. Consider seeding
    // or widening tolerances if it ever becomes flaky in CI.
    #[test]
    fn test_power_of_two_selection() {
        let policy = PowerOfTwoPolicy::new();
        let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular);
        let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular);
        let worker3 = BasicWorker::new("http://w3:8000".to_string(), WorkerType::Regular);

        // Set different loads
        for _ in 0..10 {
            worker1.increment_load();
        }
        for _ in 0..5 {
            worker2.increment_load();
        }
        // worker3 has load 0

        let workers: Vec<Box<dyn Worker>> =
            vec![Box::new(worker1), Box::new(worker2), Box::new(worker3)];

        // Run multiple selections
        let mut selected_counts = vec![0; 3];
        for _ in 0..100 {
            if let Some(idx) = policy.select_worker(&workers, None) {
                selected_counts[idx] += 1;
            }
        }

        // Worker with lowest load (worker3) should be selected most often
        assert!(selected_counts[2] > selected_counts[1]);
        assert!(selected_counts[1] > selected_counts[0]);
    }

    // Cached loads supplied via update_loads must take precedence over the
    // workers' local load counters (which are all zero here).
    #[test]
    fn test_power_of_two_with_cached_loads() {
        let policy = PowerOfTwoPolicy::new();
        let workers: Vec<Box<dyn Worker>> = vec![
            Box::new(BasicWorker::new(
                "http://w1:8000".to_string(),
                WorkerType::Regular,
            )),
            Box::new(BasicWorker::new(
                "http://w2:8000".to_string(),
                WorkerType::Regular,
            )),
        ];

        // Update cached loads
        let mut loads = HashMap::new();
        loads.insert("http://w1:8000".to_string(), 100);
        loads.insert("http://w2:8000".to_string(), 10);
        policy.update_loads(&loads);

        // Should prefer worker2 with lower cached load
        let mut w2_selected = 0;
        for _ in 0..50 {
            if let Some(idx) = policy.select_worker(&workers, None) {
                if idx == 1 {
                    w2_selected += 1;
                }
            }
        }

        // Worker2 should be selected significantly more often
        assert!(w2_selected > 35); // Should win most of the time
    }

    // The two-candidate draw requires >= 2 workers; a single healthy worker
    // is returned via the fast path.
    #[test]
    fn test_power_of_two_single_worker() {
        let policy = PowerOfTwoPolicy::new();
        let workers: Vec<Box<dyn Worker>> = vec![Box::new(BasicWorker::new(
            "http://w1:8000".to_string(),
            WorkerType::Regular,
        ))];

        // With single worker, should always select it
        assert_eq!(policy.select_worker(&workers, None), Some(0));
    }
}
|
||||
116
sgl-router/src/policies/random.rs
Normal file
116
sgl-router/src/policies/random.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
//! Random load balancing policy
|
||||
|
||||
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use rand::Rng;
|
||||
|
||||
/// Random selection policy
///
/// Selects workers randomly with uniform distribution among healthy workers.
/// Stateless, so the trait's default no-op `reset`/`on_request_complete` apply.
#[derive(Debug, Default)]
pub struct RandomPolicy;
|
||||
|
||||
impl RandomPolicy {
    /// Create the policy (unit struct; no state to initialize).
    pub fn new() -> Self {
        Self
    }
}
|
||||
|
||||
impl LoadBalancingPolicy for RandomPolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
_request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
|
||||
if healthy_indices.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let random_idx = rng.gen_range(0..healthy_indices.len());
|
||||
Some(healthy_indices[random_idx])
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"random"
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{BasicWorker, WorkerType};
    use std::collections::HashMap;

    // NOTE(review): probabilistic test — 100 draws over 3 workers; the
    // "every worker selected at least once" assertion is overwhelmingly
    // likely but not guaranteed.
    #[test]
    fn test_random_selection() {
        let policy = RandomPolicy::new();
        let workers: Vec<Box<dyn Worker>> = vec![
            Box::new(BasicWorker::new(
                "http://w1:8000".to_string(),
                WorkerType::Regular,
            )),
            Box::new(BasicWorker::new(
                "http://w2:8000".to_string(),
                WorkerType::Regular,
            )),
            Box::new(BasicWorker::new(
                "http://w3:8000".to_string(),
                WorkerType::Regular,
            )),
        ];

        // Test multiple selections to ensure randomness
        let mut counts = HashMap::new();
        for _ in 0..100 {
            if let Some(idx) = policy.select_worker(&workers, None) {
                *counts.entry(idx).or_insert(0) += 1;
            }
        }

        // All workers should be selected at least once
        assert_eq!(counts.len(), 3);
        assert!(counts.values().all(|&count| count > 0));
    }

    // Unhealthy workers must never be chosen.
    #[test]
    fn test_random_with_unhealthy_workers() {
        let policy = RandomPolicy::new();
        let workers: Vec<Box<dyn Worker>> = vec![
            Box::new(BasicWorker::new(
                "http://w1:8000".to_string(),
                WorkerType::Regular,
            )),
            Box::new(BasicWorker::new(
                "http://w2:8000".to_string(),
                WorkerType::Regular,
            )),
        ];

        // Mark first worker as unhealthy
        workers[0].set_healthy(false);

        // Should always select the healthy worker (index 1)
        for _ in 0..10 {
            assert_eq!(policy.select_worker(&workers, None), Some(1));
        }
    }

    // With no healthy workers the policy yields None rather than panicking.
    #[test]
    fn test_random_no_healthy_workers() {
        let policy = RandomPolicy::new();
        let workers: Vec<Box<dyn Worker>> = vec![Box::new(BasicWorker::new(
            "http://w1:8000".to_string(),
            WorkerType::Regular,
        ))];

        workers[0].set_healthy(false);
        assert_eq!(policy.select_worker(&workers, None), None);
    }
}
|
||||
136
sgl-router/src/policies/round_robin.rs
Normal file
136
sgl-router/src/policies/round_robin.rs
Normal file
@@ -0,0 +1,136 @@
|
||||
//! Round-robin load balancing policy
|
||||
|
||||
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
/// Round-robin selection policy
///
/// Selects workers in sequential order, cycling through all healthy workers.
#[derive(Debug, Default)]
pub struct RoundRobinPolicy {
    /// Monotonically increasing ticket counter; the selected position in the
    /// healthy-worker list is `counter % healthy_count`.
    counter: AtomicUsize,
}
|
||||
|
||||
impl RoundRobinPolicy {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
counter: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LoadBalancingPolicy for RoundRobinPolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
_request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
|
||||
if healthy_indices.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Get and increment counter atomically
|
||||
let count = self.counter.fetch_add(1, Ordering::Relaxed);
|
||||
let selected_idx = count % healthy_indices.len();
|
||||
|
||||
Some(healthy_indices[selected_idx])
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"round_robin"
|
||||
}
|
||||
|
||||
fn reset(&self) {
|
||||
self.counter.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{BasicWorker, WorkerType};

    // The counter cycles deterministically over all healthy workers.
    #[test]
    fn test_round_robin_selection() {
        let policy = RoundRobinPolicy::new();
        let workers: Vec<Box<dyn Worker>> = vec![
            Box::new(BasicWorker::new(
                "http://w1:8000".to_string(),
                WorkerType::Regular,
            )),
            Box::new(BasicWorker::new(
                "http://w2:8000".to_string(),
                WorkerType::Regular,
            )),
            Box::new(BasicWorker::new(
                "http://w3:8000".to_string(),
                WorkerType::Regular,
            )),
        ];

        // Should select workers in order: 0, 1, 2, 0, 1, 2, ...
        assert_eq!(policy.select_worker(&workers, None), Some(0))
;
        assert_eq!(policy.select_worker(&workers, None), Some(1));
        assert_eq!(policy.select_worker(&workers, None), Some(2));
        assert_eq!(policy.select_worker(&workers, None), Some(0));
        assert_eq!(policy.select_worker(&workers, None), Some(1));
    }

    // The rotation is over the *healthy* list, so an unhealthy worker is
    // skipped entirely (counter mod 2 over indices [0, 2]).
    #[test]
    fn test_round_robin_with_unhealthy_workers() {
        let policy = RoundRobinPolicy::new();
        let workers: Vec<Box<dyn Worker>> = vec![
            Box::new(BasicWorker::new(
                "http://w1:8000".to_string(),
                WorkerType::Regular,
            )),
            Box::new(BasicWorker::new(
                "http://w2:8000".to_string(),
                WorkerType::Regular,
            )),
            Box::new(BasicWorker::new(
                "http://w3:8000".to_string(),
                WorkerType::Regular,
            )),
        ];

        // Mark middle worker as unhealthy
        workers[1].set_healthy(false);

        // Should skip unhealthy worker: 0, 2, 0, 2, ...
        assert_eq!(policy.select_worker(&workers, None), Some(0));
        assert_eq!(policy.select_worker(&workers, None), Some(2));
        assert_eq!(policy.select_worker(&workers, None), Some(0));
        assert_eq!(policy.select_worker(&workers, None), Some(2));
    }

    // reset() rewinds the rotation to the beginning.
    #[test]
    fn test_round_robin_reset() {
        let policy = RoundRobinPolicy::new();
        let workers: Vec<Box<dyn Worker>> = vec![
            Box::new(BasicWorker::new(
                "http://w1:8000".to_string(),
                WorkerType::Regular,
            )),
            Box::new(BasicWorker::new(
                "http://w2:8000".to_string(),
                WorkerType::Regular,
            )),
        ];

        // Advance the counter
        assert_eq!(policy.select_worker(&workers, None), Some(0));
        assert_eq!(policy.select_worker(&workers, None), Some(1));

        // Reset should start from beginning
        policy.reset();
        assert_eq!(policy.select_worker(&workers, None), Some(0));
    }
}
|
||||
File diff suppressed because it is too large
Load Diff
66
sgl-router/src/routers/factory.rs
Normal file
66
sgl-router/src/routers/factory.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
//! Factory for creating router instances
|
||||
|
||||
use super::{pd_router::PDRouter, router::Router, RouterTrait};
|
||||
use crate::config::{PolicyConfig, RouterConfig, RoutingMode};
|
||||
use crate::policies::PolicyFactory;
|
||||
|
||||
/// Factory for creating router instances based on configuration
|
||||
pub struct RouterFactory;
|
||||
|
||||
impl RouterFactory {
|
||||
/// Create a router instance from configuration
|
||||
pub fn create_router(config: &RouterConfig) -> Result<Box<dyn RouterTrait>, String> {
|
||||
match &config.mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
Self::create_regular_router(worker_urls, &config.policy, config)
|
||||
}
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
} => Self::create_pd_router(prefill_urls, decode_urls, &config.policy, config),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a regular router with injected policy
|
||||
fn create_regular_router(
|
||||
worker_urls: &[String],
|
||||
policy_config: &PolicyConfig,
|
||||
router_config: &RouterConfig,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Create policy
|
||||
let policy = PolicyFactory::create_from_config(policy_config);
|
||||
|
||||
// Create regular router with injected policy
|
||||
let router = Router::new(
|
||||
worker_urls.to_vec(),
|
||||
policy,
|
||||
router_config.worker_startup_timeout_secs,
|
||||
router_config.worker_startup_check_interval_secs,
|
||||
)?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create a PD router with injected policy
|
||||
fn create_pd_router(
|
||||
prefill_urls: &[(String, Option<u16>)],
|
||||
decode_urls: &[String],
|
||||
policy_config: &PolicyConfig,
|
||||
router_config: &RouterConfig,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Create policy directly from PolicyConfig
|
||||
// All policies now support PD mode through the select_worker_pair method
|
||||
let policy = PolicyFactory::create_from_config(policy_config);
|
||||
|
||||
// Create PD router with injected policy
|
||||
let router = PDRouter::new(
|
||||
prefill_urls.to_vec(),
|
||||
decode_urls.to_vec(),
|
||||
policy,
|
||||
router_config.worker_startup_timeout_secs,
|
||||
router_config.worker_startup_check_interval_secs,
|
||||
)?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
}
|
||||
101
sgl-router/src/routers/mod.rs
Normal file
101
sgl-router/src/routers/mod.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
//! Router implementations
|
||||
|
||||
use actix_web::{HttpRequest, HttpResponse};
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use std::fmt::Debug;
|
||||
|
||||
pub mod factory;
|
||||
pub mod pd_router;
|
||||
pub mod pd_types;
|
||||
pub mod request_adapter;
|
||||
pub mod router;
|
||||
|
||||
pub use factory::RouterFactory;
|
||||
|
||||
/// Worker management trait for administrative operations
///
/// This trait is separate from RouterTrait to allow Send futures
/// for use in service discovery and other background tasks.
#[async_trait]
pub trait WorkerManagement: Send + Sync {
    /// Add a worker to the router.
    /// Both `Ok` and `Err` carry a human-readable message string.
    async fn add_worker(&self, worker_url: &str) -> Result<String, String>;

    /// Remove a worker from the router
    fn remove_worker(&self, worker_url: &str);

    /// Get all worker URLs
    fn get_worker_urls(&self) -> Vec<String>;
}
|
||||
|
||||
/// Core trait for all router implementations
///
/// This trait provides a unified interface for routing requests,
/// regardless of whether it's a regular router or PD router.
/// Note: `?Send` futures — implementations are driven on the actix runtime.
#[async_trait(?Send)]
pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
    /// Get a reference to self as Any for downcasting
    fn as_any(&self) -> &dyn std::any::Any;
    /// Route a health check request
    async fn health(&self, client: &Client, req: &HttpRequest) -> HttpResponse;

    /// Route a health generate request
    async fn health_generate(&self, client: &Client, req: &HttpRequest) -> HttpResponse;

    /// Get server information
    async fn get_server_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse;

    /// Get available models
    async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse;

    /// Get model information
    async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse;

    /// Route a generate request
    async fn route_generate(
        &self,
        client: &Client,
        req: &HttpRequest,
        body: serde_json::Value,
    ) -> HttpResponse;

    /// Route a chat completion request
    async fn route_chat(
        &self,
        client: &Client,
        req: &HttpRequest,
        body: serde_json::Value,
    ) -> HttpResponse;

    /// Route a completion request
    async fn route_completion(
        &self,
        client: &Client,
        req: &HttpRequest,
        body: serde_json::Value,
    ) -> HttpResponse;

    /// Flush cache on all workers
    async fn flush_cache(&self, client: &Client) -> HttpResponse;

    /// Get worker loads (for monitoring)
    async fn get_worker_loads(&self, client: &Client) -> HttpResponse;

    /// Get router type name
    fn router_type(&self) -> &'static str;

    /// Check if this is a PD router
    /// (derived from `router_type()`, so implementations need not override it)
    fn is_pd_mode(&self) -> bool {
        self.router_type() == "pd"
    }

    /// Server liveness check - is the server process running
    fn liveness(&self) -> HttpResponse {
        // Simple liveness check - if we can respond, we're alive
        HttpResponse::Ok().body("OK")
    }

    /// Server readiness check - is the server ready to handle requests
    fn readiness(&self) -> HttpResponse;
}
|
||||
@@ -1,10 +1,11 @@
|
||||
// PD (Prefill-Decode) Router Implementation
|
||||
// This module handles routing for disaggregated prefill-decode systems
|
||||
|
||||
use super::pd_types::{api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError};
|
||||
use super::request_adapter::ToPdRequest;
|
||||
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
|
||||
use crate::pd_types::{
|
||||
api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError, PDSelectionPolicy,
|
||||
};
|
||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
use crate::policies::LoadBalancingPolicy;
|
||||
use crate::tree::Tree;
|
||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||
use actix_web::{HttpRequest, HttpResponse};
|
||||
@@ -17,13 +18,11 @@ use std::time::{Duration, Instant};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
// Removed over-engineered ProxyResponse - using HttpResponse directly
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PDRouter {
|
||||
pub prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
pub selection_policy: PDSelectionPolicy,
|
||||
pub policy: Arc<dyn LoadBalancingPolicy>,
|
||||
pub prefill_tree: Option<Arc<Mutex<Tree>>>,
|
||||
pub timeout_secs: u64,
|
||||
pub interval_secs: u64,
|
||||
@@ -42,7 +41,7 @@ impl PDRouter {
|
||||
bootstrap_port: Option<u16>,
|
||||
) -> Result<String, PDRouterError> {
|
||||
// Wait for the new server to be healthy
|
||||
crate::router::Router::wait_for_healthy_workers(
|
||||
crate::routers::router::Router::wait_for_healthy_workers(
|
||||
&[url.clone()],
|
||||
self.timeout_secs,
|
||||
self.interval_secs,
|
||||
@@ -78,7 +77,7 @@ impl PDRouter {
|
||||
|
||||
pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
|
||||
// Wait for the new server to be healthy
|
||||
crate::router::Router::wait_for_healthy_workers(
|
||||
crate::routers::router::Router::wait_for_healthy_workers(
|
||||
&[url.clone()],
|
||||
self.timeout_secs,
|
||||
self.interval_secs,
|
||||
@@ -103,9 +102,6 @@ impl PDRouter {
|
||||
|
||||
workers.push(worker);
|
||||
|
||||
// Initialize load tracking
|
||||
// Worker tracks its own load internally
|
||||
|
||||
info!("Added decode server: {}", url);
|
||||
Ok(format!("Successfully added decode server: {}", url))
|
||||
}
|
||||
@@ -128,9 +124,6 @@ impl PDRouter {
|
||||
});
|
||||
}
|
||||
|
||||
// Remove from load tracking
|
||||
// Worker load tracking is internal
|
||||
|
||||
// Remove from cache tree if using cache-aware policy
|
||||
if let Some(ref tree) = self.prefill_tree {
|
||||
// Note: Tree doesn't have a remove method, so we rebuild it
|
||||
@@ -170,7 +163,7 @@ impl PDRouter {
|
||||
pub fn new(
|
||||
prefill_urls: Vec<(String, Option<u16>)>,
|
||||
decode_urls: Vec<String>,
|
||||
selection_policy: PDSelectionPolicy,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
) -> Result<Self, String> {
|
||||
@@ -185,25 +178,38 @@ impl PDRouter {
|
||||
.map(WorkerFactory::create_decode)
|
||||
.collect();
|
||||
|
||||
// Wait for PD workers to be healthy
|
||||
// Wait for PD workers to be healthy (skip if empty - for service discovery mode)
|
||||
let all_urls: Vec<String> = prefill_workers
|
||||
.iter()
|
||||
.chain(decode_workers.iter())
|
||||
.map(|worker| worker.url().to_string())
|
||||
.collect();
|
||||
crate::router::Router::wait_for_healthy_workers(&all_urls, timeout_secs, interval_secs)?;
|
||||
if !all_urls.is_empty() {
|
||||
crate::routers::router::Router::wait_for_healthy_workers(
|
||||
&all_urls,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
)?;
|
||||
}
|
||||
|
||||
// Initialize cache-aware components if needed
|
||||
let prefill_tree = match &selection_policy {
|
||||
PDSelectionPolicy::CacheAware { .. } => {
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
// Initialize tree with prefill workers
|
||||
for worker in &prefill_workers {
|
||||
tree.lock().unwrap().insert("", worker.url());
|
||||
}
|
||||
Some(tree)
|
||||
let prefill_tree = if policy.name() == "cache_aware" {
|
||||
// Initialize the policy's internal tree with prefill workers
|
||||
if let Some(cache_policy) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_policy.init_workers(&prefill_workers);
|
||||
}
|
||||
_ => None,
|
||||
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
// Initialize tree with prefill workers
|
||||
for worker in &prefill_workers {
|
||||
tree.lock().unwrap().insert("", worker.url());
|
||||
}
|
||||
Some(tree)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Set up background load monitoring for power-of-two selection
|
||||
@@ -216,10 +222,11 @@ impl PDRouter {
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
let load_monitor_handle = if matches!(selection_policy, PDSelectionPolicy::PowerOfTwo) {
|
||||
let load_monitor_handle = if policy.name() == "power_of_two" {
|
||||
let monitor_urls = all_urls.clone();
|
||||
let monitor_interval = interval_secs;
|
||||
let monitor_client = http_client.clone();
|
||||
let policy_clone = Arc::clone(&policy);
|
||||
|
||||
Some(Arc::new(tokio::spawn(async move {
|
||||
Self::monitor_worker_loads_with_client(
|
||||
@@ -227,6 +234,7 @@ impl PDRouter {
|
||||
tx,
|
||||
monitor_interval,
|
||||
monitor_client,
|
||||
policy_clone,
|
||||
)
|
||||
.await;
|
||||
})))
|
||||
@@ -246,7 +254,7 @@ impl PDRouter {
|
||||
Ok(PDRouter {
|
||||
prefill_workers,
|
||||
decode_workers,
|
||||
selection_policy,
|
||||
policy,
|
||||
prefill_tree,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
@@ -270,15 +278,21 @@ impl PDRouter {
|
||||
let _request_id = Uuid::new_v4();
|
||||
|
||||
// Get stream flag and return_logprob flag before moving the request
|
||||
let is_stream = typed_req.is_stream();
|
||||
let is_stream = typed_req.stream;
|
||||
let return_logprob = typed_req
|
||||
.other
|
||||
.get("return_logprob")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// Extract text for cache-aware routing from the typed request
|
||||
let request_text = typed_req.text.as_ref().and_then(|t| match t {
|
||||
super::pd_types::InputText::Single(s) => Some(s.as_str()),
|
||||
super::pd_types::InputText::Batch(v) => v.first().map(|s| s.as_str()),
|
||||
});
|
||||
|
||||
// Select servers
|
||||
let (prefill, decode) = match self.select_pd_pair(client).await {
|
||||
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
||||
Ok(pair) => pair,
|
||||
Err(e) => {
|
||||
error!("Failed to select PD pair: {}", e);
|
||||
@@ -339,15 +353,24 @@ impl PDRouter {
|
||||
let start = Instant::now();
|
||||
|
||||
// Get stream flag and return_logprob flag before moving the request
|
||||
let is_stream = typed_req.is_stream();
|
||||
let is_stream = typed_req.stream;
|
||||
let return_logprob = typed_req
|
||||
.other
|
||||
.get("return_logprob")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// Extract text for cache-aware routing from chat messages
|
||||
let request_text = typed_req
|
||||
.other
|
||||
.get("messages")
|
||||
.and_then(|messages| messages.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|msg| msg.get("content"))
|
||||
.and_then(|content| content.as_str());
|
||||
|
||||
// Select servers
|
||||
let (prefill, decode) = match self.select_pd_pair(client).await {
|
||||
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
||||
Ok(pair) => pair,
|
||||
Err(e) => {
|
||||
error!("Failed to select PD pair: {}", e);
|
||||
@@ -424,7 +447,7 @@ impl PDRouter {
|
||||
.json(&json_request);
|
||||
|
||||
// Copy headers from original request
|
||||
for (name, value) in crate::router::copy_request_headers(req) {
|
||||
for (name, value) in crate::routers::router::copy_request_headers(req) {
|
||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
|
||||
prefill_request = prefill_request.header(&name, &value);
|
||||
decode_request = decode_request.header(&name, &value);
|
||||
@@ -620,104 +643,47 @@ impl PDRouter {
|
||||
async fn select_pd_pair(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
request_text: Option<&str>,
|
||||
) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
|
||||
// Check we have workers
|
||||
if self
|
||||
// Get read locks for both worker lists
|
||||
let prefill_workers = self
|
||||
.prefill_workers
|
||||
.read()
|
||||
.map_err(|e| format!("Failed to acquire prefill workers lock: {}", e))?
|
||||
.is_empty()
|
||||
{
|
||||
return Err("No prefill workers available. Please check if prefill servers are configured and healthy.".to_string());
|
||||
}
|
||||
if self
|
||||
.map_err(|e| format!("Failed to acquire prefill workers lock: {}", e))?;
|
||||
let decode_workers = self
|
||||
.decode_workers
|
||||
.read()
|
||||
.map_err(|e| format!("Failed to acquire decode workers lock: {}", e))?
|
||||
.is_empty()
|
||||
{
|
||||
.map_err(|e| format!("Failed to acquire decode workers lock: {}", e))?;
|
||||
|
||||
// Check we have workers
|
||||
if prefill_workers.is_empty() {
|
||||
return Err("No prefill workers available. Please check if prefill servers are configured and healthy.".to_string());
|
||||
}
|
||||
if decode_workers.is_empty() {
|
||||
return Err("No decode workers available. Please check if decode servers are configured and healthy.".to_string());
|
||||
}
|
||||
|
||||
match &self.selection_policy {
|
||||
PDSelectionPolicy::Random => self.select_random(),
|
||||
PDSelectionPolicy::PowerOfTwo => self.select_power_of_two().await,
|
||||
PDSelectionPolicy::CacheAware { .. } => {
|
||||
// TODO: Implement cache-aware selection
|
||||
self.select_power_of_two().await
|
||||
// Use the policy to select worker pair
|
||||
match self
|
||||
.policy
|
||||
.select_worker_pair(&prefill_workers, &decode_workers, request_text)
|
||||
{
|
||||
Some((prefill_idx, decode_idx)) => {
|
||||
let prefill = prefill_workers[prefill_idx].clone_worker();
|
||||
let decode = decode_workers[decode_idx].clone_worker();
|
||||
Ok((prefill, decode))
|
||||
}
|
||||
None => Err("Failed to select worker pair".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn select_random(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
|
||||
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
|
||||
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
|
||||
|
||||
let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone_worker();
|
||||
let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone_worker();
|
||||
|
||||
Ok((prefill, decode))
|
||||
}
|
||||
|
||||
async fn select_power_of_two(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
|
||||
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
|
||||
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
|
||||
|
||||
let (p1_idx, p2_idx) = get_two_random_indices(prefill_list.len());
|
||||
let (d1_idx, d2_idx) = get_two_random_indices(decode_list.len());
|
||||
|
||||
let loads = self.worker_loads.borrow();
|
||||
|
||||
let p1_load = loads
|
||||
.get(prefill_list[p1_idx].url())
|
||||
.copied()
|
||||
.unwrap_or(isize::MAX);
|
||||
let p2_load = loads
|
||||
.get(prefill_list[p2_idx].url())
|
||||
.copied()
|
||||
.unwrap_or(isize::MAX);
|
||||
let d1_load = loads
|
||||
.get(decode_list[d1_idx].url())
|
||||
.copied()
|
||||
.unwrap_or(isize::MAX);
|
||||
let d2_load = loads
|
||||
.get(decode_list[d2_idx].url())
|
||||
.copied()
|
||||
.unwrap_or(isize::MAX);
|
||||
|
||||
info!(
|
||||
"Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}",
|
||||
prefill_list[p1_idx].url(),
|
||||
p1_load,
|
||||
prefill_list[p2_idx].url(),
|
||||
p2_load,
|
||||
decode_list[d1_idx].url(),
|
||||
d1_load,
|
||||
decode_list[d2_idx].url(),
|
||||
d2_load
|
||||
);
|
||||
|
||||
let selected_prefill = if p1_load <= p2_load {
|
||||
prefill_list[p1_idx].clone_worker()
|
||||
} else {
|
||||
prefill_list[p2_idx].clone_worker()
|
||||
};
|
||||
|
||||
let selected_decode = if d1_load <= d2_load {
|
||||
decode_list[d1_idx].clone_worker()
|
||||
} else {
|
||||
decode_list[d2_idx].clone_worker()
|
||||
};
|
||||
|
||||
Ok((selected_prefill, selected_decode))
|
||||
}
|
||||
|
||||
// Background task to monitor worker loads with shared client
|
||||
async fn monitor_worker_loads_with_client(
|
||||
worker_urls: Vec<String>,
|
||||
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
|
||||
interval_secs: u64,
|
||||
client: reqwest::Client,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
) {
|
||||
loop {
|
||||
let mut loads = HashMap::new();
|
||||
@@ -742,6 +708,9 @@ impl PDRouter {
|
||||
|
||||
debug!("Worker loads updated: {:?}", loads);
|
||||
|
||||
// Update the policy with current loads
|
||||
policy.update_loads(&loads);
|
||||
|
||||
// Check if receiver is still active
|
||||
if tx.send(loads).is_err() {
|
||||
info!("Load monitor receiver dropped, shutting down monitor task");
|
||||
@@ -792,18 +761,6 @@ impl PDRouter {
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
fn get_two_random_indices(len: usize) -> (usize, usize) {
|
||||
if len == 1 {
|
||||
(0, 0)
|
||||
} else {
|
||||
let idx1 = rand::random::<usize>() % len;
|
||||
let mut idx2 = rand::random::<usize>() % len;
|
||||
while idx2 == idx1 {
|
||||
idx2 = rand::random::<usize>() % len;
|
||||
}
|
||||
(idx1, idx2)
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||
match client.get(format!("{}/get_load", worker_url)).send().await {
|
||||
@@ -841,61 +798,72 @@ async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option<i
|
||||
// PD-specific endpoints
|
||||
impl PDRouter {
|
||||
pub async fn health_generate(&self, client: &reqwest::Client) -> HttpResponse {
|
||||
let mut all_healthy = true;
|
||||
let mut unhealthy_servers = Vec::new();
|
||||
// Test model generation capability by selecting a random pair and testing them
|
||||
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
|
||||
|
||||
// Collect all worker URLs with their types
|
||||
let mut worker_infos = Vec::new();
|
||||
// Select a random worker pair using the policy
|
||||
let (prefill, decode) = match self.select_pd_pair(client, None).await {
|
||||
Ok(pair) => pair,
|
||||
Err(e) => {
|
||||
return HttpResponse::ServiceUnavailable()
|
||||
.body(format!("No healthy worker pair available: {}", e));
|
||||
}
|
||||
};
|
||||
|
||||
for worker in self.prefill_workers.read().unwrap().iter() {
|
||||
worker_infos.push((worker.url().to_string(), "prefill"));
|
||||
}
|
||||
// Test prefill server's health_generate
|
||||
let prefill_url = format!("{}/health_generate", prefill.url());
|
||||
let prefill_result = client.get(&prefill_url).send().await;
|
||||
|
||||
for worker in self.decode_workers.read().unwrap().iter() {
|
||||
worker_infos.push((worker.url().to_string(), "decode"));
|
||||
}
|
||||
// Test decode server's health_generate
|
||||
let decode_url = format!("{}/health_generate", decode.url());
|
||||
let decode_result = client.get(&decode_url).send().await;
|
||||
|
||||
// Create tasks with URL tracking
|
||||
let tasks: Vec<_> = worker_infos
|
||||
.iter()
|
||||
.map(|(url, _)| {
|
||||
let health_url = format!("{}/health_generate", url);
|
||||
client.get(&health_url).send()
|
||||
})
|
||||
.collect();
|
||||
// Check results
|
||||
let mut errors = Vec::new();
|
||||
|
||||
let results = futures_util::future::join_all(tasks).await;
|
||||
|
||||
for ((url, worker_type), result) in worker_infos.iter().zip(results.into_iter()) {
|
||||
match result {
|
||||
Ok(res) if res.status().is_success() => {
|
||||
debug!("Health check passed for {} server: {}", worker_type, url);
|
||||
}
|
||||
Ok(res) => {
|
||||
all_healthy = false;
|
||||
let msg = format!(
|
||||
"{} server {} returned status {}",
|
||||
worker_type,
|
||||
url,
|
||||
res.status()
|
||||
);
|
||||
error!("{}", msg);
|
||||
unhealthy_servers.push(msg);
|
||||
}
|
||||
Err(e) => {
|
||||
all_healthy = false;
|
||||
let msg = format!("{} server {} error: {}", worker_type, url, e);
|
||||
error!("{}", msg);
|
||||
unhealthy_servers.push(msg);
|
||||
}
|
||||
match prefill_result {
|
||||
Ok(res) if res.status().is_success() => {
|
||||
debug!(
|
||||
"Health generate passed for prefill server: {}",
|
||||
prefill.url()
|
||||
);
|
||||
}
|
||||
Ok(res) => {
|
||||
errors.push(format!(
|
||||
"Prefill {} returned status {}",
|
||||
prefill.url(),
|
||||
res.status()
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
errors.push(format!("Prefill {} error: {}", prefill.url(), e));
|
||||
}
|
||||
}
|
||||
|
||||
if all_healthy {
|
||||
HttpResponse::Ok().body("Health check passed on all servers")
|
||||
match decode_result {
|
||||
Ok(res) if res.status().is_success() => {
|
||||
debug!("Health generate passed for decode server: {}", decode.url());
|
||||
}
|
||||
Ok(res) => {
|
||||
errors.push(format!(
|
||||
"Decode {} returned status {}",
|
||||
decode.url(),
|
||||
res.status()
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
errors.push(format!("Decode {} error: {}", decode.url(), e));
|
||||
}
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
HttpResponse::Ok().body(format!(
|
||||
"Health generate passed on selected pair: prefill={}, decode={}",
|
||||
prefill.url(),
|
||||
decode.url()
|
||||
))
|
||||
} else {
|
||||
HttpResponse::ServiceUnavailable()
|
||||
.body(format!("Health check failed: {:?}", unhealthy_servers))
|
||||
HttpResponse::ServiceUnavailable().body(format!("Health generate failed: {:?}", errors))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -955,7 +923,7 @@ impl PDRouter {
|
||||
if let Some(worker_url) = first_worker_url {
|
||||
// Send request directly without going through Router
|
||||
let mut request_builder = client.get(format!("{}/v1/models", worker_url));
|
||||
for (name, value) in crate::router::copy_request_headers(req) {
|
||||
for (name, value) in crate::routers::router::copy_request_headers(req) {
|
||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
||||
{
|
||||
request_builder = request_builder.header(name, value);
|
||||
@@ -1035,7 +1003,7 @@ impl PDRouter {
|
||||
|
||||
if let Some(worker_url) = first_worker_url {
|
||||
let mut request_builder = client.get(format!("{}/get_model_info", worker_url));
|
||||
for (name, value) in crate::router::copy_request_headers(req) {
|
||||
for (name, value) in crate::routers::router::copy_request_headers(req) {
|
||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
||||
{
|
||||
request_builder = request_builder.header(name, value);
|
||||
@@ -1102,3 +1070,324 @@ impl PDRouter {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerManagement for PDRouter {
|
||||
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||
// For PD router, we don't support adding workers via this generic method
|
||||
Err(
|
||||
"PD router requires specific add_prefill_server or add_decode_server methods"
|
||||
.to_string(),
|
||||
)
|
||||
}
|
||||
|
||||
fn remove_worker(&self, worker_url: &str) {
|
||||
// For PD router, we would need to know if it's a prefill or decode server
|
||||
// For now, try both
|
||||
if let Ok(mut workers) = self.prefill_workers.write() {
|
||||
if let Some(index) = workers.iter().position(|w| w.url() == worker_url) {
|
||||
workers.remove(index);
|
||||
info!("Removed prefill worker: {}", worker_url);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(mut workers) = self.decode_workers.write() {
|
||||
if let Some(index) = workers.iter().position(|w| w.url() == worker_url) {
|
||||
workers.remove(index);
|
||||
info!("Removed decode worker: {}", worker_url);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_worker_urls(&self) -> Vec<String> {
|
||||
let mut urls = Vec::new();
|
||||
|
||||
// Add prefill worker URLs
|
||||
if let Ok(workers) = self.prefill_workers.read() {
|
||||
for worker in workers.iter() {
|
||||
urls.push(worker.url().to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Add decode worker URLs
|
||||
if let Ok(workers) = self.decode_workers.read() {
|
||||
for worker in workers.iter() {
|
||||
urls.push(worker.url().to_string());
|
||||
}
|
||||
}
|
||||
|
||||
urls
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait(?Send)]
|
||||
impl RouterTrait for PDRouter {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
async fn health(&self, _client: &Client, _req: &HttpRequest) -> HttpResponse {
|
||||
// This is a server readiness check - checking if we have healthy workers
|
||||
// Workers handle their own health checks in the background
|
||||
let mut all_healthy = true;
|
||||
let mut unhealthy_servers = Vec::new();
|
||||
|
||||
// Check prefill servers
|
||||
for worker in self.prefill_workers.read().unwrap().iter() {
|
||||
if !worker.is_healthy() {
|
||||
all_healthy = false;
|
||||
unhealthy_servers.push(format!("Prefill: {}", worker.url()));
|
||||
}
|
||||
}
|
||||
|
||||
// Check decode servers
|
||||
for worker in self.decode_workers.read().unwrap().iter() {
|
||||
if !worker.is_healthy() {
|
||||
all_healthy = false;
|
||||
unhealthy_servers.push(format!("Decode: {}", worker.url()));
|
||||
}
|
||||
}
|
||||
|
||||
if all_healthy {
|
||||
HttpResponse::Ok().body("All servers healthy")
|
||||
} else {
|
||||
HttpResponse::ServiceUnavailable()
|
||||
.body(format!("Unhealthy servers: {:?}", unhealthy_servers))
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_generate(&self, client: &Client, _req: &HttpRequest) -> HttpResponse {
|
||||
// Use the existing PDRouter health_generate method
|
||||
PDRouter::health_generate(self, client).await
|
||||
}
|
||||
|
||||
async fn get_server_info(&self, client: &Client, _req: &HttpRequest) -> HttpResponse {
|
||||
// Use the existing PDRouter get_server_info method
|
||||
PDRouter::get_server_info(self, client).await
|
||||
}
|
||||
|
||||
async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse {
|
||||
// Get first prefill worker URL to avoid holding lock across await
|
||||
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
|
||||
workers.first().map(|w| w.url().to_string())
|
||||
} else {
|
||||
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
|
||||
};
|
||||
|
||||
if let Some(worker_url) = first_worker_url {
|
||||
// Send request directly without going through Router
|
||||
let mut request_builder = client.get(format!("{}/v1/models", worker_url));
|
||||
for (name, value) in crate::routers::router::copy_request_headers(req) {
|
||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
||||
{
|
||||
request_builder = request_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
match request_builder.send().await {
|
||||
Ok(res) => {
|
||||
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
match res.bytes().await {
|
||||
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
||||
Err(e) => HttpResponse::InternalServerError()
|
||||
.body(format!("Failed to read response body: {}", e)),
|
||||
}
|
||||
}
|
||||
Err(e) => HttpResponse::InternalServerError()
|
||||
.body(format!("Failed to send request: {}", e)),
|
||||
}
|
||||
} else {
|
||||
HttpResponse::ServiceUnavailable().body("No prefill servers available")
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse {
|
||||
// For PD router, get model info from the first prefill server
|
||||
// Get first prefill worker URL to avoid holding lock across await
|
||||
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
|
||||
workers.first().map(|w| w.url().to_string())
|
||||
} else {
|
||||
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
|
||||
};
|
||||
|
||||
if let Some(worker_url) = first_worker_url {
|
||||
let mut request_builder = client.get(format!("{}/get_model_info", worker_url));
|
||||
for (name, value) in crate::routers::router::copy_request_headers(req) {
|
||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
||||
{
|
||||
request_builder = request_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
match request_builder.send().await {
|
||||
Ok(res) => {
|
||||
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
match res.bytes().await {
|
||||
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
||||
Err(e) => HttpResponse::InternalServerError()
|
||||
.body(format!("Failed to read response body: {}", e)),
|
||||
}
|
||||
}
|
||||
Err(e) => HttpResponse::InternalServerError()
|
||||
.body(format!("Failed to send request: {}", e)),
|
||||
}
|
||||
} else {
|
||||
HttpResponse::ServiceUnavailable().body("No prefill servers available")
|
||||
}
|
||||
}
|
||||
|
||||
async fn route_generate(
|
||||
&self,
|
||||
client: &Client,
|
||||
req: &HttpRequest,
|
||||
body: serde_json::Value,
|
||||
) -> HttpResponse {
|
||||
match serde_json::from_value::<GenerateRequest>(body.clone()) {
|
||||
Ok(openai_req) => {
|
||||
// Convert OpenAI format to PD format
|
||||
let pd_req = openai_req.to_pd_request();
|
||||
PDRouter::route_generate(self, client, req, pd_req, "/generate").await
|
||||
}
|
||||
Err(_) => {
|
||||
// If that fails, try to deserialize directly as PD format (for backwards compatibility)
|
||||
match serde_json::from_value::<GenerateReqInput>(body) {
|
||||
Ok(pd_req) => {
|
||||
PDRouter::route_generate(self, client, req, pd_req, "/generate").await
|
||||
}
|
||||
Err(e) => {
|
||||
HttpResponse::BadRequest().body(format!("Invalid request format: {}", e))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn route_chat(
|
||||
&self,
|
||||
client: &Client,
|
||||
req: &HttpRequest,
|
||||
body: serde_json::Value,
|
||||
) -> HttpResponse {
|
||||
match serde_json::from_value::<ChatCompletionRequest>(body.clone()) {
|
||||
Ok(openai_req) => {
|
||||
// Convert OpenAI format to PD format
|
||||
let pd_req = openai_req.to_pd_request();
|
||||
PDRouter::route_chat(self, client, req, pd_req, "/v1/chat/completions").await
|
||||
}
|
||||
Err(_) => {
|
||||
// If that fails, try to deserialize directly as PD format (for backwards compatibility)
|
||||
match serde_json::from_value::<ChatReqInput>(body) {
|
||||
Ok(pd_req) => {
|
||||
PDRouter::route_chat(self, client, req, pd_req, "/v1/chat/completions")
|
||||
.await
|
||||
}
|
||||
Err(e) => {
|
||||
HttpResponse::BadRequest().body(format!("Invalid request format: {}", e))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn route_completion(
|
||||
&self,
|
||||
client: &Client,
|
||||
req: &HttpRequest,
|
||||
body: serde_json::Value,
|
||||
) -> HttpResponse {
|
||||
match serde_json::from_value::<CompletionRequest>(body.clone()) {
|
||||
Ok(openai_req) => {
|
||||
// Convert OpenAI format to PD format (CompletionRequest -> GenerateReqInput)
|
||||
let pd_req = openai_req.to_pd_request();
|
||||
PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
|
||||
}
|
||||
Err(_) => {
|
||||
// If that fails, try to deserialize directly as PD format (for backwards compatibility)
|
||||
match serde_json::from_value::<GenerateReqInput>(body) {
|
||||
Ok(pd_req) => {
|
||||
PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
|
||||
}
|
||||
Err(e) => {
|
||||
HttpResponse::BadRequest().body(format!("Invalid request format: {}", e))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn flush_cache(&self, client: &Client) -> HttpResponse {
|
||||
// Use the existing PDRouter flush_cache method
|
||||
PDRouter::flush_cache(self, client).await
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self, client: &Client) -> HttpResponse {
|
||||
// Use the existing PDRouter get_loads method
|
||||
PDRouter::get_loads(self, client).await
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
"pd"
|
||||
}
|
||||
|
||||
fn readiness(&self) -> HttpResponse {
|
||||
// PD router is ready if it has at least one healthy prefill AND one healthy decode worker
|
||||
let healthy_prefill_count = self
|
||||
.prefill_workers
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.filter(|w| w.is_healthy())
|
||||
.count();
|
||||
|
||||
let healthy_decode_count = self
|
||||
.decode_workers
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.filter(|w| w.is_healthy())
|
||||
.count();
|
||||
|
||||
let total_prefill = self.prefill_workers.read().unwrap().len();
|
||||
let total_decode = self.decode_workers.read().unwrap().len();
|
||||
|
||||
if healthy_prefill_count > 0 && healthy_decode_count > 0 {
|
||||
HttpResponse::Ok().json(serde_json::json!({
|
||||
"status": "ready",
|
||||
"prefill": {
|
||||
"healthy": healthy_prefill_count,
|
||||
"total": total_prefill
|
||||
},
|
||||
"decode": {
|
||||
"healthy": healthy_decode_count,
|
||||
"total": total_decode
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
let mut reasons = Vec::new();
|
||||
if healthy_prefill_count == 0 {
|
||||
reasons.push("no healthy prefill workers");
|
||||
}
|
||||
if healthy_decode_count == 0 {
|
||||
reasons.push("no healthy decode workers");
|
||||
}
|
||||
|
||||
HttpResponse::ServiceUnavailable().json(serde_json::json!({
|
||||
"status": "not_ready",
|
||||
"reason": reasons.join(", "),
|
||||
"prefill": {
|
||||
"healthy": healthy_prefill_count,
|
||||
"total": total_prefill
|
||||
},
|
||||
"decode": {
|
||||
"healthy": healthy_decode_count,
|
||||
"total": total_decode
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
// Request adapter to bridge OpenAI API types with PD routing requirements
|
||||
|
||||
use super::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch};
|
||||
use crate::openai_api_types::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, StringOrArray,
|
||||
};
|
||||
use crate::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Adapter trait to convert OpenAI requests to PD-compatible requests
|
||||
1055
sgl-router/src/routers/router.rs
Normal file
1055
sgl-router/src/routers/router.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,8 @@
|
||||
use crate::config::RouterConfig;
|
||||
use crate::logging::{self, LoggingConfig};
|
||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
use crate::prometheus::{self, PrometheusConfig};
|
||||
use crate::request_adapter::ToPdRequest;
|
||||
use crate::router::PolicyConfig;
|
||||
use crate::router::Router;
|
||||
use crate::routers::{RouterFactory, RouterTrait};
|
||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||
use actix_web::{
|
||||
error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
|
||||
@@ -19,27 +18,19 @@ use tracing::{error, info, warn, Level};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AppState {
|
||||
router: Arc<Router>,
|
||||
router: Arc<dyn RouterTrait>,
|
||||
client: Client,
|
||||
is_pd_mode: bool, // Add flag to track PD mode
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(
|
||||
worker_urls: Vec<String>,
|
||||
client: Client,
|
||||
policy_config: PolicyConfig,
|
||||
) -> Result<Self, String> {
|
||||
// Check if this is PD mode from policy config
|
||||
let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. });
|
||||
pub fn new(router_config: RouterConfig, client: Client) -> Result<Self, String> {
|
||||
// Use RouterFactory to create the appropriate router type
|
||||
let router = RouterFactory::create_router(&router_config)?;
|
||||
|
||||
// Create router based on policy
|
||||
let router = Arc::new(Router::new(worker_urls, policy_config)?);
|
||||
Ok(Self {
|
||||
router,
|
||||
client,
|
||||
is_pd_mode,
|
||||
})
|
||||
// Convert Box<dyn RouterTrait> to Arc<dyn RouterTrait>
|
||||
let router = Arc::from(router);
|
||||
|
||||
Ok(Self { router, client })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,65 +67,39 @@ fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/liveness")]
|
||||
async fn liveness(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router.liveness()
|
||||
}
|
||||
|
||||
#[get("/readiness")]
|
||||
async fn readiness(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router.readiness()
|
||||
}
|
||||
|
||||
#[get("/health")]
|
||||
async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/health", &req)
|
||||
.await
|
||||
data.router.health(&data.client, &req).await
|
||||
}
|
||||
|
||||
#[get("/health_generate")]
|
||||
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
// Check if we're in PD mode
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, check health on all servers
|
||||
data.router
|
||||
.route_pd_health_generate(&data.client, &req)
|
||||
.await
|
||||
} else {
|
||||
// Regular mode
|
||||
data.router
|
||||
.route_to_first(&data.client, "/health_generate", &req)
|
||||
.await
|
||||
}
|
||||
data.router.health_generate(&data.client, &req).await
|
||||
}
|
||||
|
||||
#[get("/get_server_info")]
|
||||
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, aggregate info from both prefill and decode servers
|
||||
data.router.get_pd_server_info(&data.client, &req).await
|
||||
} else {
|
||||
// Regular mode - return first server's info
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_server_info", &req)
|
||||
.await
|
||||
}
|
||||
data.router.get_server_info(&data.client, &req).await
|
||||
}
|
||||
|
||||
#[get("/v1/models")]
|
||||
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, return models from the first prefill server
|
||||
data.router.get_pd_models(&data.client, &req).await
|
||||
} else {
|
||||
// Regular mode
|
||||
data.router
|
||||
.route_to_first(&data.client, "/v1/models", &req)
|
||||
.await
|
||||
}
|
||||
data.router.get_models(&data.client, &req).await
|
||||
}
|
||||
|
||||
#[get("/get_model_info")]
|
||||
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, get model info from the first prefill server
|
||||
data.router.get_pd_model_info(&data.client, &req).await
|
||||
} else {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_model_info", &req)
|
||||
.await
|
||||
}
|
||||
data.router.get_model_info(&data.client, &req).await
|
||||
}
|
||||
|
||||
#[post("/generate")]
|
||||
@@ -143,24 +108,12 @@ async fn generate(
|
||||
body: web::Json<GenerateRequest>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let client = &state.client;
|
||||
let router = &state.router;
|
||||
|
||||
// Use typed request directly for both PD and regular routing
|
||||
if state.is_pd_mode {
|
||||
// For PD mode, convert to PD request with bootstrap
|
||||
let pd_request = body.into_inner().to_pd_request();
|
||||
|
||||
Ok(router
|
||||
.route_pd_generate_typed(&client, &req, pd_request, "/generate")
|
||||
.await)
|
||||
} else {
|
||||
// For regular mode, use typed request directly
|
||||
let request = body.into_inner();
|
||||
Ok(router
|
||||
.route_typed_request(&client, &req, &request, "/generate")
|
||||
.await)
|
||||
}
|
||||
let json_body = serde_json::to_value(body.into_inner())
|
||||
.map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?;
|
||||
Ok(state
|
||||
.router
|
||||
.route_generate(&state.client, &req, json_body)
|
||||
.await)
|
||||
}
|
||||
|
||||
#[post("/v1/chat/completions")]
|
||||
@@ -169,24 +122,12 @@ async fn v1_chat_completions(
|
||||
body: web::Json<ChatCompletionRequest>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let client = &state.client;
|
||||
let router = &state.router;
|
||||
|
||||
// Use typed request directly for both PD and regular routing
|
||||
if state.is_pd_mode {
|
||||
// For PD mode, convert to PD request with bootstrap
|
||||
let pd_request = body.into_inner().to_pd_request();
|
||||
|
||||
Ok(router
|
||||
.route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions")
|
||||
.await)
|
||||
} else {
|
||||
// For regular mode, use typed request directly
|
||||
let request = body.into_inner();
|
||||
Ok(router
|
||||
.route_typed_request(&client, &req, &request, "/v1/chat/completions")
|
||||
.await)
|
||||
}
|
||||
let json_body = serde_json::to_value(body.into_inner())
|
||||
.map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?;
|
||||
Ok(state
|
||||
.router
|
||||
.route_chat(&state.client, &req, json_body)
|
||||
.await)
|
||||
}
|
||||
|
||||
#[post("/v1/completions")]
|
||||
@@ -195,24 +136,12 @@ async fn v1_completions(
|
||||
body: web::Json<CompletionRequest>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let client = &state.client;
|
||||
let router = &state.router;
|
||||
|
||||
// Use typed request directly for both PD and regular routing
|
||||
if state.is_pd_mode {
|
||||
// For PD mode, convert to PD request with bootstrap
|
||||
let pd_request = body.into_inner().to_pd_request();
|
||||
|
||||
Ok(router
|
||||
.route_pd_generate_typed(&client, &req, pd_request, "/v1/completions")
|
||||
.await)
|
||||
} else {
|
||||
// For regular mode, use typed request directly
|
||||
let request = body.into_inner();
|
||||
Ok(router
|
||||
.route_typed_request(&client, &req, &request, "/v1/completions")
|
||||
.await)
|
||||
}
|
||||
let json_body = serde_json::to_value(body.into_inner())
|
||||
.map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?;
|
||||
Ok(state
|
||||
.router
|
||||
.route_completion(&state.client, &req, json_body)
|
||||
.await)
|
||||
}
|
||||
|
||||
#[post("/add_worker")]
|
||||
@@ -254,29 +183,19 @@ async fn remove_worker(
|
||||
}
|
||||
|
||||
#[post("/flush_cache")]
|
||||
async fn flush_cache(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, flush cache on both prefill and decode servers
|
||||
data.router.route_pd_flush_cache(&data.client).await
|
||||
} else {
|
||||
// Route to all workers for cache flushing
|
||||
data.router
|
||||
.route_to_all(&data.client, "/flush_cache", &req)
|
||||
.await
|
||||
}
|
||||
async fn flush_cache(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router.flush_cache(&data.client).await
|
||||
}
|
||||
|
||||
#[get("/get_loads")]
|
||||
async fn get_loads(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
// Get loads from all workers
|
||||
data.router.get_all_loads(&data.client, &req).await
|
||||
async fn get_loads(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router.get_worker_loads(&data.client).await
|
||||
}
|
||||
|
||||
pub struct ServerConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub worker_urls: Vec<String>,
|
||||
pub policy_config: PolicyConfig,
|
||||
pub router_config: RouterConfig,
|
||||
pub max_payload_size: usize,
|
||||
pub log_dir: Option<String>,
|
||||
pub log_level: Option<String>,
|
||||
@@ -324,8 +243,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
}
|
||||
|
||||
info!("🚧 Initializing router on {}:{}", config.host, config.port);
|
||||
info!("🚧 Initializing workers on {:?}", config.worker_urls);
|
||||
info!("🚧 Policy Config: {:?}", config.policy_config);
|
||||
info!("🚧 Router mode: {:?}", config.router_config.mode);
|
||||
info!("🚧 Policy: {:?}", config.router_config.policy);
|
||||
info!(
|
||||
"🚧 Max payload size: {} MB",
|
||||
config.max_payload_size / (1024 * 1024)
|
||||
@@ -345,12 +264,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
let app_state_init = AppState::new(
|
||||
config.worker_urls.clone(),
|
||||
client.clone(),
|
||||
config.policy_config.clone(),
|
||||
)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
|
||||
let app_state_init = AppState::new(config.router_config.clone(), client.clone())
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
|
||||
let router_arc = Arc::clone(&app_state_init.router);
|
||||
let app_state = web::Data::new(app_state_init);
|
||||
|
||||
@@ -397,6 +312,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
.service(v1_completions)
|
||||
.service(v1_models)
|
||||
.service(get_model_info)
|
||||
.service(liveness)
|
||||
.service(readiness)
|
||||
.service(health)
|
||||
.service(health_generate)
|
||||
.service(get_server_info)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::router::Router;
|
||||
use crate::routers::RouterTrait;
|
||||
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use k8s_openapi::api::core::v1::Pod;
|
||||
@@ -176,7 +176,7 @@ impl PodInfo {
|
||||
|
||||
pub async fn start_service_discovery(
|
||||
config: ServiceDiscoveryConfig,
|
||||
router: Arc<Router>,
|
||||
router: Arc<dyn RouterTrait>,
|
||||
) -> Result<task::JoinHandle<()>, kube::Error> {
|
||||
// Don't initialize anything if service discovery is disabled
|
||||
if !config.enabled {
|
||||
@@ -346,7 +346,7 @@ pub async fn start_service_discovery(
|
||||
async fn handle_pod_event(
|
||||
pod_info: &PodInfo,
|
||||
tracked_pods: Arc<Mutex<HashSet<PodInfo>>>,
|
||||
router: Arc<Router>,
|
||||
router: Arc<dyn RouterTrait>,
|
||||
port: u16,
|
||||
pd_mode: bool,
|
||||
) {
|
||||
@@ -379,17 +379,32 @@ async fn handle_pod_event(
|
||||
pod_info.name, pod_info.pod_type, worker_url
|
||||
);
|
||||
|
||||
// Handle PD mode with specific pod types
|
||||
let result = if pd_mode && pod_info.pod_type.is_some() {
|
||||
// Use PD-aware worker management
|
||||
if let Some(pod_type) = &pod_info.pod_type {
|
||||
router
|
||||
.add_pd_worker(&worker_url, pod_type.clone(), pod_info.bootstrap_port)
|
||||
.await
|
||||
// Need to import PDRouter type
|
||||
use crate::routers::pd_router::PDRouter;
|
||||
|
||||
// Try to downcast to PDRouter
|
||||
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
|
||||
match &pod_info.pod_type {
|
||||
Some(PodType::Prefill) => pd_router
|
||||
.add_prefill_server(worker_url.clone(), pod_info.bootstrap_port)
|
||||
.await
|
||||
.map_err(|e| e.to_string()),
|
||||
Some(PodType::Decode) => pd_router
|
||||
.add_decode_server(worker_url.clone())
|
||||
.await
|
||||
.map_err(|e| e.to_string()),
|
||||
Some(PodType::Regular) | None => {
|
||||
// Fall back to regular add_worker for regular pods
|
||||
router.add_worker(&worker_url).await
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Err("Pod type is None in PD mode".to_string())
|
||||
Err("PD mode enabled but router is not a PDRouter".to_string())
|
||||
}
|
||||
} else {
|
||||
// Fallback to regular worker management
|
||||
// Regular mode or no pod type specified
|
||||
router.add_worker(&worker_url).await
|
||||
};
|
||||
|
||||
@@ -412,7 +427,7 @@ async fn handle_pod_event(
|
||||
async fn handle_pod_deletion(
|
||||
pod_info: &PodInfo,
|
||||
tracked_pods: Arc<Mutex<HashSet<PodInfo>>>,
|
||||
router: Arc<Router>,
|
||||
router: Arc<dyn RouterTrait>,
|
||||
port: u16,
|
||||
pd_mode: bool,
|
||||
) {
|
||||
@@ -435,18 +450,34 @@ async fn handle_pod_deletion(
|
||||
pod_info.name, pod_info.pod_type, worker_url
|
||||
);
|
||||
|
||||
// Handle PD mode removal
|
||||
if pd_mode && pod_info.pod_type.is_some() {
|
||||
// Use PD-aware worker removal
|
||||
if let Some(pod_type) = &pod_info.pod_type {
|
||||
if let Err(e) = router.remove_pd_worker(&worker_url, pod_type.clone()).await {
|
||||
error!(
|
||||
"Failed to remove PD worker {} from router: {}",
|
||||
worker_url, e
|
||||
);
|
||||
use crate::routers::pd_router::PDRouter;
|
||||
|
||||
// Try to downcast to PDRouter for PD-specific removal
|
||||
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
|
||||
match &pod_info.pod_type {
|
||||
Some(PodType::Prefill) => {
|
||||
if let Err(e) = pd_router.remove_prefill_server(&worker_url).await {
|
||||
error!("Failed to remove prefill server {}: {}", worker_url, e);
|
||||
}
|
||||
}
|
||||
Some(PodType::Decode) => {
|
||||
if let Err(e) = pd_router.remove_decode_server(&worker_url).await {
|
||||
error!("Failed to remove decode server {}: {}", worker_url, e);
|
||||
}
|
||||
}
|
||||
Some(PodType::Regular) | None => {
|
||||
// Fall back to regular remove_worker
|
||||
router.remove_worker(&worker_url);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// PD mode but not a PDRouter, use generic removal
|
||||
router.remove_worker(&worker_url);
|
||||
}
|
||||
} else {
|
||||
// Fallback to regular worker removal
|
||||
// Regular mode removal
|
||||
router.remove_worker(&worker_url);
|
||||
}
|
||||
} else {
|
||||
@@ -462,11 +493,9 @@ async fn handle_pod_deletion(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::router::Router;
|
||||
use k8s_openapi::api::core::v1::{Pod, PodCondition, PodSpec, PodStatus};
|
||||
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
|
||||
use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time;
|
||||
use std::sync::RwLock;
|
||||
|
||||
// Helper function to create a Pod for testing PodInfo::from_pod
|
||||
fn create_k8s_pod(
|
||||
@@ -546,14 +575,14 @@ mod tests {
|
||||
}
|
||||
|
||||
// Helper to create a Router instance for testing event handlers
|
||||
fn create_test_router() -> Arc<Router> {
|
||||
let workers = Arc::new(RwLock::new(Vec::new()));
|
||||
Arc::new(Router::Random {
|
||||
workers,
|
||||
timeout_secs: 5,
|
||||
interval_secs: 1,
|
||||
_health_checker: None,
|
||||
})
|
||||
fn create_test_router() -> Arc<dyn RouterTrait> {
|
||||
use crate::config::PolicyConfig;
|
||||
use crate::policies::PolicyFactory;
|
||||
use crate::routers::router::Router;
|
||||
|
||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
|
||||
let router = Router::new(vec![], policy, 5, 1).unwrap();
|
||||
Arc::new(router) as Arc<dyn RouterTrait>
|
||||
}
|
||||
|
||||
// Helper to create a PD config for testing
|
||||
|
||||
@@ -6,7 +6,7 @@ use sglang_router_rs::openai_api_types::{
|
||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
||||
SamplingParams, StringOrArray, UserMessageContent,
|
||||
};
|
||||
use sglang_router_rs::request_adapter::{RouteableRequest, ToPdRequest};
|
||||
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};
|
||||
|
||||
#[test]
|
||||
fn test_benchmark_request_creation() {
|
||||
|
||||
@@ -8,12 +8,18 @@
|
||||
//! Note: PD mode is enabled via the pd_disaggregation flag, not as a policy type.
|
||||
//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode.
|
||||
|
||||
// TODO: This test file needs to be updated for the new configuration structure
|
||||
// where RoutingMode and PolicyConfig are separate
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_pd_routing {
|
||||
use rand::Rng;
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::pd_types::PDSelectionPolicy;
|
||||
use sglang_router_rs::router::{PolicyConfig, Router};
|
||||
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::routers::pd_types::get_hostname;
|
||||
use sglang_router_rs::routers::pd_types::PDSelectionPolicy;
|
||||
use sglang_router_rs::routers::RouterFactory;
|
||||
|
||||
// Test-only struct to help validate PD request parsing
|
||||
#[derive(Debug)]
|
||||
@@ -116,49 +122,68 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_pd_router_configuration() {
|
||||
// Test PrefillDecodeConfig creation with various policies
|
||||
// This config is used when pd_disaggregation=true
|
||||
let configs = vec![
|
||||
PolicyConfig::PrefillDecodeConfig {
|
||||
selection_policy: PDSelectionPolicy::Random,
|
||||
prefill_urls: vec![
|
||||
("http://prefill1:8080".to_string(), Some(9000)),
|
||||
("http://prefill2:8080".to_string(), None),
|
||||
],
|
||||
decode_urls: vec![
|
||||
"http://decode1:8080".to_string(),
|
||||
"http://decode2:8080".to_string(),
|
||||
],
|
||||
timeout_secs: 10,
|
||||
interval_secs: 1,
|
||||
},
|
||||
PolicyConfig::PrefillDecodeConfig {
|
||||
selection_policy: PDSelectionPolicy::PowerOfTwo,
|
||||
prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))],
|
||||
decode_urls: vec!["http://decode:8080".to_string()],
|
||||
timeout_secs: 5,
|
||||
interval_secs: 1,
|
||||
},
|
||||
PolicyConfig::PrefillDecodeConfig {
|
||||
selection_policy: PDSelectionPolicy::CacheAware {
|
||||
// Test PD router configuration with various policies
|
||||
// In the new structure, RoutingMode and PolicyConfig are separate
|
||||
let test_cases = vec![
|
||||
(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![
|
||||
("http://prefill1:8080".to_string(), Some(9000)),
|
||||
("http://prefill2:8080".to_string(), None),
|
||||
],
|
||||
decode_urls: vec![
|
||||
"http://decode1:8080".to_string(),
|
||||
"http://decode2:8080".to_string(),
|
||||
],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
),
|
||||
(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))],
|
||||
decode_urls: vec!["http://decode:8080".to_string()],
|
||||
},
|
||||
PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 5,
|
||||
},
|
||||
),
|
||||
(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![
|
||||
("http://p1:8080".to_string(), Some(9000)),
|
||||
("http://p2:8080".to_string(), Some(9001)),
|
||||
("http://p3:8080".to_string(), Some(9002)),
|
||||
],
|
||||
decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
|
||||
},
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold: 0.7,
|
||||
balance_abs_threshold: 20,
|
||||
balance_rel_threshold: 1.2,
|
||||
eviction_interval_secs: 60,
|
||||
max_tree_size: 1000000,
|
||||
},
|
||||
prefill_urls: vec![
|
||||
("http://p1:8080".to_string(), Some(9000)),
|
||||
("http://p2:8080".to_string(), Some(9001)),
|
||||
("http://p3:8080".to_string(), Some(9002)),
|
||||
],
|
||||
decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
|
||||
timeout_secs: 10,
|
||||
interval_secs: 2,
|
||||
},
|
||||
),
|
||||
];
|
||||
|
||||
for config in configs {
|
||||
for (mode, policy) in test_cases {
|
||||
let config = RouterConfig {
|
||||
mode,
|
||||
policy,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 3001,
|
||||
max_payload_size: 1024 * 1024,
|
||||
request_timeout_secs: 60,
|
||||
worker_startup_timeout_secs: 10,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
log_level: None,
|
||||
};
|
||||
|
||||
// Router creation will fail due to health checks, but config should be valid
|
||||
let result = Router::new(vec![], config);
|
||||
let result = RouterFactory::create_router(&config);
|
||||
assert!(result.is_err());
|
||||
let error_msg = result.unwrap_err();
|
||||
// Error should be about health/timeout, not configuration
|
||||
@@ -225,9 +250,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_bootstrap_injection_simulation() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Since we can't test the actual inject_bootstrap_fields function here
|
||||
// (it's private in the router module), we'll test the expected behavior
|
||||
|
||||
@@ -315,8 +337,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_hostname_extraction() {
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Test various URL formats
|
||||
let test_cases = vec![
|
||||
("http://localhost:8080", "localhost"),
|
||||
@@ -662,7 +682,6 @@ mod test_pd_routing {
|
||||
#[test]
|
||||
fn test_bootstrap_injection_with_benchmark_requests() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Test bootstrap injection with actual benchmark request patterns
|
||||
let mut benchmark_request = json!({
|
||||
@@ -790,9 +809,6 @@ mod test_pd_routing {
|
||||
|
||||
#[test]
|
||||
fn test_large_batch_bootstrap_injection() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::pd_types::get_hostname;
|
||||
|
||||
// Test bootstrap injection performance with very large batches
|
||||
// This simulates the bench_one_batch_server.py scenario
|
||||
let large_batch_sizes = vec![1024, 4096, 8192];
|
||||
|
||||
Reference in New Issue
Block a user