[router] allow one router to support different model families and serving mode (#10244)
This commit is contained in:
129
sgl-router/tests/cache_aware_backward_compat_test.rs
Normal file
129
sgl-router/tests/cache_aware_backward_compat_test.rs
Normal file
@@ -0,0 +1,129 @@
|
||||
use sglang_router_rs::core::{BasicWorker, Worker, WorkerType};
|
||||
use sglang_router_rs::policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn test_backward_compatibility_with_empty_model_id() {
|
||||
let config = CacheAwareConfig {
|
||||
cache_threshold: 0.5,
|
||||
balance_abs_threshold: 2,
|
||||
balance_rel_threshold: 1.5,
|
||||
eviction_interval_secs: 0, // Disable background eviction for testing
|
||||
max_tree_size: 100,
|
||||
};
|
||||
|
||||
let policy = CacheAwarePolicy::with_config(config);
|
||||
|
||||
// Create workers with empty model_id (simulating existing routers)
|
||||
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
|
||||
// No model_id label - should default to "unknown"
|
||||
|
||||
let mut labels2 = HashMap::new();
|
||||
labels2.insert("model_id".to_string(), "unknown".to_string());
|
||||
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular)
|
||||
.with_labels(labels2);
|
||||
|
||||
// Add workers - should both go to "default" tree
|
||||
policy.add_worker(&worker1);
|
||||
policy.add_worker(&worker2);
|
||||
|
||||
// Create worker list
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker1.clone()), Arc::new(worker2.clone())];
|
||||
|
||||
// Select worker - should work without errors
|
||||
let selected = policy.select_worker(&workers, Some("test request"));
|
||||
assert!(selected.is_some(), "Should select a worker");
|
||||
|
||||
// Remove workers - should work without errors
|
||||
policy.remove_worker(&worker1);
|
||||
policy.remove_worker(&worker2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mixed_model_ids() {
|
||||
let config = CacheAwareConfig {
|
||||
cache_threshold: 0.5,
|
||||
balance_abs_threshold: 2,
|
||||
balance_rel_threshold: 1.5,
|
||||
eviction_interval_secs: 0,
|
||||
max_tree_size: 100,
|
||||
};
|
||||
|
||||
let policy = CacheAwarePolicy::with_config(config);
|
||||
|
||||
// Create workers with different model_id scenarios
|
||||
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
|
||||
// No model_id label - defaults to "unknown" which goes to "default" tree
|
||||
|
||||
let mut labels2 = HashMap::new();
|
||||
labels2.insert("model_id".to_string(), "llama-3".to_string());
|
||||
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular)
|
||||
.with_labels(labels2);
|
||||
|
||||
let mut labels3 = HashMap::new();
|
||||
labels3.insert("model_id".to_string(), "unknown".to_string());
|
||||
let worker3 = BasicWorker::new("http://worker3:8080".to_string(), WorkerType::Regular)
|
||||
.with_labels(labels3);
|
||||
|
||||
let mut labels4 = HashMap::new();
|
||||
labels4.insert("model_id".to_string(), "llama-3".to_string());
|
||||
let worker4 = BasicWorker::new("http://worker4:8080".to_string(), WorkerType::Regular)
|
||||
.with_labels(labels4);
|
||||
|
||||
// Add all workers
|
||||
policy.add_worker(&worker1);
|
||||
policy.add_worker(&worker2);
|
||||
policy.add_worker(&worker3);
|
||||
policy.add_worker(&worker4);
|
||||
|
||||
// Test selection with default workers only
|
||||
let default_workers: Vec<Arc<dyn Worker>> =
|
||||
vec![Arc::new(worker1.clone()), Arc::new(worker3.clone())];
|
||||
let selected = policy.select_worker(&default_workers, Some("test request"));
|
||||
assert!(selected.is_some(), "Should select from default workers");
|
||||
|
||||
// Test selection with specific model workers only
|
||||
let llama_workers: Vec<Arc<dyn Worker>> =
|
||||
vec![Arc::new(worker2.clone()), Arc::new(worker4.clone())];
|
||||
let selected = policy.select_worker(&llama_workers, Some("test request"));
|
||||
assert!(selected.is_some(), "Should select from llama-3 workers");
|
||||
|
||||
// Test selection with mixed workers
|
||||
let all_workers: Vec<Arc<dyn Worker>> = vec![
|
||||
Arc::new(worker1.clone()),
|
||||
Arc::new(worker2.clone()),
|
||||
Arc::new(worker3.clone()),
|
||||
Arc::new(worker4.clone()),
|
||||
];
|
||||
let selected = policy.select_worker(&all_workers, Some("test request"));
|
||||
assert!(selected.is_some(), "Should select from all workers");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_remove_worker_by_url_backward_compat() {
|
||||
let config = CacheAwareConfig::default();
|
||||
let policy = CacheAwarePolicy::with_config(config);
|
||||
|
||||
// Create workers with different model_ids
|
||||
let mut labels1 = HashMap::new();
|
||||
labels1.insert("model_id".to_string(), "llama-3".to_string());
|
||||
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular)
|
||||
.with_labels(labels1);
|
||||
|
||||
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular);
|
||||
// No model_id label - defaults to "unknown"
|
||||
|
||||
// Add workers
|
||||
policy.add_worker(&worker1);
|
||||
policy.add_worker(&worker2);
|
||||
|
||||
// Remove by URL (backward compatibility method)
|
||||
// Should remove from all trees since we don't know the model
|
||||
policy.remove_worker_by_url("http://worker1:8080");
|
||||
|
||||
// Verify removal worked
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker2.clone())];
|
||||
let selected = policy.select_worker(&workers, Some("test"));
|
||||
assert_eq!(selected, Some(0), "Should only have worker2 left");
|
||||
}
|
||||
168
sgl-router/tests/policy_registry_integration.rs
Normal file
168
sgl-router/tests/policy_registry_integration.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
//! Integration tests for PolicyRegistry with RouterManager
|
||||
|
||||
use sglang_router_rs::config::{PolicyConfig, RouterConfig};
|
||||
use sglang_router_rs::core::WorkerRegistry;
|
||||
use sglang_router_rs::policies::PolicyRegistry;
|
||||
use sglang_router_rs::protocols::worker_spec::WorkerConfigRequest;
|
||||
use sglang_router_rs::routers::router_manager::RouterManager;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_policy_registry_with_router_manager() {
|
||||
// Create RouterConfig
|
||||
let config = RouterConfig {
|
||||
enable_igw: true,
|
||||
policy: PolicyConfig::RoundRobin,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Create HTTP client
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Create shared registries
|
||||
let worker_registry = Arc::new(WorkerRegistry::new());
|
||||
let policy_registry = Arc::new(PolicyRegistry::new(PolicyConfig::RoundRobin));
|
||||
|
||||
// Create RouterManager with shared registries
|
||||
let _router_manager = RouterManager::new(
|
||||
config,
|
||||
client,
|
||||
worker_registry.clone(),
|
||||
policy_registry.clone(),
|
||||
);
|
||||
|
||||
// Test adding workers with different models and policies
|
||||
|
||||
// Add first worker for llama-3 with cache_aware policy hint
|
||||
let mut labels1 = HashMap::new();
|
||||
labels1.insert("policy".to_string(), "cache_aware".to_string());
|
||||
|
||||
let _worker1_config = WorkerConfigRequest {
|
||||
url: "http://worker1:8000".to_string(),
|
||||
model_id: Some("llama-3".to_string()),
|
||||
worker_type: None,
|
||||
priority: None,
|
||||
cost: None,
|
||||
labels: labels1,
|
||||
bootstrap_port: None,
|
||||
tokenizer_path: None,
|
||||
reasoning_parser: None,
|
||||
tool_parser: None,
|
||||
chat_template: None,
|
||||
};
|
||||
|
||||
// This would normally connect to a real worker, but for testing we'll just verify the structure
|
||||
// In a real test, we'd need to mock the worker or use a test server
|
||||
|
||||
// Verify PolicyRegistry has the correct policy for llama-3
|
||||
let _llama_policy = policy_registry.get_policy("llama-3");
|
||||
// After first worker is added, llama-3 should have a policy
|
||||
|
||||
// Add second worker for llama-3 with different policy hint (should be ignored)
|
||||
let mut labels2 = HashMap::new();
|
||||
labels2.insert("policy".to_string(), "random".to_string());
|
||||
|
||||
let _worker2_config = WorkerConfigRequest {
|
||||
url: "http://worker2:8000".to_string(),
|
||||
model_id: Some("llama-3".to_string()),
|
||||
worker_type: None,
|
||||
priority: None,
|
||||
cost: None,
|
||||
labels: labels2,
|
||||
bootstrap_port: None,
|
||||
tokenizer_path: None,
|
||||
reasoning_parser: None,
|
||||
tool_parser: None,
|
||||
chat_template: None,
|
||||
};
|
||||
|
||||
// The second worker should use the same policy as the first (cache_aware)
|
||||
|
||||
// Add worker for different model (gpt-4) with random policy
|
||||
let mut labels3 = HashMap::new();
|
||||
labels3.insert("policy".to_string(), "random".to_string());
|
||||
|
||||
let _worker3_config = WorkerConfigRequest {
|
||||
url: "http://worker3:8000".to_string(),
|
||||
model_id: Some("gpt-4".to_string()),
|
||||
worker_type: None,
|
||||
priority: None,
|
||||
cost: None,
|
||||
labels: labels3,
|
||||
bootstrap_port: None,
|
||||
tokenizer_path: None,
|
||||
reasoning_parser: None,
|
||||
tool_parser: None,
|
||||
chat_template: None,
|
||||
};
|
||||
|
||||
// Verify gpt-4 has random policy
|
||||
let _gpt_policy = policy_registry.get_policy("gpt-4");
|
||||
|
||||
// Test removing workers
|
||||
// When we remove both llama-3 workers, the policy should be cleaned up
|
||||
|
||||
println!("PolicyRegistry integration test structure created");
|
||||
println!("Note: This test requires mocking or test servers to fully execute");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_policy_registry_cleanup() {
|
||||
use sglang_router_rs::config::PolicyConfig;
|
||||
use sglang_router_rs::policies::PolicyRegistry;
|
||||
|
||||
let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);
|
||||
|
||||
// Add workers for a model
|
||||
let policy1 = registry.on_worker_added("model-1", Some("cache_aware"));
|
||||
assert_eq!(policy1.name(), "cache_aware");
|
||||
|
||||
// Second worker uses existing policy
|
||||
let policy2 = registry.on_worker_added("model-1", Some("random"));
|
||||
assert_eq!(policy2.name(), "cache_aware"); // Should still be cache_aware
|
||||
|
||||
// Verify policy exists
|
||||
assert!(registry.get_policy("model-1").is_some());
|
||||
|
||||
// Remove first worker - policy should remain
|
||||
registry.on_worker_removed("model-1");
|
||||
assert!(registry.get_policy("model-1").is_some());
|
||||
|
||||
// Remove second worker - policy should be cleaned up
|
||||
registry.on_worker_removed("model-1");
|
||||
assert!(registry.get_policy("model-1").is_none());
|
||||
|
||||
println!("✓ PolicyRegistry cleanup test passed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_policy_registry_multiple_models() {
|
||||
use sglang_router_rs::config::PolicyConfig;
|
||||
use sglang_router_rs::policies::PolicyRegistry;
|
||||
|
||||
let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);
|
||||
|
||||
// Add workers for different models with different policies
|
||||
let llama_policy = registry.on_worker_added("llama-3", Some("cache_aware"));
|
||||
let gpt_policy = registry.on_worker_added("gpt-4", Some("random"));
|
||||
let mistral_policy = registry.on_worker_added("mistral", None); // Uses default
|
||||
|
||||
assert_eq!(llama_policy.name(), "cache_aware");
|
||||
assert_eq!(gpt_policy.name(), "random");
|
||||
assert_eq!(mistral_policy.name(), "round_robin"); // Default
|
||||
|
||||
// Verify all policies are stored
|
||||
assert!(registry.get_policy("llama-3").is_some());
|
||||
assert!(registry.get_policy("gpt-4").is_some());
|
||||
assert!(registry.get_policy("mistral").is_some());
|
||||
|
||||
// Get all mappings
|
||||
let mappings = registry.get_all_mappings();
|
||||
assert_eq!(mappings.len(), 3);
|
||||
assert_eq!(mappings.get("llama-3").unwrap(), "cache_aware");
|
||||
assert_eq!(mappings.get("gpt-4").unwrap(), "random");
|
||||
assert_eq!(mappings.get("mistral").unwrap(), "round_robin");
|
||||
|
||||
println!("✓ PolicyRegistry multiple models test passed");
|
||||
}
|
||||
@@ -197,12 +197,14 @@ async fn test_unsupported_endpoints() {
|
||||
rid: None,
|
||||
};
|
||||
|
||||
let response = router.route_generate(None, &generate_request).await;
|
||||
let response = router.route_generate(None, &generate_request, None).await;
|
||||
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
|
||||
|
||||
// Test completion endpoint (should also not be supported)
|
||||
let completion_request = create_minimal_completion_request();
|
||||
let response = router.route_completion(None, &completion_request).await;
|
||||
let response = router
|
||||
.route_completion(None, &completion_request, None)
|
||||
.await;
|
||||
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
@@ -228,7 +230,7 @@ async fn test_openai_router_chat_completion_with_mock() {
|
||||
chat_request.temperature = Some(0.7);
|
||||
|
||||
// Route the request
|
||||
let response = router.route_chat(None, &chat_request).await;
|
||||
let response = router.route_chat(None, &chat_request, None).await;
|
||||
|
||||
// Should get a successful response from mock server
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
@@ -269,7 +271,9 @@ async fn test_openai_e2e_with_server() {
|
||||
let chat_request: ChatCompletionRequest =
|
||||
serde_json::from_str(&body_str).unwrap();
|
||||
|
||||
router.route_chat(Some(&parts.headers), &chat_request).await
|
||||
router
|
||||
.route_chat(Some(&parts.headers), &chat_request, None)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}),
|
||||
@@ -327,7 +331,7 @@ async fn test_openai_router_chat_streaming_with_mock() {
|
||||
});
|
||||
let chat_request: ChatCompletionRequest = serde_json::from_value(val).unwrap();
|
||||
|
||||
let response = router.route_chat(None, &chat_request).await;
|
||||
let response = router.route_chat(None, &chat_request, None).await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
// Should be SSE
|
||||
@@ -371,7 +375,7 @@ async fn test_openai_router_circuit_breaker() {
|
||||
|
||||
// First few requests should fail and record failures
|
||||
for _ in 0..3 {
|
||||
let response = router.route_chat(None, &chat_request).await;
|
||||
let response = router.route_chat(None, &chat_request, None).await;
|
||||
// Should get either an error or circuit breaker response
|
||||
assert!(
|
||||
response.status() == StatusCode::INTERNAL_SERVER_ERROR
|
||||
|
||||
Reference in New Issue
Block a user