diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index 6b24a5fd1..5e25b2c3b 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -214,83 +214,590 @@ impl RouterConfig { pub fn has_metrics(&self) -> bool { self.metrics.is_some() } +} - /* Commented out - no longer needed without compatibility layer - /// Convert to routing PolicyConfig for internal use - pub fn to_routing_policy_config(&self) -> ConfigResult { - match (&self.mode, &self.policy) { - ( - RoutingMode::PrefillDecode { - prefill_urls, - decode_urls, - }, - policy, - ) => { - // Map policy to PDSelectionPolicy - let selection_policy = match policy { - PolicyConfig::Random => crate::pd_types::PDSelectionPolicy::Random, - PolicyConfig::PowerOfTwo { .. } => { - crate::pd_types::PDSelectionPolicy::PowerOfTwo - } - PolicyConfig::CacheAware { .. } => { - return Err(ConfigError::IncompatibleConfig { - reason: "CacheAware policy is not supported in PD disaggregated mode" - .to_string(), - }); - } - PolicyConfig::RoundRobin => { - return Err(ConfigError::IncompatibleConfig { - reason: "RoundRobin policy is not supported in PD disaggregated mode" - .to_string(), - }); - } - }; +#[cfg(test)] +mod tests { + use super::*; - Ok(crate::router::PolicyConfig::PrefillDecodeConfig { - selection_policy, - prefill_urls: prefill_urls.clone(), - decode_urls: decode_urls.clone(), - timeout_secs: self.worker_startup_timeout_secs, - interval_secs: self.worker_startup_check_interval_secs, - }) + // ============= RouterConfig Tests ============= + + #[test] + fn test_router_config_default() { + let config = RouterConfig::default(); + + assert!( + matches!(config.mode, RoutingMode::Regular { worker_urls } if worker_urls.is_empty()) + ); + assert!(matches!(config.policy, PolicyConfig::Random)); + assert_eq!(config.host, "127.0.0.1"); + assert_eq!(config.port, 3001); + assert_eq!(config.max_payload_size, 268_435_456); + assert_eq!(config.request_timeout_secs, 600); + assert_eq!(config.worker_startup_timeout_secs, 300); + assert_eq!(config.worker_startup_check_interval_secs, 10); + assert!(config.discovery.is_none()); + assert!(config.metrics.is_none()); + assert!(config.log_dir.is_none()); + assert!(config.log_level.is_none()); + } + + #[test] + fn test_router_config_new() { + let mode = RoutingMode::Regular { + worker_urls: vec!["http://worker1".to_string(), "http://worker2".to_string()], + }; + let policy = PolicyConfig::RoundRobin; + + let config = RouterConfig::new(mode, policy); + + match config.mode { + RoutingMode::Regular { worker_urls } => { + assert_eq!(worker_urls.len(), 2); + assert_eq!(worker_urls[0], "http://worker1"); + assert_eq!(worker_urls[1], "http://worker2"); } - (RoutingMode::Regular { .. }, PolicyConfig::Random) => { - Ok(crate::router::PolicyConfig::RandomConfig { - timeout_secs: self.worker_startup_timeout_secs, - interval_secs: self.worker_startup_check_interval_secs, - }) - } - (RoutingMode::Regular { .. }, PolicyConfig::RoundRobin) => { - Ok(crate::router::PolicyConfig::RoundRobinConfig { - timeout_secs: self.worker_startup_timeout_secs, - interval_secs: self.worker_startup_check_interval_secs, - }) - } - ( - RoutingMode::Regular { .. }, - PolicyConfig::CacheAware { - cache_threshold, - balance_abs_threshold, - balance_rel_threshold, - eviction_interval_secs, - max_tree_size, - }, - ) => Ok(crate::router::PolicyConfig::CacheAwareConfig { - cache_threshold: *cache_threshold, - balance_abs_threshold: *balance_abs_threshold, - balance_rel_threshold: *balance_rel_threshold, - eviction_interval_secs: *eviction_interval_secs, - max_tree_size: *max_tree_size, - timeout_secs: self.worker_startup_timeout_secs, - interval_secs: self.worker_startup_check_interval_secs, - }), - (RoutingMode::Regular { .. }, PolicyConfig::PowerOfTwo { .. }) => { - Err(ConfigError::IncompatibleConfig { - reason: "PowerOfTwo policy is only supported in PD disaggregated mode" - .to_string(), - }) + _ => panic!("Expected Regular mode"), + } + + assert!(matches!(config.policy, PolicyConfig::RoundRobin)); + // Other fields should be default + assert_eq!(config.host, "127.0.0.1"); + assert_eq!(config.port, 3001); + } + + #[test] + fn test_router_config_serialization() { + let config = RouterConfig { + mode: RoutingMode::Regular { + worker_urls: vec!["http://worker1".to_string()], + }, + policy: PolicyConfig::Random, + host: "0.0.0.0".to_string(), + port: 8080, + max_payload_size: 1024, + request_timeout_secs: 30, + worker_startup_timeout_secs: 60, + worker_startup_check_interval_secs: 5, + discovery: Some(DiscoveryConfig::default()), + metrics: Some(MetricsConfig::default()), + log_dir: Some("/var/log".to_string()), + log_level: Some("debug".to_string()), + }; + + let json = serde_json::to_string(&config).unwrap(); + let deserialized: RouterConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(config.host, deserialized.host); + assert_eq!(config.port, deserialized.port); + assert_eq!(config.max_payload_size, deserialized.max_payload_size); + assert!(deserialized.discovery.is_some()); + assert!(deserialized.metrics.is_some()); + } + + // ============= RoutingMode Tests ============= + + #[test] + fn test_routing_mode_is_pd_mode() { + let regular = RoutingMode::Regular { + worker_urls: vec!["http://worker1".to_string()], + }; + assert!(!regular.is_pd_mode()); + + let pd = RoutingMode::PrefillDecode { + prefill_urls: vec![("http://prefill1".to_string(), Some(8001))], + decode_urls: vec!["http://decode1".to_string()], + }; + assert!(pd.is_pd_mode()); + } + + #[test] + fn test_routing_mode_worker_count() { + let regular = RoutingMode::Regular { + worker_urls: vec![ + "http://worker1".to_string(), + "http://worker2".to_string(), + "http://worker3".to_string(), + ], + }; + assert_eq!(regular.worker_count(), 3); + + let pd = RoutingMode::PrefillDecode { + prefill_urls: vec![ + ("http://prefill1".to_string(), Some(8001)), + ("http://prefill2".to_string(), None), + ], + decode_urls: vec![ + "http://decode1".to_string(), + "http://decode2".to_string(), + "http://decode3".to_string(), + ], + }; + assert_eq!(pd.worker_count(), 5); + + let empty_regular = RoutingMode::Regular { + worker_urls: vec![], + }; + assert_eq!(empty_regular.worker_count(), 0); + } + + #[test] + fn test_routing_mode_serialization() { + // Test Regular mode + let regular = RoutingMode::Regular { + worker_urls: vec!["http://worker1".to_string()], + }; + let json = serde_json::to_string(®ular).unwrap(); + assert!(json.contains("\"type\":\"regular\"")); + assert!(json.contains("\"worker_urls\"")); + + // Test PrefillDecode mode + let pd = RoutingMode::PrefillDecode { + prefill_urls: vec![("http://prefill1".to_string(), Some(8001))], + decode_urls: vec!["http://decode1".to_string()], + }; + let json = serde_json::to_string(&pd).unwrap(); + assert!(json.contains("\"type\":\"prefill_decode\"")); + assert!(json.contains("\"prefill_urls\"")); + assert!(json.contains("\"decode_urls\"")); + } + + // ============= PolicyConfig Tests ============= + + #[test] + fn test_policy_config_name() { + assert_eq!(PolicyConfig::Random.name(), "random"); + assert_eq!(PolicyConfig::RoundRobin.name(), "round_robin"); + + let cache_aware = PolicyConfig::CacheAware { + cache_threshold: 0.8, + balance_abs_threshold: 10, + balance_rel_threshold: 1.5, + eviction_interval_secs: 300, + max_tree_size: 1000, + }; + assert_eq!(cache_aware.name(), "cache_aware"); + + let power_of_two = PolicyConfig::PowerOfTwo { + load_check_interval_secs: 60, + }; + assert_eq!(power_of_two.name(), "power_of_two"); + } + + #[test] + fn test_policy_config_serialization() { + // Test Random + let random = PolicyConfig::Random; + let json = serde_json::to_string(&random).unwrap(); + assert_eq!(json, r#"{"type":"random"}"#); + + // Test CacheAware with all parameters + let cache_aware = PolicyConfig::CacheAware { + cache_threshold: 0.8, + balance_abs_threshold: 10, + balance_rel_threshold: 1.5, + eviction_interval_secs: 300, + max_tree_size: 1000, + }; + let json = serde_json::to_string(&cache_aware).unwrap(); + assert!(json.contains("\"type\":\"cache_aware\"")); + assert!(json.contains("\"cache_threshold\":0.8")); + assert!(json.contains("\"balance_abs_threshold\":10")); + + // Test PowerOfTwo + let power_of_two = PolicyConfig::PowerOfTwo { + load_check_interval_secs: 60, + }; + let json = serde_json::to_string(&power_of_two).unwrap(); + assert!(json.contains("\"type\":\"power_of_two\"")); + assert!(json.contains("\"load_check_interval_secs\":60")); + } + + #[test] + fn test_cache_aware_parameters() { + let cache_aware = PolicyConfig::CacheAware { + cache_threshold: 0.75, + balance_abs_threshold: 20, + balance_rel_threshold: 2.0, + eviction_interval_secs: 600, + max_tree_size: 5000, + }; + + match cache_aware { + PolicyConfig::CacheAware { + cache_threshold, + balance_abs_threshold, + balance_rel_threshold, + eviction_interval_secs, + max_tree_size, + } => { + assert!((cache_threshold - 0.75).abs() < 0.0001); + assert_eq!(balance_abs_threshold, 20); + assert!((balance_rel_threshold - 2.0).abs() < 0.0001); + assert_eq!(eviction_interval_secs, 600); + assert_eq!(max_tree_size, 5000); } + _ => panic!("Expected CacheAware"), } } - */ + + #[test] + fn test_power_of_two_parameters() { + let power_of_two = PolicyConfig::PowerOfTwo { + load_check_interval_secs: 120, + }; + + match power_of_two { + PolicyConfig::PowerOfTwo { + load_check_interval_secs, + } => { + assert_eq!(load_check_interval_secs, 120); + } + _ => panic!("Expected PowerOfTwo"), + } + } + + // ============= DiscoveryConfig Tests ============= + + #[test] + fn test_discovery_config_default() { + let config = DiscoveryConfig::default(); + + assert!(!config.enabled); + assert!(config.namespace.is_none()); + assert_eq!(config.port, 8000); + assert_eq!(config.check_interval_secs, 60); + assert!(config.selector.is_empty()); + assert!(config.prefill_selector.is_empty()); + assert!(config.decode_selector.is_empty()); + assert_eq!(config.bootstrap_port_annotation, "sglang.ai/bootstrap-port"); + } + + #[test] + fn test_discovery_config_with_selectors() { + let mut selector = HashMap::new(); + selector.insert("app".to_string(), "sglang".to_string()); + selector.insert("role".to_string(), "worker".to_string()); + + let config = DiscoveryConfig { + enabled: true, + namespace: Some("default".to_string()), + port: 9000, + check_interval_secs: 30, + selector: selector.clone(), + prefill_selector: selector.clone(), + decode_selector: selector.clone(), + bootstrap_port_annotation: "custom.io/port".to_string(), + }; + + assert!(config.enabled); + assert_eq!(config.namespace, Some("default".to_string())); + assert_eq!(config.port, 9000); + assert_eq!(config.selector.len(), 2); + assert_eq!(config.selector.get("app"), Some(&"sglang".to_string())); + } + + #[test] + fn test_discovery_config_namespace() { + // Test None namespace (all namespaces) + let config = DiscoveryConfig { + namespace: None, + ..Default::default() + }; + assert!(config.namespace.is_none()); + + // Test specific namespace + let config = DiscoveryConfig { + namespace: Some("production".to_string()), + ..Default::default() + }; + assert_eq!(config.namespace, Some("production".to_string())); + } + + // ============= MetricsConfig Tests ============= + + #[test] + fn test_metrics_config_default() { + let config = MetricsConfig::default(); + + assert_eq!(config.port, 29000); + assert_eq!(config.host, "127.0.0.1"); + } + + #[test] + fn test_metrics_config_custom() { + let config = MetricsConfig { + port: 9090, + host: "0.0.0.0".to_string(), + }; + + assert_eq!(config.port, 9090); + assert_eq!(config.host, "0.0.0.0"); + } + + // ============= RouterConfig Utility Methods Tests ============= + + #[test] + fn test_mode_type() { + let config = RouterConfig { + mode: RoutingMode::Regular { + worker_urls: vec![], + }, + ..Default::default() + }; + assert_eq!(config.mode_type(), "regular"); + + let config = RouterConfig { + mode: RoutingMode::PrefillDecode { + prefill_urls: vec![], + decode_urls: vec![], + }, + ..Default::default() + }; + assert_eq!(config.mode_type(), "prefill_decode"); + } + + #[test] + fn test_has_service_discovery() { + let config = RouterConfig::default(); + assert!(!config.has_service_discovery()); + + let config = RouterConfig { + discovery: Some(DiscoveryConfig { + enabled: false, + ..Default::default() + }), + ..Default::default() + }; + assert!(!config.has_service_discovery()); + + let config = RouterConfig { + discovery: Some(DiscoveryConfig { + enabled: true, + ..Default::default() + }), + ..Default::default() + }; + assert!(config.has_service_discovery()); + } + + #[test] + fn test_has_metrics() { + let config = RouterConfig::default(); + assert!(!config.has_metrics()); + + let config = RouterConfig { + metrics: Some(MetricsConfig::default()), + ..Default::default() + }; + assert!(config.has_metrics()); + } + + // ============= Edge Cases ============= + + #[test] + fn test_large_worker_lists() { + let large_urls: Vec = (0..1000).map(|i| format!("http://worker{}", i)).collect(); + + let mode = RoutingMode::Regular { + worker_urls: large_urls.clone(), + }; + + assert_eq!(mode.worker_count(), 1000); + + // Test serialization with large list + let config = RouterConfig { + mode, + ..Default::default() + }; + + let json = serde_json::to_string(&config).unwrap(); + let deserialized: RouterConfig = serde_json::from_str(&json).unwrap(); + + match deserialized.mode { + RoutingMode::Regular { worker_urls } => { + assert_eq!(worker_urls.len(), 1000); + } + _ => panic!("Expected Regular mode"), + } + } + + #[test] + fn test_unicode_in_config() { + let config = RouterConfig { + mode: RoutingMode::Regular { + worker_urls: vec!["http://работник1".to_string(), "http://工作者2".to_string()], + }, + log_dir: Some("/日志/目录".to_string()), + ..Default::default() + }; + + let json = serde_json::to_string(&config).unwrap(); + let deserialized: RouterConfig = serde_json::from_str(&json).unwrap(); + + match deserialized.mode { + RoutingMode::Regular { worker_urls } => { + assert_eq!(worker_urls[0], "http://работник1"); + assert_eq!(worker_urls[1], "http://工作者2"); + } + _ => panic!("Expected Regular mode"), + } + + assert_eq!(deserialized.log_dir, Some("/日志/目录".to_string())); + } + + #[test] + fn test_empty_string_fields() { + let config = RouterConfig { + host: "".to_string(), + log_dir: Some("".to_string()), + log_level: Some("".to_string()), + ..Default::default() + }; + + assert_eq!(config.host, ""); + assert_eq!(config.log_dir, Some("".to_string())); + assert_eq!(config.log_level, Some("".to_string())); + } + + // ============= Complex Configuration Tests ============= + + #[test] + fn test_full_pd_mode_config() { + let config = RouterConfig { + mode: RoutingMode::PrefillDecode { + prefill_urls: vec![ + ("http://prefill1:8000".to_string(), Some(8001)), + ("http://prefill2:8000".to_string(), None), + ], + decode_urls: vec![ + "http://decode1:8000".to_string(), + "http://decode2:8000".to_string(), + ], + }, + policy: PolicyConfig::PowerOfTwo { + load_check_interval_secs: 30, + }, + host: "0.0.0.0".to_string(), + port: 3000, + max_payload_size: 1048576, + request_timeout_secs: 120, + worker_startup_timeout_secs: 60, + worker_startup_check_interval_secs: 5, + discovery: Some(DiscoveryConfig { + enabled: true, + namespace: Some("sglang".to_string()), + ..Default::default() + }), + metrics: Some(MetricsConfig { + port: 9090, + host: "0.0.0.0".to_string(), + }), + log_dir: Some("/var/log/sglang".to_string()), + log_level: Some("info".to_string()), + }; + + assert!(config.mode.is_pd_mode()); + assert_eq!(config.mode.worker_count(), 4); + assert_eq!(config.policy.name(), "power_of_two"); + assert!(config.has_service_discovery()); + assert!(config.has_metrics()); + } + + #[test] + fn test_full_regular_mode_config() { + let mut selector = HashMap::new(); + selector.insert("app".to_string(), "sglang".to_string()); + + let config = RouterConfig { + mode: RoutingMode::Regular { + worker_urls: vec![ + "http://worker1:8000".to_string(), + "http://worker2:8000".to_string(), + "http://worker3:8000".to_string(), + ], + }, + policy: PolicyConfig::CacheAware { + cache_threshold: 0.9, + balance_abs_threshold: 5, + balance_rel_threshold: 1.2, + eviction_interval_secs: 600, + max_tree_size: 10000, + }, + host: "0.0.0.0".to_string(), + port: 3001, + max_payload_size: 536870912, + request_timeout_secs: 300, + worker_startup_timeout_secs: 180, + worker_startup_check_interval_secs: 15, + discovery: Some(DiscoveryConfig { + enabled: true, + namespace: None, + port: 8080, + check_interval_secs: 45, + selector, + ..Default::default() + }), + metrics: Some(MetricsConfig::default()), + log_dir: None, + log_level: Some("debug".to_string()), + }; + + assert!(!config.mode.is_pd_mode()); + assert_eq!(config.mode.worker_count(), 3); + assert_eq!(config.policy.name(), "cache_aware"); + assert!(config.has_service_discovery()); + assert!(config.has_metrics()); + } + + #[test] + fn test_config_with_all_options() { + let mut selectors = HashMap::new(); + selectors.insert("env".to_string(), "prod".to_string()); + selectors.insert("version".to_string(), "v1".to_string()); + + let config = RouterConfig { + mode: RoutingMode::Regular { + worker_urls: vec!["http://worker1".to_string()], + }, + policy: PolicyConfig::RoundRobin, + host: "::1".to_string(), // IPv6 + port: 8888, + max_payload_size: 1024 * 1024 * 512, // 512MB + request_timeout_secs: 900, + worker_startup_timeout_secs: 600, + worker_startup_check_interval_secs: 20, + discovery: Some(DiscoveryConfig { + enabled: true, + namespace: Some("production".to_string()), + port: 8443, + check_interval_secs: 120, + selector: selectors.clone(), + prefill_selector: selectors.clone(), + decode_selector: selectors, + bootstrap_port_annotation: "mycompany.io/bootstrap".to_string(), + }), + metrics: Some(MetricsConfig { + port: 9999, + host: "::".to_string(), // IPv6 any + }), + log_dir: Some("/opt/logs/sglang".to_string()), + log_level: Some("trace".to_string()), + }; + + assert!(config.has_service_discovery()); + assert!(config.has_metrics()); + assert_eq!(config.mode_type(), "regular"); + + // Test round-trip serialization + let json = serde_json::to_string_pretty(&config).unwrap(); + let deserialized: RouterConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.host, "::1"); + assert_eq!(deserialized.port, 8888); + assert_eq!( + deserialized.discovery.unwrap().namespace, + Some("production".to_string()) + ); + } } diff --git a/sgl-router/src/metrics.rs b/sgl-router/src/metrics.rs index 76e952a03..78a06de44 100644 --- a/sgl-router/src/metrics.rs +++ b/sgl-router/src/metrics.rs @@ -322,3 +322,414 @@ impl RouterMetrics { .set(count as f64); } } + +#[cfg(test)] +mod tests { + use super::*; + use std::net::TcpListener; + + // ============= PrometheusConfig Tests ============= + + #[test] + fn test_prometheus_config_default() { + let config = PrometheusConfig::default(); + assert_eq!(config.port, 29000); + assert_eq!(config.host, "0.0.0.0"); + } + + #[test] + fn test_prometheus_config_custom() { + let config = PrometheusConfig { + port: 8080, + host: "127.0.0.1".to_string(), + }; + assert_eq!(config.port, 8080); + assert_eq!(config.host, "127.0.0.1"); + } + + #[test] + fn test_prometheus_config_clone() { + let config = PrometheusConfig { + port: 9090, + host: "192.168.1.1".to_string(), + }; + let cloned = config.clone(); + assert_eq!(cloned.port, config.port); + assert_eq!(cloned.host, config.host); + } + + // ============= IP Address Parsing Tests ============= + + #[test] + fn test_valid_ipv4_parsing() { + let test_cases = vec!["127.0.0.1", "192.168.1.1", "0.0.0.0"]; + + for ip_str in test_cases { + let config = PrometheusConfig { + port: 29000, + host: ip_str.to_string(), + }; + + let ip_addr: IpAddr = config.host.parse().unwrap(); + assert!(matches!(ip_addr, IpAddr::V4(_))); + } + } + + #[test] + fn test_valid_ipv6_parsing() { + let test_cases = vec!["::1", "2001:db8::1", "::"]; + + for ip_str in test_cases { + let config = PrometheusConfig { + port: 29000, + host: ip_str.to_string(), + }; + + let ip_addr: IpAddr = config.host.parse().unwrap(); + assert!(matches!(ip_addr, IpAddr::V6(_))); + } + } + + #[test] + fn test_invalid_ip_parsing() { + let test_cases = vec!["invalid", "256.256.256.256", "hostname"]; + + for ip_str in test_cases { + let config = PrometheusConfig { + port: 29000, + host: ip_str.to_string(), + }; + + let ip_addr: IpAddr = config + .host + .parse() + .unwrap_or(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); + + // Should fall back to 0.0.0.0 + assert_eq!(ip_addr, IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); + } + } + + // ============= Socket Address Creation Tests ============= + + #[test] + fn test_socket_addr_creation() { + let test_cases = vec![("127.0.0.1", 8080), ("0.0.0.0", 29000), ("::1", 9090)]; + + for (host, port) in test_cases { + let config = PrometheusConfig { + port, + host: host.to_string(), + }; + + let ip_addr: IpAddr = config.host.parse().unwrap(); + let socket_addr = SocketAddr::new(ip_addr, config.port); + + assert_eq!(socket_addr.port(), port); + assert_eq!(socket_addr.ip().to_string(), host); + } + } + + #[test] + fn test_socket_addr_with_different_ports() { + let ports = vec![0, 80, 8080, 65535]; + + for port in ports { + let config = PrometheusConfig { + port, + host: "127.0.0.1".to_string(), + }; + + let ip_addr: IpAddr = config.host.parse().unwrap(); + let socket_addr = SocketAddr::new(ip_addr, config.port); + + assert_eq!(socket_addr.port(), port); + } + } + + // ============= Duration Bucket Tests ============= + + #[test] + fn test_duration_bucket_values() { + let expected_buckets = vec![ + 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0, + 60.0, 90.0, 120.0, 180.0, 240.0, + ]; + + // The buckets are defined in start_prometheus function + assert_eq!(expected_buckets.len(), 20); + + // Verify proper ordering + for i in 1..expected_buckets.len() { + assert!(expected_buckets[i] > expected_buckets[i - 1]); + } + } + + #[test] + fn test_duration_bucket_coverage() { + let test_cases = vec![ + (0.0005, "sub-millisecond"), + (0.005, "5ms"), + (0.05, "50ms"), + (1.0, "1s"), + (10.0, "10s"), + (60.0, "1m"), + (240.0, "4m"), + ]; + + let buckets = vec![ + 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0, + 60.0, 90.0, 120.0, 180.0, 240.0, + ]; + + for (duration, label) in test_cases { + let bucket_found = buckets + .iter() + .any(|&b| ((b - duration) as f64).abs() < 0.0001 || b > duration); + assert!(bucket_found, "No bucket found for {} ({})", duration, label); + } + } + + // ============= Matcher Configuration Tests ============= + + #[test] + fn test_duration_suffix_matcher() { + let matcher = Matcher::Suffix(String::from("duration_seconds")); + + // Test matching behavior + let _matching_metrics = vec![ + "request_duration_seconds", + "response_duration_seconds", + "sgl_router_request_duration_seconds", + ]; + + let _non_matching_metrics = + vec!["duration_total", "duration_seconds_total", "other_metric"]; + + // Note: We can't directly test Matcher matching without the internals, + // but we can verify the matcher is created correctly + match matcher { + Matcher::Suffix(suffix) => assert_eq!(suffix, "duration_seconds"), + _ => panic!("Expected Suffix matcher"), + } + } + + // ============= Builder Configuration Tests ============= + + #[test] + fn test_prometheus_builder_configuration() { + // This test verifies the builder configuration without actually starting Prometheus + let _config = PrometheusConfig::default(); + + let duration_matcher = Matcher::Suffix(String::from("duration_seconds")); + let duration_bucket = [ + 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0, + 60.0, 90.0, 120.0, 180.0, 240.0, + ]; + + // Verify bucket configuration + assert_eq!(duration_bucket.len(), 20); + + // Verify matcher is suffix type + match duration_matcher { + Matcher::Suffix(s) => assert_eq!(s, "duration_seconds"), + _ => panic!("Expected Suffix matcher"), + } + } + + // ============= Upkeep Timeout Tests ============= + + #[test] + fn test_upkeep_timeout_duration() { + let timeout = Duration::from_secs(5 * 60); + assert_eq!(timeout.as_secs(), 300); + } + + // ============= Custom Bucket Tests ============= + + #[test] + fn test_custom_buckets_for_different_metrics() { + // Test that we can create different bucket configurations + let request_buckets = vec![0.001, 0.01, 0.1, 1.0, 10.0]; + let generate_buckets = vec![0.1, 0.5, 1.0, 5.0, 30.0, 60.0]; + + assert_eq!(request_buckets.len(), 5); + assert_eq!(generate_buckets.len(), 6); + + // Verify each set is sorted + for i in 1..request_buckets.len() { + assert!(request_buckets[i] > request_buckets[i - 1]); + } + + for i in 1..generate_buckets.len() { + assert!(generate_buckets[i] > generate_buckets[i - 1]); + } + } + + // ============= RouterMetrics Tests ============= + + #[test] + fn test_metrics_static_methods() { + // Test that all static methods can be called without panic + RouterMetrics::record_request("/generate"); + RouterMetrics::record_request_duration("/generate", Duration::from_millis(100)); + RouterMetrics::record_request_error("/generate", "timeout"); + RouterMetrics::record_retry("/generate"); + + RouterMetrics::set_active_workers(5); + RouterMetrics::set_worker_health("http://worker1", true); + RouterMetrics::set_worker_load("http://worker1", 10); + RouterMetrics::record_processed_request("http://worker1"); + + RouterMetrics::record_policy_decision("random", "http://worker1"); + RouterMetrics::record_cache_hit(); + RouterMetrics::record_cache_miss(); + RouterMetrics::set_tree_size("http://worker1", 1000); + RouterMetrics::record_load_balancing_event(); + RouterMetrics::set_load_range(20, 5); + + RouterMetrics::record_pd_request("/v1/chat/completions"); + RouterMetrics::record_pd_request_duration("/v1/chat/completions", Duration::from_secs(1)); + RouterMetrics::record_pd_prefill_request("http://prefill1"); + RouterMetrics::record_pd_decode_request("http://decode1"); + RouterMetrics::record_pd_error("invalid_request"); + RouterMetrics::record_pd_prefill_error("http://prefill1"); + RouterMetrics::record_pd_decode_error("http://decode1"); + RouterMetrics::record_pd_stream_error("http://decode1"); + + RouterMetrics::record_discovery_update(3, 1); + RouterMetrics::record_generate_duration(Duration::from_secs(2)); + RouterMetrics::set_running_requests("http://worker1", 15); + } + + // ============= Port Availability Tests ============= + + #[test] + fn test_port_already_in_use() { + // Skip this test if we can't bind to the port + let port = 29123; // Use a different port to avoid conflicts + + if let Ok(_listener) = TcpListener::bind(("127.0.0.1", port)) { + // Port is available, we can test + let config = PrometheusConfig { + port, + host: "127.0.0.1".to_string(), + }; + + // Just verify config is created correctly + assert_eq!(config.port, port); + } + } + + // ============= Integration Test Helpers ============= + + #[test] + fn test_metrics_endpoint_accessibility() { + // This would be an integration test in practice + // Here we just verify the configuration + let config = PrometheusConfig { + port: 29000, + host: "127.0.0.1".to_string(), + }; + + let ip_addr: IpAddr = config.host.parse().unwrap(); + let socket_addr = SocketAddr::new(ip_addr, config.port); + + assert_eq!(socket_addr.to_string(), "127.0.0.1:29000"); + } + + #[test] + fn test_concurrent_metric_updates() { + // Test that metric updates can be called concurrently + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + use std::thread; + + let done = Arc::new(AtomicBool::new(false)); + let mut handles = vec![]; + + for i in 0..3 { + let done_clone = done.clone(); + let handle = thread::spawn(move || { + let worker = format!("http://worker{}", i); + while !done_clone.load(Ordering::Relaxed) { + RouterMetrics::set_worker_load(&worker, i * 10); + RouterMetrics::record_processed_request(&worker); + thread::sleep(Duration::from_millis(1)); + } + }); + handles.push(handle); + } + + // Let threads run briefly + thread::sleep(Duration::from_millis(10)); + done.store(true, Ordering::Relaxed); + + // Wait for all threads + for handle in handles { + handle.join().unwrap(); + } + + // If we get here without panic, concurrent access works + assert!(true); + } + + // ============= Edge Cases Tests ============= + + #[test] + fn test_empty_string_metrics() { + // Test that empty strings don't cause issues + RouterMetrics::record_request(""); + RouterMetrics::set_worker_health("", true); + RouterMetrics::record_policy_decision("", ""); + + // If we get here without panic, empty strings are handled + assert!(true); + } + + #[test] + fn test_very_long_metric_labels() { + let long_label = "a".repeat(1000); + + RouterMetrics::record_request(&long_label); + RouterMetrics::set_worker_health(&long_label, false); + + // If we get here without panic, long labels are handled + assert!(true); + } + + #[test] + fn test_special_characters_in_labels() { + let special_labels = vec![ + "test/with/slashes", + "test-with-dashes", + "test_with_underscores", + "test.with.dots", + "test:with:colons", + ]; + + for label in special_labels { + RouterMetrics::record_request(label); + RouterMetrics::set_worker_health(label, true); + } + + // If we get here without panic, special characters are handled + assert!(true); + } + + #[test] + fn test_extreme_metric_values() { + // Test extreme values + RouterMetrics::set_active_workers(0); + RouterMetrics::set_active_workers(usize::MAX); + + RouterMetrics::set_worker_load("worker", 0); + RouterMetrics::set_worker_load("worker", usize::MAX); + + RouterMetrics::record_request_duration("route", Duration::from_nanos(1)); + RouterMetrics::record_request_duration("route", Duration::from_secs(86400)); // 24 hours + + // If we get here without panic, extreme values are handled + assert!(true); + } +} diff --git a/sgl-router/src/routers/pd_types.rs b/sgl-router/src/routers/pd_types.rs index 155274b06..e83ab5b60 100644 --- a/sgl-router/src/routers/pd_types.rs +++ b/sgl-router/src/routers/pd_types.rs @@ -58,7 +58,7 @@ pub enum PDSelectionPolicy { }, } // Bootstrap types from PDLB -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Deserialize, Serialize, PartialEq)] #[serde(untagged)] pub enum SingleOrBatch { Single(T), diff --git a/sgl-router/src/routers/request_adapter.rs b/sgl-router/src/routers/request_adapter.rs index f5611bbe4..201c61aa5 100644 --- a/sgl-router/src/routers/request_adapter.rs +++ b/sgl-router/src/routers/request_adapter.rs @@ -211,6 +211,7 @@ impl ToPdRequest for ChatCompletionRequest { self.temperature => "temperature", self.top_p => "top_p", self.n => "n", + self.stream_options => "stream_options", self.stop => "stop", self.max_tokens => "max_tokens", self.max_completion_tokens => "max_completion_tokens", @@ -262,3 +263,1015 @@ pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone { impl RouteableRequest for GenerateRequest {} impl RouteableRequest for CompletionRequest {} impl RouteableRequest for ChatCompletionRequest {} + +#[cfg(test)] +mod tests { + use super::*; + use crate::openai_api_types::*; + use serde_json::json; + use std::collections::HashMap; + + // ============= GenerateRequest to_pd_request Tests ============= + + #[test] + fn test_generate_to_pd_request_with_text_only() { + let req = GenerateRequest { + text: Some("Hello world".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + // Check text field conversion + assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Hello world")); + assert!(pd_req.input_ids.is_none()); + + // Check bootstrap fields are None + assert!(pd_req.bootstrap_host.is_none()); + assert!(pd_req.bootstrap_port.is_none()); + assert!(pd_req.bootstrap_room.is_none()); + + // Check stream flag + assert_eq!(pd_req.stream, false); + + // Check other fields + let other = pd_req.other.as_object().unwrap(); + assert_eq!(other.get("stream"), Some(&json!(false))); + assert_eq!(other.get("return_logprob"), Some(&json!(false))); + } + + #[test] + fn test_generate_to_pd_request_with_prompt_string() { + let req = GenerateRequest { + text: None, + prompt: Some(StringOrArray::String("Test prompt".to_string())), + input_ids: None, + stream: true, + parameters: None, + sampling_params: None, + return_logprob: true, + }; + + let pd_req = req.to_pd_request(); + + assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Test prompt")); + assert!(pd_req.input_ids.is_none()); + assert_eq!(pd_req.stream, true); + + let other = pd_req.other.as_object().unwrap(); + assert_eq!(other.get("stream"), Some(&json!(true))); + assert_eq!(other.get("return_logprob"), Some(&json!(true))); + } + + #[test] + fn test_generate_to_pd_request_with_prompt_array() { + let req = GenerateRequest { + text: None, + prompt: Some(StringOrArray::Array(vec![ + "Prompt 1".to_string(), + "Prompt 2".to_string(), + "Prompt 3".to_string(), + ])), + input_ids: None, + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + match pd_req.text { + Some(SingleOrBatch::Batch(ref batch)) => { + assert_eq!(batch.len(), 3); + assert_eq!(batch[0], "Prompt 1"); + assert_eq!(batch[1], "Prompt 2"); + assert_eq!(batch[2], "Prompt 3"); + } + _ => panic!("Expected batch text"), + } + } + + #[test] + fn test_generate_to_pd_request_with_single_input_ids() { + let req = GenerateRequest { + text: None, + prompt: None, + input_ids: Some(InputIds::Single(vec![100, 200, 300, 400])), + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + assert!(pd_req.text.is_none()); + assert!(matches!( + pd_req.input_ids, + Some(SingleOrBatch::Single(ref ids)) if ids == &vec![100, 200, 300, 400] + )); + } + + #[test] + fn test_generate_to_pd_request_with_batch_input_ids() { + let req = GenerateRequest { + text: None, + prompt: None, + input_ids: Some(InputIds::Batch(vec![ + vec![1, 2, 3], + vec![4, 5, 6, 7], + vec![8, 9], + ])), + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + match pd_req.input_ids { + Some(SingleOrBatch::Batch(ref batch)) => { + assert_eq!(batch.len(), 3); + assert_eq!(batch[0], vec![1, 2, 3]); + assert_eq!(batch[1], vec![4, 5, 6, 7]); + assert_eq!(batch[2], vec![8, 9]); + } + _ => panic!("Expected batch input_ids"), + } + } + + #[test] + fn test_generate_to_pd_request_priority_text_over_prompt() { + let req = GenerateRequest { + text: Some("SGLang text".to_string()), + prompt: Some(StringOrArray::String("OpenAI prompt".to_string())), + input_ids: Some(InputIds::Single(vec![1, 2, 3])), + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + // text should take priority + assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "SGLang text")); + assert!(pd_req.input_ids.is_none()); + } + + #[test] + fn test_generate_to_pd_request_priority_prompt_over_input_ids() { + let req = GenerateRequest { + text: None, + prompt: Some(StringOrArray::String("OpenAI prompt".to_string())), + input_ids: Some(InputIds::Single(vec![1, 2, 3])), + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + // prompt should take priority over input_ids + assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "OpenAI prompt")); + assert!(pd_req.input_ids.is_none()); + } + + #[test] + fn test_generate_to_pd_request_with_parameters() { + let params = GenerateParameters { + max_new_tokens: Some(100), + temperature: Some(0.8), + top_p: Some(0.95), + seed: Some(12345), + stop: Some(vec!["END".to_string(), "STOP".to_string()]), + repetition_penalty: Some(1.1), + ..Default::default() + }; + + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: Some(params), + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Check that max_new_tokens and temperature were extracted to top level + assert_eq!(other.get("max_new_tokens"), Some(&json!(100))); + assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.8 < 0.0001); + + // Check that other parameters remain under "parameters" + let params = other.get("parameters").unwrap().as_object().unwrap(); + assert!(params.get("top_p").unwrap().as_f64().unwrap() - 0.95 < 0.0001); + assert_eq!(params.get("seed"), Some(&json!(12345))); + assert_eq!(params.get("stop"), Some(&json!(vec!["END", "STOP"]))); + assert!(params.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1 < 0.0001); + } + + #[test] + fn test_generate_to_pd_request_with_sampling_params() { + let sampling = SamplingParams { + max_new_tokens: Some(200), + temperature: Some(0.7), + top_p: Some(0.9), + top_k: Some(50), + frequency_penalty: Some(0.1), + presence_penalty: Some(0.2), + repetition_penalty: Some(1.05), + ..Default::default() + }; + + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: None, + sampling_params: Some(sampling), + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Check extracted top-level fields + assert_eq!(other.get("max_new_tokens"), Some(&json!(200))); + assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.7 < 0.0001); + + // Check full sampling_params is preserved + let sampling = other.get("sampling_params").unwrap().as_object().unwrap(); + assert_eq!(sampling.get("max_new_tokens"), Some(&json!(200))); + assert!(sampling.get("temperature").unwrap().as_f64().unwrap() - 0.7 < 0.0001); + assert!(sampling.get("top_p").unwrap().as_f64().unwrap() - 0.9 < 0.0001); + assert_eq!(sampling.get("top_k"), Some(&json!(50))); + assert!(sampling.get("frequency_penalty").unwrap().as_f64().unwrap() - 0.1 < 0.0001); + assert!(sampling.get("presence_penalty").unwrap().as_f64().unwrap() - 0.2 < 0.0001); + } + + #[test] + fn test_generate_to_pd_request_sampling_params_override_parameters() { + // When both parameters and sampling_params have max_new_tokens/temperature, + // sampling_params should take precedence (processed last) + let params = GenerateParameters { + max_new_tokens: Some(100), + temperature: Some(0.5), + ..Default::default() + }; + + let sampling = SamplingParams { + max_new_tokens: Some(200), + temperature: Some(0.9), + ..Default::default() + }; + + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: Some(params), + sampling_params: Some(sampling), + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Should use values from sampling_params since they're processed last + assert_eq!(other.get("max_new_tokens"), Some(&json!(200))); + assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.9 < 0.0001); + } + + #[test] + fn test_generate_to_pd_request_empty_parameters() { + let params = GenerateParameters::default(); + + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: Some(params), + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Should not have parameters field if all values are None/default + assert!(!other.contains_key("parameters")); + assert!(!other.contains_key("max_new_tokens")); + assert!(!other.contains_key("temperature")); + } + + #[test] + fn test_generate_to_pd_request_all_fields() { + let params = GenerateParameters { + max_new_tokens: Some(150), + temperature: Some(0.6), + top_k: Some(40), + ..Default::default() + }; + + let sampling = SamplingParams { + max_new_tokens: Some(250), // Will override parameters + temperature: Some(0.8), // Will override parameters + presence_penalty: Some(0.1), + ..Default::default() + }; + + let req = GenerateRequest { + text: Some("Complex test".to_string()), + prompt: Some(StringOrArray::String("Ignored prompt".to_string())), + input_ids: None, + stream: true, + parameters: Some(params), + sampling_params: Some(sampling), + return_logprob: true, + }; + + let pd_req = req.to_pd_request(); + + // Verify all fields + assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Complex test")); + assert!(pd_req.input_ids.is_none()); + assert_eq!(pd_req.stream, true); + assert!(pd_req.bootstrap_host.is_none()); + assert!(pd_req.bootstrap_port.is_none()); + assert!(pd_req.bootstrap_room.is_none()); + + let other = pd_req.other.as_object().unwrap(); + assert_eq!(other.get("stream"), Some(&json!(true))); + assert_eq!(other.get("return_logprob"), Some(&json!(true))); + // Sampling params override parameters + assert_eq!(other.get("max_new_tokens"), Some(&json!(250))); + assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.8 < 0.0001); + assert!(other.contains_key("parameters")); + assert!(other.contains_key("sampling_params")); + } + + // ============= CompletionRequest to_pd_request Tests ============= + + #[test] + fn test_completion_to_pd_request_basic() { + let req = CompletionRequest { + model: "gpt-3.5-turbo".to_string(), + prompt: StringOrArray::String("Complete this sentence".to_string()), + max_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + suffix: None, + }; + + let pd_req = req.to_pd_request(); + + assert!( + matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Complete this sentence") + ); + assert!(pd_req.input_ids.is_none()); + assert_eq!(pd_req.stream, false); + + let other = pd_req.other.as_object().unwrap(); + assert_eq!(other.get("model"), Some(&json!("gpt-3.5-turbo"))); + assert_eq!(other.get("stream"), Some(&json!(false))); + } + + #[test] + fn test_completion_to_pd_request_array_prompt() { + let req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::Array(vec![ + "First prompt".to_string(), + "Second prompt".to_string(), + ]), + max_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + suffix: None, + }; + + let pd_req = req.to_pd_request(); + + match pd_req.text { + Some(SingleOrBatch::Batch(ref batch)) => { + assert_eq!(batch.len(), 2); + assert_eq!(batch[0], "First prompt"); + assert_eq!(batch[1], "Second prompt"); + } + _ => panic!("Expected batch text"), + } + } + + #[test] + fn test_completion_to_pd_request_parameter_mapping() { + let req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::String("test".to_string()), + max_tokens: Some(150), // -> max_new_tokens + temperature: Some(0.75), + top_p: Some(0.92), + n: Some(3), // -> best_of + stream: true, + stream_options: None, + logprobs: Some(10), // -> top_n_tokens + echo: true, // -> return_full_text + stop: Some(StringOrArray::Array(vec![ + "\\n".to_string(), + "END".to_string(), + ])), + presence_penalty: Some(0.5), // -> repetition_penalty = 1.5 + frequency_penalty: Some(0.2), + best_of: Some(5), + logit_bias: None, + user: Some("user123".to_string()), + seed: Some(42), + suffix: Some("...".to_string()), + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + let params = other.get("parameters").unwrap().as_object().unwrap(); + + // Check parameter mappings + assert_eq!(params.get("max_new_tokens"), Some(&json!(150))); + assert!(params.get("temperature").unwrap().as_f64().unwrap() - 0.75 < 0.0001); + assert!(params.get("top_p").unwrap().as_f64().unwrap() - 0.92 < 0.0001); + assert_eq!(params.get("best_of"), Some(&json!(3))); + assert_eq!(params.get("top_n_tokens"), Some(&json!(10))); + assert_eq!(params.get("return_full_text"), Some(&json!(true))); + assert_eq!(params.get("stop"), Some(&json!(vec!["\\n", "END"]))); + assert!(params.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.5 < 0.0001); + assert_eq!(params.get("seed"), Some(&json!(42))); + + // Check other fields + assert_eq!(other.get("model"), Some(&json!("test"))); + assert_eq!(other.get("stream"), Some(&json!(true))); + } + + #[test] + fn test_completion_to_pd_request_stop_string() { + let req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::String("test".to_string()), + stop: Some(StringOrArray::String("STOP".to_string())), + max_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + suffix: None, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + let params = other.get("parameters").unwrap().as_object().unwrap(); + + // Single string stop should be converted to array + assert_eq!(params.get("stop"), Some(&json!(vec!["STOP"]))); + } + + #[test] + fn test_completion_to_pd_request_no_presence_penalty() { + let req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::String("test".to_string()), + presence_penalty: None, + max_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + suffix: None, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + let params = other.get("parameters").unwrap().as_object().unwrap(); + + // Should not have repetition_penalty if presence_penalty is None + assert!(!params.contains_key("repetition_penalty")); + } + + // ============= ChatCompletionRequest to_pd_request Tests ============= + + #[test] + fn test_chat_to_pd_request_basic() { + let messages = vec![ + ChatMessage::System { + role: "system".to_string(), + content: "You are a helpful assistant".to_string(), + name: None, + }, + ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Text("Hello!".to_string()), + name: None, + }, + ]; + + let req = ChatCompletionRequest { + messages, + model: "gpt-4".to_string(), + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + logprobs: false, + top_logprobs: None, + user: None, + seed: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + functions: None, + function_call: None, + }; + + let pd_req = req.to_pd_request(); + + assert_eq!(pd_req.stream, false); + assert!(pd_req.bootstrap_host.is_none()); + assert!(pd_req.bootstrap_port.is_none()); + assert!(pd_req.bootstrap_room.is_none()); + + let other = pd_req.other.as_object().unwrap(); + assert!(other.contains_key("messages")); + assert_eq!(other.get("model"), Some(&json!("gpt-4"))); + assert_eq!(other.get("stream"), Some(&json!(false))); + + // Check messages are preserved + let messages = other.get("messages").unwrap().as_array().unwrap(); + assert_eq!(messages.len(), 2); + } + + #[test] + fn test_chat_to_pd_request_with_all_optional_fields() { + let messages = vec![ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Text("Test".to_string()), + name: Some("test_user".to_string()), + }]; + + let mut logit_bias = HashMap::new(); + logit_bias.insert("50256".to_string(), -100); + + let tool = Tool { + tool_type: "function".to_string(), + function: Function { + name: "get_weather".to_string(), + description: Some("Get weather info".to_string()), + parameters: json!({"type": "object"}), + }, + }; + + let req = ChatCompletionRequest { + messages, + model: "gpt-4".to_string(), + temperature: Some(0.8), + top_p: Some(0.95), + n: Some(2), + stream: true, + stream_options: Some(StreamOptions { + include_usage: Some(true), + }), + stop: Some(StringOrArray::String("\\n\\n".to_string())), + max_tokens: Some(200), + max_completion_tokens: Some(150), + presence_penalty: Some(0.1), + frequency_penalty: Some(0.2), + logit_bias: Some(logit_bias), + logprobs: true, + top_logprobs: Some(5), + user: Some("user456".to_string()), + seed: Some(12345), + response_format: Some(ResponseFormat::JsonObject), + tools: Some(vec![tool]), + tool_choice: Some(ToolChoice::Auto), + parallel_tool_calls: Some(false), + functions: None, + function_call: None, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Check all fields are preserved + assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.8 < 0.0001); + assert!(other.get("top_p").unwrap().as_f64().unwrap() - 0.95 < 0.0001); + assert_eq!(other.get("n"), Some(&json!(2))); + assert_eq!(other.get("stream"), Some(&json!(true))); + assert!(other.contains_key("stream_options")); + assert!(other.contains_key("stop")); + assert_eq!(other.get("max_tokens"), Some(&json!(200))); + assert_eq!(other.get("max_completion_tokens"), Some(&json!(150))); + assert!(other.get("presence_penalty").unwrap().as_f64().unwrap() - 0.1 < 0.0001); + assert!(other.get("frequency_penalty").unwrap().as_f64().unwrap() - 0.2 < 0.0001); + assert!(other.contains_key("logit_bias")); + assert_eq!(other.get("logprobs"), Some(&json!(true))); + assert_eq!(other.get("top_logprobs"), Some(&json!(5))); + assert_eq!(other.get("user"), Some(&json!("user456"))); + assert_eq!(other.get("seed"), Some(&json!(12345))); + assert!(other.contains_key("response_format")); + assert!(other.contains_key("tools")); + assert!(other.contains_key("tool_choice")); + assert_eq!(other.get("parallel_tool_calls"), Some(&json!(false))); + } + + #[test] + fn test_chat_to_pd_request_multimodal_content() { + let messages = vec![ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Parts(vec![ + ContentPart::Text { + text: "What's in this image?".to_string(), + }, + ContentPart::ImageUrl { + image_url: ImageUrl { + url: "https://example.com/image.jpg".to_string(), + detail: Some("high".to_string()), + }, + }, + ]), + name: None, + }]; + + let req = ChatCompletionRequest { + messages, + model: "gpt-4-vision".to_string(), + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + logprobs: false, + top_logprobs: None, + user: None, + seed: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + functions: None, + function_call: None, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Messages with multimodal content should be preserved + assert!(other.contains_key("messages")); + let messages = other.get("messages").unwrap().as_array().unwrap(); + assert_eq!(messages.len(), 1); + + // Verify the message structure is preserved + let msg = &messages[0]; + assert_eq!(msg["role"], "user"); + assert!(msg["content"].is_array()); + } + + #[test] + fn test_chat_to_pd_request_logprobs_boolean() { + let messages = vec![ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Text("Test".to_string()), + name: None, + }]; + + let req = ChatCompletionRequest { + messages, + model: "test".to_string(), + logprobs: true, // Boolean logprobs flag + top_logprobs: Some(3), + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + seed: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + functions: None, + function_call: None, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + assert_eq!(other.get("logprobs"), Some(&json!(true))); + assert_eq!(other.get("top_logprobs"), Some(&json!(3))); + } + + #[test] + fn test_chat_to_pd_request_minimal_fields() { + let messages = vec![ChatMessage::Assistant { + role: "assistant".to_string(), + content: Some("I can help with that.".to_string()), + name: None, + tool_calls: None, + function_call: None, + }]; + + let req = ChatCompletionRequest { + messages, + model: "gpt-3.5-turbo".to_string(), + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + logprobs: false, + top_logprobs: None, + user: None, + seed: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + functions: None, + function_call: None, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Should only have required fields + assert!(other.contains_key("messages")); + assert!(other.contains_key("model")); + assert!(other.contains_key("stream")); + + // Optional fields should not be present + assert!(!other.contains_key("temperature")); + assert!(!other.contains_key("top_p")); + assert!(!other.contains_key("max_tokens")); + assert!(!other.contains_key("stop")); + } + + #[test] + fn test_routeable_request_to_json() { + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let json = req.to_json().unwrap(); + assert_eq!(json["text"], "test"); + assert_eq!(json["stream"], false); + } + + // ============= Macro Tests ============= + + #[test] + fn test_insert_if_some_macro() { + let mut map = serde_json::Map::new(); + + let some_value: Option = Some(42); + let none_value: Option = None; + + insert_if_some!(map, + some_value => "present", + none_value => "absent" + ); + + assert_eq!(map.get("present"), Some(&json!(42))); + assert!(!map.contains_key("absent")); + } + + #[test] + fn test_insert_value_macro() { + let mut map = serde_json::Map::new(); + + let value1 = "test"; + let value2 = 42; + + insert_value!(map, + value1 => "string_field", + value2 => "int_field" + ); + + assert_eq!(map.get("string_field"), Some(&json!("test"))); + assert_eq!(map.get("int_field"), Some(&json!(42))); + } + + // ============= Edge Cases and Error Handling ============= + + #[test] + fn test_null_value_handling() { + let params = GenerateParameters { + max_new_tokens: None, + temperature: None, + ..Default::default() + }; + + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: Some(params), + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Should not have parameters field if all fields are None + assert!(!other.contains_key("parameters")); + } + + #[test] + fn test_large_batch_conversion() { + let large_batch: Vec = (0..1000).map(|i| format!("item_{}", i)).collect(); + + let req = GenerateRequest { + text: None, + prompt: Some(StringOrArray::Array(large_batch.clone())), + input_ids: None, + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + if let Some(SingleOrBatch::Batch(batch)) = pd_req.text { + assert_eq!(batch.len(), 1000); + assert_eq!(batch[0], "item_0"); + assert_eq!(batch[999], "item_999"); + } else { + panic!("Expected batch text"); + } + } + + #[test] + fn test_unicode_string_handling() { + let unicode_text = "Hello 世界 🌍 नमस्ते мир".to_string(); + + let req = GenerateRequest { + text: Some(unicode_text.clone()), + prompt: None, + input_ids: None, + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + if let Some(SingleOrBatch::Single(text)) = pd_req.text { + assert_eq!(text, unicode_text); + } else { + panic!("Expected single text"); + } + } + + #[test] + fn test_deeply_nested_parameters() { + let mut nested_params = serde_json::Map::new(); + nested_params.insert( + "nested".to_string(), + json!({ + "level1": { + "level2": { + "level3": "value" + } + } + }), + ); + + let params = GenerateParameters { + max_new_tokens: Some(100), + ..Default::default() + }; + + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: Some(params), + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Parameters should be preserved even with nested structures + assert!(other.contains_key("max_new_tokens")); + } + + // ============= Bootstrap Field Tests ============= + + #[test] + fn test_bootstrap_fields_none() { + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + assert_eq!(pd_req.bootstrap_host, None); + assert_eq!(pd_req.bootstrap_port, None); + assert_eq!(pd_req.bootstrap_room, None); + } +}