[router] remove old/oudated/useless comments across code base (#10968)
This commit is contained in:
@@ -205,7 +205,6 @@ impl RoutingMode {
|
||||
decode_urls,
|
||||
..
|
||||
} => prefill_urls.len() + decode_urls.len(),
|
||||
// OpenAI mode represents a single upstream
|
||||
RoutingMode::OpenAI { .. } => 1,
|
||||
}
|
||||
}
|
||||
@@ -515,8 +514,6 @@ impl RouterConfig {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ============= RouterConfig Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_router_config_default() {
|
||||
let config = RouterConfig::default();
|
||||
@@ -556,7 +553,6 @@ mod tests {
|
||||
}
|
||||
|
||||
assert!(matches!(config.policy, PolicyConfig::RoundRobin));
|
||||
// Other fields should be default
|
||||
assert_eq!(config.host, "127.0.0.1");
|
||||
assert_eq!(config.port, 3001);
|
||||
}
|
||||
@@ -583,13 +579,10 @@ mod tests {
|
||||
assert_eq!(config.max_payload_size, deserialized.max_payload_size);
|
||||
assert_eq!(config.log_dir, deserialized.log_dir);
|
||||
assert_eq!(config.log_level, deserialized.log_level);
|
||||
// discovery and metrics are None in Default implementation
|
||||
assert!(deserialized.discovery.is_none());
|
||||
assert!(deserialized.metrics.is_none());
|
||||
}
|
||||
|
||||
// ============= RoutingMode Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_routing_mode_is_pd_mode() {
|
||||
let regular = RoutingMode::Regular {
|
||||
@@ -640,7 +633,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_routing_mode_serialization() {
|
||||
// Test Regular mode
|
||||
let regular = RoutingMode::Regular {
|
||||
worker_urls: vec!["http://worker1".to_string()],
|
||||
};
|
||||
@@ -648,7 +640,6 @@ mod tests {
|
||||
assert!(json.contains("\"type\":\"regular\""));
|
||||
assert!(json.contains("\"worker_urls\""));
|
||||
|
||||
// Test PrefillDecode mode
|
||||
let pd = RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
|
||||
decode_urls: vec!["http://decode1".to_string()],
|
||||
@@ -661,8 +652,6 @@ mod tests {
|
||||
assert!(json.contains("\"decode_urls\""));
|
||||
}
|
||||
|
||||
// ============= PolicyConfig Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_policy_config_name() {
|
||||
assert_eq!(PolicyConfig::Random.name(), "random");
|
||||
@@ -685,12 +674,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_policy_config_serialization() {
|
||||
// Test Random
|
||||
let random = PolicyConfig::Random;
|
||||
let json = serde_json::to_string(&random).unwrap();
|
||||
assert_eq!(json, r#"{"type":"random"}"#);
|
||||
|
||||
// Test CacheAware with all parameters
|
||||
let cache_aware = PolicyConfig::CacheAware {
|
||||
cache_threshold: 0.8,
|
||||
balance_abs_threshold: 10,
|
||||
@@ -703,7 +690,6 @@ mod tests {
|
||||
assert!(json.contains("\"cache_threshold\":0.8"));
|
||||
assert!(json.contains("\"balance_abs_threshold\":10"));
|
||||
|
||||
// Test PowerOfTwo
|
||||
let power_of_two = PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 60,
|
||||
};
|
||||
@@ -756,8 +742,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// ============= DiscoveryConfig Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_discovery_config_default() {
|
||||
let config = DiscoveryConfig::default();
|
||||
@@ -798,14 +782,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_discovery_config_namespace() {
|
||||
// Test None namespace (all namespaces)
|
||||
let config = DiscoveryConfig {
|
||||
namespace: None,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(config.namespace.is_none());
|
||||
|
||||
// Test specific namespace
|
||||
let config = DiscoveryConfig {
|
||||
namespace: Some("production".to_string()),
|
||||
..Default::default()
|
||||
@@ -813,8 +795,6 @@ mod tests {
|
||||
assert_eq!(config.namespace, Some("production".to_string()));
|
||||
}
|
||||
|
||||
// ============= MetricsConfig Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_metrics_config_default() {
|
||||
let config = MetricsConfig::default();
|
||||
@@ -834,8 +814,6 @@ mod tests {
|
||||
assert_eq!(config.host, "0.0.0.0");
|
||||
}
|
||||
|
||||
// ============= RouterConfig Utility Methods Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_mode_type() {
|
||||
let config = RouterConfig {
|
||||
@@ -894,8 +872,6 @@ mod tests {
|
||||
assert!(config.has_metrics());
|
||||
}
|
||||
|
||||
// ============= Edge Cases =============
|
||||
|
||||
#[test]
|
||||
fn test_large_worker_lists() {
|
||||
let large_urls: Vec<String> = (0..1000).map(|i| format!("http://worker{}", i)).collect();
|
||||
@@ -906,7 +882,6 @@ mod tests {
|
||||
|
||||
assert_eq!(mode.worker_count(), 1000);
|
||||
|
||||
// Test serialization with large list
|
||||
let config = RouterConfig {
|
||||
mode,
|
||||
..Default::default()
|
||||
@@ -961,8 +936,6 @@ mod tests {
|
||||
assert_eq!(config.log_level, Some("".to_string()));
|
||||
}
|
||||
|
||||
// ============= Complex Configuration Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_full_pd_mode_config() {
|
||||
let config = RouterConfig {
|
||||
@@ -1149,7 +1122,6 @@ mod tests {
|
||||
assert!(config.has_metrics());
|
||||
assert_eq!(config.mode_type(), "regular");
|
||||
|
||||
// Test round-trip serialization
|
||||
let json = serde_json::to_string_pretty(&config).unwrap();
|
||||
let deserialized: RouterConfig = serde_json::from_str(&json).unwrap();
|
||||
|
||||
@@ -1161,11 +1133,8 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
// ============= Policy Fallback Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_pd_policy_fallback_both_specified() {
|
||||
// When both prefill and decode policies are specified, they should be used
|
||||
let pd = RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
||||
decode_urls: vec!["http://decode1".to_string()],
|
||||
@@ -1183,21 +1152,19 @@ mod tests {
|
||||
|
||||
let main_policy = PolicyConfig::Random;
|
||||
|
||||
// Both specific policies should be used
|
||||
match pd.get_prefill_policy(&main_policy) {
|
||||
PolicyConfig::CacheAware { .. } => {} // Success
|
||||
PolicyConfig::CacheAware { .. } => {}
|
||||
_ => panic!("Expected CacheAware for prefill"),
|
||||
}
|
||||
|
||||
match pd.get_decode_policy(&main_policy) {
|
||||
PolicyConfig::PowerOfTwo { .. } => {} // Success
|
||||
PolicyConfig::PowerOfTwo { .. } => {}
|
||||
_ => panic!("Expected PowerOfTwo for decode"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pd_policy_fallback_only_prefill() {
|
||||
// When only prefill policy is specified, decode should use main policy
|
||||
let pd = RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
||||
decode_urls: vec!["http://decode1".to_string()],
|
||||
@@ -1213,22 +1180,19 @@ mod tests {
|
||||
|
||||
let main_policy = PolicyConfig::RoundRobin;
|
||||
|
||||
// Prefill should use specific policy
|
||||
match pd.get_prefill_policy(&main_policy) {
|
||||
PolicyConfig::CacheAware { .. } => {} // Success
|
||||
PolicyConfig::CacheAware { .. } => {}
|
||||
_ => panic!("Expected CacheAware for prefill"),
|
||||
}
|
||||
|
||||
// Decode should fall back to main policy
|
||||
match pd.get_decode_policy(&main_policy) {
|
||||
PolicyConfig::RoundRobin => {} // Success
|
||||
PolicyConfig::RoundRobin => {}
|
||||
_ => panic!("Expected RoundRobin for decode"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pd_policy_fallback_only_decode() {
|
||||
// When only decode policy is specified, prefill should use main policy
|
||||
let pd = RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
||||
decode_urls: vec!["http://decode1".to_string()],
|
||||
@@ -1240,22 +1204,19 @@ mod tests {
|
||||
|
||||
let main_policy = PolicyConfig::Random;
|
||||
|
||||
// Prefill should fall back to main policy
|
||||
match pd.get_prefill_policy(&main_policy) {
|
||||
PolicyConfig::Random => {} // Success
|
||||
PolicyConfig::Random => {}
|
||||
_ => panic!("Expected Random for prefill"),
|
||||
}
|
||||
|
||||
// Decode should use specific policy
|
||||
match pd.get_decode_policy(&main_policy) {
|
||||
PolicyConfig::PowerOfTwo { .. } => {} // Success
|
||||
PolicyConfig::PowerOfTwo { .. } => {}
|
||||
_ => panic!("Expected PowerOfTwo for decode"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pd_policy_fallback_none_specified() {
|
||||
// When no specific policies are specified, both should use main policy
|
||||
let pd = RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
||||
decode_urls: vec!["http://decode1".to_string()],
|
||||
@@ -1271,7 +1232,6 @@ mod tests {
|
||||
max_tree_size: 2000,
|
||||
};
|
||||
|
||||
// Both should fall back to main policy
|
||||
match pd.get_prefill_policy(&main_policy) {
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold, ..
|
||||
@@ -1293,21 +1253,19 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_regular_mode_policy_fallback() {
|
||||
// For regular mode, the helper methods should just return the main policy
|
||||
let regular = RoutingMode::Regular {
|
||||
worker_urls: vec!["http://worker1".to_string()],
|
||||
};
|
||||
|
||||
let main_policy = PolicyConfig::RoundRobin;
|
||||
|
||||
// Both methods should return main policy for regular mode
|
||||
match regular.get_prefill_policy(&main_policy) {
|
||||
PolicyConfig::RoundRobin => {} // Success
|
||||
PolicyConfig::RoundRobin => {}
|
||||
_ => panic!("Expected RoundRobin for regular mode"),
|
||||
}
|
||||
|
||||
match regular.get_decode_policy(&main_policy) {
|
||||
PolicyConfig::RoundRobin => {} // Success
|
||||
PolicyConfig::RoundRobin => {}
|
||||
_ => panic!("Expected RoundRobin for regular mode"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -670,7 +670,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_validate_pd_mode_with_separate_policies() {
|
||||
// Test PD mode with different policies for prefill and decode
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![
|
||||
@@ -701,7 +700,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_validate_pd_mode_power_of_two_insufficient_workers() {
|
||||
// Test that power-of-two policy requires at least 2 workers
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill1:8000".to_string(), None)], // Only 1 prefill
|
||||
@@ -726,7 +724,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_validate_grpc_requires_tokenizer() {
|
||||
// Test that gRPC connection mode requires tokenizer configuration
|
||||
let mut config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["grpc://worker:50051".to_string()],
|
||||
@@ -748,7 +745,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_validate_grpc_with_model_path() {
|
||||
// Test that gRPC works with model_path
|
||||
let mut config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["grpc://worker:50051".to_string()],
|
||||
@@ -765,7 +761,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_validate_grpc_with_tokenizer_path() {
|
||||
// Test that gRPC works with tokenizer_path
|
||||
let mut config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["grpc://worker:50051".to_string()],
|
||||
|
||||
@@ -336,7 +336,6 @@ mod tests {
|
||||
};
|
||||
let cb = CircuitBreaker::with_config(config);
|
||||
|
||||
// Record failures up to threshold
|
||||
assert_eq!(cb.state(), CircuitState::Closed);
|
||||
cb.record_failure();
|
||||
assert_eq!(cb.state(), CircuitState::Closed);
|
||||
@@ -344,7 +343,6 @@ mod tests {
|
||||
assert_eq!(cb.state(), CircuitState::Closed);
|
||||
cb.record_failure();
|
||||
|
||||
// Circuit should now be open
|
||||
assert_eq!(cb.state(), CircuitState::Open);
|
||||
assert!(!cb.can_execute());
|
||||
assert_eq!(cb.failure_count(), 3);
|
||||
@@ -359,14 +357,11 @@ mod tests {
|
||||
};
|
||||
let cb = CircuitBreaker::with_config(config);
|
||||
|
||||
// Open the circuit
|
||||
cb.record_failure();
|
||||
assert_eq!(cb.state(), CircuitState::Open);
|
||||
|
||||
// Wait for timeout
|
||||
thread::sleep(Duration::from_millis(150));
|
||||
|
||||
// Circuit should be half-open
|
||||
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
||||
assert!(cb.can_execute());
|
||||
}
|
||||
@@ -381,20 +376,16 @@ mod tests {
|
||||
};
|
||||
let cb = CircuitBreaker::with_config(config);
|
||||
|
||||
// Open the circuit
|
||||
cb.record_failure();
|
||||
assert_eq!(cb.state(), CircuitState::Open);
|
||||
|
||||
// Wait for timeout
|
||||
thread::sleep(Duration::from_millis(100));
|
||||
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
||||
|
||||
// Record successes
|
||||
cb.record_success();
|
||||
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
||||
cb.record_success();
|
||||
|
||||
// Circuit should now be closed
|
||||
assert_eq!(cb.state(), CircuitState::Closed);
|
||||
assert!(cb.can_execute());
|
||||
}
|
||||
@@ -408,18 +399,14 @@ mod tests {
|
||||
};
|
||||
let cb = CircuitBreaker::with_config(config);
|
||||
|
||||
// Open the circuit
|
||||
cb.record_failure();
|
||||
assert_eq!(cb.state(), CircuitState::Open);
|
||||
|
||||
// Wait for timeout
|
||||
thread::sleep(Duration::from_millis(100));
|
||||
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
||||
|
||||
// Record a failure in half-open state
|
||||
cb.record_failure();
|
||||
|
||||
// Circuit should reopen immediately
|
||||
assert_eq!(cb.state(), CircuitState::Open);
|
||||
assert!(!cb.can_execute());
|
||||
}
|
||||
@@ -432,17 +419,14 @@ mod tests {
|
||||
};
|
||||
let cb = CircuitBreaker::with_config(config);
|
||||
|
||||
// Record some failures
|
||||
cb.record_failure();
|
||||
cb.record_failure();
|
||||
assert_eq!(cb.failure_count(), 2);
|
||||
|
||||
// Success should reset failure count
|
||||
cb.record_success();
|
||||
assert_eq!(cb.failure_count(), 0);
|
||||
assert_eq!(cb.success_count(), 1);
|
||||
|
||||
// Can now record more failures without opening
|
||||
cb.record_failure();
|
||||
cb.record_failure();
|
||||
assert_eq!(cb.state(), CircuitState::Closed);
|
||||
@@ -456,11 +440,9 @@ mod tests {
|
||||
};
|
||||
let cb = CircuitBreaker::with_config(config);
|
||||
|
||||
// Open the circuit
|
||||
cb.record_failure();
|
||||
assert_eq!(cb.state(), CircuitState::Open);
|
||||
|
||||
// Manual reset
|
||||
cb.reset();
|
||||
assert_eq!(cb.state(), CircuitState::Closed);
|
||||
assert_eq!(cb.failure_count(), 0);
|
||||
@@ -505,7 +487,6 @@ mod tests {
|
||||
let cb2 = cb1.clone();
|
||||
assert_eq!(cb2.failure_count(), 1);
|
||||
|
||||
// Changes to cb1 affect cb2 (shared state)
|
||||
cb1.record_failure();
|
||||
assert_eq!(cb2.failure_count(), 2);
|
||||
}
|
||||
|
||||
@@ -1562,19 +1562,16 @@ mod tests {
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
// Test health status
|
||||
assert!(dp_worker.is_healthy());
|
||||
dp_worker.set_healthy(false);
|
||||
assert!(!dp_worker.is_healthy());
|
||||
|
||||
// Test load tracking
|
||||
assert_eq!(dp_worker.load(), 0);
|
||||
dp_worker.increment_load();
|
||||
assert_eq!(dp_worker.load(), 1);
|
||||
dp_worker.decrement_load();
|
||||
assert_eq!(dp_worker.load(), 0);
|
||||
|
||||
// Test processed tracking
|
||||
assert_eq!(dp_worker.processed_requests(), 0);
|
||||
dp_worker.increment_processed();
|
||||
assert_eq!(dp_worker.processed_requests(), 1);
|
||||
|
||||
@@ -1485,7 +1485,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_parse_server_info_with_fallback() {
|
||||
// Test with "model" instead of "model_id"
|
||||
let json = serde_json::json!({
|
||||
"model": "gpt-4",
|
||||
"dp_size": 2
|
||||
|
||||
@@ -459,14 +459,12 @@ mod tests {
|
||||
// Register worker (WorkerFactory returns Box<dyn Worker>, convert to Arc)
|
||||
let worker_id = registry.register(Arc::from(worker));
|
||||
|
||||
// Verify registration
|
||||
assert!(registry.get(&worker_id).is_some());
|
||||
assert!(registry.get_by_url("http://worker1:8080").is_some());
|
||||
assert_eq!(registry.get_by_model("llama-3-8b").len(), 1);
|
||||
assert_eq!(registry.get_by_type(&WorkerType::Regular).len(), 1);
|
||||
assert_eq!(registry.get_by_connection(&ConnectionMode::Http).len(), 1);
|
||||
|
||||
// Test stats
|
||||
let stats = registry.stats();
|
||||
assert_eq!(stats.total_workers, 1);
|
||||
assert_eq!(stats.total_models, 1);
|
||||
@@ -519,27 +517,22 @@ mod tests {
|
||||
registry.register(Arc::from(worker2));
|
||||
registry.register(Arc::from(worker3));
|
||||
|
||||
// Test get_by_model_fast for llama-3
|
||||
let llama_workers = registry.get_by_model_fast("llama-3");
|
||||
assert_eq!(llama_workers.len(), 2);
|
||||
let urls: Vec<String> = llama_workers.iter().map(|w| w.url().to_string()).collect();
|
||||
assert!(urls.contains(&"http://worker1:8080".to_string()));
|
||||
assert!(urls.contains(&"http://worker2:8080".to_string()));
|
||||
|
||||
// Test get_by_model_fast for gpt-4
|
||||
let gpt_workers = registry.get_by_model_fast("gpt-4");
|
||||
assert_eq!(gpt_workers.len(), 1);
|
||||
assert_eq!(gpt_workers[0].url(), "http://worker3:8080");
|
||||
|
||||
// Test get_by_model_fast for non-existent model
|
||||
let unknown_workers = registry.get_by_model_fast("unknown-model");
|
||||
assert_eq!(unknown_workers.len(), 0);
|
||||
|
||||
// Test that both get_by_model and get_by_model_fast return same results
|
||||
let llama_workers_slow = registry.get_by_model("llama-3");
|
||||
assert_eq!(llama_workers.len(), llama_workers_slow.len());
|
||||
|
||||
// Test removal updates the model index
|
||||
registry.remove_by_url("http://worker1:8080");
|
||||
let llama_workers_after = registry.get_by_model_fast("llama-3");
|
||||
assert_eq!(llama_workers_after.len(), 1);
|
||||
|
||||
@@ -266,7 +266,6 @@ mod tests {
|
||||
assert_eq!(chain.responses[1].input, "Second");
|
||||
assert_eq!(chain.responses[2].input, "Third");
|
||||
|
||||
// Test with max_depth
|
||||
let limited_chain = store.get_response_chain(&id3, Some(2)).await.unwrap();
|
||||
assert_eq!(limited_chain.responses.len(), 2);
|
||||
assert_eq!(limited_chain.responses[0].input, "Second");
|
||||
@@ -314,7 +313,6 @@ mod tests {
|
||||
let deleted_count = store.delete_user_responses("user1").await.unwrap();
|
||||
assert_eq!(deleted_count, 2);
|
||||
|
||||
// Verify they're gone
|
||||
let user1_responses_after = store.list_user_responses("user1", None).await.unwrap();
|
||||
assert_eq!(user1_responses_after.len(), 0);
|
||||
|
||||
|
||||
@@ -223,7 +223,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_proto_types_compilation() {
|
||||
// Test that protobuf types can be constructed
|
||||
let health_req = proto::HealthCheckRequest {
|
||||
tokenized: Some(proto::TokenizedInput {
|
||||
original_text: "test".to_string(),
|
||||
@@ -320,8 +319,6 @@ mod tests {
|
||||
}
|
||||
|
||||
// TODO: SessionParams not in current proto - skip test
|
||||
// #[test]
|
||||
// fn test_session_params() { ... }
|
||||
|
||||
#[test]
|
||||
fn test_embed_request() {
|
||||
@@ -349,7 +346,6 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_connect_invalid_endpoint() {
|
||||
// Test connecting to an invalid endpoint should return error
|
||||
let result = SglangSchedulerClient::connect("invalid://endpoint").await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
@@ -365,7 +361,6 @@ mod tests {
|
||||
assert_eq!(tokenized.input_ids, vec![1, 15043, 1917, 2]);
|
||||
}
|
||||
|
||||
// Test response type construction
|
||||
#[test]
|
||||
fn test_generate_stream_chunk() {
|
||||
let chunk = proto::GenerateStreamChunk {
|
||||
@@ -383,6 +378,4 @@ mod tests {
|
||||
}
|
||||
|
||||
// TODO: ModelInfo not in current proto - skip test
|
||||
// #[test]
|
||||
// fn test_model_info() { ... }
|
||||
}
|
||||
|
||||
@@ -288,8 +288,6 @@ impl McpClientManager {
|
||||
}
|
||||
}
|
||||
|
||||
// ===== Helpers =====
|
||||
|
||||
fn client_for(&self, server_name: &str) -> McpResult<&RunningService<RoleClient, ()>> {
|
||||
self.clients
|
||||
.get(server_name)
|
||||
@@ -317,8 +315,6 @@ impl McpClientManager {
|
||||
.ok_or_else(|| McpError::ResourceNotFound(uri.to_string()))
|
||||
}
|
||||
|
||||
// ===== Tool Methods =====
|
||||
|
||||
/// Call a tool by name
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
@@ -380,8 +376,6 @@ impl McpClientManager {
|
||||
self.clients.keys().cloned().collect()
|
||||
}
|
||||
|
||||
// ===== Prompt Methods =====
|
||||
|
||||
/// Get a prompt by name with arguments
|
||||
pub async fn get_prompt(
|
||||
&self,
|
||||
@@ -439,8 +433,6 @@ impl McpClientManager {
|
||||
})
|
||||
}
|
||||
|
||||
// ===== Resource Methods =====
|
||||
|
||||
/// Read a resource by URI
|
||||
pub async fn read_resource(&self, uri: &str) -> McpResult<ReadResourceResult> {
|
||||
let (server_name, _resource) = self.resource_entry(uri)?;
|
||||
|
||||
@@ -598,8 +598,6 @@ mod tests {
|
||||
use super::*;
|
||||
use std::net::TcpListener;
|
||||
|
||||
// ============= PrometheusConfig Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_prometheus_config_default() {
|
||||
let config = PrometheusConfig::default();
|
||||
@@ -628,8 +626,6 @@ mod tests {
|
||||
assert_eq!(cloned.host, config.host);
|
||||
}
|
||||
|
||||
// ============= IP Address Parsing Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_valid_ipv4_parsing() {
|
||||
let test_cases = vec!["127.0.0.1", "192.168.1.1", "0.0.0.0"];
|
||||
@@ -679,8 +675,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Socket Address Creation Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_socket_addr_creation() {
|
||||
let test_cases = vec![("127.0.0.1", 8080), ("0.0.0.0", 29000), ("::1", 9090)];
|
||||
@@ -716,8 +710,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Duration Bucket Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_duration_bucket_coverage() {
|
||||
let test_cases: [(f64, &str); 7] = [
|
||||
@@ -743,8 +735,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Matcher Configuration Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_duration_suffix_matcher() {
|
||||
let matcher = Matcher::Suffix(String::from("duration_seconds"));
|
||||
@@ -763,8 +753,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Builder Configuration Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_prometheus_builder_configuration() {
|
||||
let _config = PrometheusConfig::default();
|
||||
@@ -783,16 +771,12 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Upkeep Timeout Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_upkeep_timeout_duration() {
|
||||
let timeout = Duration::from_secs(5 * 60);
|
||||
assert_eq!(timeout.as_secs(), 300);
|
||||
}
|
||||
|
||||
// ============= Custom Bucket Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_custom_buckets_for_different_metrics() {
|
||||
let request_buckets = [0.001, 0.01, 0.1, 1.0, 10.0];
|
||||
@@ -810,8 +794,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// ============= RouterMetrics Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_metrics_static_methods() {
|
||||
RouterMetrics::record_request("/generate");
|
||||
@@ -876,8 +858,6 @@ mod tests {
|
||||
TokenizerMetrics::set_vocab_size("huggingface", 50000);
|
||||
}
|
||||
|
||||
// ============= Port Availability Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_port_already_in_use() {
|
||||
let port = 29123;
|
||||
@@ -892,8 +872,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Integration Test Helpers =============
|
||||
|
||||
#[test]
|
||||
fn test_metrics_endpoint_accessibility() {
|
||||
let config = PrometheusConfig {
|
||||
@@ -937,8 +915,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Edge Cases Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_empty_string_metrics() {
|
||||
RouterMetrics::record_request("");
|
||||
|
||||
@@ -178,8 +178,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Logging Middleware =============
|
||||
|
||||
/// Custom span maker that includes request ID
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RequestSpan;
|
||||
@@ -336,8 +334,6 @@ pub fn log_request(entry: RequestLogEntry) {
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Concurrency Limiting with Queue Support ============
|
||||
|
||||
/// Request queue entry
|
||||
pub struct QueuedRequest {
|
||||
/// Time when the request was queued
|
||||
|
||||
@@ -54,21 +54,17 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_create_from_config() {
|
||||
// Test Random
|
||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
|
||||
assert_eq!(policy.name(), "random");
|
||||
|
||||
// Test RoundRobin
|
||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::RoundRobin);
|
||||
assert_eq!(policy.name(), "round_robin");
|
||||
|
||||
// Test PowerOfTwo
|
||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 60,
|
||||
});
|
||||
assert_eq!(policy.name(), "power_of_two");
|
||||
|
||||
// Test CacheAware
|
||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::CacheAware {
|
||||
cache_threshold: 0.7,
|
||||
balance_abs_threshold: 10,
|
||||
|
||||
@@ -75,7 +75,6 @@ mod tests {
|
||||
),
|
||||
];
|
||||
|
||||
// Test multiple selections to ensure randomness
|
||||
let mut counts = HashMap::new();
|
||||
for _ in 0..100 {
|
||||
if let Some(idx) = policy.select_worker(&workers, None) {
|
||||
|
||||
@@ -49,12 +49,6 @@ use std::collections::HashMap;
|
||||
// - StringOrArray & LoRAPath types
|
||||
// - Helper functions
|
||||
|
||||
// ==================================================================
|
||||
// = OPENAI SPEC - Chat Completions API =
|
||||
// ==================================================================
|
||||
|
||||
// ============= Message Types =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ChatMessage {
|
||||
@@ -119,8 +113,6 @@ pub struct ImageUrl {
|
||||
pub detail: Option<String>, // "auto", "low", or "high"
|
||||
}
|
||||
|
||||
// ============= Response Format Types =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ResponseFormat {
|
||||
@@ -140,8 +132,6 @@ pub struct JsonSchemaFormat {
|
||||
pub strict: Option<bool>,
|
||||
}
|
||||
|
||||
// ============= Streaming Delta Types =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatMessageDelta {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
@@ -177,8 +167,6 @@ pub struct FunctionCallDelta {
|
||||
pub arguments: Option<String>,
|
||||
}
|
||||
|
||||
// ============= Request =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
|
||||
pub struct ChatCompletionRequest {
|
||||
/// A list of messages comprising the conversation so far
|
||||
@@ -299,7 +287,6 @@ pub struct ChatCompletionRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub verbosity: Option<i32>,
|
||||
|
||||
// ============= SGLang Extensions =============
|
||||
/// Top-k sampling parameter (-1 to disable)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_k: Option<i32>,
|
||||
@@ -423,8 +410,6 @@ impl GenerationRequest for ChatCompletionRequest {
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Regular Response =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
@@ -453,8 +438,6 @@ pub struct ChatChoice {
|
||||
pub hidden_states: Option<Vec<f32>>,
|
||||
}
|
||||
|
||||
// ============= Streaming Response =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatCompletionStreamResponse {
|
||||
pub id: String,
|
||||
@@ -477,9 +460,6 @@ pub struct ChatStreamChoice {
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = OPENAI SPEC - Completions API =
|
||||
// ==================================================================
|
||||
// Completions API request types (v1/completions) - DEPRECATED but still supported
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@@ -554,7 +534,6 @@ pub struct CompletionRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seed: Option<i64>,
|
||||
|
||||
// ============= SGLang Extensions =============
|
||||
/// Top-k sampling parameter (-1 to disable)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_k: Option<i32>,
|
||||
@@ -599,7 +578,6 @@ pub struct CompletionRequest {
|
||||
#[serde(default = "default_true")]
|
||||
pub skip_special_tokens: bool,
|
||||
|
||||
// ============= SGLang Extensions =============
|
||||
/// Path to LoRA adapter(s) for model customization
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub lora_path: Option<LoRAPath>,
|
||||
@@ -638,8 +616,6 @@ impl GenerationRequest for CompletionRequest {
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Regular Response =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct CompletionResponse {
|
||||
pub id: String,
|
||||
@@ -668,8 +644,6 @@ pub struct CompletionChoice {
|
||||
pub hidden_states: Option<Vec<f32>>,
|
||||
}
|
||||
|
||||
// ============= Streaming Response =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct CompletionStreamResponse {
|
||||
pub id: String,
|
||||
@@ -690,12 +664,6 @@ pub struct CompletionStreamChoice {
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = OPENAI SPEC - Responses API =
|
||||
// ==================================================================
|
||||
|
||||
// ============= Tool Definitions =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ResponseTool {
|
||||
#[serde(rename = "type")]
|
||||
@@ -709,8 +677,6 @@ pub enum ResponseToolType {
|
||||
CodeInterpreter,
|
||||
}
|
||||
|
||||
// ============= Reasoning Configuration =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ResponseReasoningParam {
|
||||
#[serde(default = "default_reasoning_effort")]
|
||||
@@ -729,8 +695,6 @@ pub enum ReasoningEffort {
|
||||
High,
|
||||
}
|
||||
|
||||
// ============= Input/Output Items =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
@@ -790,8 +754,6 @@ pub enum ResponseReasoningContent {
|
||||
ReasoningText { text: String },
|
||||
}
|
||||
|
||||
// ============= Output Items for Response =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
@@ -823,8 +785,6 @@ pub enum ResponseOutputItem {
|
||||
},
|
||||
}
|
||||
|
||||
// ============= Service Tier =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ServiceTier {
|
||||
@@ -841,8 +801,6 @@ impl Default for ServiceTier {
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Truncation =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Truncation {
|
||||
@@ -856,8 +814,6 @@ impl Default for Truncation {
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Response Status =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ResponseStatus {
|
||||
@@ -868,8 +824,6 @@ pub enum ResponseStatus {
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
// ============= Reasoning Info =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ReasoningInfo {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
@@ -878,8 +832,6 @@ pub struct ReasoningInfo {
|
||||
pub summary: Option<String>,
|
||||
}
|
||||
|
||||
// ============= Text Format =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ResponseTextFormat {
|
||||
pub format: TextFormatType,
|
||||
@@ -891,8 +843,6 @@ pub struct TextFormatType {
|
||||
pub format_type: String,
|
||||
}
|
||||
|
||||
// ============= Include Fields =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum IncludeField {
|
||||
@@ -910,8 +860,6 @@ pub enum IncludeField {
|
||||
ReasoningEncryptedContent,
|
||||
}
|
||||
|
||||
// ============= Usage Info =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct UsageInfo {
|
||||
pub prompt_tokens: u32,
|
||||
@@ -928,8 +876,6 @@ pub struct PromptTokenUsageInfo {
|
||||
pub cached_tokens: u32,
|
||||
}
|
||||
|
||||
// ============= Response Usage Format =============
|
||||
|
||||
/// OpenAI Responses API usage format (different from standard UsageInfo)
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ResponseUsage {
|
||||
@@ -1038,7 +984,6 @@ fn generate_request_id() -> String {
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ResponsesRequest {
|
||||
// ============= Core OpenAI API fields =============
|
||||
/// Run the request in the background
|
||||
#[serde(default)]
|
||||
pub background: bool,
|
||||
@@ -1122,7 +1067,6 @@ pub struct ResponsesRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
|
||||
// ============= SGLang Extensions =============
|
||||
/// Request ID
|
||||
#[serde(default = "generate_request_id")]
|
||||
pub request_id: String,
|
||||
@@ -1606,8 +1550,6 @@ impl ResponsesResponse {
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Helper Functions =============
|
||||
|
||||
impl ResponseOutputItem {
|
||||
/// Create a new message output item
|
||||
pub fn new_message(
|
||||
@@ -1708,20 +1650,12 @@ impl UsageInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = OPENAI SPEC - Common =
|
||||
// ==================================================================
|
||||
|
||||
// ============= Shared Request Components =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct StreamOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub include_usage: Option<bool>,
|
||||
}
|
||||
|
||||
// ============= Tool Choice Types =============
|
||||
|
||||
/// Tool choice value for simple string options
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
@@ -1793,8 +1727,6 @@ pub struct FunctionCallResponse {
|
||||
pub arguments: Option<String>, // JSON string
|
||||
}
|
||||
|
||||
// ============= Usage Tracking =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
@@ -1809,8 +1741,6 @@ pub struct CompletionTokensDetails {
|
||||
pub reasoning_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
// ============= Logprobs Types =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct LogProbs {
|
||||
pub tokens: Vec<String>,
|
||||
@@ -1860,10 +1790,6 @@ pub struct ErrorDetail {
|
||||
pub code: Option<String>,
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = SGLANG SPEC - GENERATE API =
|
||||
// ==================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum InputIds {
|
||||
@@ -1975,7 +1901,6 @@ pub struct GenerateRequest {
|
||||
#[serde(default)]
|
||||
pub return_logprob: bool,
|
||||
|
||||
// ============= SGLang Extensions =============
|
||||
/// Path to LoRA adapter(s) for model customization
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub lora_path: Option<LoRAPath>,
|
||||
@@ -2036,10 +1961,6 @@ impl GenerationRequest for GenerateRequest {
|
||||
}
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = SGLANG SPEC - RERANK API =
|
||||
// ==================================================================
|
||||
|
||||
// Constants for rerank API
|
||||
pub const DEFAULT_MODEL_NAME: &str = "default";
|
||||
|
||||
@@ -2237,10 +2158,6 @@ impl RerankResponse {
|
||||
}
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = OPENAI SPEC - Embeddings API =
|
||||
// ==================================================================
|
||||
|
||||
/// Embeddings request compatible with OpenAI API
|
||||
/// We intentionally keep fields flexible to pass through to workers.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@@ -2292,10 +2209,6 @@ impl GenerationRequest for EmbeddingRequest {
|
||||
}
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = COMMON =
|
||||
// ==================================================================
|
||||
|
||||
/// Helper function for serde default value
|
||||
pub fn default_true() -> bool {
|
||||
true
|
||||
@@ -2359,10 +2272,6 @@ mod tests {
|
||||
use super::*;
|
||||
use serde_json::{from_str, json, to_string};
|
||||
|
||||
// ==================================================================
|
||||
// = RERANK REQUEST TESTS =
|
||||
// ==================================================================
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_serialization() {
|
||||
let request = RerankRequest {
|
||||
@@ -2534,10 +2443,6 @@ mod tests {
|
||||
assert_eq!(request.effective_top_k(), 3);
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = RERANK RESPONSE TESTS =
|
||||
// ==================================================================
|
||||
|
||||
#[test]
|
||||
fn test_rerank_response_creation() {
|
||||
let results = vec![
|
||||
@@ -2709,10 +2614,6 @@ mod tests {
|
||||
assert_eq!(response.results[0].document, None);
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = RERANK RESULT TESTS =
|
||||
// ==================================================================
|
||||
|
||||
#[test]
|
||||
fn test_rerank_result_serialization() {
|
||||
let result = RerankResult {
|
||||
@@ -2755,10 +2656,6 @@ mod tests {
|
||||
assert_eq!(deserialized.meta_info, result.meta_info);
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = V1 COMPATIBILITY TESTS =
|
||||
// ==================================================================
|
||||
|
||||
#[test]
|
||||
fn test_v1_rerank_req_input_serialization() {
|
||||
let v1_input = V1RerankReqInput {
|
||||
@@ -2791,10 +2688,6 @@ mod tests {
|
||||
assert_eq!(request.user, None);
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = GENERATION REQUEST TRAIT TESTS =
|
||||
// ==================================================================
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_generation_request_trait() {
|
||||
let request = RerankRequest {
|
||||
@@ -2812,10 +2705,6 @@ mod tests {
|
||||
assert_eq!(request.extract_text_for_routing(), "test query");
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = EDGE CASES AND STRESS TESTS =
|
||||
// ==================================================================
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_very_long_query() {
|
||||
let long_query = "a".repeat(100000);
|
||||
@@ -2918,10 +2807,6 @@ mod tests {
|
||||
assert_eq!(usage.total_tokens, 150);
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = INTEGRATION TESTS =
|
||||
// ==================================================================
|
||||
|
||||
#[test]
|
||||
fn test_full_rerank_workflow() {
|
||||
// Create request
|
||||
@@ -2980,7 +2865,6 @@ mod tests {
|
||||
// Apply top_k
|
||||
response.apply_top_k(request.effective_top_k());
|
||||
|
||||
// Verify results
|
||||
assert_eq!(response.results.len(), 2);
|
||||
assert_eq!(response.results[0].score, 0.95);
|
||||
assert_eq!(response.results[0].index, 0);
|
||||
@@ -2995,10 +2879,6 @@ mod tests {
|
||||
assert_eq!(deserialized.model, response.model);
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = EMBEDDINGS REQUEST TESTS =
|
||||
// ==================================================================
|
||||
|
||||
#[test]
|
||||
fn test_embedding_request_serialization_string_input() {
|
||||
let req = EmbeddingRequest {
|
||||
|
||||
@@ -537,10 +537,6 @@ pub trait ValidatableRequest:
|
||||
}
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = OPENAI CHAT COMPLETION VALIDATION =
|
||||
// ==================================================================
|
||||
|
||||
impl SamplingOptionsProvider for ChatCompletionRequest {
|
||||
fn get_temperature(&self) -> Option<f32> {
|
||||
self.temperature
|
||||
@@ -909,7 +905,6 @@ mod tests {
|
||||
fn test_chat_cross_parameter_conflicts() {
|
||||
let mut request = create_valid_chat_request();
|
||||
|
||||
// Test 1: max_tokens vs max_completion_tokens conflict
|
||||
request.max_tokens = Some(100);
|
||||
request.max_completion_tokens = Some(200);
|
||||
assert!(
|
||||
@@ -921,7 +916,6 @@ mod tests {
|
||||
request.max_tokens = None;
|
||||
request.max_completion_tokens = None;
|
||||
|
||||
// Test 2: tools vs functions conflict (deprecated)
|
||||
request.tools = Some(vec![]);
|
||||
request.functions = Some(vec![]);
|
||||
assert!(
|
||||
@@ -929,7 +923,6 @@ mod tests {
|
||||
"Should reject both tools and functions"
|
||||
);
|
||||
|
||||
// Test 3: logprobs=true without top_logprobs should be valid
|
||||
let mut request = create_valid_chat_request();
|
||||
request.logprobs = true;
|
||||
request.top_logprobs = None;
|
||||
@@ -938,7 +931,6 @@ mod tests {
|
||||
"logprobs=true without top_logprobs should be valid"
|
||||
);
|
||||
|
||||
// Test 4: top_logprobs without logprobs=true should fail (OpenAI rule)
|
||||
let mut request = create_valid_chat_request();
|
||||
request.logprobs = false;
|
||||
request.top_logprobs = Some(5);
|
||||
@@ -967,7 +959,6 @@ mod tests {
|
||||
fn test_parameter_ranges() {
|
||||
let mut request = create_valid_chat_request();
|
||||
|
||||
// Test temperature range (0.0 to 2.0)
|
||||
request.temperature = Some(1.5);
|
||||
assert!(request.validate().is_ok());
|
||||
request.temperature = Some(-0.1);
|
||||
@@ -975,7 +966,6 @@ mod tests {
|
||||
request.temperature = Some(3.0);
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Test top_p range (0.0 to 1.0)
|
||||
request.temperature = Some(1.0); // Reset
|
||||
request.top_p = Some(0.9);
|
||||
assert!(request.validate().is_ok());
|
||||
@@ -984,7 +974,6 @@ mod tests {
|
||||
request.top_p = Some(1.5);
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Test frequency_penalty range (-2.0 to 2.0)
|
||||
request.top_p = Some(0.9); // Reset
|
||||
request.frequency_penalty = Some(1.5);
|
||||
assert!(request.validate().is_ok());
|
||||
@@ -993,7 +982,6 @@ mod tests {
|
||||
request.frequency_penalty = Some(3.0);
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Test presence_penalty range (-2.0 to 2.0)
|
||||
request.frequency_penalty = Some(0.0); // Reset
|
||||
request.presence_penalty = Some(-1.5);
|
||||
assert!(request.validate().is_ok());
|
||||
@@ -1002,7 +990,6 @@ mod tests {
|
||||
request.presence_penalty = Some(2.5);
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Test repetition_penalty range (0.0 to 2.0)
|
||||
request.presence_penalty = Some(0.0); // Reset
|
||||
request.repetition_penalty = Some(1.2);
|
||||
assert!(request.validate().is_ok());
|
||||
@@ -1011,7 +998,6 @@ mod tests {
|
||||
request.repetition_penalty = Some(2.1);
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Test min_p range (0.0 to 1.0)
|
||||
request.repetition_penalty = Some(1.0); // Reset
|
||||
request.min_p = Some(0.5);
|
||||
assert!(request.validate().is_ok());
|
||||
|
||||
@@ -373,7 +373,6 @@ mod tests {
|
||||
// Both should use the same passthrough parser instance
|
||||
assert!(Arc::ptr_eq(&parser1, &parser2));
|
||||
|
||||
// Verify it's actually a passthrough parser
|
||||
let parser = parser1.lock().unwrap();
|
||||
assert_eq!(parser.model_type(), "passthrough");
|
||||
}
|
||||
@@ -456,7 +455,6 @@ mod tests {
|
||||
|
||||
match p.detect_and_parse_reasoning(&input) {
|
||||
Ok(result) => {
|
||||
// Verify parsing worked correctly with substantial content
|
||||
// Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
|
||||
assert!(result
|
||||
.normal_text
|
||||
|
||||
@@ -88,7 +88,6 @@ mod tests {
|
||||
fn test_kimi_partial_unicode() {
|
||||
let mut parser = KimiParser::new();
|
||||
|
||||
// Test partial Unicode token buffering
|
||||
let result1 = parser
|
||||
.parse_reasoning_streaming_incremental("◁thi")
|
||||
.unwrap();
|
||||
|
||||
@@ -96,8 +96,6 @@ impl GrpcRouter {
|
||||
})
|
||||
}
|
||||
|
||||
// ============ Chat Implementation ============
|
||||
|
||||
/// Main route_chat implementation
|
||||
async fn route_chat_impl(
|
||||
&self,
|
||||
@@ -207,7 +205,6 @@ impl GrpcRouter {
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Helper Methods ============
|
||||
/// Select a worker for the request
|
||||
fn select_worker_for_request(
|
||||
&self,
|
||||
@@ -809,7 +806,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_transform_messages_mixed_content_types() {
|
||||
// Test with both text and multimodal content
|
||||
let messages = vec![
|
||||
ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
@@ -833,7 +829,6 @@ mod tests {
|
||||
},
|
||||
];
|
||||
|
||||
// Test String format
|
||||
let result_string =
|
||||
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
|
||||
.unwrap();
|
||||
@@ -842,7 +837,6 @@ mod tests {
|
||||
assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text");
|
||||
assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image");
|
||||
|
||||
// Test OpenAI format
|
||||
let result_openai =
|
||||
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI)
|
||||
.unwrap();
|
||||
|
||||
@@ -957,7 +957,6 @@ impl RouterTrait for PDRouter {
|
||||
}
|
||||
|
||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||
// Test model generation capability by selecting a random pair and testing them
|
||||
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
|
||||
|
||||
// Select a random worker pair using the policy
|
||||
@@ -972,7 +971,6 @@ impl RouterTrait for PDRouter {
|
||||
}
|
||||
};
|
||||
|
||||
// Test prefill server's health_generate
|
||||
let prefill_url = format!("{}/health_generate", prefill.url());
|
||||
let (prefill_result, decode_result) = tokio::join!(
|
||||
self.client.get(&prefill_url).send(),
|
||||
|
||||
@@ -1018,7 +1018,6 @@ mod tests {
|
||||
};
|
||||
let port = 8080u16;
|
||||
|
||||
// Test that unified handler works for regular mode
|
||||
handle_pod_event(
|
||||
&pod_info,
|
||||
Arc::clone(&tracked_pods),
|
||||
@@ -1045,7 +1044,6 @@ mod tests {
|
||||
};
|
||||
let port = 8080u16;
|
||||
|
||||
// Test that unified handler works for PD mode with prefill
|
||||
handle_pod_event(
|
||||
&pod_info,
|
||||
Arc::clone(&tracked_pods),
|
||||
@@ -1080,7 +1078,6 @@ mod tests {
|
||||
|
||||
let port = 8080u16;
|
||||
|
||||
// Test that unified handler works for deletion in PD mode
|
||||
handle_pod_deletion(
|
||||
&pod_info,
|
||||
Arc::clone(&tracked_pods),
|
||||
|
||||
@@ -279,11 +279,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_create_tiktoken_tokenizer() {
|
||||
// Test creating tokenizer for GPT models
|
||||
let tokenizer = create_tokenizer("gpt-4").unwrap();
|
||||
assert!(tokenizer.vocab_size() > 0);
|
||||
|
||||
// Test encoding and decoding
|
||||
let text = "Hello, world!";
|
||||
let encoding = tokenizer.encode(text).unwrap();
|
||||
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
|
||||
@@ -292,7 +290,6 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_download_tokenizer_from_hf() {
|
||||
// Test with a small model that should have tokenizer files
|
||||
// Skip this test if HF_TOKEN is not set and we're in CI
|
||||
if std::env::var("CI").is_ok() && std::env::var("HF_TOKEN").is_err() {
|
||||
println!("Skipping HF download test in CI without HF_TOKEN");
|
||||
|
||||
@@ -206,7 +206,6 @@ mod tests {
|
||||
// The incremental text should be " world" (with the space that the mock tokenizer adds)
|
||||
assert_eq!(text2, " world");
|
||||
|
||||
// Verify the full text
|
||||
assert_eq!(seq.text().unwrap(), "Hello world");
|
||||
}
|
||||
|
||||
|
||||
@@ -398,7 +398,6 @@ mod tests {
|
||||
// The fix ensures we only output NEW text, not accumulated text
|
||||
assert_eq!(outputs.len(), 3);
|
||||
|
||||
// Verify no text is repeated
|
||||
for i in 0..outputs.len() {
|
||||
for j in i + 1..outputs.len() {
|
||||
// No output should contain another (no accumulation)
|
||||
|
||||
@@ -36,22 +36,17 @@ fn test_tokenizer_wrapper() {
|
||||
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
|
||||
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
|
||||
|
||||
// Test encoding
|
||||
let encoding = tokenizer.encode("Hello world").unwrap();
|
||||
assert_eq!(encoding.token_ids(), &[1, 2]);
|
||||
|
||||
// Test decoding
|
||||
let text = tokenizer.decode(&[1, 2], false).unwrap();
|
||||
assert_eq!(text, "Hello world");
|
||||
|
||||
// Test vocab size
|
||||
assert_eq!(tokenizer.vocab_size(), 8);
|
||||
|
||||
// Test token to ID
|
||||
assert_eq!(tokenizer.token_to_id("Hello"), Some(1));
|
||||
assert_eq!(tokenizer.token_to_id("unknown"), None);
|
||||
|
||||
// Test ID to token
|
||||
assert_eq!(tokenizer.id_to_token(1), Some("Hello".to_string()));
|
||||
assert_eq!(tokenizer.id_to_token(9999), None);
|
||||
}
|
||||
|
||||
@@ -246,7 +246,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_unrecognized_model_name_returns_error() {
|
||||
// Test that unrecognized model names return an error
|
||||
let result = TiktokenTokenizer::from_model_name("distilgpt-2");
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
@@ -268,7 +267,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_recognized_model_names() {
|
||||
// Test that recognized model names work correctly
|
||||
assert!(TiktokenTokenizer::from_model_name("gpt-4").is_ok());
|
||||
assert!(TiktokenTokenizer::from_model_name("gpt-3.5-turbo").is_ok());
|
||||
assert!(TiktokenTokenizer::from_model_name("text-davinci-003").is_ok());
|
||||
|
||||
@@ -139,7 +139,6 @@ mod tests {
|
||||
async fn test_single_call_with_semicolon() {
|
||||
let parser = LlamaParser::new();
|
||||
// Note: Llama 3.2 doesn't handle multiple calls well
|
||||
// Test that we can at least parse a single call followed by semicolon
|
||||
let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
@@ -102,7 +102,6 @@ impl PythonicParser {
|
||||
if bracket_count == 0 {
|
||||
// Found the matching bracket
|
||||
let extracted: String = chars[start_idx..=i].iter().collect();
|
||||
// Verify this actually contains a function call
|
||||
if extracted.contains('(') && extracted.contains(')') {
|
||||
return Some(extracted);
|
||||
}
|
||||
|
||||
@@ -21,21 +21,18 @@ fn test_parse_state_new() {
|
||||
fn test_parse_state_process_char() {
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Test bracket tracking
|
||||
state.process_char('{');
|
||||
assert_eq!(state.bracket_depth, 1);
|
||||
|
||||
state.process_char('}');
|
||||
assert_eq!(state.bracket_depth, 0);
|
||||
|
||||
// Test string tracking
|
||||
state.process_char('"');
|
||||
assert!(state.in_string);
|
||||
|
||||
state.process_char('"');
|
||||
assert!(!state.in_string);
|
||||
|
||||
// Test escape handling
|
||||
state.process_char('"');
|
||||
state.process_char('\\');
|
||||
assert!(state.escape_next);
|
||||
@@ -63,10 +60,8 @@ fn test_token_config() {
|
||||
fn test_parser_registry() {
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
// Test has default mappings
|
||||
assert!(!registry.list_mappings().is_empty());
|
||||
|
||||
// Test model pattern matching
|
||||
let mappings = registry.list_mappings();
|
||||
let has_gpt = mappings.iter().any(|(m, _)| m.starts_with("gpt"));
|
||||
assert!(has_gpt);
|
||||
@@ -76,10 +71,8 @@ fn test_parser_registry() {
|
||||
fn test_parser_registry_pattern_matching() {
|
||||
let mut registry = ParserRegistry::new_for_testing();
|
||||
|
||||
// Test that model mappings work by checking the list
|
||||
registry.map_model("test-model", "json");
|
||||
|
||||
// Verify through list_mappings
|
||||
let mappings = registry.list_mappings();
|
||||
let has_test = mappings
|
||||
.iter()
|
||||
@@ -112,25 +105,21 @@ fn test_tool_call_serialization() {
|
||||
fn test_partial_json_parser() {
|
||||
let parser = PartialJson::default();
|
||||
|
||||
// Test complete JSON
|
||||
let input = r#"{"name": "test", "value": 42}"#;
|
||||
let (value, consumed) = parser.parse_value(input).unwrap();
|
||||
assert_eq!(value["name"], "test");
|
||||
assert_eq!(value["value"], 42);
|
||||
assert_eq!(consumed, input.len());
|
||||
|
||||
// Test incomplete JSON object
|
||||
let input = r#"{"name": "test", "value": "#;
|
||||
let (value, _consumed) = parser.parse_value(input).unwrap();
|
||||
assert_eq!(value["name"], "test");
|
||||
assert!(value["value"].is_null());
|
||||
|
||||
// Test incomplete string
|
||||
let input = r#"{"name": "tes"#;
|
||||
let (value, _consumed) = parser.parse_value(input).unwrap();
|
||||
assert_eq!(value["name"], "tes");
|
||||
|
||||
// Test incomplete array
|
||||
let input = r#"[1, 2, "#;
|
||||
let (value, _consumed) = parser.parse_value(input).unwrap();
|
||||
assert!(value.is_array());
|
||||
@@ -193,11 +182,9 @@ fn test_compute_diff() {
|
||||
|
||||
#[test]
|
||||
fn test_stream_result_variants() {
|
||||
// Test Incomplete
|
||||
let result = StreamResult::Incomplete;
|
||||
matches!(result, StreamResult::Incomplete);
|
||||
|
||||
// Test ToolName
|
||||
let result = StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: "test".to_string(),
|
||||
@@ -209,7 +196,6 @@ fn test_stream_result_variants() {
|
||||
panic!("Expected ToolName variant");
|
||||
}
|
||||
|
||||
// Test ToolComplete
|
||||
let tool = ToolCall {
|
||||
id: "123".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
@@ -255,7 +241,6 @@ fn test_partial_tool_call() {
|
||||
async fn test_json_parser_complete_single() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Test single tool call with arguments
|
||||
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco", "units": "celsius"}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
@@ -269,7 +254,6 @@ async fn test_json_parser_complete_single() {
|
||||
async fn test_json_parser_complete_array() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Test array of tool calls
|
||||
let input = r#"[
|
||||
{"name": "get_weather", "arguments": {"location": "SF"}},
|
||||
{"name": "get_news", "arguments": {"query": "technology"}}
|
||||
@@ -286,7 +270,6 @@ async fn test_json_parser_complete_array() {
|
||||
async fn test_json_parser_with_parameters() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Test with "parameters" instead of "arguments"
|
||||
let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20, "operation": "add"}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
@@ -299,7 +282,6 @@ async fn test_json_parser_with_parameters() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_parser_with_tokens() {
|
||||
// Test with custom wrapper tokens
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["[TOOL_CALLS] [".to_string()],
|
||||
end_tokens: vec!["]".to_string()],
|
||||
@@ -315,7 +297,6 @@ async fn test_json_parser_with_tokens() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiline_json_with_tokens() {
|
||||
// Test that regex with (?s) flag properly handles multi-line JSON
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<tool>".to_string()],
|
||||
end_tokens: vec!["</tool>".to_string()],
|
||||
@@ -342,7 +323,6 @@ async fn test_multiline_json_with_tokens() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiline_json_array() {
|
||||
// Test multi-line JSON array without wrapper tokens
|
||||
let parser = JsonParser::new();
|
||||
|
||||
let input = r#"[
|
||||
@@ -390,7 +370,6 @@ async fn test_json_parser_streaming() {
|
||||
let parser = JsonParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Test with complete JSON
|
||||
let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
|
||||
|
||||
let result = parser
|
||||
@@ -417,7 +396,6 @@ async fn test_registry_with_json_parser() {
|
||||
// Should get JSON parser for OpenAI models
|
||||
let parser = registry.get_parser("gpt-4-turbo").unwrap();
|
||||
|
||||
// Test that the parser works
|
||||
let input = r#"{"name": "test", "arguments": {"x": 1}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
@@ -677,7 +655,6 @@ mod edge_cases {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_token_pairs_with_conflicts() {
|
||||
// Test with overlapping token patterns
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<<".to_string(), "<tool>".to_string()],
|
||||
end_tokens: vec![">>".to_string(), "</tool>".to_string()],
|
||||
@@ -708,7 +685,6 @@ mod edge_cases {
|
||||
async fn test_streaming_with_partial_chunks() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Test 1: Very incomplete JSON (just opening brace) should return Incomplete
|
||||
let mut state1 = ParseState::new();
|
||||
let partial = r#"{"#;
|
||||
let result = parser
|
||||
@@ -720,7 +696,6 @@ mod edge_cases {
|
||||
"Should return Incomplete for just opening brace"
|
||||
);
|
||||
|
||||
// Test 2: Complete JSON should return ToolComplete
|
||||
let mut state2 = ParseState::new();
|
||||
let complete = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;
|
||||
let result = parser
|
||||
@@ -738,7 +713,6 @@ mod edge_cases {
|
||||
_ => panic!("Expected ToolComplete for complete JSON"),
|
||||
}
|
||||
|
||||
// Test 3: Partial JSON with name
|
||||
// The PartialJson parser can complete partial JSON by filling in missing values
|
||||
let mut state3 = ParseState::new();
|
||||
let partial_with_name = r#"{"name": "test", "argum"#;
|
||||
@@ -863,7 +837,6 @@ mod stress_tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_parser_usage() {
|
||||
// Test that parser can be used concurrently
|
||||
let parser = std::sync::Arc::new(JsonParser::new());
|
||||
|
||||
let mut handles = vec![];
|
||||
|
||||
@@ -679,7 +679,6 @@ mod tests {
|
||||
fn test_get_smallest_tenant() {
|
||||
let tree = Tree::new();
|
||||
|
||||
// Test empty tree
|
||||
assert_eq!(tree.get_smallest_tenant(), "empty");
|
||||
|
||||
// Insert data for tenant1 - "ap" + "icot" = 6 chars
|
||||
@@ -689,7 +688,6 @@ mod tests {
|
||||
// Insert data for tenant2 - "cat" = 3 chars
|
||||
tree.insert("cat", "tenant2");
|
||||
|
||||
// Test - tenant2 should be smallest with 3 chars vs 6 chars
|
||||
assert_eq!(
|
||||
tree.get_smallest_tenant(),
|
||||
"tenant2",
|
||||
@@ -702,7 +700,6 @@ mod tests {
|
||||
tree.insert("do", "tenant3");
|
||||
tree.insert("hi", "tenant4");
|
||||
|
||||
// Test - should return either tenant3 or tenant4 (both have 2 chars)
|
||||
let smallest = tree.get_smallest_tenant();
|
||||
assert!(
|
||||
smallest == "tenant3" || smallest == "tenant4",
|
||||
@@ -720,7 +717,6 @@ mod tests {
|
||||
"Expected tenant3 to be smallest with 2 characters"
|
||||
);
|
||||
|
||||
// Test eviction
|
||||
tree.evict_tenant_by_size(3); // This should evict tenants with more than 3 chars
|
||||
|
||||
let post_eviction_smallest = tree.get_smallest_tenant();
|
||||
@@ -731,7 +727,6 @@ mod tests {
|
||||
fn test_tenant_char_count() {
|
||||
let tree = Tree::new();
|
||||
|
||||
// Phase 1: Initial insertions
|
||||
tree.insert("apple", "tenant1");
|
||||
tree.insert("apricot", "tenant1");
|
||||
tree.insert("banana", "tenant1");
|
||||
@@ -755,7 +750,6 @@ mod tests {
|
||||
"Phase 1: Initial insertions"
|
||||
);
|
||||
|
||||
// Phase 2: Additional insertions
|
||||
tree.insert("apartment", "tenant1");
|
||||
tree.insert("appetite", "tenant2");
|
||||
tree.insert("ball", "tenant1");
|
||||
@@ -778,7 +772,6 @@ mod tests {
|
||||
"Phase 2: Additional insertions"
|
||||
);
|
||||
|
||||
// Phase 3: Overlapping insertions
|
||||
tree.insert("zebra", "tenant1");
|
||||
tree.insert("zebra", "tenant2");
|
||||
tree.insert("zero", "tenant1");
|
||||
@@ -801,7 +794,6 @@ mod tests {
|
||||
"Phase 3: Overlapping insertions"
|
||||
);
|
||||
|
||||
// Phase 4: Eviction test
|
||||
tree.evict_tenant_by_size(10);
|
||||
|
||||
let computed_sizes = tree.get_used_size_per_tenant();
|
||||
@@ -1088,8 +1080,6 @@ mod tests {
|
||||
|
||||
tree.pretty_print();
|
||||
|
||||
// Test sequentially
|
||||
|
||||
for (text, tenant) in TEST_PAIRS.iter() {
|
||||
let (matched_text, matched_tenant) = tree.prefix_match(text);
|
||||
assert_eq!(matched_text, *text);
|
||||
@@ -1162,7 +1152,6 @@ mod tests {
|
||||
|
||||
tree.pretty_print();
|
||||
|
||||
// Verify initial sizes
|
||||
let sizes_before = tree.get_used_size_per_tenant();
|
||||
assert_eq!(sizes_before.get("tenant1").unwrap(), &5); // "hello" = 5
|
||||
assert_eq!(sizes_before.get("tenant2").unwrap(), &10); // "hello" + "world" = 10
|
||||
@@ -1172,12 +1161,10 @@ mod tests {
|
||||
|
||||
tree.pretty_print();
|
||||
|
||||
// Verify sizes after eviction
|
||||
let sizes_after = tree.get_used_size_per_tenant();
|
||||
assert_eq!(sizes_after.get("tenant1").unwrap(), &5); // Should be unchanged
|
||||
assert_eq!(sizes_after.get("tenant2").unwrap(), &5); // Only "world" remains
|
||||
|
||||
// Verify "world" remains for tenant2
|
||||
let (matched, tenant) = tree.prefix_match("world");
|
||||
assert_eq!(matched, "world");
|
||||
assert_eq!(tenant, "tenant2");
|
||||
@@ -1208,7 +1195,6 @@ mod tests {
|
||||
|
||||
// Check sizes after eviction
|
||||
let sizes_after = tree.get_used_size_per_tenant();
|
||||
// Verify all tenants are under their size limits
|
||||
for (tenant, &size) in sizes_after.iter() {
|
||||
assert!(
|
||||
size <= max_size,
|
||||
@@ -1287,7 +1273,6 @@ mod tests {
|
||||
let final_sizes = tree.get_used_size_per_tenant();
|
||||
println!("Final sizes after test completion: {:?}", final_sizes);
|
||||
|
||||
// Verify all tenants are under limit
|
||||
for (_, &size) in final_sizes.iter() {
|
||||
assert!(
|
||||
size <= max_size,
|
||||
@@ -1364,14 +1349,12 @@ mod tests {
|
||||
tree.insert("help", "tenant1"); // tenant1: hel -> p
|
||||
tree.insert("helicopter", "tenant2"); // tenant2: hel -> icopter
|
||||
|
||||
// Test tenant1's data
|
||||
assert_eq!(tree.prefix_match_tenant("hello", "tenant1"), "hello"); // Full match for tenant1
|
||||
assert_eq!(tree.prefix_match_tenant("help", "tenant1"), "help"); // Exclusive to tenant1
|
||||
assert_eq!(tree.prefix_match_tenant("hel", "tenant1"), "hel"); // Shared prefix
|
||||
assert_eq!(tree.prefix_match_tenant("hello world", "tenant1"), "hello"); // Should stop at tenant1's boundary
|
||||
assert_eq!(tree.prefix_match_tenant("helicopter", "tenant1"), "hel"); // Should stop at tenant1's boundary
|
||||
|
||||
// Test tenant2's data
|
||||
assert_eq!(tree.prefix_match_tenant("hello", "tenant2"), "hello"); // Full match for tenant2
|
||||
assert_eq!(
|
||||
tree.prefix_match_tenant("hello world", "tenant2"),
|
||||
@@ -1384,7 +1367,6 @@ mod tests {
|
||||
assert_eq!(tree.prefix_match_tenant("hel", "tenant2"), "hel"); // Shared prefix
|
||||
assert_eq!(tree.prefix_match_tenant("help", "tenant2"), "hel"); // Should stop at tenant2's boundary
|
||||
|
||||
// Test non-existent tenant
|
||||
assert_eq!(tree.prefix_match_tenant("hello", "tenant3"), ""); // Non-existent tenant
|
||||
assert_eq!(tree.prefix_match_tenant("help", "tenant3"), ""); // Non-existent tenant
|
||||
}
|
||||
@@ -1399,7 +1381,6 @@ mod tests {
|
||||
tree.insert("hello", "tenant2");
|
||||
tree.insert("help", "tenant2");
|
||||
|
||||
// Verify initial state
|
||||
let initial_sizes = tree.get_used_size_per_tenant();
|
||||
assert_eq!(initial_sizes.get("tenant1").unwrap(), &10); // "hello" + "world"
|
||||
assert_eq!(initial_sizes.get("tenant2").unwrap(), &6); // "hello" + "p"
|
||||
@@ -1407,7 +1388,6 @@ mod tests {
|
||||
// Evict tenant1
|
||||
tree.remove_tenant("tenant1");
|
||||
|
||||
// Verify after eviction
|
||||
let final_sizes = tree.get_used_size_per_tenant();
|
||||
assert!(
|
||||
!final_sizes.contains_key("tenant1"),
|
||||
@@ -1419,11 +1399,9 @@ mod tests {
|
||||
"tenant2 should be unaffected"
|
||||
);
|
||||
|
||||
// Verify tenant1's data is inaccessible
|
||||
assert_eq!(tree.prefix_match_tenant("hello", "tenant1"), "");
|
||||
assert_eq!(tree.prefix_match_tenant("world", "tenant1"), "");
|
||||
|
||||
// Verify tenant2's data is still accessible
|
||||
assert_eq!(tree.prefix_match_tenant("hello", "tenant2"), "hello");
|
||||
assert_eq!(tree.prefix_match_tenant("help", "tenant2"), "help");
|
||||
}
|
||||
@@ -1441,7 +1419,6 @@ mod tests {
|
||||
tree.insert("banana", "tenant2");
|
||||
tree.insert("ball", "tenant2");
|
||||
|
||||
// Verify initial state
|
||||
let initial_sizes = tree.get_used_size_per_tenant();
|
||||
println!("Initial sizes: {:?}", initial_sizes);
|
||||
tree.pretty_print();
|
||||
@@ -1449,29 +1426,24 @@ mod tests {
|
||||
// Evict tenant1
|
||||
tree.remove_tenant("tenant1");
|
||||
|
||||
// Verify final state
|
||||
let final_sizes = tree.get_used_size_per_tenant();
|
||||
println!("Final sizes: {:?}", final_sizes);
|
||||
tree.pretty_print();
|
||||
|
||||
// Verify tenant1 is completely removed
|
||||
assert!(
|
||||
!final_sizes.contains_key("tenant1"),
|
||||
"tenant1 should be completely removed"
|
||||
);
|
||||
|
||||
// Verify all tenant1's data is inaccessible
|
||||
assert_eq!(tree.prefix_match_tenant("apple", "tenant1"), "");
|
||||
assert_eq!(tree.prefix_match_tenant("application", "tenant1"), "");
|
||||
assert_eq!(tree.prefix_match_tenant("banana", "tenant1"), "");
|
||||
|
||||
// Verify tenant2's data is intact
|
||||
assert_eq!(tree.prefix_match_tenant("apple", "tenant2"), "apple");
|
||||
assert_eq!(tree.prefix_match_tenant("appetite", "tenant2"), "appetite");
|
||||
assert_eq!(tree.prefix_match_tenant("banana", "tenant2"), "banana");
|
||||
assert_eq!(tree.prefix_match_tenant("ball", "tenant2"), "ball");
|
||||
|
||||
// Verify the tree structure is still valid for tenant2
|
||||
let tenant2_size = final_sizes.get("tenant2").unwrap();
|
||||
assert_eq!(tenant2_size, &(5 + 5 + 6 + 2)); // "apple" + "etite" + "banana" + "ll"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user