[router] remove old/outdated/useless comments across code base (#10968)
This commit is contained in:
@@ -205,7 +205,6 @@ impl RoutingMode {
|
|||||||
decode_urls,
|
decode_urls,
|
||||||
..
|
..
|
||||||
} => prefill_urls.len() + decode_urls.len(),
|
} => prefill_urls.len() + decode_urls.len(),
|
||||||
// OpenAI mode represents a single upstream
|
|
||||||
RoutingMode::OpenAI { .. } => 1,
|
RoutingMode::OpenAI { .. } => 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -515,8 +514,6 @@ impl RouterConfig {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
// ============= RouterConfig Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_router_config_default() {
|
fn test_router_config_default() {
|
||||||
let config = RouterConfig::default();
|
let config = RouterConfig::default();
|
||||||
@@ -556,7 +553,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
assert!(matches!(config.policy, PolicyConfig::RoundRobin));
|
assert!(matches!(config.policy, PolicyConfig::RoundRobin));
|
||||||
// Other fields should be default
|
|
||||||
assert_eq!(config.host, "127.0.0.1");
|
assert_eq!(config.host, "127.0.0.1");
|
||||||
assert_eq!(config.port, 3001);
|
assert_eq!(config.port, 3001);
|
||||||
}
|
}
|
||||||
@@ -583,13 +579,10 @@ mod tests {
|
|||||||
assert_eq!(config.max_payload_size, deserialized.max_payload_size);
|
assert_eq!(config.max_payload_size, deserialized.max_payload_size);
|
||||||
assert_eq!(config.log_dir, deserialized.log_dir);
|
assert_eq!(config.log_dir, deserialized.log_dir);
|
||||||
assert_eq!(config.log_level, deserialized.log_level);
|
assert_eq!(config.log_level, deserialized.log_level);
|
||||||
// discovery and metrics are None in Default implementation
|
|
||||||
assert!(deserialized.discovery.is_none());
|
assert!(deserialized.discovery.is_none());
|
||||||
assert!(deserialized.metrics.is_none());
|
assert!(deserialized.metrics.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= RoutingMode Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_routing_mode_is_pd_mode() {
|
fn test_routing_mode_is_pd_mode() {
|
||||||
let regular = RoutingMode::Regular {
|
let regular = RoutingMode::Regular {
|
||||||
@@ -640,7 +633,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_routing_mode_serialization() {
|
fn test_routing_mode_serialization() {
|
||||||
// Test Regular mode
|
|
||||||
let regular = RoutingMode::Regular {
|
let regular = RoutingMode::Regular {
|
||||||
worker_urls: vec!["http://worker1".to_string()],
|
worker_urls: vec!["http://worker1".to_string()],
|
||||||
};
|
};
|
||||||
@@ -648,7 +640,6 @@ mod tests {
|
|||||||
assert!(json.contains("\"type\":\"regular\""));
|
assert!(json.contains("\"type\":\"regular\""));
|
||||||
assert!(json.contains("\"worker_urls\""));
|
assert!(json.contains("\"worker_urls\""));
|
||||||
|
|
||||||
// Test PrefillDecode mode
|
|
||||||
let pd = RoutingMode::PrefillDecode {
|
let pd = RoutingMode::PrefillDecode {
|
||||||
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
|
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
|
||||||
decode_urls: vec!["http://decode1".to_string()],
|
decode_urls: vec!["http://decode1".to_string()],
|
||||||
@@ -661,8 +652,6 @@ mod tests {
|
|||||||
assert!(json.contains("\"decode_urls\""));
|
assert!(json.contains("\"decode_urls\""));
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= PolicyConfig Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_policy_config_name() {
|
fn test_policy_config_name() {
|
||||||
assert_eq!(PolicyConfig::Random.name(), "random");
|
assert_eq!(PolicyConfig::Random.name(), "random");
|
||||||
@@ -685,12 +674,10 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_policy_config_serialization() {
|
fn test_policy_config_serialization() {
|
||||||
// Test Random
|
|
||||||
let random = PolicyConfig::Random;
|
let random = PolicyConfig::Random;
|
||||||
let json = serde_json::to_string(&random).unwrap();
|
let json = serde_json::to_string(&random).unwrap();
|
||||||
assert_eq!(json, r#"{"type":"random"}"#);
|
assert_eq!(json, r#"{"type":"random"}"#);
|
||||||
|
|
||||||
// Test CacheAware with all parameters
|
|
||||||
let cache_aware = PolicyConfig::CacheAware {
|
let cache_aware = PolicyConfig::CacheAware {
|
||||||
cache_threshold: 0.8,
|
cache_threshold: 0.8,
|
||||||
balance_abs_threshold: 10,
|
balance_abs_threshold: 10,
|
||||||
@@ -703,7 +690,6 @@ mod tests {
|
|||||||
assert!(json.contains("\"cache_threshold\":0.8"));
|
assert!(json.contains("\"cache_threshold\":0.8"));
|
||||||
assert!(json.contains("\"balance_abs_threshold\":10"));
|
assert!(json.contains("\"balance_abs_threshold\":10"));
|
||||||
|
|
||||||
// Test PowerOfTwo
|
|
||||||
let power_of_two = PolicyConfig::PowerOfTwo {
|
let power_of_two = PolicyConfig::PowerOfTwo {
|
||||||
load_check_interval_secs: 60,
|
load_check_interval_secs: 60,
|
||||||
};
|
};
|
||||||
@@ -756,8 +742,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= DiscoveryConfig Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_discovery_config_default() {
|
fn test_discovery_config_default() {
|
||||||
let config = DiscoveryConfig::default();
|
let config = DiscoveryConfig::default();
|
||||||
@@ -798,14 +782,12 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_discovery_config_namespace() {
|
fn test_discovery_config_namespace() {
|
||||||
// Test None namespace (all namespaces)
|
|
||||||
let config = DiscoveryConfig {
|
let config = DiscoveryConfig {
|
||||||
namespace: None,
|
namespace: None,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
assert!(config.namespace.is_none());
|
assert!(config.namespace.is_none());
|
||||||
|
|
||||||
// Test specific namespace
|
|
||||||
let config = DiscoveryConfig {
|
let config = DiscoveryConfig {
|
||||||
namespace: Some("production".to_string()),
|
namespace: Some("production".to_string()),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
@@ -813,8 +795,6 @@ mod tests {
|
|||||||
assert_eq!(config.namespace, Some("production".to_string()));
|
assert_eq!(config.namespace, Some("production".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= MetricsConfig Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_metrics_config_default() {
|
fn test_metrics_config_default() {
|
||||||
let config = MetricsConfig::default();
|
let config = MetricsConfig::default();
|
||||||
@@ -834,8 +814,6 @@ mod tests {
|
|||||||
assert_eq!(config.host, "0.0.0.0");
|
assert_eq!(config.host, "0.0.0.0");
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= RouterConfig Utility Methods Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_mode_type() {
|
fn test_mode_type() {
|
||||||
let config = RouterConfig {
|
let config = RouterConfig {
|
||||||
@@ -894,8 +872,6 @@ mod tests {
|
|||||||
assert!(config.has_metrics());
|
assert!(config.has_metrics());
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Edge Cases =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_large_worker_lists() {
|
fn test_large_worker_lists() {
|
||||||
let large_urls: Vec<String> = (0..1000).map(|i| format!("http://worker{}", i)).collect();
|
let large_urls: Vec<String> = (0..1000).map(|i| format!("http://worker{}", i)).collect();
|
||||||
@@ -906,7 +882,6 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(mode.worker_count(), 1000);
|
assert_eq!(mode.worker_count(), 1000);
|
||||||
|
|
||||||
// Test serialization with large list
|
|
||||||
let config = RouterConfig {
|
let config = RouterConfig {
|
||||||
mode,
|
mode,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
@@ -961,8 +936,6 @@ mod tests {
|
|||||||
assert_eq!(config.log_level, Some("".to_string()));
|
assert_eq!(config.log_level, Some("".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Complex Configuration Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_full_pd_mode_config() {
|
fn test_full_pd_mode_config() {
|
||||||
let config = RouterConfig {
|
let config = RouterConfig {
|
||||||
@@ -1149,7 +1122,6 @@ mod tests {
|
|||||||
assert!(config.has_metrics());
|
assert!(config.has_metrics());
|
||||||
assert_eq!(config.mode_type(), "regular");
|
assert_eq!(config.mode_type(), "regular");
|
||||||
|
|
||||||
// Test round-trip serialization
|
|
||||||
let json = serde_json::to_string_pretty(&config).unwrap();
|
let json = serde_json::to_string_pretty(&config).unwrap();
|
||||||
let deserialized: RouterConfig = serde_json::from_str(&json).unwrap();
|
let deserialized: RouterConfig = serde_json::from_str(&json).unwrap();
|
||||||
|
|
||||||
@@ -1161,11 +1133,8 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Policy Fallback Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pd_policy_fallback_both_specified() {
|
fn test_pd_policy_fallback_both_specified() {
|
||||||
// When both prefill and decode policies are specified, they should be used
|
|
||||||
let pd = RoutingMode::PrefillDecode {
|
let pd = RoutingMode::PrefillDecode {
|
||||||
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
||||||
decode_urls: vec!["http://decode1".to_string()],
|
decode_urls: vec!["http://decode1".to_string()],
|
||||||
@@ -1183,21 +1152,19 @@ mod tests {
|
|||||||
|
|
||||||
let main_policy = PolicyConfig::Random;
|
let main_policy = PolicyConfig::Random;
|
||||||
|
|
||||||
// Both specific policies should be used
|
|
||||||
match pd.get_prefill_policy(&main_policy) {
|
match pd.get_prefill_policy(&main_policy) {
|
||||||
PolicyConfig::CacheAware { .. } => {} // Success
|
PolicyConfig::CacheAware { .. } => {}
|
||||||
_ => panic!("Expected CacheAware for prefill"),
|
_ => panic!("Expected CacheAware for prefill"),
|
||||||
}
|
}
|
||||||
|
|
||||||
match pd.get_decode_policy(&main_policy) {
|
match pd.get_decode_policy(&main_policy) {
|
||||||
PolicyConfig::PowerOfTwo { .. } => {} // Success
|
PolicyConfig::PowerOfTwo { .. } => {}
|
||||||
_ => panic!("Expected PowerOfTwo for decode"),
|
_ => panic!("Expected PowerOfTwo for decode"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pd_policy_fallback_only_prefill() {
|
fn test_pd_policy_fallback_only_prefill() {
|
||||||
// When only prefill policy is specified, decode should use main policy
|
|
||||||
let pd = RoutingMode::PrefillDecode {
|
let pd = RoutingMode::PrefillDecode {
|
||||||
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
||||||
decode_urls: vec!["http://decode1".to_string()],
|
decode_urls: vec!["http://decode1".to_string()],
|
||||||
@@ -1213,22 +1180,19 @@ mod tests {
|
|||||||
|
|
||||||
let main_policy = PolicyConfig::RoundRobin;
|
let main_policy = PolicyConfig::RoundRobin;
|
||||||
|
|
||||||
// Prefill should use specific policy
|
|
||||||
match pd.get_prefill_policy(&main_policy) {
|
match pd.get_prefill_policy(&main_policy) {
|
||||||
PolicyConfig::CacheAware { .. } => {} // Success
|
PolicyConfig::CacheAware { .. } => {}
|
||||||
_ => panic!("Expected CacheAware for prefill"),
|
_ => panic!("Expected CacheAware for prefill"),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode should fall back to main policy
|
|
||||||
match pd.get_decode_policy(&main_policy) {
|
match pd.get_decode_policy(&main_policy) {
|
||||||
PolicyConfig::RoundRobin => {} // Success
|
PolicyConfig::RoundRobin => {}
|
||||||
_ => panic!("Expected RoundRobin for decode"),
|
_ => panic!("Expected RoundRobin for decode"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pd_policy_fallback_only_decode() {
|
fn test_pd_policy_fallback_only_decode() {
|
||||||
// When only decode policy is specified, prefill should use main policy
|
|
||||||
let pd = RoutingMode::PrefillDecode {
|
let pd = RoutingMode::PrefillDecode {
|
||||||
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
||||||
decode_urls: vec!["http://decode1".to_string()],
|
decode_urls: vec!["http://decode1".to_string()],
|
||||||
@@ -1240,22 +1204,19 @@ mod tests {
|
|||||||
|
|
||||||
let main_policy = PolicyConfig::Random;
|
let main_policy = PolicyConfig::Random;
|
||||||
|
|
||||||
// Prefill should fall back to main policy
|
|
||||||
match pd.get_prefill_policy(&main_policy) {
|
match pd.get_prefill_policy(&main_policy) {
|
||||||
PolicyConfig::Random => {} // Success
|
PolicyConfig::Random => {}
|
||||||
_ => panic!("Expected Random for prefill"),
|
_ => panic!("Expected Random for prefill"),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode should use specific policy
|
|
||||||
match pd.get_decode_policy(&main_policy) {
|
match pd.get_decode_policy(&main_policy) {
|
||||||
PolicyConfig::PowerOfTwo { .. } => {} // Success
|
PolicyConfig::PowerOfTwo { .. } => {}
|
||||||
_ => panic!("Expected PowerOfTwo for decode"),
|
_ => panic!("Expected PowerOfTwo for decode"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pd_policy_fallback_none_specified() {
|
fn test_pd_policy_fallback_none_specified() {
|
||||||
// When no specific policies are specified, both should use main policy
|
|
||||||
let pd = RoutingMode::PrefillDecode {
|
let pd = RoutingMode::PrefillDecode {
|
||||||
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
prefill_urls: vec![("http://prefill1".to_string(), None)],
|
||||||
decode_urls: vec!["http://decode1".to_string()],
|
decode_urls: vec!["http://decode1".to_string()],
|
||||||
@@ -1271,7 +1232,6 @@ mod tests {
|
|||||||
max_tree_size: 2000,
|
max_tree_size: 2000,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Both should fall back to main policy
|
|
||||||
match pd.get_prefill_policy(&main_policy) {
|
match pd.get_prefill_policy(&main_policy) {
|
||||||
PolicyConfig::CacheAware {
|
PolicyConfig::CacheAware {
|
||||||
cache_threshold, ..
|
cache_threshold, ..
|
||||||
@@ -1293,21 +1253,19 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_regular_mode_policy_fallback() {
|
fn test_regular_mode_policy_fallback() {
|
||||||
// For regular mode, the helper methods should just return the main policy
|
|
||||||
let regular = RoutingMode::Regular {
|
let regular = RoutingMode::Regular {
|
||||||
worker_urls: vec!["http://worker1".to_string()],
|
worker_urls: vec!["http://worker1".to_string()],
|
||||||
};
|
};
|
||||||
|
|
||||||
let main_policy = PolicyConfig::RoundRobin;
|
let main_policy = PolicyConfig::RoundRobin;
|
||||||
|
|
||||||
// Both methods should return main policy for regular mode
|
|
||||||
match regular.get_prefill_policy(&main_policy) {
|
match regular.get_prefill_policy(&main_policy) {
|
||||||
PolicyConfig::RoundRobin => {} // Success
|
PolicyConfig::RoundRobin => {}
|
||||||
_ => panic!("Expected RoundRobin for regular mode"),
|
_ => panic!("Expected RoundRobin for regular mode"),
|
||||||
}
|
}
|
||||||
|
|
||||||
match regular.get_decode_policy(&main_policy) {
|
match regular.get_decode_policy(&main_policy) {
|
||||||
PolicyConfig::RoundRobin => {} // Success
|
PolicyConfig::RoundRobin => {}
|
||||||
_ => panic!("Expected RoundRobin for regular mode"),
|
_ => panic!("Expected RoundRobin for regular mode"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -670,7 +670,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_validate_pd_mode_with_separate_policies() {
|
fn test_validate_pd_mode_with_separate_policies() {
|
||||||
// Test PD mode with different policies for prefill and decode
|
|
||||||
let config = RouterConfig::new(
|
let config = RouterConfig::new(
|
||||||
RoutingMode::PrefillDecode {
|
RoutingMode::PrefillDecode {
|
||||||
prefill_urls: vec![
|
prefill_urls: vec![
|
||||||
@@ -701,7 +700,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_validate_pd_mode_power_of_two_insufficient_workers() {
|
fn test_validate_pd_mode_power_of_two_insufficient_workers() {
|
||||||
// Test that power-of-two policy requires at least 2 workers
|
|
||||||
let config = RouterConfig::new(
|
let config = RouterConfig::new(
|
||||||
RoutingMode::PrefillDecode {
|
RoutingMode::PrefillDecode {
|
||||||
prefill_urls: vec![("http://prefill1:8000".to_string(), None)], // Only 1 prefill
|
prefill_urls: vec![("http://prefill1:8000".to_string(), None)], // Only 1 prefill
|
||||||
@@ -726,7 +724,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_validate_grpc_requires_tokenizer() {
|
fn test_validate_grpc_requires_tokenizer() {
|
||||||
// Test that gRPC connection mode requires tokenizer configuration
|
|
||||||
let mut config = RouterConfig::new(
|
let mut config = RouterConfig::new(
|
||||||
RoutingMode::Regular {
|
RoutingMode::Regular {
|
||||||
worker_urls: vec!["grpc://worker:50051".to_string()],
|
worker_urls: vec!["grpc://worker:50051".to_string()],
|
||||||
@@ -748,7 +745,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_validate_grpc_with_model_path() {
|
fn test_validate_grpc_with_model_path() {
|
||||||
// Test that gRPC works with model_path
|
|
||||||
let mut config = RouterConfig::new(
|
let mut config = RouterConfig::new(
|
||||||
RoutingMode::Regular {
|
RoutingMode::Regular {
|
||||||
worker_urls: vec!["grpc://worker:50051".to_string()],
|
worker_urls: vec!["grpc://worker:50051".to_string()],
|
||||||
@@ -765,7 +761,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_validate_grpc_with_tokenizer_path() {
|
fn test_validate_grpc_with_tokenizer_path() {
|
||||||
// Test that gRPC works with tokenizer_path
|
|
||||||
let mut config = RouterConfig::new(
|
let mut config = RouterConfig::new(
|
||||||
RoutingMode::Regular {
|
RoutingMode::Regular {
|
||||||
worker_urls: vec!["grpc://worker:50051".to_string()],
|
worker_urls: vec!["grpc://worker:50051".to_string()],
|
||||||
|
|||||||
@@ -336,7 +336,6 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let cb = CircuitBreaker::with_config(config);
|
let cb = CircuitBreaker::with_config(config);
|
||||||
|
|
||||||
// Record failures up to threshold
|
|
||||||
assert_eq!(cb.state(), CircuitState::Closed);
|
assert_eq!(cb.state(), CircuitState::Closed);
|
||||||
cb.record_failure();
|
cb.record_failure();
|
||||||
assert_eq!(cb.state(), CircuitState::Closed);
|
assert_eq!(cb.state(), CircuitState::Closed);
|
||||||
@@ -344,7 +343,6 @@ mod tests {
|
|||||||
assert_eq!(cb.state(), CircuitState::Closed);
|
assert_eq!(cb.state(), CircuitState::Closed);
|
||||||
cb.record_failure();
|
cb.record_failure();
|
||||||
|
|
||||||
// Circuit should now be open
|
|
||||||
assert_eq!(cb.state(), CircuitState::Open);
|
assert_eq!(cb.state(), CircuitState::Open);
|
||||||
assert!(!cb.can_execute());
|
assert!(!cb.can_execute());
|
||||||
assert_eq!(cb.failure_count(), 3);
|
assert_eq!(cb.failure_count(), 3);
|
||||||
@@ -359,14 +357,11 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let cb = CircuitBreaker::with_config(config);
|
let cb = CircuitBreaker::with_config(config);
|
||||||
|
|
||||||
// Open the circuit
|
|
||||||
cb.record_failure();
|
cb.record_failure();
|
||||||
assert_eq!(cb.state(), CircuitState::Open);
|
assert_eq!(cb.state(), CircuitState::Open);
|
||||||
|
|
||||||
// Wait for timeout
|
|
||||||
thread::sleep(Duration::from_millis(150));
|
thread::sleep(Duration::from_millis(150));
|
||||||
|
|
||||||
// Circuit should be half-open
|
|
||||||
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
||||||
assert!(cb.can_execute());
|
assert!(cb.can_execute());
|
||||||
}
|
}
|
||||||
@@ -381,20 +376,16 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let cb = CircuitBreaker::with_config(config);
|
let cb = CircuitBreaker::with_config(config);
|
||||||
|
|
||||||
// Open the circuit
|
|
||||||
cb.record_failure();
|
cb.record_failure();
|
||||||
assert_eq!(cb.state(), CircuitState::Open);
|
assert_eq!(cb.state(), CircuitState::Open);
|
||||||
|
|
||||||
// Wait for timeout
|
|
||||||
thread::sleep(Duration::from_millis(100));
|
thread::sleep(Duration::from_millis(100));
|
||||||
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
||||||
|
|
||||||
// Record successes
|
|
||||||
cb.record_success();
|
cb.record_success();
|
||||||
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
||||||
cb.record_success();
|
cb.record_success();
|
||||||
|
|
||||||
// Circuit should now be closed
|
|
||||||
assert_eq!(cb.state(), CircuitState::Closed);
|
assert_eq!(cb.state(), CircuitState::Closed);
|
||||||
assert!(cb.can_execute());
|
assert!(cb.can_execute());
|
||||||
}
|
}
|
||||||
@@ -408,18 +399,14 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let cb = CircuitBreaker::with_config(config);
|
let cb = CircuitBreaker::with_config(config);
|
||||||
|
|
||||||
// Open the circuit
|
|
||||||
cb.record_failure();
|
cb.record_failure();
|
||||||
assert_eq!(cb.state(), CircuitState::Open);
|
assert_eq!(cb.state(), CircuitState::Open);
|
||||||
|
|
||||||
// Wait for timeout
|
|
||||||
thread::sleep(Duration::from_millis(100));
|
thread::sleep(Duration::from_millis(100));
|
||||||
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
||||||
|
|
||||||
// Record a failure in half-open state
|
|
||||||
cb.record_failure();
|
cb.record_failure();
|
||||||
|
|
||||||
// Circuit should reopen immediately
|
|
||||||
assert_eq!(cb.state(), CircuitState::Open);
|
assert_eq!(cb.state(), CircuitState::Open);
|
||||||
assert!(!cb.can_execute());
|
assert!(!cb.can_execute());
|
||||||
}
|
}
|
||||||
@@ -432,17 +419,14 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let cb = CircuitBreaker::with_config(config);
|
let cb = CircuitBreaker::with_config(config);
|
||||||
|
|
||||||
// Record some failures
|
|
||||||
cb.record_failure();
|
cb.record_failure();
|
||||||
cb.record_failure();
|
cb.record_failure();
|
||||||
assert_eq!(cb.failure_count(), 2);
|
assert_eq!(cb.failure_count(), 2);
|
||||||
|
|
||||||
// Success should reset failure count
|
|
||||||
cb.record_success();
|
cb.record_success();
|
||||||
assert_eq!(cb.failure_count(), 0);
|
assert_eq!(cb.failure_count(), 0);
|
||||||
assert_eq!(cb.success_count(), 1);
|
assert_eq!(cb.success_count(), 1);
|
||||||
|
|
||||||
// Can now record more failures without opening
|
|
||||||
cb.record_failure();
|
cb.record_failure();
|
||||||
cb.record_failure();
|
cb.record_failure();
|
||||||
assert_eq!(cb.state(), CircuitState::Closed);
|
assert_eq!(cb.state(), CircuitState::Closed);
|
||||||
@@ -456,11 +440,9 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let cb = CircuitBreaker::with_config(config);
|
let cb = CircuitBreaker::with_config(config);
|
||||||
|
|
||||||
// Open the circuit
|
|
||||||
cb.record_failure();
|
cb.record_failure();
|
||||||
assert_eq!(cb.state(), CircuitState::Open);
|
assert_eq!(cb.state(), CircuitState::Open);
|
||||||
|
|
||||||
// Manual reset
|
|
||||||
cb.reset();
|
cb.reset();
|
||||||
assert_eq!(cb.state(), CircuitState::Closed);
|
assert_eq!(cb.state(), CircuitState::Closed);
|
||||||
assert_eq!(cb.failure_count(), 0);
|
assert_eq!(cb.failure_count(), 0);
|
||||||
@@ -505,7 +487,6 @@ mod tests {
|
|||||||
let cb2 = cb1.clone();
|
let cb2 = cb1.clone();
|
||||||
assert_eq!(cb2.failure_count(), 1);
|
assert_eq!(cb2.failure_count(), 1);
|
||||||
|
|
||||||
// Changes to cb1 affect cb2 (shared state)
|
|
||||||
cb1.record_failure();
|
cb1.record_failure();
|
||||||
assert_eq!(cb2.failure_count(), 2);
|
assert_eq!(cb2.failure_count(), 2);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1562,19 +1562,16 @@ mod tests {
|
|||||||
.worker_type(WorkerType::Regular)
|
.worker_type(WorkerType::Regular)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
// Test health status
|
|
||||||
assert!(dp_worker.is_healthy());
|
assert!(dp_worker.is_healthy());
|
||||||
dp_worker.set_healthy(false);
|
dp_worker.set_healthy(false);
|
||||||
assert!(!dp_worker.is_healthy());
|
assert!(!dp_worker.is_healthy());
|
||||||
|
|
||||||
// Test load tracking
|
|
||||||
assert_eq!(dp_worker.load(), 0);
|
assert_eq!(dp_worker.load(), 0);
|
||||||
dp_worker.increment_load();
|
dp_worker.increment_load();
|
||||||
assert_eq!(dp_worker.load(), 1);
|
assert_eq!(dp_worker.load(), 1);
|
||||||
dp_worker.decrement_load();
|
dp_worker.decrement_load();
|
||||||
assert_eq!(dp_worker.load(), 0);
|
assert_eq!(dp_worker.load(), 0);
|
||||||
|
|
||||||
// Test processed tracking
|
|
||||||
assert_eq!(dp_worker.processed_requests(), 0);
|
assert_eq!(dp_worker.processed_requests(), 0);
|
||||||
dp_worker.increment_processed();
|
dp_worker.increment_processed();
|
||||||
assert_eq!(dp_worker.processed_requests(), 1);
|
assert_eq!(dp_worker.processed_requests(), 1);
|
||||||
|
|||||||
@@ -1485,7 +1485,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parse_server_info_with_fallback() {
|
fn test_parse_server_info_with_fallback() {
|
||||||
// Test with "model" instead of "model_id"
|
|
||||||
let json = serde_json::json!({
|
let json = serde_json::json!({
|
||||||
"model": "gpt-4",
|
"model": "gpt-4",
|
||||||
"dp_size": 2
|
"dp_size": 2
|
||||||
|
|||||||
@@ -459,14 +459,12 @@ mod tests {
|
|||||||
// Register worker (WorkerFactory returns Box<dyn Worker>, convert to Arc)
|
// Register worker (WorkerFactory returns Box<dyn Worker>, convert to Arc)
|
||||||
let worker_id = registry.register(Arc::from(worker));
|
let worker_id = registry.register(Arc::from(worker));
|
||||||
|
|
||||||
// Verify registration
|
|
||||||
assert!(registry.get(&worker_id).is_some());
|
assert!(registry.get(&worker_id).is_some());
|
||||||
assert!(registry.get_by_url("http://worker1:8080").is_some());
|
assert!(registry.get_by_url("http://worker1:8080").is_some());
|
||||||
assert_eq!(registry.get_by_model("llama-3-8b").len(), 1);
|
assert_eq!(registry.get_by_model("llama-3-8b").len(), 1);
|
||||||
assert_eq!(registry.get_by_type(&WorkerType::Regular).len(), 1);
|
assert_eq!(registry.get_by_type(&WorkerType::Regular).len(), 1);
|
||||||
assert_eq!(registry.get_by_connection(&ConnectionMode::Http).len(), 1);
|
assert_eq!(registry.get_by_connection(&ConnectionMode::Http).len(), 1);
|
||||||
|
|
||||||
// Test stats
|
|
||||||
let stats = registry.stats();
|
let stats = registry.stats();
|
||||||
assert_eq!(stats.total_workers, 1);
|
assert_eq!(stats.total_workers, 1);
|
||||||
assert_eq!(stats.total_models, 1);
|
assert_eq!(stats.total_models, 1);
|
||||||
@@ -519,27 +517,22 @@ mod tests {
|
|||||||
registry.register(Arc::from(worker2));
|
registry.register(Arc::from(worker2));
|
||||||
registry.register(Arc::from(worker3));
|
registry.register(Arc::from(worker3));
|
||||||
|
|
||||||
// Test get_by_model_fast for llama-3
|
|
||||||
let llama_workers = registry.get_by_model_fast("llama-3");
|
let llama_workers = registry.get_by_model_fast("llama-3");
|
||||||
assert_eq!(llama_workers.len(), 2);
|
assert_eq!(llama_workers.len(), 2);
|
||||||
let urls: Vec<String> = llama_workers.iter().map(|w| w.url().to_string()).collect();
|
let urls: Vec<String> = llama_workers.iter().map(|w| w.url().to_string()).collect();
|
||||||
assert!(urls.contains(&"http://worker1:8080".to_string()));
|
assert!(urls.contains(&"http://worker1:8080".to_string()));
|
||||||
assert!(urls.contains(&"http://worker2:8080".to_string()));
|
assert!(urls.contains(&"http://worker2:8080".to_string()));
|
||||||
|
|
||||||
// Test get_by_model_fast for gpt-4
|
|
||||||
let gpt_workers = registry.get_by_model_fast("gpt-4");
|
let gpt_workers = registry.get_by_model_fast("gpt-4");
|
||||||
assert_eq!(gpt_workers.len(), 1);
|
assert_eq!(gpt_workers.len(), 1);
|
||||||
assert_eq!(gpt_workers[0].url(), "http://worker3:8080");
|
assert_eq!(gpt_workers[0].url(), "http://worker3:8080");
|
||||||
|
|
||||||
// Test get_by_model_fast for non-existent model
|
|
||||||
let unknown_workers = registry.get_by_model_fast("unknown-model");
|
let unknown_workers = registry.get_by_model_fast("unknown-model");
|
||||||
assert_eq!(unknown_workers.len(), 0);
|
assert_eq!(unknown_workers.len(), 0);
|
||||||
|
|
||||||
// Test that both get_by_model and get_by_model_fast return same results
|
|
||||||
let llama_workers_slow = registry.get_by_model("llama-3");
|
let llama_workers_slow = registry.get_by_model("llama-3");
|
||||||
assert_eq!(llama_workers.len(), llama_workers_slow.len());
|
assert_eq!(llama_workers.len(), llama_workers_slow.len());
|
||||||
|
|
||||||
// Test removal updates the model index
|
|
||||||
registry.remove_by_url("http://worker1:8080");
|
registry.remove_by_url("http://worker1:8080");
|
||||||
let llama_workers_after = registry.get_by_model_fast("llama-3");
|
let llama_workers_after = registry.get_by_model_fast("llama-3");
|
||||||
assert_eq!(llama_workers_after.len(), 1);
|
assert_eq!(llama_workers_after.len(), 1);
|
||||||
|
|||||||
@@ -266,7 +266,6 @@ mod tests {
|
|||||||
assert_eq!(chain.responses[1].input, "Second");
|
assert_eq!(chain.responses[1].input, "Second");
|
||||||
assert_eq!(chain.responses[2].input, "Third");
|
assert_eq!(chain.responses[2].input, "Third");
|
||||||
|
|
||||||
// Test with max_depth
|
|
||||||
let limited_chain = store.get_response_chain(&id3, Some(2)).await.unwrap();
|
let limited_chain = store.get_response_chain(&id3, Some(2)).await.unwrap();
|
||||||
assert_eq!(limited_chain.responses.len(), 2);
|
assert_eq!(limited_chain.responses.len(), 2);
|
||||||
assert_eq!(limited_chain.responses[0].input, "Second");
|
assert_eq!(limited_chain.responses[0].input, "Second");
|
||||||
@@ -314,7 +313,6 @@ mod tests {
|
|||||||
let deleted_count = store.delete_user_responses("user1").await.unwrap();
|
let deleted_count = store.delete_user_responses("user1").await.unwrap();
|
||||||
assert_eq!(deleted_count, 2);
|
assert_eq!(deleted_count, 2);
|
||||||
|
|
||||||
// Verify they're gone
|
|
||||||
let user1_responses_after = store.list_user_responses("user1", None).await.unwrap();
|
let user1_responses_after = store.list_user_responses("user1", None).await.unwrap();
|
||||||
assert_eq!(user1_responses_after.len(), 0);
|
assert_eq!(user1_responses_after.len(), 0);
|
||||||
|
|
||||||
|
|||||||
@@ -223,7 +223,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_proto_types_compilation() {
|
fn test_proto_types_compilation() {
|
||||||
// Test that protobuf types can be constructed
|
|
||||||
let health_req = proto::HealthCheckRequest {
|
let health_req = proto::HealthCheckRequest {
|
||||||
tokenized: Some(proto::TokenizedInput {
|
tokenized: Some(proto::TokenizedInput {
|
||||||
original_text: "test".to_string(),
|
original_text: "test".to_string(),
|
||||||
@@ -320,8 +319,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: SessionParams not in current proto - skip test
|
// TODO: SessionParams not in current proto - skip test
|
||||||
// #[test]
|
|
||||||
// fn test_session_params() { ... }
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_embed_request() {
|
fn test_embed_request() {
|
||||||
@@ -349,7 +346,6 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_client_connect_invalid_endpoint() {
|
async fn test_client_connect_invalid_endpoint() {
|
||||||
// Test connecting to an invalid endpoint should return error
|
|
||||||
let result = SglangSchedulerClient::connect("invalid://endpoint").await;
|
let result = SglangSchedulerClient::connect("invalid://endpoint").await;
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
}
|
}
|
||||||
@@ -365,7 +361,6 @@ mod tests {
|
|||||||
assert_eq!(tokenized.input_ids, vec![1, 15043, 1917, 2]);
|
assert_eq!(tokenized.input_ids, vec![1, 15043, 1917, 2]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test response type construction
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_generate_stream_chunk() {
|
fn test_generate_stream_chunk() {
|
||||||
let chunk = proto::GenerateStreamChunk {
|
let chunk = proto::GenerateStreamChunk {
|
||||||
@@ -383,6 +378,4 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: ModelInfo not in current proto - skip test
|
// TODO: ModelInfo not in current proto - skip test
|
||||||
// #[test]
|
|
||||||
// fn test_model_info() { ... }
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -288,8 +288,6 @@ impl McpClientManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ===== Helpers =====
|
|
||||||
|
|
||||||
fn client_for(&self, server_name: &str) -> McpResult<&RunningService<RoleClient, ()>> {
|
fn client_for(&self, server_name: &str) -> McpResult<&RunningService<RoleClient, ()>> {
|
||||||
self.clients
|
self.clients
|
||||||
.get(server_name)
|
.get(server_name)
|
||||||
@@ -317,8 +315,6 @@ impl McpClientManager {
|
|||||||
.ok_or_else(|| McpError::ResourceNotFound(uri.to_string()))
|
.ok_or_else(|| McpError::ResourceNotFound(uri.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ===== Tool Methods =====
|
|
||||||
|
|
||||||
/// Call a tool by name
|
/// Call a tool by name
|
||||||
pub async fn call_tool(
|
pub async fn call_tool(
|
||||||
&self,
|
&self,
|
||||||
@@ -380,8 +376,6 @@ impl McpClientManager {
|
|||||||
self.clients.keys().cloned().collect()
|
self.clients.keys().cloned().collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ===== Prompt Methods =====
|
|
||||||
|
|
||||||
/// Get a prompt by name with arguments
|
/// Get a prompt by name with arguments
|
||||||
pub async fn get_prompt(
|
pub async fn get_prompt(
|
||||||
&self,
|
&self,
|
||||||
@@ -439,8 +433,6 @@ impl McpClientManager {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ===== Resource Methods =====
|
|
||||||
|
|
||||||
/// Read a resource by URI
|
/// Read a resource by URI
|
||||||
pub async fn read_resource(&self, uri: &str) -> McpResult<ReadResourceResult> {
|
pub async fn read_resource(&self, uri: &str) -> McpResult<ReadResourceResult> {
|
||||||
let (server_name, _resource) = self.resource_entry(uri)?;
|
let (server_name, _resource) = self.resource_entry(uri)?;
|
||||||
|
|||||||
@@ -598,8 +598,6 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use std::net::TcpListener;
|
use std::net::TcpListener;
|
||||||
|
|
||||||
// ============= PrometheusConfig Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_prometheus_config_default() {
|
fn test_prometheus_config_default() {
|
||||||
let config = PrometheusConfig::default();
|
let config = PrometheusConfig::default();
|
||||||
@@ -628,8 +626,6 @@ mod tests {
|
|||||||
assert_eq!(cloned.host, config.host);
|
assert_eq!(cloned.host, config.host);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= IP Address Parsing Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_valid_ipv4_parsing() {
|
fn test_valid_ipv4_parsing() {
|
||||||
let test_cases = vec!["127.0.0.1", "192.168.1.1", "0.0.0.0"];
|
let test_cases = vec!["127.0.0.1", "192.168.1.1", "0.0.0.0"];
|
||||||
@@ -679,8 +675,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Socket Address Creation Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_socket_addr_creation() {
|
fn test_socket_addr_creation() {
|
||||||
let test_cases = vec![("127.0.0.1", 8080), ("0.0.0.0", 29000), ("::1", 9090)];
|
let test_cases = vec![("127.0.0.1", 8080), ("0.0.0.0", 29000), ("::1", 9090)];
|
||||||
@@ -716,8 +710,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Duration Bucket Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_duration_bucket_coverage() {
|
fn test_duration_bucket_coverage() {
|
||||||
let test_cases: [(f64, &str); 7] = [
|
let test_cases: [(f64, &str); 7] = [
|
||||||
@@ -743,8 +735,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Matcher Configuration Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_duration_suffix_matcher() {
|
fn test_duration_suffix_matcher() {
|
||||||
let matcher = Matcher::Suffix(String::from("duration_seconds"));
|
let matcher = Matcher::Suffix(String::from("duration_seconds"));
|
||||||
@@ -763,8 +753,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Builder Configuration Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_prometheus_builder_configuration() {
|
fn test_prometheus_builder_configuration() {
|
||||||
let _config = PrometheusConfig::default();
|
let _config = PrometheusConfig::default();
|
||||||
@@ -783,16 +771,12 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Upkeep Timeout Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_upkeep_timeout_duration() {
|
fn test_upkeep_timeout_duration() {
|
||||||
let timeout = Duration::from_secs(5 * 60);
|
let timeout = Duration::from_secs(5 * 60);
|
||||||
assert_eq!(timeout.as_secs(), 300);
|
assert_eq!(timeout.as_secs(), 300);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Custom Bucket Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_custom_buckets_for_different_metrics() {
|
fn test_custom_buckets_for_different_metrics() {
|
||||||
let request_buckets = [0.001, 0.01, 0.1, 1.0, 10.0];
|
let request_buckets = [0.001, 0.01, 0.1, 1.0, 10.0];
|
||||||
@@ -810,8 +794,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= RouterMetrics Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_metrics_static_methods() {
|
fn test_metrics_static_methods() {
|
||||||
RouterMetrics::record_request("/generate");
|
RouterMetrics::record_request("/generate");
|
||||||
@@ -876,8 +858,6 @@ mod tests {
|
|||||||
TokenizerMetrics::set_vocab_size("huggingface", 50000);
|
TokenizerMetrics::set_vocab_size("huggingface", 50000);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Port Availability Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_port_already_in_use() {
|
fn test_port_already_in_use() {
|
||||||
let port = 29123;
|
let port = 29123;
|
||||||
@@ -892,8 +872,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Integration Test Helpers =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_metrics_endpoint_accessibility() {
|
fn test_metrics_endpoint_accessibility() {
|
||||||
let config = PrometheusConfig {
|
let config = PrometheusConfig {
|
||||||
@@ -937,8 +915,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Edge Cases Tests =============
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_empty_string_metrics() {
|
fn test_empty_string_metrics() {
|
||||||
RouterMetrics::record_request("");
|
RouterMetrics::record_request("");
|
||||||
|
|||||||
@@ -178,8 +178,6 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Logging Middleware =============
|
|
||||||
|
|
||||||
/// Custom span maker that includes request ID
|
/// Custom span maker that includes request ID
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct RequestSpan;
|
pub struct RequestSpan;
|
||||||
@@ -336,8 +334,6 @@ pub fn log_request(entry: RequestLogEntry) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ Concurrency Limiting with Queue Support ============
|
|
||||||
|
|
||||||
/// Request queue entry
|
/// Request queue entry
|
||||||
pub struct QueuedRequest {
|
pub struct QueuedRequest {
|
||||||
/// Time when the request was queued
|
/// Time when the request was queued
|
||||||
|
|||||||
@@ -54,21 +54,17 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_create_from_config() {
|
fn test_create_from_config() {
|
||||||
// Test Random
|
|
||||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
|
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
|
||||||
assert_eq!(policy.name(), "random");
|
assert_eq!(policy.name(), "random");
|
||||||
|
|
||||||
// Test RoundRobin
|
|
||||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::RoundRobin);
|
let policy = PolicyFactory::create_from_config(&PolicyConfig::RoundRobin);
|
||||||
assert_eq!(policy.name(), "round_robin");
|
assert_eq!(policy.name(), "round_robin");
|
||||||
|
|
||||||
// Test PowerOfTwo
|
|
||||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::PowerOfTwo {
|
let policy = PolicyFactory::create_from_config(&PolicyConfig::PowerOfTwo {
|
||||||
load_check_interval_secs: 60,
|
load_check_interval_secs: 60,
|
||||||
});
|
});
|
||||||
assert_eq!(policy.name(), "power_of_two");
|
assert_eq!(policy.name(), "power_of_two");
|
||||||
|
|
||||||
// Test CacheAware
|
|
||||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::CacheAware {
|
let policy = PolicyFactory::create_from_config(&PolicyConfig::CacheAware {
|
||||||
cache_threshold: 0.7,
|
cache_threshold: 0.7,
|
||||||
balance_abs_threshold: 10,
|
balance_abs_threshold: 10,
|
||||||
|
|||||||
@@ -75,7 +75,6 @@ mod tests {
|
|||||||
),
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
// Test multiple selections to ensure randomness
|
|
||||||
let mut counts = HashMap::new();
|
let mut counts = HashMap::new();
|
||||||
for _ in 0..100 {
|
for _ in 0..100 {
|
||||||
if let Some(idx) = policy.select_worker(&workers, None) {
|
if let Some(idx) = policy.select_worker(&workers, None) {
|
||||||
|
|||||||
@@ -49,12 +49,6 @@ use std::collections::HashMap;
|
|||||||
// - StringOrArray & LoRAPath types
|
// - StringOrArray & LoRAPath types
|
||||||
// - Helper functions
|
// - Helper functions
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = OPENAI SPEC - Chat Completions API =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
// ============= Message Types =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub enum ChatMessage {
|
pub enum ChatMessage {
|
||||||
@@ -119,8 +113,6 @@ pub struct ImageUrl {
|
|||||||
pub detail: Option<String>, // "auto", "low", or "high"
|
pub detail: Option<String>, // "auto", "low", or "high"
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Response Format Types =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
pub enum ResponseFormat {
|
pub enum ResponseFormat {
|
||||||
@@ -140,8 +132,6 @@ pub struct JsonSchemaFormat {
|
|||||||
pub strict: Option<bool>,
|
pub strict: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Streaming Delta Types =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct ChatMessageDelta {
|
pub struct ChatMessageDelta {
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
@@ -177,8 +167,6 @@ pub struct FunctionCallDelta {
|
|||||||
pub arguments: Option<String>,
|
pub arguments: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Request =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
|
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
|
||||||
pub struct ChatCompletionRequest {
|
pub struct ChatCompletionRequest {
|
||||||
/// A list of messages comprising the conversation so far
|
/// A list of messages comprising the conversation so far
|
||||||
@@ -299,7 +287,6 @@ pub struct ChatCompletionRequest {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub verbosity: Option<i32>,
|
pub verbosity: Option<i32>,
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Top-k sampling parameter (-1 to disable)
|
/// Top-k sampling parameter (-1 to disable)
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub top_k: Option<i32>,
|
pub top_k: Option<i32>,
|
||||||
@@ -423,8 +410,6 @@ impl GenerationRequest for ChatCompletionRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Regular Response =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct ChatCompletionResponse {
|
pub struct ChatCompletionResponse {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@@ -453,8 +438,6 @@ pub struct ChatChoice {
|
|||||||
pub hidden_states: Option<Vec<f32>>,
|
pub hidden_states: Option<Vec<f32>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Streaming Response =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct ChatCompletionStreamResponse {
|
pub struct ChatCompletionStreamResponse {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@@ -477,9 +460,6 @@ pub struct ChatStreamChoice {
|
|||||||
pub finish_reason: Option<String>,
|
pub finish_reason: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = OPENAI SPEC - Completions API =
|
|
||||||
// ==================================================================
|
|
||||||
// Completions API request types (v1/completions) - DEPRECATED but still supported
|
// Completions API request types (v1/completions) - DEPRECATED but still supported
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@@ -554,7 +534,6 @@ pub struct CompletionRequest {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub seed: Option<i64>,
|
pub seed: Option<i64>,
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Top-k sampling parameter (-1 to disable)
|
/// Top-k sampling parameter (-1 to disable)
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub top_k: Option<i32>,
|
pub top_k: Option<i32>,
|
||||||
@@ -599,7 +578,6 @@ pub struct CompletionRequest {
|
|||||||
#[serde(default = "default_true")]
|
#[serde(default = "default_true")]
|
||||||
pub skip_special_tokens: bool,
|
pub skip_special_tokens: bool,
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Path to LoRA adapter(s) for model customization
|
/// Path to LoRA adapter(s) for model customization
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub lora_path: Option<LoRAPath>,
|
pub lora_path: Option<LoRAPath>,
|
||||||
@@ -638,8 +616,6 @@ impl GenerationRequest for CompletionRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Regular Response =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct CompletionResponse {
|
pub struct CompletionResponse {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@@ -668,8 +644,6 @@ pub struct CompletionChoice {
|
|||||||
pub hidden_states: Option<Vec<f32>>,
|
pub hidden_states: Option<Vec<f32>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Streaming Response =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct CompletionStreamResponse {
|
pub struct CompletionStreamResponse {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@@ -690,12 +664,6 @@ pub struct CompletionStreamChoice {
|
|||||||
pub finish_reason: Option<String>,
|
pub finish_reason: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = OPENAI SPEC - Responses API =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
// ============= Tool Definitions =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct ResponseTool {
|
pub struct ResponseTool {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
@@ -709,8 +677,6 @@ pub enum ResponseToolType {
|
|||||||
CodeInterpreter,
|
CodeInterpreter,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Reasoning Configuration =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct ResponseReasoningParam {
|
pub struct ResponseReasoningParam {
|
||||||
#[serde(default = "default_reasoning_effort")]
|
#[serde(default = "default_reasoning_effort")]
|
||||||
@@ -729,8 +695,6 @@ pub enum ReasoningEffort {
|
|||||||
High,
|
High,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Input/Output Items =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
@@ -790,8 +754,6 @@ pub enum ResponseReasoningContent {
|
|||||||
ReasoningText { text: String },
|
ReasoningText { text: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Output Items for Response =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
@@ -823,8 +785,6 @@ pub enum ResponseOutputItem {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Service Tier =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub enum ServiceTier {
|
pub enum ServiceTier {
|
||||||
@@ -841,8 +801,6 @@ impl Default for ServiceTier {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Truncation =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub enum Truncation {
|
pub enum Truncation {
|
||||||
@@ -856,8 +814,6 @@ impl Default for Truncation {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Response Status =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub enum ResponseStatus {
|
pub enum ResponseStatus {
|
||||||
@@ -868,8 +824,6 @@ pub enum ResponseStatus {
|
|||||||
Cancelled,
|
Cancelled,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Reasoning Info =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct ReasoningInfo {
|
pub struct ReasoningInfo {
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
@@ -878,8 +832,6 @@ pub struct ReasoningInfo {
|
|||||||
pub summary: Option<String>,
|
pub summary: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Text Format =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct ResponseTextFormat {
|
pub struct ResponseTextFormat {
|
||||||
pub format: TextFormatType,
|
pub format: TextFormatType,
|
||||||
@@ -891,8 +843,6 @@ pub struct TextFormatType {
|
|||||||
pub format_type: String,
|
pub format_type: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Include Fields =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub enum IncludeField {
|
pub enum IncludeField {
|
||||||
@@ -910,8 +860,6 @@ pub enum IncludeField {
|
|||||||
ReasoningEncryptedContent,
|
ReasoningEncryptedContent,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Usage Info =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct UsageInfo {
|
pub struct UsageInfo {
|
||||||
pub prompt_tokens: u32,
|
pub prompt_tokens: u32,
|
||||||
@@ -928,8 +876,6 @@ pub struct PromptTokenUsageInfo {
|
|||||||
pub cached_tokens: u32,
|
pub cached_tokens: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Response Usage Format =============
|
|
||||||
|
|
||||||
/// OpenAI Responses API usage format (different from standard UsageInfo)
|
/// OpenAI Responses API usage format (different from standard UsageInfo)
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct ResponseUsage {
|
pub struct ResponseUsage {
|
||||||
@@ -1038,7 +984,6 @@ fn generate_request_id() -> String {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct ResponsesRequest {
|
pub struct ResponsesRequest {
|
||||||
// ============= Core OpenAI API fields =============
|
|
||||||
/// Run the request in the background
|
/// Run the request in the background
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub background: bool,
|
pub background: bool,
|
||||||
@@ -1122,7 +1067,6 @@ pub struct ResponsesRequest {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub user: Option<String>,
|
pub user: Option<String>,
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Request ID
|
/// Request ID
|
||||||
#[serde(default = "generate_request_id")]
|
#[serde(default = "generate_request_id")]
|
||||||
pub request_id: String,
|
pub request_id: String,
|
||||||
@@ -1606,8 +1550,6 @@ impl ResponsesResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Helper Functions =============
|
|
||||||
|
|
||||||
impl ResponseOutputItem {
|
impl ResponseOutputItem {
|
||||||
/// Create a new message output item
|
/// Create a new message output item
|
||||||
pub fn new_message(
|
pub fn new_message(
|
||||||
@@ -1708,20 +1650,12 @@ impl UsageInfo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = OPENAI SPEC - Common =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
// ============= Shared Request Components =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct StreamOptions {
|
pub struct StreamOptions {
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub include_usage: Option<bool>,
|
pub include_usage: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Tool Choice Types =============
|
|
||||||
|
|
||||||
/// Tool choice value for simple string options
|
/// Tool choice value for simple string options
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
@@ -1793,8 +1727,6 @@ pub struct FunctionCallResponse {
|
|||||||
pub arguments: Option<String>, // JSON string
|
pub arguments: Option<String>, // JSON string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Usage Tracking =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct Usage {
|
pub struct Usage {
|
||||||
pub prompt_tokens: u32,
|
pub prompt_tokens: u32,
|
||||||
@@ -1809,8 +1741,6 @@ pub struct CompletionTokensDetails {
|
|||||||
pub reasoning_tokens: Option<u32>,
|
pub reasoning_tokens: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Logprobs Types =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct LogProbs {
|
pub struct LogProbs {
|
||||||
pub tokens: Vec<String>,
|
pub tokens: Vec<String>,
|
||||||
@@ -1860,10 +1790,6 @@ pub struct ErrorDetail {
|
|||||||
pub code: Option<String>,
|
pub code: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = SGLANG SPEC - GENERATE API =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub enum InputIds {
|
pub enum InputIds {
|
||||||
@@ -1975,7 +1901,6 @@ pub struct GenerateRequest {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub return_logprob: bool,
|
pub return_logprob: bool,
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Path to LoRA adapter(s) for model customization
|
/// Path to LoRA adapter(s) for model customization
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub lora_path: Option<LoRAPath>,
|
pub lora_path: Option<LoRAPath>,
|
||||||
@@ -2036,10 +1961,6 @@ impl GenerationRequest for GenerateRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = SGLANG SPEC - RERANK API =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
// Constants for rerank API
|
// Constants for rerank API
|
||||||
pub const DEFAULT_MODEL_NAME: &str = "default";
|
pub const DEFAULT_MODEL_NAME: &str = "default";
|
||||||
|
|
||||||
@@ -2237,10 +2158,6 @@ impl RerankResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = OPENAI SPEC - Embeddings API =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
/// Embeddings request compatible with OpenAI API
|
/// Embeddings request compatible with OpenAI API
|
||||||
/// We intentionally keep fields flexible to pass through to workers.
|
/// We intentionally keep fields flexible to pass through to workers.
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@@ -2292,10 +2209,6 @@ impl GenerationRequest for EmbeddingRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = COMMON =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
/// Helper function for serde default value
|
/// Helper function for serde default value
|
||||||
pub fn default_true() -> bool {
|
pub fn default_true() -> bool {
|
||||||
true
|
true
|
||||||
@@ -2359,10 +2272,6 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use serde_json::{from_str, json, to_string};
|
use serde_json::{from_str, json, to_string};
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = RERANK REQUEST TESTS =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_rerank_request_serialization() {
|
fn test_rerank_request_serialization() {
|
||||||
let request = RerankRequest {
|
let request = RerankRequest {
|
||||||
@@ -2534,10 +2443,6 @@ mod tests {
|
|||||||
assert_eq!(request.effective_top_k(), 3);
|
assert_eq!(request.effective_top_k(), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = RERANK RESPONSE TESTS =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_rerank_response_creation() {
|
fn test_rerank_response_creation() {
|
||||||
let results = vec![
|
let results = vec![
|
||||||
@@ -2709,10 +2614,6 @@ mod tests {
|
|||||||
assert_eq!(response.results[0].document, None);
|
assert_eq!(response.results[0].document, None);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = RERANK RESULT TESTS =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_rerank_result_serialization() {
|
fn test_rerank_result_serialization() {
|
||||||
let result = RerankResult {
|
let result = RerankResult {
|
||||||
@@ -2755,10 +2656,6 @@ mod tests {
|
|||||||
assert_eq!(deserialized.meta_info, result.meta_info);
|
assert_eq!(deserialized.meta_info, result.meta_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = V1 COMPATIBILITY TESTS =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_v1_rerank_req_input_serialization() {
|
fn test_v1_rerank_req_input_serialization() {
|
||||||
let v1_input = V1RerankReqInput {
|
let v1_input = V1RerankReqInput {
|
||||||
@@ -2791,10 +2688,6 @@ mod tests {
|
|||||||
assert_eq!(request.user, None);
|
assert_eq!(request.user, None);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = GENERATION REQUEST TRAIT TESTS =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_rerank_request_generation_request_trait() {
|
fn test_rerank_request_generation_request_trait() {
|
||||||
let request = RerankRequest {
|
let request = RerankRequest {
|
||||||
@@ -2812,10 +2705,6 @@ mod tests {
|
|||||||
assert_eq!(request.extract_text_for_routing(), "test query");
|
assert_eq!(request.extract_text_for_routing(), "test query");
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = EDGE CASES AND STRESS TESTS =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_rerank_request_very_long_query() {
|
fn test_rerank_request_very_long_query() {
|
||||||
let long_query = "a".repeat(100000);
|
let long_query = "a".repeat(100000);
|
||||||
@@ -2918,10 +2807,6 @@ mod tests {
|
|||||||
assert_eq!(usage.total_tokens, 150);
|
assert_eq!(usage.total_tokens, 150);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = INTEGRATION TESTS =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_full_rerank_workflow() {
|
fn test_full_rerank_workflow() {
|
||||||
// Create request
|
// Create request
|
||||||
@@ -2980,7 +2865,6 @@ mod tests {
|
|||||||
// Apply top_k
|
// Apply top_k
|
||||||
response.apply_top_k(request.effective_top_k());
|
response.apply_top_k(request.effective_top_k());
|
||||||
|
|
||||||
// Verify results
|
|
||||||
assert_eq!(response.results.len(), 2);
|
assert_eq!(response.results.len(), 2);
|
||||||
assert_eq!(response.results[0].score, 0.95);
|
assert_eq!(response.results[0].score, 0.95);
|
||||||
assert_eq!(response.results[0].index, 0);
|
assert_eq!(response.results[0].index, 0);
|
||||||
@@ -2995,10 +2879,6 @@ mod tests {
|
|||||||
assert_eq!(deserialized.model, response.model);
|
assert_eq!(deserialized.model, response.model);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = EMBEDDINGS REQUEST TESTS =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_embedding_request_serialization_string_input() {
|
fn test_embedding_request_serialization_string_input() {
|
||||||
let req = EmbeddingRequest {
|
let req = EmbeddingRequest {
|
||||||
|
|||||||
@@ -537,10 +537,6 @@ pub trait ValidatableRequest:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================================================================
|
|
||||||
// = OPENAI CHAT COMPLETION VALIDATION =
|
|
||||||
// ==================================================================
|
|
||||||
|
|
||||||
impl SamplingOptionsProvider for ChatCompletionRequest {
|
impl SamplingOptionsProvider for ChatCompletionRequest {
|
||||||
fn get_temperature(&self) -> Option<f32> {
|
fn get_temperature(&self) -> Option<f32> {
|
||||||
self.temperature
|
self.temperature
|
||||||
@@ -909,7 +905,6 @@ mod tests {
|
|||||||
fn test_chat_cross_parameter_conflicts() {
|
fn test_chat_cross_parameter_conflicts() {
|
||||||
let mut request = create_valid_chat_request();
|
let mut request = create_valid_chat_request();
|
||||||
|
|
||||||
// Test 1: max_tokens vs max_completion_tokens conflict
|
|
||||||
request.max_tokens = Some(100);
|
request.max_tokens = Some(100);
|
||||||
request.max_completion_tokens = Some(200);
|
request.max_completion_tokens = Some(200);
|
||||||
assert!(
|
assert!(
|
||||||
@@ -921,7 +916,6 @@ mod tests {
|
|||||||
request.max_tokens = None;
|
request.max_tokens = None;
|
||||||
request.max_completion_tokens = None;
|
request.max_completion_tokens = None;
|
||||||
|
|
||||||
// Test 2: tools vs functions conflict (deprecated)
|
|
||||||
request.tools = Some(vec![]);
|
request.tools = Some(vec![]);
|
||||||
request.functions = Some(vec![]);
|
request.functions = Some(vec![]);
|
||||||
assert!(
|
assert!(
|
||||||
@@ -929,7 +923,6 @@ mod tests {
|
|||||||
"Should reject both tools and functions"
|
"Should reject both tools and functions"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Test 3: logprobs=true without top_logprobs should be valid
|
|
||||||
let mut request = create_valid_chat_request();
|
let mut request = create_valid_chat_request();
|
||||||
request.logprobs = true;
|
request.logprobs = true;
|
||||||
request.top_logprobs = None;
|
request.top_logprobs = None;
|
||||||
@@ -938,7 +931,6 @@ mod tests {
|
|||||||
"logprobs=true without top_logprobs should be valid"
|
"logprobs=true without top_logprobs should be valid"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Test 4: top_logprobs without logprobs=true should fail (OpenAI rule)
|
|
||||||
let mut request = create_valid_chat_request();
|
let mut request = create_valid_chat_request();
|
||||||
request.logprobs = false;
|
request.logprobs = false;
|
||||||
request.top_logprobs = Some(5);
|
request.top_logprobs = Some(5);
|
||||||
@@ -967,7 +959,6 @@ mod tests {
|
|||||||
fn test_parameter_ranges() {
|
fn test_parameter_ranges() {
|
||||||
let mut request = create_valid_chat_request();
|
let mut request = create_valid_chat_request();
|
||||||
|
|
||||||
// Test temperature range (0.0 to 2.0)
|
|
||||||
request.temperature = Some(1.5);
|
request.temperature = Some(1.5);
|
||||||
assert!(request.validate().is_ok());
|
assert!(request.validate().is_ok());
|
||||||
request.temperature = Some(-0.1);
|
request.temperature = Some(-0.1);
|
||||||
@@ -975,7 +966,6 @@ mod tests {
|
|||||||
request.temperature = Some(3.0);
|
request.temperature = Some(3.0);
|
||||||
assert!(request.validate().is_err());
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
// Test top_p range (0.0 to 1.0)
|
|
||||||
request.temperature = Some(1.0); // Reset
|
request.temperature = Some(1.0); // Reset
|
||||||
request.top_p = Some(0.9);
|
request.top_p = Some(0.9);
|
||||||
assert!(request.validate().is_ok());
|
assert!(request.validate().is_ok());
|
||||||
@@ -984,7 +974,6 @@ mod tests {
|
|||||||
request.top_p = Some(1.5);
|
request.top_p = Some(1.5);
|
||||||
assert!(request.validate().is_err());
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
// Test frequency_penalty range (-2.0 to 2.0)
|
|
||||||
request.top_p = Some(0.9); // Reset
|
request.top_p = Some(0.9); // Reset
|
||||||
request.frequency_penalty = Some(1.5);
|
request.frequency_penalty = Some(1.5);
|
||||||
assert!(request.validate().is_ok());
|
assert!(request.validate().is_ok());
|
||||||
@@ -993,7 +982,6 @@ mod tests {
|
|||||||
request.frequency_penalty = Some(3.0);
|
request.frequency_penalty = Some(3.0);
|
||||||
assert!(request.validate().is_err());
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
// Test presence_penalty range (-2.0 to 2.0)
|
|
||||||
request.frequency_penalty = Some(0.0); // Reset
|
request.frequency_penalty = Some(0.0); // Reset
|
||||||
request.presence_penalty = Some(-1.5);
|
request.presence_penalty = Some(-1.5);
|
||||||
assert!(request.validate().is_ok());
|
assert!(request.validate().is_ok());
|
||||||
@@ -1002,7 +990,6 @@ mod tests {
|
|||||||
request.presence_penalty = Some(2.5);
|
request.presence_penalty = Some(2.5);
|
||||||
assert!(request.validate().is_err());
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
// Test repetition_penalty range (0.0 to 2.0)
|
|
||||||
request.presence_penalty = Some(0.0); // Reset
|
request.presence_penalty = Some(0.0); // Reset
|
||||||
request.repetition_penalty = Some(1.2);
|
request.repetition_penalty = Some(1.2);
|
||||||
assert!(request.validate().is_ok());
|
assert!(request.validate().is_ok());
|
||||||
@@ -1011,7 +998,6 @@ mod tests {
|
|||||||
request.repetition_penalty = Some(2.1);
|
request.repetition_penalty = Some(2.1);
|
||||||
assert!(request.validate().is_err());
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
// Test min_p range (0.0 to 1.0)
|
|
||||||
request.repetition_penalty = Some(1.0); // Reset
|
request.repetition_penalty = Some(1.0); // Reset
|
||||||
request.min_p = Some(0.5);
|
request.min_p = Some(0.5);
|
||||||
assert!(request.validate().is_ok());
|
assert!(request.validate().is_ok());
|
||||||
|
|||||||
@@ -373,7 +373,6 @@ mod tests {
|
|||||||
// Both should use the same passthrough parser instance
|
// Both should use the same passthrough parser instance
|
||||||
assert!(Arc::ptr_eq(&parser1, &parser2));
|
assert!(Arc::ptr_eq(&parser1, &parser2));
|
||||||
|
|
||||||
// Verify it's actually a passthrough parser
|
|
||||||
let parser = parser1.lock().unwrap();
|
let parser = parser1.lock().unwrap();
|
||||||
assert_eq!(parser.model_type(), "passthrough");
|
assert_eq!(parser.model_type(), "passthrough");
|
||||||
}
|
}
|
||||||
@@ -456,7 +455,6 @@ mod tests {
|
|||||||
|
|
||||||
match p.detect_and_parse_reasoning(&input) {
|
match p.detect_and_parse_reasoning(&input) {
|
||||||
Ok(result) => {
|
Ok(result) => {
|
||||||
// Verify parsing worked correctly with substantial content
|
|
||||||
// Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
|
// Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
|
||||||
assert!(result
|
assert!(result
|
||||||
.normal_text
|
.normal_text
|
||||||
|
|||||||
@@ -88,7 +88,6 @@ mod tests {
|
|||||||
fn test_kimi_partial_unicode() {
|
fn test_kimi_partial_unicode() {
|
||||||
let mut parser = KimiParser::new();
|
let mut parser = KimiParser::new();
|
||||||
|
|
||||||
// Test partial Unicode token buffering
|
|
||||||
let result1 = parser
|
let result1 = parser
|
||||||
.parse_reasoning_streaming_incremental("◁thi")
|
.parse_reasoning_streaming_incremental("◁thi")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|||||||
@@ -96,8 +96,6 @@ impl GrpcRouter {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ Chat Implementation ============
|
|
||||||
|
|
||||||
/// Main route_chat implementation
|
/// Main route_chat implementation
|
||||||
async fn route_chat_impl(
|
async fn route_chat_impl(
|
||||||
&self,
|
&self,
|
||||||
@@ -207,7 +205,6 @@ impl GrpcRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ Helper Methods ============
|
|
||||||
/// Select a worker for the request
|
/// Select a worker for the request
|
||||||
fn select_worker_for_request(
|
fn select_worker_for_request(
|
||||||
&self,
|
&self,
|
||||||
@@ -809,7 +806,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_transform_messages_mixed_content_types() {
|
fn test_transform_messages_mixed_content_types() {
|
||||||
// Test with both text and multimodal content
|
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
ChatMessage::User {
|
ChatMessage::User {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
@@ -833,7 +829,6 @@ mod tests {
|
|||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
// Test String format
|
|
||||||
let result_string =
|
let result_string =
|
||||||
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
|
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -842,7 +837,6 @@ mod tests {
|
|||||||
assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text");
|
assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text");
|
||||||
assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image");
|
assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image");
|
||||||
|
|
||||||
// Test OpenAI format
|
|
||||||
let result_openai =
|
let result_openai =
|
||||||
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI)
|
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|||||||
@@ -957,7 +957,6 @@ impl RouterTrait for PDRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||||
// Test model generation capability by selecting a random pair and testing them
|
|
||||||
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
|
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
|
||||||
|
|
||||||
// Select a random worker pair using the policy
|
// Select a random worker pair using the policy
|
||||||
@@ -972,7 +971,6 @@ impl RouterTrait for PDRouter {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Test prefill server's health_generate
|
|
||||||
let prefill_url = format!("{}/health_generate", prefill.url());
|
let prefill_url = format!("{}/health_generate", prefill.url());
|
||||||
let (prefill_result, decode_result) = tokio::join!(
|
let (prefill_result, decode_result) = tokio::join!(
|
||||||
self.client.get(&prefill_url).send(),
|
self.client.get(&prefill_url).send(),
|
||||||
|
|||||||
@@ -1018,7 +1018,6 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let port = 8080u16;
|
let port = 8080u16;
|
||||||
|
|
||||||
// Test that unified handler works for regular mode
|
|
||||||
handle_pod_event(
|
handle_pod_event(
|
||||||
&pod_info,
|
&pod_info,
|
||||||
Arc::clone(&tracked_pods),
|
Arc::clone(&tracked_pods),
|
||||||
@@ -1045,7 +1044,6 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let port = 8080u16;
|
let port = 8080u16;
|
||||||
|
|
||||||
// Test that unified handler works for PD mode with prefill
|
|
||||||
handle_pod_event(
|
handle_pod_event(
|
||||||
&pod_info,
|
&pod_info,
|
||||||
Arc::clone(&tracked_pods),
|
Arc::clone(&tracked_pods),
|
||||||
@@ -1080,7 +1078,6 @@ mod tests {
|
|||||||
|
|
||||||
let port = 8080u16;
|
let port = 8080u16;
|
||||||
|
|
||||||
// Test that unified handler works for deletion in PD mode
|
|
||||||
handle_pod_deletion(
|
handle_pod_deletion(
|
||||||
&pod_info,
|
&pod_info,
|
||||||
Arc::clone(&tracked_pods),
|
Arc::clone(&tracked_pods),
|
||||||
|
|||||||
@@ -279,11 +279,9 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_create_tiktoken_tokenizer() {
|
fn test_create_tiktoken_tokenizer() {
|
||||||
// Test creating tokenizer for GPT models
|
|
||||||
let tokenizer = create_tokenizer("gpt-4").unwrap();
|
let tokenizer = create_tokenizer("gpt-4").unwrap();
|
||||||
assert!(tokenizer.vocab_size() > 0);
|
assert!(tokenizer.vocab_size() > 0);
|
||||||
|
|
||||||
// Test encoding and decoding
|
|
||||||
let text = "Hello, world!";
|
let text = "Hello, world!";
|
||||||
let encoding = tokenizer.encode(text).unwrap();
|
let encoding = tokenizer.encode(text).unwrap();
|
||||||
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
|
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
|
||||||
@@ -292,7 +290,6 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_download_tokenizer_from_hf() {
|
async fn test_download_tokenizer_from_hf() {
|
||||||
// Test with a small model that should have tokenizer files
|
|
||||||
// Skip this test if HF_TOKEN is not set and we're in CI
|
// Skip this test if HF_TOKEN is not set and we're in CI
|
||||||
if std::env::var("CI").is_ok() && std::env::var("HF_TOKEN").is_err() {
|
if std::env::var("CI").is_ok() && std::env::var("HF_TOKEN").is_err() {
|
||||||
println!("Skipping HF download test in CI without HF_TOKEN");
|
println!("Skipping HF download test in CI without HF_TOKEN");
|
||||||
|
|||||||
@@ -206,7 +206,6 @@ mod tests {
|
|||||||
// The incremental text should be " world" (with the space that the mock tokenizer adds)
|
// The incremental text should be " world" (with the space that the mock tokenizer adds)
|
||||||
assert_eq!(text2, " world");
|
assert_eq!(text2, " world");
|
||||||
|
|
||||||
// Verify the full text
|
|
||||||
assert_eq!(seq.text().unwrap(), "Hello world");
|
assert_eq!(seq.text().unwrap(), "Hello world");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -398,7 +398,6 @@ mod tests {
|
|||||||
// The fix ensures we only output NEW text, not accumulated text
|
// The fix ensures we only output NEW text, not accumulated text
|
||||||
assert_eq!(outputs.len(), 3);
|
assert_eq!(outputs.len(), 3);
|
||||||
|
|
||||||
// Verify no text is repeated
|
|
||||||
for i in 0..outputs.len() {
|
for i in 0..outputs.len() {
|
||||||
for j in i + 1..outputs.len() {
|
for j in i + 1..outputs.len() {
|
||||||
// No output should contain another (no accumulation)
|
// No output should contain another (no accumulation)
|
||||||
|
|||||||
@@ -36,22 +36,17 @@ fn test_tokenizer_wrapper() {
|
|||||||
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
|
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
|
||||||
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
|
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
|
||||||
|
|
||||||
// Test encoding
|
|
||||||
let encoding = tokenizer.encode("Hello world").unwrap();
|
let encoding = tokenizer.encode("Hello world").unwrap();
|
||||||
assert_eq!(encoding.token_ids(), &[1, 2]);
|
assert_eq!(encoding.token_ids(), &[1, 2]);
|
||||||
|
|
||||||
// Test decoding
|
|
||||||
let text = tokenizer.decode(&[1, 2], false).unwrap();
|
let text = tokenizer.decode(&[1, 2], false).unwrap();
|
||||||
assert_eq!(text, "Hello world");
|
assert_eq!(text, "Hello world");
|
||||||
|
|
||||||
// Test vocab size
|
|
||||||
assert_eq!(tokenizer.vocab_size(), 8);
|
assert_eq!(tokenizer.vocab_size(), 8);
|
||||||
|
|
||||||
// Test token to ID
|
|
||||||
assert_eq!(tokenizer.token_to_id("Hello"), Some(1));
|
assert_eq!(tokenizer.token_to_id("Hello"), Some(1));
|
||||||
assert_eq!(tokenizer.token_to_id("unknown"), None);
|
assert_eq!(tokenizer.token_to_id("unknown"), None);
|
||||||
|
|
||||||
// Test ID to token
|
|
||||||
assert_eq!(tokenizer.id_to_token(1), Some("Hello".to_string()));
|
assert_eq!(tokenizer.id_to_token(1), Some("Hello".to_string()));
|
||||||
assert_eq!(tokenizer.id_to_token(9999), None);
|
assert_eq!(tokenizer.id_to_token(9999), None);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -246,7 +246,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_unrecognized_model_name_returns_error() {
|
fn test_unrecognized_model_name_returns_error() {
|
||||||
// Test that unrecognized model names return an error
|
|
||||||
let result = TiktokenTokenizer::from_model_name("distilgpt-2");
|
let result = TiktokenTokenizer::from_model_name("distilgpt-2");
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
if let Err(e) = result {
|
if let Err(e) = result {
|
||||||
@@ -268,7 +267,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_recognized_model_names() {
|
fn test_recognized_model_names() {
|
||||||
// Test that recognized model names work correctly
|
|
||||||
assert!(TiktokenTokenizer::from_model_name("gpt-4").is_ok());
|
assert!(TiktokenTokenizer::from_model_name("gpt-4").is_ok());
|
||||||
assert!(TiktokenTokenizer::from_model_name("gpt-3.5-turbo").is_ok());
|
assert!(TiktokenTokenizer::from_model_name("gpt-3.5-turbo").is_ok());
|
||||||
assert!(TiktokenTokenizer::from_model_name("text-davinci-003").is_ok());
|
assert!(TiktokenTokenizer::from_model_name("text-davinci-003").is_ok());
|
||||||
|
|||||||
@@ -139,7 +139,6 @@ mod tests {
|
|||||||
async fn test_single_call_with_semicolon() {
|
async fn test_single_call_with_semicolon() {
|
||||||
let parser = LlamaParser::new();
|
let parser = LlamaParser::new();
|
||||||
// Note: Llama 3.2 doesn't handle multiple calls well
|
// Note: Llama 3.2 doesn't handle multiple calls well
|
||||||
// Test that we can at least parse a single call followed by semicolon
|
|
||||||
let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#;
|
let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#;
|
||||||
|
|
||||||
let result = parser.parse_complete(input).await.unwrap();
|
let result = parser.parse_complete(input).await.unwrap();
|
||||||
|
|||||||
@@ -102,7 +102,6 @@ impl PythonicParser {
|
|||||||
if bracket_count == 0 {
|
if bracket_count == 0 {
|
||||||
// Found the matching bracket
|
// Found the matching bracket
|
||||||
let extracted: String = chars[start_idx..=i].iter().collect();
|
let extracted: String = chars[start_idx..=i].iter().collect();
|
||||||
// Verify this actually contains a function call
|
|
||||||
if extracted.contains('(') && extracted.contains(')') {
|
if extracted.contains('(') && extracted.contains(')') {
|
||||||
return Some(extracted);
|
return Some(extracted);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,21 +21,18 @@ fn test_parse_state_new() {
|
|||||||
fn test_parse_state_process_char() {
|
fn test_parse_state_process_char() {
|
||||||
let mut state = ParseState::new();
|
let mut state = ParseState::new();
|
||||||
|
|
||||||
// Test bracket tracking
|
|
||||||
state.process_char('{');
|
state.process_char('{');
|
||||||
assert_eq!(state.bracket_depth, 1);
|
assert_eq!(state.bracket_depth, 1);
|
||||||
|
|
||||||
state.process_char('}');
|
state.process_char('}');
|
||||||
assert_eq!(state.bracket_depth, 0);
|
assert_eq!(state.bracket_depth, 0);
|
||||||
|
|
||||||
// Test string tracking
|
|
||||||
state.process_char('"');
|
state.process_char('"');
|
||||||
assert!(state.in_string);
|
assert!(state.in_string);
|
||||||
|
|
||||||
state.process_char('"');
|
state.process_char('"');
|
||||||
assert!(!state.in_string);
|
assert!(!state.in_string);
|
||||||
|
|
||||||
// Test escape handling
|
|
||||||
state.process_char('"');
|
state.process_char('"');
|
||||||
state.process_char('\\');
|
state.process_char('\\');
|
||||||
assert!(state.escape_next);
|
assert!(state.escape_next);
|
||||||
@@ -63,10 +60,8 @@ fn test_token_config() {
|
|||||||
fn test_parser_registry() {
|
fn test_parser_registry() {
|
||||||
let registry = ParserRegistry::new();
|
let registry = ParserRegistry::new();
|
||||||
|
|
||||||
// Test has default mappings
|
|
||||||
assert!(!registry.list_mappings().is_empty());
|
assert!(!registry.list_mappings().is_empty());
|
||||||
|
|
||||||
// Test model pattern matching
|
|
||||||
let mappings = registry.list_mappings();
|
let mappings = registry.list_mappings();
|
||||||
let has_gpt = mappings.iter().any(|(m, _)| m.starts_with("gpt"));
|
let has_gpt = mappings.iter().any(|(m, _)| m.starts_with("gpt"));
|
||||||
assert!(has_gpt);
|
assert!(has_gpt);
|
||||||
@@ -76,10 +71,8 @@ fn test_parser_registry() {
|
|||||||
fn test_parser_registry_pattern_matching() {
|
fn test_parser_registry_pattern_matching() {
|
||||||
let mut registry = ParserRegistry::new_for_testing();
|
let mut registry = ParserRegistry::new_for_testing();
|
||||||
|
|
||||||
// Test that model mappings work by checking the list
|
|
||||||
registry.map_model("test-model", "json");
|
registry.map_model("test-model", "json");
|
||||||
|
|
||||||
// Verify through list_mappings
|
|
||||||
let mappings = registry.list_mappings();
|
let mappings = registry.list_mappings();
|
||||||
let has_test = mappings
|
let has_test = mappings
|
||||||
.iter()
|
.iter()
|
||||||
@@ -112,25 +105,21 @@ fn test_tool_call_serialization() {
|
|||||||
fn test_partial_json_parser() {
|
fn test_partial_json_parser() {
|
||||||
let parser = PartialJson::default();
|
let parser = PartialJson::default();
|
||||||
|
|
||||||
// Test complete JSON
|
|
||||||
let input = r#"{"name": "test", "value": 42}"#;
|
let input = r#"{"name": "test", "value": 42}"#;
|
||||||
let (value, consumed) = parser.parse_value(input).unwrap();
|
let (value, consumed) = parser.parse_value(input).unwrap();
|
||||||
assert_eq!(value["name"], "test");
|
assert_eq!(value["name"], "test");
|
||||||
assert_eq!(value["value"], 42);
|
assert_eq!(value["value"], 42);
|
||||||
assert_eq!(consumed, input.len());
|
assert_eq!(consumed, input.len());
|
||||||
|
|
||||||
// Test incomplete JSON object
|
|
||||||
let input = r#"{"name": "test", "value": "#;
|
let input = r#"{"name": "test", "value": "#;
|
||||||
let (value, _consumed) = parser.parse_value(input).unwrap();
|
let (value, _consumed) = parser.parse_value(input).unwrap();
|
||||||
assert_eq!(value["name"], "test");
|
assert_eq!(value["name"], "test");
|
||||||
assert!(value["value"].is_null());
|
assert!(value["value"].is_null());
|
||||||
|
|
||||||
// Test incomplete string
|
|
||||||
let input = r#"{"name": "tes"#;
|
let input = r#"{"name": "tes"#;
|
||||||
let (value, _consumed) = parser.parse_value(input).unwrap();
|
let (value, _consumed) = parser.parse_value(input).unwrap();
|
||||||
assert_eq!(value["name"], "tes");
|
assert_eq!(value["name"], "tes");
|
||||||
|
|
||||||
// Test incomplete array
|
|
||||||
let input = r#"[1, 2, "#;
|
let input = r#"[1, 2, "#;
|
||||||
let (value, _consumed) = parser.parse_value(input).unwrap();
|
let (value, _consumed) = parser.parse_value(input).unwrap();
|
||||||
assert!(value.is_array());
|
assert!(value.is_array());
|
||||||
@@ -193,11 +182,9 @@ fn test_compute_diff() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_stream_result_variants() {
|
fn test_stream_result_variants() {
|
||||||
// Test Incomplete
|
|
||||||
let result = StreamResult::Incomplete;
|
let result = StreamResult::Incomplete;
|
||||||
matches!(result, StreamResult::Incomplete);
|
matches!(result, StreamResult::Incomplete);
|
||||||
|
|
||||||
// Test ToolName
|
|
||||||
let result = StreamResult::ToolName {
|
let result = StreamResult::ToolName {
|
||||||
index: 0,
|
index: 0,
|
||||||
name: "test".to_string(),
|
name: "test".to_string(),
|
||||||
@@ -209,7 +196,6 @@ fn test_stream_result_variants() {
|
|||||||
panic!("Expected ToolName variant");
|
panic!("Expected ToolName variant");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test ToolComplete
|
|
||||||
let tool = ToolCall {
|
let tool = ToolCall {
|
||||||
id: "123".to_string(),
|
id: "123".to_string(),
|
||||||
r#type: "function".to_string(),
|
r#type: "function".to_string(),
|
||||||
@@ -255,7 +241,6 @@ fn test_partial_tool_call() {
|
|||||||
async fn test_json_parser_complete_single() {
|
async fn test_json_parser_complete_single() {
|
||||||
let parser = JsonParser::new();
|
let parser = JsonParser::new();
|
||||||
|
|
||||||
// Test single tool call with arguments
|
|
||||||
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco", "units": "celsius"}}"#;
|
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco", "units": "celsius"}}"#;
|
||||||
let result = parser.parse_complete(input).await.unwrap();
|
let result = parser.parse_complete(input).await.unwrap();
|
||||||
|
|
||||||
@@ -269,7 +254,6 @@ async fn test_json_parser_complete_single() {
|
|||||||
async fn test_json_parser_complete_array() {
|
async fn test_json_parser_complete_array() {
|
||||||
let parser = JsonParser::new();
|
let parser = JsonParser::new();
|
||||||
|
|
||||||
// Test array of tool calls
|
|
||||||
let input = r#"[
|
let input = r#"[
|
||||||
{"name": "get_weather", "arguments": {"location": "SF"}},
|
{"name": "get_weather", "arguments": {"location": "SF"}},
|
||||||
{"name": "get_news", "arguments": {"query": "technology"}}
|
{"name": "get_news", "arguments": {"query": "technology"}}
|
||||||
@@ -286,7 +270,6 @@ async fn test_json_parser_complete_array() {
|
|||||||
async fn test_json_parser_with_parameters() {
|
async fn test_json_parser_with_parameters() {
|
||||||
let parser = JsonParser::new();
|
let parser = JsonParser::new();
|
||||||
|
|
||||||
// Test with "parameters" instead of "arguments"
|
|
||||||
let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20, "operation": "add"}}"#;
|
let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20, "operation": "add"}}"#;
|
||||||
let result = parser.parse_complete(input).await.unwrap();
|
let result = parser.parse_complete(input).await.unwrap();
|
||||||
|
|
||||||
@@ -299,7 +282,6 @@ async fn test_json_parser_with_parameters() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_json_parser_with_tokens() {
|
async fn test_json_parser_with_tokens() {
|
||||||
// Test with custom wrapper tokens
|
|
||||||
let parser = JsonParser::with_config(TokenConfig {
|
let parser = JsonParser::with_config(TokenConfig {
|
||||||
start_tokens: vec!["[TOOL_CALLS] [".to_string()],
|
start_tokens: vec!["[TOOL_CALLS] [".to_string()],
|
||||||
end_tokens: vec!["]".to_string()],
|
end_tokens: vec!["]".to_string()],
|
||||||
@@ -315,7 +297,6 @@ async fn test_json_parser_with_tokens() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_multiline_json_with_tokens() {
|
async fn test_multiline_json_with_tokens() {
|
||||||
// Test that regex with (?s) flag properly handles multi-line JSON
|
|
||||||
let parser = JsonParser::with_config(TokenConfig {
|
let parser = JsonParser::with_config(TokenConfig {
|
||||||
start_tokens: vec!["<tool>".to_string()],
|
start_tokens: vec!["<tool>".to_string()],
|
||||||
end_tokens: vec!["</tool>".to_string()],
|
end_tokens: vec!["</tool>".to_string()],
|
||||||
@@ -342,7 +323,6 @@ async fn test_multiline_json_with_tokens() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_multiline_json_array() {
|
async fn test_multiline_json_array() {
|
||||||
// Test multi-line JSON array without wrapper tokens
|
|
||||||
let parser = JsonParser::new();
|
let parser = JsonParser::new();
|
||||||
|
|
||||||
let input = r#"[
|
let input = r#"[
|
||||||
@@ -390,7 +370,6 @@ async fn test_json_parser_streaming() {
|
|||||||
let parser = JsonParser::new();
|
let parser = JsonParser::new();
|
||||||
let mut state = ParseState::new();
|
let mut state = ParseState::new();
|
||||||
|
|
||||||
// Test with complete JSON
|
|
||||||
let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
|
let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
|
||||||
|
|
||||||
let result = parser
|
let result = parser
|
||||||
@@ -417,7 +396,6 @@ async fn test_registry_with_json_parser() {
|
|||||||
// Should get JSON parser for OpenAI models
|
// Should get JSON parser for OpenAI models
|
||||||
let parser = registry.get_parser("gpt-4-turbo").unwrap();
|
let parser = registry.get_parser("gpt-4-turbo").unwrap();
|
||||||
|
|
||||||
// Test that the parser works
|
|
||||||
let input = r#"{"name": "test", "arguments": {"x": 1}}"#;
|
let input = r#"{"name": "test", "arguments": {"x": 1}}"#;
|
||||||
let result = parser.parse_complete(input).await.unwrap();
|
let result = parser.parse_complete(input).await.unwrap();
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
@@ -677,7 +655,6 @@ mod edge_cases {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_multiple_token_pairs_with_conflicts() {
|
async fn test_multiple_token_pairs_with_conflicts() {
|
||||||
// Test with overlapping token patterns
|
|
||||||
let parser = JsonParser::with_config(TokenConfig {
|
let parser = JsonParser::with_config(TokenConfig {
|
||||||
start_tokens: vec!["<<".to_string(), "<tool>".to_string()],
|
start_tokens: vec!["<<".to_string(), "<tool>".to_string()],
|
||||||
end_tokens: vec![">>".to_string(), "</tool>".to_string()],
|
end_tokens: vec![">>".to_string(), "</tool>".to_string()],
|
||||||
@@ -708,7 +685,6 @@ mod edge_cases {
|
|||||||
async fn test_streaming_with_partial_chunks() {
|
async fn test_streaming_with_partial_chunks() {
|
||||||
let parser = JsonParser::new();
|
let parser = JsonParser::new();
|
||||||
|
|
||||||
// Test 1: Very incomplete JSON (just opening brace) should return Incomplete
|
|
||||||
let mut state1 = ParseState::new();
|
let mut state1 = ParseState::new();
|
||||||
let partial = r#"{"#;
|
let partial = r#"{"#;
|
||||||
let result = parser
|
let result = parser
|
||||||
@@ -720,7 +696,6 @@ mod edge_cases {
|
|||||||
"Should return Incomplete for just opening brace"
|
"Should return Incomplete for just opening brace"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Test 2: Complete JSON should return ToolComplete
|
|
||||||
let mut state2 = ParseState::new();
|
let mut state2 = ParseState::new();
|
||||||
let complete = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;
|
let complete = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;
|
||||||
let result = parser
|
let result = parser
|
||||||
@@ -738,7 +713,6 @@ mod edge_cases {
|
|||||||
_ => panic!("Expected ToolComplete for complete JSON"),
|
_ => panic!("Expected ToolComplete for complete JSON"),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test 3: Partial JSON with name
|
|
||||||
// The PartialJson parser can complete partial JSON by filling in missing values
|
// The PartialJson parser can complete partial JSON by filling in missing values
|
||||||
let mut state3 = ParseState::new();
|
let mut state3 = ParseState::new();
|
||||||
let partial_with_name = r#"{"name": "test", "argum"#;
|
let partial_with_name = r#"{"name": "test", "argum"#;
|
||||||
@@ -863,7 +837,6 @@ mod stress_tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_concurrent_parser_usage() {
|
async fn test_concurrent_parser_usage() {
|
||||||
// Test that parser can be used concurrently
|
|
||||||
let parser = std::sync::Arc::new(JsonParser::new());
|
let parser = std::sync::Arc::new(JsonParser::new());
|
||||||
|
|
||||||
let mut handles = vec![];
|
let mut handles = vec![];
|
||||||
|
|||||||
@@ -679,7 +679,6 @@ mod tests {
|
|||||||
fn test_get_smallest_tenant() {
|
fn test_get_smallest_tenant() {
|
||||||
let tree = Tree::new();
|
let tree = Tree::new();
|
||||||
|
|
||||||
// Test empty tree
|
|
||||||
assert_eq!(tree.get_smallest_tenant(), "empty");
|
assert_eq!(tree.get_smallest_tenant(), "empty");
|
||||||
|
|
||||||
// Insert data for tenant1 - "ap" + "icot" = 6 chars
|
// Insert data for tenant1 - "ap" + "icot" = 6 chars
|
||||||
@@ -689,7 +688,6 @@ mod tests {
|
|||||||
// Insert data for tenant2 - "cat" = 3 chars
|
// Insert data for tenant2 - "cat" = 3 chars
|
||||||
tree.insert("cat", "tenant2");
|
tree.insert("cat", "tenant2");
|
||||||
|
|
||||||
// Test - tenant2 should be smallest with 3 chars vs 6 chars
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tree.get_smallest_tenant(),
|
tree.get_smallest_tenant(),
|
||||||
"tenant2",
|
"tenant2",
|
||||||
@@ -702,7 +700,6 @@ mod tests {
|
|||||||
tree.insert("do", "tenant3");
|
tree.insert("do", "tenant3");
|
||||||
tree.insert("hi", "tenant4");
|
tree.insert("hi", "tenant4");
|
||||||
|
|
||||||
// Test - should return either tenant3 or tenant4 (both have 2 chars)
|
|
||||||
let smallest = tree.get_smallest_tenant();
|
let smallest = tree.get_smallest_tenant();
|
||||||
assert!(
|
assert!(
|
||||||
smallest == "tenant3" || smallest == "tenant4",
|
smallest == "tenant3" || smallest == "tenant4",
|
||||||
@@ -720,7 +717,6 @@ mod tests {
|
|||||||
"Expected tenant3 to be smallest with 2 characters"
|
"Expected tenant3 to be smallest with 2 characters"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Test eviction
|
|
||||||
tree.evict_tenant_by_size(3); // This should evict tenants with more than 3 chars
|
tree.evict_tenant_by_size(3); // This should evict tenants with more than 3 chars
|
||||||
|
|
||||||
let post_eviction_smallest = tree.get_smallest_tenant();
|
let post_eviction_smallest = tree.get_smallest_tenant();
|
||||||
@@ -731,7 +727,6 @@ mod tests {
|
|||||||
fn test_tenant_char_count() {
|
fn test_tenant_char_count() {
|
||||||
let tree = Tree::new();
|
let tree = Tree::new();
|
||||||
|
|
||||||
// Phase 1: Initial insertions
|
|
||||||
tree.insert("apple", "tenant1");
|
tree.insert("apple", "tenant1");
|
||||||
tree.insert("apricot", "tenant1");
|
tree.insert("apricot", "tenant1");
|
||||||
tree.insert("banana", "tenant1");
|
tree.insert("banana", "tenant1");
|
||||||
@@ -755,7 +750,6 @@ mod tests {
|
|||||||
"Phase 1: Initial insertions"
|
"Phase 1: Initial insertions"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Phase 2: Additional insertions
|
|
||||||
tree.insert("apartment", "tenant1");
|
tree.insert("apartment", "tenant1");
|
||||||
tree.insert("appetite", "tenant2");
|
tree.insert("appetite", "tenant2");
|
||||||
tree.insert("ball", "tenant1");
|
tree.insert("ball", "tenant1");
|
||||||
@@ -778,7 +772,6 @@ mod tests {
|
|||||||
"Phase 2: Additional insertions"
|
"Phase 2: Additional insertions"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Phase 3: Overlapping insertions
|
|
||||||
tree.insert("zebra", "tenant1");
|
tree.insert("zebra", "tenant1");
|
||||||
tree.insert("zebra", "tenant2");
|
tree.insert("zebra", "tenant2");
|
||||||
tree.insert("zero", "tenant1");
|
tree.insert("zero", "tenant1");
|
||||||
@@ -801,7 +794,6 @@ mod tests {
|
|||||||
"Phase 3: Overlapping insertions"
|
"Phase 3: Overlapping insertions"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Phase 4: Eviction test
|
|
||||||
tree.evict_tenant_by_size(10);
|
tree.evict_tenant_by_size(10);
|
||||||
|
|
||||||
let computed_sizes = tree.get_used_size_per_tenant();
|
let computed_sizes = tree.get_used_size_per_tenant();
|
||||||
@@ -1088,8 +1080,6 @@ mod tests {
|
|||||||
|
|
||||||
tree.pretty_print();
|
tree.pretty_print();
|
||||||
|
|
||||||
// Test sequentially
|
|
||||||
|
|
||||||
for (text, tenant) in TEST_PAIRS.iter() {
|
for (text, tenant) in TEST_PAIRS.iter() {
|
||||||
let (matched_text, matched_tenant) = tree.prefix_match(text);
|
let (matched_text, matched_tenant) = tree.prefix_match(text);
|
||||||
assert_eq!(matched_text, *text);
|
assert_eq!(matched_text, *text);
|
||||||
@@ -1162,7 +1152,6 @@ mod tests {
|
|||||||
|
|
||||||
tree.pretty_print();
|
tree.pretty_print();
|
||||||
|
|
||||||
// Verify initial sizes
|
|
||||||
let sizes_before = tree.get_used_size_per_tenant();
|
let sizes_before = tree.get_used_size_per_tenant();
|
||||||
assert_eq!(sizes_before.get("tenant1").unwrap(), &5); // "hello" = 5
|
assert_eq!(sizes_before.get("tenant1").unwrap(), &5); // "hello" = 5
|
||||||
assert_eq!(sizes_before.get("tenant2").unwrap(), &10); // "hello" + "world" = 10
|
assert_eq!(sizes_before.get("tenant2").unwrap(), &10); // "hello" + "world" = 10
|
||||||
@@ -1172,12 +1161,10 @@ mod tests {
|
|||||||
|
|
||||||
tree.pretty_print();
|
tree.pretty_print();
|
||||||
|
|
||||||
// Verify sizes after eviction
|
|
||||||
let sizes_after = tree.get_used_size_per_tenant();
|
let sizes_after = tree.get_used_size_per_tenant();
|
||||||
assert_eq!(sizes_after.get("tenant1").unwrap(), &5); // Should be unchanged
|
assert_eq!(sizes_after.get("tenant1").unwrap(), &5); // Should be unchanged
|
||||||
assert_eq!(sizes_after.get("tenant2").unwrap(), &5); // Only "world" remains
|
assert_eq!(sizes_after.get("tenant2").unwrap(), &5); // Only "world" remains
|
||||||
|
|
||||||
// Verify "world" remains for tenant2
|
|
||||||
let (matched, tenant) = tree.prefix_match("world");
|
let (matched, tenant) = tree.prefix_match("world");
|
||||||
assert_eq!(matched, "world");
|
assert_eq!(matched, "world");
|
||||||
assert_eq!(tenant, "tenant2");
|
assert_eq!(tenant, "tenant2");
|
||||||
@@ -1208,7 +1195,6 @@ mod tests {
|
|||||||
|
|
||||||
// Check sizes after eviction
|
// Check sizes after eviction
|
||||||
let sizes_after = tree.get_used_size_per_tenant();
|
let sizes_after = tree.get_used_size_per_tenant();
|
||||||
// Verify all tenants are under their size limits
|
|
||||||
for (tenant, &size) in sizes_after.iter() {
|
for (tenant, &size) in sizes_after.iter() {
|
||||||
assert!(
|
assert!(
|
||||||
size <= max_size,
|
size <= max_size,
|
||||||
@@ -1287,7 +1273,6 @@ mod tests {
|
|||||||
let final_sizes = tree.get_used_size_per_tenant();
|
let final_sizes = tree.get_used_size_per_tenant();
|
||||||
println!("Final sizes after test completion: {:?}", final_sizes);
|
println!("Final sizes after test completion: {:?}", final_sizes);
|
||||||
|
|
||||||
// Verify all tenants are under limit
|
|
||||||
for (_, &size) in final_sizes.iter() {
|
for (_, &size) in final_sizes.iter() {
|
||||||
assert!(
|
assert!(
|
||||||
size <= max_size,
|
size <= max_size,
|
||||||
@@ -1364,14 +1349,12 @@ mod tests {
|
|||||||
tree.insert("help", "tenant1"); // tenant1: hel -> p
|
tree.insert("help", "tenant1"); // tenant1: hel -> p
|
||||||
tree.insert("helicopter", "tenant2"); // tenant2: hel -> icopter
|
tree.insert("helicopter", "tenant2"); // tenant2: hel -> icopter
|
||||||
|
|
||||||
// Test tenant1's data
|
|
||||||
assert_eq!(tree.prefix_match_tenant("hello", "tenant1"), "hello"); // Full match for tenant1
|
assert_eq!(tree.prefix_match_tenant("hello", "tenant1"), "hello"); // Full match for tenant1
|
||||||
assert_eq!(tree.prefix_match_tenant("help", "tenant1"), "help"); // Exclusive to tenant1
|
assert_eq!(tree.prefix_match_tenant("help", "tenant1"), "help"); // Exclusive to tenant1
|
||||||
assert_eq!(tree.prefix_match_tenant("hel", "tenant1"), "hel"); // Shared prefix
|
assert_eq!(tree.prefix_match_tenant("hel", "tenant1"), "hel"); // Shared prefix
|
||||||
assert_eq!(tree.prefix_match_tenant("hello world", "tenant1"), "hello"); // Should stop at tenant1's boundary
|
assert_eq!(tree.prefix_match_tenant("hello world", "tenant1"), "hello"); // Should stop at tenant1's boundary
|
||||||
assert_eq!(tree.prefix_match_tenant("helicopter", "tenant1"), "hel"); // Should stop at tenant1's boundary
|
assert_eq!(tree.prefix_match_tenant("helicopter", "tenant1"), "hel"); // Should stop at tenant1's boundary
|
||||||
|
|
||||||
// Test tenant2's data
|
|
||||||
assert_eq!(tree.prefix_match_tenant("hello", "tenant2"), "hello"); // Full match for tenant2
|
assert_eq!(tree.prefix_match_tenant("hello", "tenant2"), "hello"); // Full match for tenant2
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tree.prefix_match_tenant("hello world", "tenant2"),
|
tree.prefix_match_tenant("hello world", "tenant2"),
|
||||||
@@ -1384,7 +1367,6 @@ mod tests {
|
|||||||
assert_eq!(tree.prefix_match_tenant("hel", "tenant2"), "hel"); // Shared prefix
|
assert_eq!(tree.prefix_match_tenant("hel", "tenant2"), "hel"); // Shared prefix
|
||||||
assert_eq!(tree.prefix_match_tenant("help", "tenant2"), "hel"); // Should stop at tenant2's boundary
|
assert_eq!(tree.prefix_match_tenant("help", "tenant2"), "hel"); // Should stop at tenant2's boundary
|
||||||
|
|
||||||
// Test non-existent tenant
|
|
||||||
assert_eq!(tree.prefix_match_tenant("hello", "tenant3"), ""); // Non-existent tenant
|
assert_eq!(tree.prefix_match_tenant("hello", "tenant3"), ""); // Non-existent tenant
|
||||||
assert_eq!(tree.prefix_match_tenant("help", "tenant3"), ""); // Non-existent tenant
|
assert_eq!(tree.prefix_match_tenant("help", "tenant3"), ""); // Non-existent tenant
|
||||||
}
|
}
|
||||||
@@ -1399,7 +1381,6 @@ mod tests {
|
|||||||
tree.insert("hello", "tenant2");
|
tree.insert("hello", "tenant2");
|
||||||
tree.insert("help", "tenant2");
|
tree.insert("help", "tenant2");
|
||||||
|
|
||||||
// Verify initial state
|
|
||||||
let initial_sizes = tree.get_used_size_per_tenant();
|
let initial_sizes = tree.get_used_size_per_tenant();
|
||||||
assert_eq!(initial_sizes.get("tenant1").unwrap(), &10); // "hello" + "world"
|
assert_eq!(initial_sizes.get("tenant1").unwrap(), &10); // "hello" + "world"
|
||||||
assert_eq!(initial_sizes.get("tenant2").unwrap(), &6); // "hello" + "p"
|
assert_eq!(initial_sizes.get("tenant2").unwrap(), &6); // "hello" + "p"
|
||||||
@@ -1407,7 +1388,6 @@ mod tests {
|
|||||||
// Evict tenant1
|
// Evict tenant1
|
||||||
tree.remove_tenant("tenant1");
|
tree.remove_tenant("tenant1");
|
||||||
|
|
||||||
// Verify after eviction
|
|
||||||
let final_sizes = tree.get_used_size_per_tenant();
|
let final_sizes = tree.get_used_size_per_tenant();
|
||||||
assert!(
|
assert!(
|
||||||
!final_sizes.contains_key("tenant1"),
|
!final_sizes.contains_key("tenant1"),
|
||||||
@@ -1419,11 +1399,9 @@ mod tests {
|
|||||||
"tenant2 should be unaffected"
|
"tenant2 should be unaffected"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Verify tenant1's data is inaccessible
|
|
||||||
assert_eq!(tree.prefix_match_tenant("hello", "tenant1"), "");
|
assert_eq!(tree.prefix_match_tenant("hello", "tenant1"), "");
|
||||||
assert_eq!(tree.prefix_match_tenant("world", "tenant1"), "");
|
assert_eq!(tree.prefix_match_tenant("world", "tenant1"), "");
|
||||||
|
|
||||||
// Verify tenant2's data is still accessible
|
|
||||||
assert_eq!(tree.prefix_match_tenant("hello", "tenant2"), "hello");
|
assert_eq!(tree.prefix_match_tenant("hello", "tenant2"), "hello");
|
||||||
assert_eq!(tree.prefix_match_tenant("help", "tenant2"), "help");
|
assert_eq!(tree.prefix_match_tenant("help", "tenant2"), "help");
|
||||||
}
|
}
|
||||||
@@ -1441,7 +1419,6 @@ mod tests {
|
|||||||
tree.insert("banana", "tenant2");
|
tree.insert("banana", "tenant2");
|
||||||
tree.insert("ball", "tenant2");
|
tree.insert("ball", "tenant2");
|
||||||
|
|
||||||
// Verify initial state
|
|
||||||
let initial_sizes = tree.get_used_size_per_tenant();
|
let initial_sizes = tree.get_used_size_per_tenant();
|
||||||
println!("Initial sizes: {:?}", initial_sizes);
|
println!("Initial sizes: {:?}", initial_sizes);
|
||||||
tree.pretty_print();
|
tree.pretty_print();
|
||||||
@@ -1449,29 +1426,24 @@ mod tests {
|
|||||||
// Evict tenant1
|
// Evict tenant1
|
||||||
tree.remove_tenant("tenant1");
|
tree.remove_tenant("tenant1");
|
||||||
|
|
||||||
// Verify final state
|
|
||||||
let final_sizes = tree.get_used_size_per_tenant();
|
let final_sizes = tree.get_used_size_per_tenant();
|
||||||
println!("Final sizes: {:?}", final_sizes);
|
println!("Final sizes: {:?}", final_sizes);
|
||||||
tree.pretty_print();
|
tree.pretty_print();
|
||||||
|
|
||||||
// Verify tenant1 is completely removed
|
|
||||||
assert!(
|
assert!(
|
||||||
!final_sizes.contains_key("tenant1"),
|
!final_sizes.contains_key("tenant1"),
|
||||||
"tenant1 should be completely removed"
|
"tenant1 should be completely removed"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Verify all tenant1's data is inaccessible
|
|
||||||
assert_eq!(tree.prefix_match_tenant("apple", "tenant1"), "");
|
assert_eq!(tree.prefix_match_tenant("apple", "tenant1"), "");
|
||||||
assert_eq!(tree.prefix_match_tenant("application", "tenant1"), "");
|
assert_eq!(tree.prefix_match_tenant("application", "tenant1"), "");
|
||||||
assert_eq!(tree.prefix_match_tenant("banana", "tenant1"), "");
|
assert_eq!(tree.prefix_match_tenant("banana", "tenant1"), "");
|
||||||
|
|
||||||
// Verify tenant2's data is intact
|
|
||||||
assert_eq!(tree.prefix_match_tenant("apple", "tenant2"), "apple");
|
assert_eq!(tree.prefix_match_tenant("apple", "tenant2"), "apple");
|
||||||
assert_eq!(tree.prefix_match_tenant("appetite", "tenant2"), "appetite");
|
assert_eq!(tree.prefix_match_tenant("appetite", "tenant2"), "appetite");
|
||||||
assert_eq!(tree.prefix_match_tenant("banana", "tenant2"), "banana");
|
assert_eq!(tree.prefix_match_tenant("banana", "tenant2"), "banana");
|
||||||
assert_eq!(tree.prefix_match_tenant("ball", "tenant2"), "ball");
|
assert_eq!(tree.prefix_match_tenant("ball", "tenant2"), "ball");
|
||||||
|
|
||||||
// Verify the tree structure is still valid for tenant2
|
|
||||||
let tenant2_size = final_sizes.get("tenant2").unwrap();
|
let tenant2_size = final_sizes.get("tenant2").unwrap();
|
||||||
assert_eq!(tenant2_size, &(5 + 5 + 6 + 2)); // "apple" + "etite" + "banana" + "ll"
|
assert_eq!(tenant2_size, &(5 + 5 + 6 + 2)); // "apple" + "etite" + "banana" + "ll"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -576,7 +576,6 @@ mod model_info_tests {
|
|||||||
let ctx = TestContext::new(vec![]).await;
|
let ctx = TestContext::new(vec![]).await;
|
||||||
let app = ctx.create_app().await;
|
let app = ctx.create_app().await;
|
||||||
|
|
||||||
// Test server info with no workers
|
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method("GET")
|
.method("GET")
|
||||||
.uri("/get_server_info")
|
.uri("/get_server_info")
|
||||||
@@ -593,7 +592,6 @@ mod model_info_tests {
|
|||||||
resp.status()
|
resp.status()
|
||||||
);
|
);
|
||||||
|
|
||||||
// Test model info with no workers
|
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method("GET")
|
.method("GET")
|
||||||
.uri("/get_model_info")
|
.uri("/get_model_info")
|
||||||
@@ -610,7 +608,6 @@ mod model_info_tests {
|
|||||||
resp.status()
|
resp.status()
|
||||||
);
|
);
|
||||||
|
|
||||||
// Test v1/models with no workers
|
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method("GET")
|
.method("GET")
|
||||||
.uri("/v1/models")
|
.uri("/v1/models")
|
||||||
@@ -652,7 +649,6 @@ mod model_info_tests {
|
|||||||
|
|
||||||
let app = ctx.create_app().await;
|
let app = ctx.create_app().await;
|
||||||
|
|
||||||
// Test that model info is consistent across workers
|
|
||||||
for _ in 0..5 {
|
for _ in 0..5 {
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method("GET")
|
.method("GET")
|
||||||
@@ -795,7 +791,6 @@ mod worker_management_tests {
|
|||||||
let resp = app.clone().oneshot(req).await.unwrap();
|
let resp = app.clone().oneshot(req).await.unwrap();
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
assert_eq!(resp.status(), StatusCode::OK);
|
||||||
|
|
||||||
// Verify it's removed
|
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method("GET")
|
.method("GET")
|
||||||
.uri("/list_workers")
|
.uri("/list_workers")
|
||||||
@@ -1302,7 +1297,6 @@ mod error_tests {
|
|||||||
|
|
||||||
let app = ctx.create_app().await;
|
let app = ctx.create_app().await;
|
||||||
|
|
||||||
// Test unknown endpoint
|
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method("GET")
|
.method("GET")
|
||||||
.uri("/unknown_endpoint")
|
.uri("/unknown_endpoint")
|
||||||
@@ -1312,7 +1306,6 @@ mod error_tests {
|
|||||||
let resp = app.clone().oneshot(req).await.unwrap();
|
let resp = app.clone().oneshot(req).await.unwrap();
|
||||||
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
||||||
|
|
||||||
// Test POST to unknown endpoint
|
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method("POST")
|
.method("POST")
|
||||||
.uri("/api/v2/generate")
|
.uri("/api/v2/generate")
|
||||||
@@ -1606,7 +1599,6 @@ mod cache_tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||||
|
|
||||||
// Verify the response contains load information
|
|
||||||
assert!(body_json.is_object());
|
assert!(body_json.is_object());
|
||||||
// The exact structure depends on the implementation
|
// The exact structure depends on the implementation
|
||||||
// but should contain worker load information
|
// but should contain worker load information
|
||||||
@@ -1797,7 +1789,6 @@ mod request_id_tests {
|
|||||||
|
|
||||||
let app = ctx.create_app().await;
|
let app = ctx.create_app().await;
|
||||||
|
|
||||||
// Test 1: Request without any request ID header should generate one
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"text": "Test request",
|
"text": "Test request",
|
||||||
"stream": false
|
"stream": false
|
||||||
@@ -1830,7 +1821,6 @@ mod request_id_tests {
|
|||||||
"Request ID should have content after prefix"
|
"Request ID should have content after prefix"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Test 2: Request with custom x-request-id should preserve it
|
|
||||||
let custom_id = "custom-request-id-123";
|
let custom_id = "custom-request-id-123";
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method("POST")
|
.method("POST")
|
||||||
@@ -1847,7 +1837,6 @@ mod request_id_tests {
|
|||||||
assert!(response_id.is_some());
|
assert!(response_id.is_some());
|
||||||
assert_eq!(response_id.unwrap(), custom_id);
|
assert_eq!(response_id.unwrap(), custom_id);
|
||||||
|
|
||||||
// Test 3: Different endpoints should have different prefixes
|
|
||||||
let chat_payload = json!({
|
let chat_payload = json!({
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
"model": "test-model"
|
"model": "test-model"
|
||||||
@@ -1871,7 +1860,6 @@ mod request_id_tests {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
.starts_with("chatcmpl-"));
|
.starts_with("chatcmpl-"));
|
||||||
|
|
||||||
// Test 4: Alternative request ID headers should be recognized
|
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method("POST")
|
.method("POST")
|
||||||
.uri("/generate")
|
.uri("/generate")
|
||||||
@@ -1948,7 +1936,6 @@ mod request_id_tests {
|
|||||||
"stream": false
|
"stream": false
|
||||||
});
|
});
|
||||||
|
|
||||||
// Test custom header is recognized
|
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method("POST")
|
.method("POST")
|
||||||
.uri("/generate")
|
.uri("/generate")
|
||||||
@@ -2013,7 +2000,6 @@ mod rerank_tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||||
|
|
||||||
// Verify response structure
|
|
||||||
assert!(body_json.get("results").is_some());
|
assert!(body_json.get("results").is_some());
|
||||||
assert!(body_json.get("model").is_some());
|
assert!(body_json.get("model").is_some());
|
||||||
assert_eq!(body_json["model"], "test-rerank-model");
|
assert_eq!(body_json["model"], "test-rerank-model");
|
||||||
@@ -2021,7 +2007,6 @@ mod rerank_tests {
|
|||||||
let results = body_json["results"].as_array().unwrap();
|
let results = body_json["results"].as_array().unwrap();
|
||||||
assert_eq!(results.len(), 2);
|
assert_eq!(results.len(), 2);
|
||||||
|
|
||||||
// Verify results are sorted by score (highest first)
|
|
||||||
assert!(results[0]["score"].as_f64().unwrap() >= results[1]["score"].as_f64().unwrap());
|
assert!(results[0]["score"].as_f64().unwrap() >= results[1]["score"].as_f64().unwrap());
|
||||||
|
|
||||||
ctx.shutdown().await;
|
ctx.shutdown().await;
|
||||||
@@ -2164,7 +2149,6 @@ mod rerank_tests {
|
|||||||
|
|
||||||
let app = ctx.create_app().await;
|
let app = ctx.create_app().await;
|
||||||
|
|
||||||
// Test V1 API format (simplified input)
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"query": "machine learning algorithms",
|
"query": "machine learning algorithms",
|
||||||
"documents": [
|
"documents": [
|
||||||
@@ -2189,7 +2173,6 @@ mod rerank_tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||||
|
|
||||||
// Verify response structure
|
|
||||||
assert!(body_json.get("results").is_some());
|
assert!(body_json.get("results").is_some());
|
||||||
assert!(body_json.get("model").is_some());
|
assert!(body_json.get("model").is_some());
|
||||||
|
|
||||||
@@ -2199,7 +2182,6 @@ mod rerank_tests {
|
|||||||
let results = body_json["results"].as_array().unwrap();
|
let results = body_json["results"].as_array().unwrap();
|
||||||
assert_eq!(results.len(), 3); // All documents should be returned
|
assert_eq!(results.len(), 3); // All documents should be returned
|
||||||
|
|
||||||
// Verify results are sorted by score (highest first)
|
|
||||||
assert!(results[0]["score"].as_f64().unwrap() >= results[1]["score"].as_f64().unwrap());
|
assert!(results[0]["score"].as_f64().unwrap() >= results[1]["score"].as_f64().unwrap());
|
||||||
assert!(results[1]["score"].as_f64().unwrap() >= results[2]["score"].as_f64().unwrap());
|
assert!(results[1]["score"].as_f64().unwrap() >= results[2]["score"].as_f64().unwrap());
|
||||||
|
|
||||||
@@ -2224,7 +2206,6 @@ mod rerank_tests {
|
|||||||
|
|
||||||
let app = ctx.create_app().await;
|
let app = ctx.create_app().await;
|
||||||
|
|
||||||
// Test empty query string (validation should fail)
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"query": "",
|
"query": "",
|
||||||
"documents": ["Document 1", "Document 2"],
|
"documents": ["Document 1", "Document 2"],
|
||||||
@@ -2241,7 +2222,6 @@ mod rerank_tests {
|
|||||||
let resp = app.clone().oneshot(req).await.unwrap();
|
let resp = app.clone().oneshot(req).await.unwrap();
|
||||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||||
|
|
||||||
// Test query with only whitespace (validation should fail)
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"query": " ",
|
"query": " ",
|
||||||
"documents": ["Document 1", "Document 2"],
|
"documents": ["Document 1", "Document 2"],
|
||||||
@@ -2258,7 +2238,6 @@ mod rerank_tests {
|
|||||||
let resp = app.clone().oneshot(req).await.unwrap();
|
let resp = app.clone().oneshot(req).await.unwrap();
|
||||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||||
|
|
||||||
// Test empty documents list (validation should fail)
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"query": "test query",
|
"query": "test query",
|
||||||
"documents": [],
|
"documents": [],
|
||||||
@@ -2275,7 +2254,6 @@ mod rerank_tests {
|
|||||||
let resp = app.clone().oneshot(req).await.unwrap();
|
let resp = app.clone().oneshot(req).await.unwrap();
|
||||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||||
|
|
||||||
// Test invalid top_k (validation should fail)
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"query": "test query",
|
"query": "test query",
|
||||||
"documents": ["Document 1", "Document 2"],
|
"documents": ["Document 1", "Document 2"],
|
||||||
|
|||||||
@@ -93,19 +93,16 @@ fn test_mixed_model_ids() {
|
|||||||
policy.add_worker(&worker3);
|
policy.add_worker(&worker3);
|
||||||
policy.add_worker(&worker4);
|
policy.add_worker(&worker4);
|
||||||
|
|
||||||
// Test selection with default workers only
|
|
||||||
let default_workers: Vec<Arc<dyn Worker>> =
|
let default_workers: Vec<Arc<dyn Worker>> =
|
||||||
vec![Arc::new(worker1.clone()), Arc::new(worker3.clone())];
|
vec![Arc::new(worker1.clone()), Arc::new(worker3.clone())];
|
||||||
let selected = policy.select_worker(&default_workers, Some("test request"));
|
let selected = policy.select_worker(&default_workers, Some("test request"));
|
||||||
assert!(selected.is_some(), "Should select from default workers");
|
assert!(selected.is_some(), "Should select from default workers");
|
||||||
|
|
||||||
// Test selection with specific model workers only
|
|
||||||
let llama_workers: Vec<Arc<dyn Worker>> =
|
let llama_workers: Vec<Arc<dyn Worker>> =
|
||||||
vec![Arc::new(worker2.clone()), Arc::new(worker4.clone())];
|
vec![Arc::new(worker2.clone()), Arc::new(worker4.clone())];
|
||||||
let selected = policy.select_worker(&llama_workers, Some("test request"));
|
let selected = policy.select_worker(&llama_workers, Some("test request"));
|
||||||
assert!(selected.is_some(), "Should select from llama-3 workers");
|
assert!(selected.is_some(), "Should select from llama-3 workers");
|
||||||
|
|
||||||
// Test selection with mixed workers
|
|
||||||
let all_workers: Vec<Arc<dyn Worker>> = vec![
|
let all_workers: Vec<Arc<dyn Worker>> = vec![
|
||||||
Arc::new(worker1.clone()),
|
Arc::new(worker1.clone()),
|
||||||
Arc::new(worker2.clone()),
|
Arc::new(worker2.clone()),
|
||||||
@@ -144,7 +141,6 @@ fn test_remove_worker_by_url_backward_compat() {
|
|||||||
// Should remove from all trees since we don't know the model
|
// Should remove from all trees since we don't know the model
|
||||||
policy.remove_worker_by_url("http://worker1:8080");
|
policy.remove_worker_by_url("http://worker1:8080");
|
||||||
|
|
||||||
// Verify removal worked
|
|
||||||
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker2.clone())];
|
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker2.clone())];
|
||||||
let selected = policy.select_worker(&workers, Some("test"));
|
let selected = policy.select_worker(&workers, Some("test"));
|
||||||
assert_eq!(selected, Some(0), "Should only have worker2 left");
|
assert_eq!(selected, Some(0), "Should only have worker2 left");
|
||||||
|
|||||||
@@ -89,7 +89,6 @@ fn test_chat_template_with_tokens() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_llama_style_template() {
|
fn test_llama_style_template() {
|
||||||
// Test a Llama-style chat template
|
|
||||||
let template = r#"
|
let template = r#"
|
||||||
{%- if messages[0]['role'] == 'system' -%}
|
{%- if messages[0]['role'] == 'system' -%}
|
||||||
{%- set system_message = messages[0]['content'] -%}
|
{%- set system_message = messages[0]['content'] -%}
|
||||||
@@ -160,7 +159,6 @@ fn test_llama_style_template() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_chatml_template() {
|
fn test_chatml_template() {
|
||||||
// Test a ChatML-style template
|
|
||||||
let template = r#"
|
let template = r#"
|
||||||
{%- for message in messages %}
|
{%- for message in messages %}
|
||||||
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
|
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
|
||||||
@@ -241,13 +239,11 @@ assistant:
|
|||||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Test without generation prompt
|
|
||||||
let result = processor
|
let result = processor
|
||||||
.apply_chat_template(&json_messages, ChatTemplateParams::default())
|
.apply_chat_template(&json_messages, ChatTemplateParams::default())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(result.trim(), "user: Test");
|
assert_eq!(result.trim(), "user: Test");
|
||||||
|
|
||||||
// Test with generation prompt
|
|
||||||
let result_with_prompt = processor
|
let result_with_prompt = processor
|
||||||
.apply_chat_template(
|
.apply_chat_template(
|
||||||
&json_messages,
|
&json_messages,
|
||||||
@@ -275,7 +271,6 @@ fn test_empty_messages_template() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_content_format_detection() {
|
fn test_content_format_detection() {
|
||||||
// Test string format detection
|
|
||||||
let string_template = r#"
|
let string_template = r#"
|
||||||
{%- for message in messages -%}
|
{%- for message in messages -%}
|
||||||
{{ message.role }}: {{ message.content }}
|
{{ message.role }}: {{ message.content }}
|
||||||
@@ -286,7 +281,6 @@ fn test_content_format_detection() {
|
|||||||
ChatTemplateContentFormat::String
|
ChatTemplateContentFormat::String
|
||||||
);
|
);
|
||||||
|
|
||||||
// Test OpenAI format detection
|
|
||||||
let openai_template = r#"
|
let openai_template = r#"
|
||||||
{%- for message in messages -%}
|
{%- for message in messages -%}
|
||||||
{%- for content in message.content -%}
|
{%- for content in message.content -%}
|
||||||
@@ -302,7 +296,6 @@ fn test_content_format_detection() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_template_with_multimodal_content() {
|
fn test_template_with_multimodal_content() {
|
||||||
// Test that multimodal messages work correctly when serialized to JSON
|
|
||||||
let template = r#"
|
let template = r#"
|
||||||
{%- for message in messages %}
|
{%- for message in messages %}
|
||||||
{{ message.role }}:
|
{{ message.role }}:
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ mod tests {
|
|||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Test that the custom template is used
|
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
spec::ChatMessage::User {
|
spec::ChatMessage::User {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
@@ -89,7 +88,6 @@ mod tests {
|
|||||||
.apply_chat_template(&json_messages, params)
|
.apply_chat_template(&json_messages, params)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Verify the custom template format
|
|
||||||
assert!(result.contains("<|user|>Hello"));
|
assert!(result.contains("<|user|>Hello"));
|
||||||
assert!(result.contains("<|assistant|>Hi there"));
|
assert!(result.contains("<|assistant|>Hi there"));
|
||||||
assert!(result.ends_with("<|assistant|>"));
|
assert!(result.ends_with("<|assistant|>"));
|
||||||
|
|||||||
@@ -148,7 +148,6 @@ mod tests {
|
|||||||
async fn test_mock_server_with_rmcp_client() {
|
async fn test_mock_server_with_rmcp_client() {
|
||||||
let mut server = MockMCPServer::start().await.unwrap();
|
let mut server = MockMCPServer::start().await.unwrap();
|
||||||
|
|
||||||
// Test that we can connect with rmcp client
|
|
||||||
use rmcp::transport::StreamableHttpClientTransport;
|
use rmcp::transport::StreamableHttpClientTransport;
|
||||||
use rmcp::ServiceExt;
|
use rmcp::ServiceExt;
|
||||||
|
|
||||||
@@ -158,7 +157,6 @@ mod tests {
|
|||||||
assert!(client.is_ok(), "Should be able to connect to mock server");
|
assert!(client.is_ok(), "Should be able to connect to mock server");
|
||||||
|
|
||||||
if let Ok(client) = client {
|
if let Ok(client) = client {
|
||||||
// Test listing tools
|
|
||||||
let tools = client.peer().list_all_tools().await;
|
let tools = client.peer().list_all_tools().await;
|
||||||
assert!(tools.is_ok(), "Should be able to list tools");
|
assert!(tools.is_ok(), "Should be able to list tools");
|
||||||
|
|
||||||
|
|||||||
@@ -71,7 +71,6 @@ pub fn ensure_tokenizer_cached() -> PathBuf {
|
|||||||
|
|
||||||
let content = response.bytes().expect("Failed to read tokenizer content");
|
let content = response.bytes().expect("Failed to read tokenizer content");
|
||||||
|
|
||||||
// Verify we got actual JSON content
|
|
||||||
if content.len() < 100 {
|
if content.len() < 100 {
|
||||||
panic!("Downloaded content too small: {} bytes", content.len());
|
panic!("Downloaded content too small: {} bytes", content.len());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
// This test suite validates the complete MCP implementation against the
|
// This test suite validates the complete MCP implementation against the
|
||||||
// functionality required for SGLang responses API integration.
|
// functionality required for SGLang responses API integration.
|
||||||
//
|
//
|
||||||
// Test Coverage:
|
|
||||||
// - Core MCP server functionality
|
// - Core MCP server functionality
|
||||||
// - Tool session management (individual and multi-tool)
|
// - Tool session management (individual and multi-tool)
|
||||||
// - Tool execution and error handling
|
// - Tool execution and error handling
|
||||||
@@ -26,7 +25,6 @@ async fn create_mock_server() -> MockMCPServer {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_mcp_server_initialization() {
|
async fn test_mcp_server_initialization() {
|
||||||
// Test that we can create an empty configuration
|
|
||||||
let config = McpConfig { servers: vec![] };
|
let config = McpConfig { servers: vec![] };
|
||||||
|
|
||||||
// Should fail with no servers
|
// Should fail with no servers
|
||||||
@@ -329,7 +327,6 @@ async fn test_tool_info_structure() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_sse_connection() {
|
async fn test_sse_connection() {
|
||||||
// Test with a non-existent command using STDIO to avoid retry delays
|
|
||||||
// This tests that SSE configuration is properly handled even when connection fails
|
// This tests that SSE configuration is properly handled even when connection fails
|
||||||
let config = McpConfig {
|
let config = McpConfig {
|
||||||
servers: vec![McpServerConfig {
|
servers: vec![McpServerConfig {
|
||||||
@@ -351,8 +348,6 @@ async fn test_sse_connection() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_transport_types() {
|
async fn test_transport_types() {
|
||||||
// Test different transport configurations
|
|
||||||
|
|
||||||
// HTTP/Streamable transport
|
// HTTP/Streamable transport
|
||||||
let http_config = McpServerConfig {
|
let http_config = McpServerConfig {
|
||||||
name: "http_server".to_string(),
|
name: "http_server".to_string(),
|
||||||
@@ -444,7 +439,6 @@ async fn test_complete_workflow() {
|
|||||||
// 7. Clean shutdown
|
// 7. Clean shutdown
|
||||||
manager.shutdown().await;
|
manager.shutdown().await;
|
||||||
|
|
||||||
// Verify all required capabilities for responses API integration
|
|
||||||
let capabilities = [
|
let capabilities = [
|
||||||
"MCP server initialization",
|
"MCP server initialization",
|
||||||
"Tool server connection and discovery",
|
"Tool server connection and discovery",
|
||||||
|
|||||||
@@ -20,8 +20,6 @@ async fn test_policy_registry_with_router_manager() {
|
|||||||
// Create RouterManager with shared registries
|
// Create RouterManager with shared registries
|
||||||
let _router_manager = RouterManager::new(worker_registry.clone());
|
let _router_manager = RouterManager::new(worker_registry.clone());
|
||||||
|
|
||||||
// Test adding workers with different models and policies
|
|
||||||
|
|
||||||
// Add first worker for llama-3 with cache_aware policy hint
|
// Add first worker for llama-3 with cache_aware policy hint
|
||||||
let mut labels1 = HashMap::new();
|
let mut labels1 = HashMap::new();
|
||||||
labels1.insert("policy".to_string(), "cache_aware".to_string());
|
labels1.insert("policy".to_string(), "cache_aware".to_string());
|
||||||
@@ -44,7 +42,6 @@ async fn test_policy_registry_with_router_manager() {
|
|||||||
// This would normally connect to a real worker, but for testing we'll just verify the structure
|
// This would normally connect to a real worker, but for testing we'll just verify the structure
|
||||||
// In a real test, we'd need to mock the worker or use a test server
|
// In a real test, we'd need to mock the worker or use a test server
|
||||||
|
|
||||||
// Verify PolicyRegistry has the correct policy for llama-3
|
|
||||||
let _llama_policy = policy_registry.get_policy("llama-3");
|
let _llama_policy = policy_registry.get_policy("llama-3");
|
||||||
// After first worker is added, llama-3 should have a policy
|
// After first worker is added, llama-3 should have a policy
|
||||||
|
|
||||||
@@ -88,10 +85,8 @@ async fn test_policy_registry_with_router_manager() {
|
|||||||
chat_template: None,
|
chat_template: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Verify gpt-4 has random policy
|
|
||||||
let _gpt_policy = policy_registry.get_policy("gpt-4");
|
let _gpt_policy = policy_registry.get_policy("gpt-4");
|
||||||
|
|
||||||
// Test removing workers
|
|
||||||
// When we remove both llama-3 workers, the policy should be cleaned up
|
// When we remove both llama-3 workers, the policy should be cleaned up
|
||||||
|
|
||||||
println!("PolicyRegistry integration test structure created");
|
println!("PolicyRegistry integration test structure created");
|
||||||
@@ -113,7 +108,6 @@ fn test_policy_registry_cleanup() {
|
|||||||
let policy2 = registry.on_worker_added("model-1", Some("random"));
|
let policy2 = registry.on_worker_added("model-1", Some("random"));
|
||||||
assert_eq!(policy2.name(), "cache_aware"); // Should still be cache_aware
|
assert_eq!(policy2.name(), "cache_aware"); // Should still be cache_aware
|
||||||
|
|
||||||
// Verify policy exists
|
|
||||||
assert!(registry.get_policy("model-1").is_some());
|
assert!(registry.get_policy("model-1").is_some());
|
||||||
|
|
||||||
// Remove first worker - policy should remain
|
// Remove first worker - policy should remain
|
||||||
@@ -143,7 +137,6 @@ fn test_policy_registry_multiple_models() {
|
|||||||
assert_eq!(gpt_policy.name(), "random");
|
assert_eq!(gpt_policy.name(), "random");
|
||||||
assert_eq!(mistral_policy.name(), "round_robin"); // Default
|
assert_eq!(mistral_policy.name(), "round_robin"); // Default
|
||||||
|
|
||||||
// Verify all policies are stored
|
|
||||||
assert!(registry.get_policy("llama-3").is_some());
|
assert!(registry.get_policy("llama-3").is_some());
|
||||||
assert!(registry.get_policy("gpt-4").is_some());
|
assert!(registry.get_policy("gpt-4").is_some());
|
||||||
assert!(registry.get_policy("mistral").is_some());
|
assert!(registry.get_policy("mistral").is_some());
|
||||||
|
|||||||
@@ -126,7 +126,6 @@ mod request_format_tests {
|
|||||||
}])
|
}])
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Test 1: Basic text request
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"text": "Hello, world!",
|
"text": "Hello, world!",
|
||||||
"stream": false
|
"stream": false
|
||||||
@@ -135,7 +134,6 @@ mod request_format_tests {
|
|||||||
let result = ctx.make_request("/generate", payload).await;
|
let result = ctx.make_request("/generate", payload).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
// Test 2: Request with sampling parameters
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"text": "Tell me a story",
|
"text": "Tell me a story",
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
@@ -149,7 +147,6 @@ mod request_format_tests {
|
|||||||
let result = ctx.make_request("/generate", payload).await;
|
let result = ctx.make_request("/generate", payload).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
// Test 3: Request with input_ids
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"input_ids": [1, 2, 3, 4, 5],
|
"input_ids": [1, 2, 3, 4, 5],
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
@@ -176,7 +173,6 @@ mod request_format_tests {
|
|||||||
}])
|
}])
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Test 1: Basic chat completion
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"model": "test-model",
|
"model": "test-model",
|
||||||
"messages": [
|
"messages": [
|
||||||
@@ -197,7 +193,6 @@ mod request_format_tests {
|
|||||||
Some("chat.completion")
|
Some("chat.completion")
|
||||||
);
|
);
|
||||||
|
|
||||||
// Test 2: Chat completion with parameters
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"model": "test-model",
|
"model": "test-model",
|
||||||
"messages": [
|
"messages": [
|
||||||
@@ -226,7 +221,6 @@ mod request_format_tests {
|
|||||||
}])
|
}])
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Test 1: Basic completion
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"model": "test-model",
|
"model": "test-model",
|
||||||
"prompt": "Once upon a time",
|
"prompt": "Once upon a time",
|
||||||
@@ -244,7 +238,6 @@ mod request_format_tests {
|
|||||||
Some("text_completion")
|
Some("text_completion")
|
||||||
);
|
);
|
||||||
|
|
||||||
// Test 2: Completion with array prompt
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"model": "test-model",
|
"model": "test-model",
|
||||||
"prompt": ["First prompt", "Second prompt"],
|
"prompt": ["First prompt", "Second prompt"],
|
||||||
@@ -255,7 +248,6 @@ mod request_format_tests {
|
|||||||
let result = ctx.make_request("/v1/completions", payload).await;
|
let result = ctx.make_request("/v1/completions", payload).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
// Test 3: Completion with logprobs
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"model": "test-model",
|
"model": "test-model",
|
||||||
"prompt": "The capital of France is",
|
"prompt": "The capital of France is",
|
||||||
@@ -281,7 +273,6 @@ mod request_format_tests {
|
|||||||
}])
|
}])
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Test batch text generation
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"text": ["First text", "Second text", "Third text"],
|
"text": ["First text", "Second text", "Third text"],
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
@@ -294,7 +285,6 @@ mod request_format_tests {
|
|||||||
let result = ctx.make_request("/generate", payload).await;
|
let result = ctx.make_request("/generate", payload).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
// Test batch with input_ids
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
|
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
|
||||||
"stream": false
|
"stream": false
|
||||||
@@ -317,7 +307,6 @@ mod request_format_tests {
|
|||||||
}])
|
}])
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Test with return_logprob
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"text": "Test",
|
"text": "Test",
|
||||||
"return_logprob": true,
|
"return_logprob": true,
|
||||||
@@ -327,7 +316,6 @@ mod request_format_tests {
|
|||||||
let result = ctx.make_request("/generate", payload).await;
|
let result = ctx.make_request("/generate", payload).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
// Test with json_schema
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"text": "Generate JSON",
|
"text": "Generate JSON",
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
@@ -340,7 +328,6 @@ mod request_format_tests {
|
|||||||
let result = ctx.make_request("/generate", payload).await;
|
let result = ctx.make_request("/generate", payload).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
// Test with ignore_eos
|
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
"text": "Continue forever",
|
"text": "Continue forever",
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
@@ -368,7 +355,6 @@ mod request_format_tests {
|
|||||||
}])
|
}])
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Test with empty body - should still work with mock worker
|
|
||||||
let payload = json!({});
|
let payload = json!({});
|
||||||
|
|
||||||
let result = ctx.make_request("/generate", payload).await;
|
let result = ctx.make_request("/generate", payload).await;
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ fn test_responses_request_creation() {
|
|||||||
repetition_penalty: 1.0,
|
repetition_penalty: 1.0,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Test GenerationRequest trait implementation
|
|
||||||
assert!(!request.is_stream());
|
assert!(!request.is_stream());
|
||||||
assert_eq!(request.get_model(), Some("test-model"));
|
assert_eq!(request.get_model(), Some("test-model"));
|
||||||
let routing_text = request.extract_text_for_routing();
|
let routing_text = request.extract_text_for_routing();
|
||||||
@@ -139,7 +138,6 @@ fn test_usage_conversion() {
|
|||||||
8
|
8
|
||||||
);
|
);
|
||||||
|
|
||||||
// Test reverse conversion
|
|
||||||
let back_to_usage = response_usage.to_usage_info();
|
let back_to_usage = response_usage.to_usage_info();
|
||||||
assert_eq!(back_to_usage.prompt_tokens, 15);
|
assert_eq!(back_to_usage.prompt_tokens, 15);
|
||||||
assert_eq!(back_to_usage.completion_tokens, 25);
|
assert_eq!(back_to_usage.completion_tokens, 25);
|
||||||
@@ -152,7 +150,6 @@ fn test_reasoning_param_default() {
|
|||||||
effort: Some(ReasoningEffort::Medium),
|
effort: Some(ReasoningEffort::Medium),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Test JSON serialization/deserialization preserves default
|
|
||||||
let json = serde_json::to_string(¶m).unwrap();
|
let json = serde_json::to_string(¶m).unwrap();
|
||||||
let parsed: ResponseReasoningParam = serde_json::from_str(&json).unwrap();
|
let parsed: ResponseReasoningParam = serde_json::from_str(&json).unwrap();
|
||||||
|
|
||||||
@@ -197,7 +194,6 @@ fn test_json_serialization() {
|
|||||||
repetition_penalty: 1.2,
|
repetition_penalty: 1.2,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Test that everything can be serialized to JSON and back
|
|
||||||
let json = serde_json::to_string(&request).expect("Serialization should work");
|
let json = serde_json::to_string(&request).expect("Serialization should work");
|
||||||
let parsed: ResponsesRequest =
|
let parsed: ResponsesRequest =
|
||||||
serde_json::from_str(&json).expect("Deserialization should work");
|
serde_json::from_str(&json).expect("Deserialization should work");
|
||||||
|
|||||||
@@ -197,7 +197,6 @@ mod streaming_tests {
|
|||||||
let events = result.unwrap();
|
let events = result.unwrap();
|
||||||
assert!(events.len() >= 2); // At least one chunk + [DONE]
|
assert!(events.len() >= 2); // At least one chunk + [DONE]
|
||||||
|
|
||||||
// Verify events are valid JSON (except [DONE])
|
|
||||||
for event in &events {
|
for event in &events {
|
||||||
if event != "[DONE]" {
|
if event != "[DONE]" {
|
||||||
let parsed: Result<serde_json::Value, _> = serde_json::from_str(event);
|
let parsed: Result<serde_json::Value, _> = serde_json::from_str(event);
|
||||||
@@ -329,7 +328,6 @@ mod streaming_tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_sse_format_parsing() {
|
async fn test_sse_format_parsing() {
|
||||||
// Test SSE format parsing
|
|
||||||
let parse_sse_chunk = |chunk: &[u8]| -> Vec<String> {
|
let parse_sse_chunk = |chunk: &[u8]| -> Vec<String> {
|
||||||
let text = String::from_utf8_lossy(chunk);
|
let text = String::from_utf8_lossy(chunk);
|
||||||
text.lines()
|
text.lines()
|
||||||
@@ -347,7 +345,6 @@ mod streaming_tests {
|
|||||||
assert_eq!(events[1], "{\"text\":\" world\"}");
|
assert_eq!(events[1], "{\"text\":\" world\"}");
|
||||||
assert_eq!(events[2], "[DONE]");
|
assert_eq!(events[2], "[DONE]");
|
||||||
|
|
||||||
// Test with mixed content
|
|
||||||
let mixed = b"event: message\ndata: {\"test\":true}\n\n: comment\ndata: [DONE]\n\n";
|
let mixed = b"event: message\ndata: {\"test\":true}\n\n: comment\ndata: [DONE]\n\n";
|
||||||
let events = parse_sse_chunk(mixed);
|
let events = parse_sse_chunk(mixed);
|
||||||
|
|
||||||
|
|||||||
@@ -84,8 +84,6 @@ fn create_minimal_completion_request() -> CompletionRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Basic Unit Tests =============
|
|
||||||
|
|
||||||
/// Test basic OpenAI router creation and configuration
|
/// Test basic OpenAI router creation and configuration
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_openai_router_creation() {
|
async fn test_openai_router_creation() {
|
||||||
@@ -575,7 +573,6 @@ async fn test_unsupported_endpoints() {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Test generate endpoint (SGLang-specific, should not be supported)
|
|
||||||
let generate_request = GenerateRequest {
|
let generate_request = GenerateRequest {
|
||||||
prompt: None,
|
prompt: None,
|
||||||
text: Some("Hello world".to_string()),
|
text: Some("Hello world".to_string()),
|
||||||
@@ -593,7 +590,6 @@ async fn test_unsupported_endpoints() {
|
|||||||
let response = router.route_generate(None, &generate_request, None).await;
|
let response = router.route_generate(None, &generate_request, None).await;
|
||||||
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
|
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
|
||||||
|
|
||||||
// Test completion endpoint (should also not be supported)
|
|
||||||
let completion_request = create_minimal_completion_request();
|
let completion_request = create_minimal_completion_request();
|
||||||
let response = router
|
let response = router
|
||||||
.route_completion(None, &completion_request, None)
|
.route_completion(None, &completion_request, None)
|
||||||
@@ -601,8 +597,6 @@ async fn test_unsupported_endpoints() {
|
|||||||
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
|
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Mock Server E2E Tests =============
|
|
||||||
|
|
||||||
/// Test chat completion with mock OpenAI server
|
/// Test chat completion with mock OpenAI server
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_openai_router_chat_completion_with_mock() {
|
async fn test_openai_router_chat_completion_with_mock() {
|
||||||
@@ -635,7 +629,6 @@ async fn test_openai_router_chat_completion_with_mock() {
|
|||||||
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
|
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
|
||||||
let chat_response: serde_json::Value = serde_json::from_str(&body_str).unwrap();
|
let chat_response: serde_json::Value = serde_json::from_str(&body_str).unwrap();
|
||||||
|
|
||||||
// Verify it's a valid chat completion response
|
|
||||||
assert_eq!(chat_response["object"], "chat.completion");
|
assert_eq!(chat_response["object"], "chat.completion");
|
||||||
assert_eq!(chat_response["model"], "gpt-3.5-turbo");
|
assert_eq!(chat_response["model"], "gpt-3.5-turbo");
|
||||||
assert!(!chat_response["choices"].as_array().unwrap().is_empty());
|
assert!(!chat_response["choices"].as_array().unwrap().is_empty());
|
||||||
@@ -704,7 +697,6 @@ async fn test_openai_e2e_with_server() {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
let response_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
let response_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||||
|
|
||||||
// Verify the response structure
|
|
||||||
assert_eq!(response_json["object"], "chat.completion");
|
assert_eq!(response_json["object"], "chat.completion");
|
||||||
assert_eq!(response_json["model"], "gpt-3.5-turbo");
|
assert_eq!(response_json["model"], "gpt-3.5-turbo");
|
||||||
assert!(!response_json["choices"].as_array().unwrap().is_empty());
|
assert!(!response_json["choices"].as_array().unwrap().is_empty());
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ mod test_pd_routing {
|
|||||||
use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy;
|
use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy;
|
||||||
use sglang_router_rs::routers::RouterFactory;
|
use sglang_router_rs::routers::RouterFactory;
|
||||||
|
|
||||||
// Test-only struct to help validate PD request parsing
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct PDRequest {
|
struct PDRequest {
|
||||||
pub is_stream: bool,
|
pub is_stream: bool,
|
||||||
@@ -17,14 +16,12 @@ mod test_pd_routing {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl PDRequest {
|
impl PDRequest {
|
||||||
// Extract PD-relevant info from JSON for testing
|
|
||||||
pub fn from_json(json: &serde_json::Value) -> Self {
|
pub fn from_json(json: &serde_json::Value) -> Self {
|
||||||
let is_stream = json
|
let is_stream = json
|
||||||
.get("stream")
|
.get("stream")
|
||||||
.and_then(|v| v.as_bool())
|
.and_then(|v| v.as_bool())
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
|
|
||||||
// Detect batch size from text or input_ids
|
|
||||||
let batch_size = if let Some(text) = json.get("text") {
|
let batch_size = if let Some(text) = json.get("text") {
|
||||||
text.as_array().map(|arr| arr.len())
|
text.as_array().map(|arr| arr.len())
|
||||||
} else if let Some(input_ids) = json.get("input_ids") {
|
} else if let Some(input_ids) = json.get("input_ids") {
|
||||||
@@ -40,15 +37,10 @@ mod test_pd_routing {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ========================================================================
|
|
||||||
// Phase 1: Basic PD Components and Router Creation
|
|
||||||
// ========================================================================
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_worker_types() {
|
fn test_worker_types() {
|
||||||
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
|
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
|
||||||
|
|
||||||
// Test worker creation for prefill servers
|
|
||||||
let prefill_worker: Box<dyn Worker> = Box::new(
|
let prefill_worker: Box<dyn Worker> = Box::new(
|
||||||
BasicWorkerBuilder::new("http://prefill:8080")
|
BasicWorkerBuilder::new("http://prefill:8080")
|
||||||
.worker_type(WorkerType::Prefill {
|
.worker_type(WorkerType::Prefill {
|
||||||
@@ -65,7 +57,6 @@ mod test_pd_routing {
|
|||||||
_ => panic!("Expected Prefill worker type"),
|
_ => panic!("Expected Prefill worker type"),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test worker creation for decode servers
|
|
||||||
let decode_worker: Box<dyn Worker> = Box::new(
|
let decode_worker: Box<dyn Worker> = Box::new(
|
||||||
BasicWorkerBuilder::new("http://decode:8080")
|
BasicWorkerBuilder::new("http://decode:8080")
|
||||||
.worker_type(WorkerType::Decode)
|
.worker_type(WorkerType::Decode)
|
||||||
@@ -78,7 +69,6 @@ mod test_pd_routing {
|
|||||||
_ => panic!("Expected Decode worker type"),
|
_ => panic!("Expected Decode worker type"),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test regular worker creation
|
|
||||||
let regular_worker: Box<dyn Worker> = Box::new(
|
let regular_worker: Box<dyn Worker> = Box::new(
|
||||||
BasicWorkerBuilder::new("http://regular:8080")
|
BasicWorkerBuilder::new("http://regular:8080")
|
||||||
.worker_type(WorkerType::Regular)
|
.worker_type(WorkerType::Regular)
|
||||||
@@ -94,7 +84,6 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pd_selection_policies() {
|
fn test_pd_selection_policies() {
|
||||||
// Test all PD selection policy variants
|
|
||||||
// Note: These policies are only used when pd_disaggregation=true
|
// Note: These policies are only used when pd_disaggregation=true
|
||||||
let policies = vec![
|
let policies = vec![
|
||||||
PDSelectionPolicy::Random,
|
PDSelectionPolicy::Random,
|
||||||
@@ -107,7 +96,6 @@ mod test_pd_routing {
|
|||||||
];
|
];
|
||||||
|
|
||||||
for policy in policies {
|
for policy in policies {
|
||||||
// Verify each policy can be created and matched
|
|
||||||
match &policy {
|
match &policy {
|
||||||
PDSelectionPolicy::Random => {
|
PDSelectionPolicy::Random => {
|
||||||
assert!(matches!(policy, PDSelectionPolicy::Random));
|
assert!(matches!(policy, PDSelectionPolicy::Random));
|
||||||
@@ -126,7 +114,6 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_pd_router_configuration() {
|
async fn test_pd_router_configuration() {
|
||||||
// Test PD router configuration with various policies
|
|
||||||
// In the new structure, RoutingMode and PolicyConfig are separate
|
// In the new structure, RoutingMode and PolicyConfig are separate
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
(
|
(
|
||||||
@@ -221,7 +208,6 @@ mod test_pd_routing {
|
|||||||
"Router creation should succeed with empty worker"
|
"Router creation should succeed with empty worker"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Verify that no workers are registered since we didn't initialize them
|
|
||||||
let stats = app_context.worker_registry.stats();
|
let stats = app_context.worker_registry.stats();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
stats.total_workers, 0,
|
stats.total_workers, 0,
|
||||||
@@ -230,13 +216,8 @@ mod test_pd_routing {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ========================================================================
|
|
||||||
// Phase 2: Bootstrap Injection and Request Handling
|
|
||||||
// ========================================================================
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pd_request_from_json() {
|
fn test_pd_request_from_json() {
|
||||||
// Test PDRequest parsing from single text request
|
|
||||||
let single_json = json!({
|
let single_json = json!({
|
||||||
"text": "Hello world",
|
"text": "Hello world",
|
||||||
"stream": false,
|
"stream": false,
|
||||||
@@ -248,7 +229,6 @@ mod test_pd_routing {
|
|||||||
assert!(!pd_req.is_stream);
|
assert!(!pd_req.is_stream);
|
||||||
assert_eq!(pd_req.batch_size, None);
|
assert_eq!(pd_req.batch_size, None);
|
||||||
|
|
||||||
// Test PDRequest parsing from batch text request
|
|
||||||
let batch_json = json!({
|
let batch_json = json!({
|
||||||
"text": ["Hello", "World", "Test"],
|
"text": ["Hello", "World", "Test"],
|
||||||
"stream": true,
|
"stream": true,
|
||||||
@@ -259,7 +239,6 @@ mod test_pd_routing {
|
|||||||
assert!(pd_req.is_stream);
|
assert!(pd_req.is_stream);
|
||||||
assert_eq!(pd_req.batch_size, Some(3));
|
assert_eq!(pd_req.batch_size, Some(3));
|
||||||
|
|
||||||
// Test PDRequest parsing from input_ids request
|
|
||||||
let ids_json = json!({
|
let ids_json = json!({
|
||||||
"input_ids": [[1, 2, 3], [4, 5, 6]],
|
"input_ids": [[1, 2, 3], [4, 5, 6]],
|
||||||
"stream": false
|
"stream": false
|
||||||
@@ -269,7 +248,6 @@ mod test_pd_routing {
|
|||||||
assert!(!pd_req.is_stream);
|
assert!(!pd_req.is_stream);
|
||||||
assert_eq!(pd_req.batch_size, Some(2));
|
assert_eq!(pd_req.batch_size, Some(2));
|
||||||
|
|
||||||
// Test PDRequest parsing from chat request
|
|
||||||
let chat_json = json!({
|
let chat_json = json!({
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": "You are a helpful assistant"},
|
{"role": "system", "content": "You are a helpful assistant"},
|
||||||
@@ -288,14 +266,12 @@ mod test_pd_routing {
|
|||||||
// Since we can't test the actual inject_bootstrap_fields function here
|
// Since we can't test the actual inject_bootstrap_fields function here
|
||||||
// (it's private in the router module), we'll test the expected behavior
|
// (it's private in the router module), we'll test the expected behavior
|
||||||
|
|
||||||
// Simulate bootstrap injection for single request
|
|
||||||
let mut single_json = json!({
|
let mut single_json = json!({
|
||||||
"text": "Hello world",
|
"text": "Hello world",
|
||||||
"stream": false,
|
"stream": false,
|
||||||
"temperature": 0.7
|
"temperature": 0.7
|
||||||
});
|
});
|
||||||
|
|
||||||
// Create a prefill worker to simulate injection
|
|
||||||
let prefill_worker: Box<dyn Worker> = Box::new(
|
let prefill_worker: Box<dyn Worker> = Box::new(
|
||||||
BasicWorkerBuilder::new("http://prefill1:8080")
|
BasicWorkerBuilder::new("http://prefill1:8080")
|
||||||
.worker_type(WorkerType::Prefill {
|
.worker_type(WorkerType::Prefill {
|
||||||
@@ -305,24 +281,20 @@ mod test_pd_routing {
|
|||||||
.build(),
|
.build(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Extract bootstrap port from worker type
|
|
||||||
let bootstrap_port = match prefill_worker.worker_type() {
|
let bootstrap_port = match prefill_worker.worker_type() {
|
||||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Simulate what inject_bootstrap_fields would do
|
|
||||||
single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url()));
|
single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url()));
|
||||||
single_json["bootstrap_port"] = json!(bootstrap_port);
|
single_json["bootstrap_port"] = json!(bootstrap_port);
|
||||||
single_json["bootstrap_room"] = json!(12345u64); // Random room ID
|
single_json["bootstrap_room"] = json!(12345u64); // Random room ID
|
||||||
|
|
||||||
// Verify bootstrap fields are added correctly
|
|
||||||
assert_eq!(single_json["bootstrap_host"], "prefill1");
|
assert_eq!(single_json["bootstrap_host"], "prefill1");
|
||||||
assert_eq!(single_json["bootstrap_port"], json!(Some(9000)));
|
assert_eq!(single_json["bootstrap_port"], json!(Some(9000)));
|
||||||
assert!(single_json["bootstrap_room"].is_u64());
|
assert!(single_json["bootstrap_room"].is_u64());
|
||||||
assert_eq!(single_json["temperature"], 0.7); // Original field preserved
|
assert_eq!(single_json["temperature"], 0.7); // Original field preserved
|
||||||
|
|
||||||
// Simulate bootstrap injection for batch request
|
|
||||||
let mut batch_json = json!({
|
let mut batch_json = json!({
|
||||||
"text": ["Hello", "World", "Test"],
|
"text": ["Hello", "World", "Test"],
|
||||||
"stream": true
|
"stream": true
|
||||||
@@ -334,7 +306,6 @@ mod test_pd_routing {
|
|||||||
batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||||
batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);
|
batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);
|
||||||
|
|
||||||
// Verify batch bootstrap fields
|
|
||||||
assert!(batch_json["bootstrap_host"].is_array());
|
assert!(batch_json["bootstrap_host"].is_array());
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
batch_json["bootstrap_host"].as_array().unwrap().len(),
|
batch_json["bootstrap_host"].as_array().unwrap().len(),
|
||||||
@@ -347,7 +318,6 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_request_serialization() {
|
fn test_request_serialization() {
|
||||||
// Test that requests can be properly serialized and deserialized
|
|
||||||
let request = json!({
|
let request = json!({
|
||||||
"text": "Test prompt",
|
"text": "Test prompt",
|
||||||
"stream": false,
|
"stream": false,
|
||||||
@@ -360,13 +330,10 @@ mod test_pd_routing {
|
|||||||
"bootstrap_room": 12345u64
|
"bootstrap_room": 12345u64
|
||||||
});
|
});
|
||||||
|
|
||||||
// Convert to bytes (as would happen in the router)
|
|
||||||
let bytes = serde_json::to_vec(&request).unwrap();
|
let bytes = serde_json::to_vec(&request).unwrap();
|
||||||
|
|
||||||
// Parse back from bytes
|
|
||||||
let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
|
let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
|
||||||
|
|
||||||
// Verify all fields are preserved
|
|
||||||
assert_eq!(parsed["text"], "Test prompt");
|
assert_eq!(parsed["text"], "Test prompt");
|
||||||
assert_eq!(parsed["stream"], false);
|
assert_eq!(parsed["stream"], false);
|
||||||
assert_eq!(parsed["temperature"], 0.7);
|
assert_eq!(parsed["temperature"], 0.7);
|
||||||
@@ -378,7 +345,6 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_hostname_extraction() {
|
fn test_hostname_extraction() {
|
||||||
// Test various URL formats
|
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
("http://localhost:8080", "localhost"),
|
("http://localhost:8080", "localhost"),
|
||||||
("http://10.0.0.1:8080", "10.0.0.1"),
|
("http://10.0.0.1:8080", "10.0.0.1"),
|
||||||
@@ -395,13 +361,11 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pd_request_edge_cases() {
|
fn test_pd_request_edge_cases() {
|
||||||
// Test empty request
|
|
||||||
let empty_json = json!({});
|
let empty_json = json!({});
|
||||||
let pd_req = PDRequest::from_json(&empty_json);
|
let pd_req = PDRequest::from_json(&empty_json);
|
||||||
assert!(!pd_req.is_stream);
|
assert!(!pd_req.is_stream);
|
||||||
assert_eq!(pd_req.batch_size, None);
|
assert_eq!(pd_req.batch_size, None);
|
||||||
|
|
||||||
// Test request with only stream field
|
|
||||||
let stream_only = json!({
|
let stream_only = json!({
|
||||||
"stream": true
|
"stream": true
|
||||||
});
|
});
|
||||||
@@ -409,14 +373,12 @@ mod test_pd_routing {
|
|||||||
assert!(pd_req.is_stream);
|
assert!(pd_req.is_stream);
|
||||||
assert_eq!(pd_req.batch_size, None);
|
assert_eq!(pd_req.batch_size, None);
|
||||||
|
|
||||||
// Test request with empty text array
|
|
||||||
let empty_batch = json!({
|
let empty_batch = json!({
|
||||||
"text": []
|
"text": []
|
||||||
});
|
});
|
||||||
let pd_req = PDRequest::from_json(&empty_batch);
|
let pd_req = PDRequest::from_json(&empty_batch);
|
||||||
assert_eq!(pd_req.batch_size, Some(0));
|
assert_eq!(pd_req.batch_size, Some(0));
|
||||||
|
|
||||||
// Test request with non-array text (should be None)
|
|
||||||
let non_array_text = json!({
|
let non_array_text = json!({
|
||||||
"text": "single string"
|
"text": "single string"
|
||||||
});
|
});
|
||||||
@@ -424,29 +386,21 @@ mod test_pd_routing {
|
|||||||
assert_eq!(pd_req.batch_size, None);
|
assert_eq!(pd_req.batch_size, None);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ========================================================================
|
|
||||||
// Phase 2: Background Load Monitoring Tests
|
|
||||||
// ========================================================================
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_background_load_monitoring() {
|
async fn test_background_load_monitoring() {
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tokio::sync::watch;
|
use tokio::sync::watch;
|
||||||
|
|
||||||
// Create a watch channel for testing
|
|
||||||
let (tx, rx) = watch::channel(HashMap::new());
|
let (tx, rx) = watch::channel(HashMap::new());
|
||||||
|
|
||||||
// Simulate load updates
|
|
||||||
let mut loads = HashMap::new();
|
let mut loads = HashMap::new();
|
||||||
loads.insert("http://prefill1:8080".to_string(), 10);
|
loads.insert("http://prefill1:8080".to_string(), 10);
|
||||||
loads.insert("http://prefill2:8080".to_string(), 20);
|
loads.insert("http://prefill2:8080".to_string(), 20);
|
||||||
loads.insert("http://decode1:8080".to_string(), 5);
|
loads.insert("http://decode1:8080".to_string(), 5);
|
||||||
loads.insert("http://decode2:8080".to_string(), 15);
|
loads.insert("http://decode2:8080".to_string(), 15);
|
||||||
|
|
||||||
// Send the loads
|
|
||||||
tx.send(loads.clone()).unwrap();
|
tx.send(loads.clone()).unwrap();
|
||||||
|
|
||||||
// Verify receiver gets the update
|
|
||||||
let received_loads = rx.borrow();
|
let received_loads = rx.borrow();
|
||||||
assert_eq!(received_loads.get("http://prefill1:8080"), Some(&10));
|
assert_eq!(received_loads.get("http://prefill1:8080"), Some(&10));
|
||||||
assert_eq!(received_loads.get("http://prefill2:8080"), Some(&20));
|
assert_eq!(received_loads.get("http://prefill2:8080"), Some(&20));
|
||||||
@@ -456,7 +410,6 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_load_monitoring_configuration() {
|
fn test_load_monitoring_configuration() {
|
||||||
// Test that load monitoring is only enabled for PowerOfTwo policy
|
|
||||||
let policies = vec![
|
let policies = vec![
|
||||||
(PDSelectionPolicy::Random, false),
|
(PDSelectionPolicy::Random, false),
|
||||||
(PDSelectionPolicy::PowerOfTwo, true),
|
(PDSelectionPolicy::PowerOfTwo, true),
|
||||||
@@ -483,42 +436,31 @@ mod test_pd_routing {
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tokio::sync::watch;
|
use tokio::sync::watch;
|
||||||
|
|
||||||
// Test watch channel's broadcast behavior
|
|
||||||
let (tx, rx1) = watch::channel(HashMap::new());
|
let (tx, rx1) = watch::channel(HashMap::new());
|
||||||
let rx2 = rx1.clone();
|
let rx2 = rx1.clone();
|
||||||
|
|
||||||
// Initial state - empty map
|
|
||||||
assert!(rx1.borrow().is_empty());
|
assert!(rx1.borrow().is_empty());
|
||||||
assert!(rx2.borrow().is_empty());
|
assert!(rx2.borrow().is_empty());
|
||||||
|
|
||||||
// Update 1
|
|
||||||
let mut loads = HashMap::new();
|
let mut loads = HashMap::new();
|
||||||
loads.insert("worker1".to_string(), 10);
|
loads.insert("worker1".to_string(), 10);
|
||||||
tx.send(loads.clone()).unwrap();
|
tx.send(loads.clone()).unwrap();
|
||||||
|
|
||||||
// Both receivers see the update
|
|
||||||
assert_eq!(rx1.borrow().get("worker1"), Some(&10));
|
assert_eq!(rx1.borrow().get("worker1"), Some(&10));
|
||||||
assert_eq!(rx2.borrow().get("worker1"), Some(&10));
|
assert_eq!(rx2.borrow().get("worker1"), Some(&10));
|
||||||
|
|
||||||
// Update 2 - overwrites previous
|
|
||||||
loads.insert("worker1".to_string(), 20);
|
loads.insert("worker1".to_string(), 20);
|
||||||
loads.insert("worker2".to_string(), 30);
|
loads.insert("worker2".to_string(), 30);
|
||||||
tx.send(loads).unwrap();
|
tx.send(loads).unwrap();
|
||||||
|
|
||||||
// Both receivers see the latest state
|
|
||||||
assert_eq!(rx1.borrow().get("worker1"), Some(&20));
|
assert_eq!(rx1.borrow().get("worker1"), Some(&20));
|
||||||
assert_eq!(rx2.borrow().get("worker2"), Some(&30));
|
assert_eq!(rx2.borrow().get("worker2"), Some(&30));
|
||||||
}
|
}
|
||||||
|
|
||||||
// ========================================================================
|
|
||||||
// Tests based on bench_one_batch_server.py patterns
|
|
||||||
// ========================================================================
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_generate_request_formats() {
|
fn test_generate_request_formats() {
|
||||||
// Based on bench_one_batch_server.py request patterns
|
// Based on bench_one_batch_server.py request patterns
|
||||||
|
|
||||||
// Test 1: Batch request with input_ids (most common in benchmarks)
|
|
||||||
let batch_request = json!({
|
let batch_request = json!({
|
||||||
"input_ids": [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
|
"input_ids": [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
@@ -534,7 +476,6 @@ mod test_pd_routing {
|
|||||||
assert!(pd_req.is_stream);
|
assert!(pd_req.is_stream);
|
||||||
assert_eq!(pd_req.batch_size, Some(3));
|
assert_eq!(pd_req.batch_size, Some(3));
|
||||||
|
|
||||||
// Test 2: Request with return_logprob (critical for PD)
|
|
||||||
let logprob_request = json!({
|
let logprob_request = json!({
|
||||||
"input_ids": [[1, 2, 3]],
|
"input_ids": [[1, 2, 3]],
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
@@ -548,7 +489,6 @@ mod test_pd_routing {
|
|||||||
assert_eq!(logprob_request["return_logprob"], true);
|
assert_eq!(logprob_request["return_logprob"], true);
|
||||||
assert_eq!(logprob_request["stream"], false);
|
assert_eq!(logprob_request["stream"], false);
|
||||||
|
|
||||||
// Test 3: Large batch sizes from benchmark
|
|
||||||
let batch_sizes = vec![1, 16, 64]; // From bench_one_batch_server.py
|
let batch_sizes = vec![1, 16, 64]; // From bench_one_batch_server.py
|
||||||
for bs in batch_sizes {
|
for bs in batch_sizes {
|
||||||
let request = json!({
|
let request = json!({
|
||||||
@@ -567,7 +507,6 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_sampling_params_handling() {
|
fn test_sampling_params_handling() {
|
||||||
// Test various sampling parameters from bench_one_batch_server.py
|
|
||||||
let sampling_params_variations = vec![
|
let sampling_params_variations = vec![
|
||||||
json!({
|
json!({
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
@@ -595,14 +534,12 @@ mod test_pd_routing {
|
|||||||
"stream": false
|
"stream": false
|
||||||
});
|
});
|
||||||
|
|
||||||
// Verify params are preserved
|
|
||||||
assert_eq!(request["sampling_params"], params);
|
assert_eq!(request["sampling_params"], params);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_streaming_response_parsing() {
|
fn test_streaming_response_parsing() {
|
||||||
// Test SSE format parsing from streaming responses
|
|
||||||
let sse_chunks = ["data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}",
|
let sse_chunks = ["data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}",
|
||||||
"data: {\"text\":\" world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}",
|
"data: {\"text\":\" world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}",
|
||||||
"data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}",
|
"data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}",
|
||||||
@@ -615,13 +552,11 @@ mod test_pd_routing {
|
|||||||
assert!(parsed["meta_info"]["completion_tokens"].is_u64());
|
assert!(parsed["meta_info"]["completion_tokens"].is_u64());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test [DONE] detection
|
|
||||||
assert_eq!(sse_chunks[3], "data: [DONE]");
|
assert_eq!(sse_chunks[3], "data: [DONE]");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_ttft_calculation() {
|
fn test_ttft_calculation() {
|
||||||
// Test Time To First Token calculation pattern
|
|
||||||
let first_token_response = json!({
|
let first_token_response = json!({
|
||||||
"text": "Hello",
|
"text": "Hello",
|
||||||
"meta_info": {
|
"meta_info": {
|
||||||
@@ -637,7 +572,6 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_throughput_metrics() {
|
fn test_throughput_metrics() {
|
||||||
// Test throughput calculation patterns from bench_one_batch_server.py
|
|
||||||
let batch_size = 16;
|
let batch_size = 16;
|
||||||
let input_len = 1024;
|
let input_len = 1024;
|
||||||
let output_len = 16;
|
let output_len = 16;
|
||||||
@@ -655,7 +589,6 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_error_response_handling() {
|
fn test_error_response_handling() {
|
||||||
// Test error response format from bench_one_batch_server.py
|
|
||||||
let error_response = json!({
|
let error_response = json!({
|
||||||
"error": "Request has failed. Invalid input format."
|
"error": "Request has failed. Invalid input format."
|
||||||
});
|
});
|
||||||
@@ -666,7 +599,6 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_structured_output_request() {
|
fn test_structured_output_request() {
|
||||||
// Test structured output format (json_schema)
|
|
||||||
let structured_request = json!({
|
let structured_request = json!({
|
||||||
"text": "What is the capital of France? Answer in JSON.",
|
"text": "What is the capital of France? Answer in JSON.",
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
@@ -687,7 +619,6 @@ mod test_pd_routing {
|
|||||||
fn test_bootstrap_injection_with_benchmark_requests() {
|
fn test_bootstrap_injection_with_benchmark_requests() {
|
||||||
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
|
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
|
||||||
|
|
||||||
// Test bootstrap injection with actual benchmark request patterns
|
|
||||||
let mut benchmark_request = json!({
|
let mut benchmark_request = json!({
|
||||||
"input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
|
"input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
@@ -699,7 +630,6 @@ mod test_pd_routing {
|
|||||||
"stream": true
|
"stream": true
|
||||||
});
|
});
|
||||||
|
|
||||||
// Create a prefill worker to simulate injection
|
|
||||||
let prefill_worker: Box<dyn Worker> = Box::new(
|
let prefill_worker: Box<dyn Worker> = Box::new(
|
||||||
BasicWorkerBuilder::new("http://prefill:8080")
|
BasicWorkerBuilder::new("http://prefill:8080")
|
||||||
.worker_type(WorkerType::Prefill {
|
.worker_type(WorkerType::Prefill {
|
||||||
@@ -709,7 +639,6 @@ mod test_pd_routing {
|
|||||||
.build(),
|
.build(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Extract bootstrap port from worker type
|
|
||||||
let bootstrap_port = match prefill_worker.worker_type() {
|
let bootstrap_port = match prefill_worker.worker_type() {
|
||||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||||
_ => None,
|
_ => None,
|
||||||
@@ -722,7 +651,6 @@ mod test_pd_routing {
|
|||||||
benchmark_request["bootstrap_room"] =
|
benchmark_request["bootstrap_room"] =
|
||||||
json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());
|
json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());
|
||||||
|
|
||||||
// Verify bootstrap fields match batch size
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
benchmark_request["bootstrap_host"]
|
benchmark_request["bootstrap_host"]
|
||||||
.as_array()
|
.as_array()
|
||||||
@@ -745,14 +673,12 @@ mod test_pd_routing {
|
|||||||
batch_size
|
batch_size
|
||||||
);
|
);
|
||||||
|
|
||||||
// Verify original fields are preserved
|
|
||||||
assert_eq!(benchmark_request["return_logprob"], true);
|
assert_eq!(benchmark_request["return_logprob"], true);
|
||||||
assert_eq!(benchmark_request["stream"], true);
|
assert_eq!(benchmark_request["stream"], true);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_server_info_response_format() {
|
fn test_server_info_response_format() {
|
||||||
// Test server info format expected by bench_one_batch_server.py
|
|
||||||
let server_info = json!({
|
let server_info = json!({
|
||||||
"internal_states": [{
|
"internal_states": [{
|
||||||
"avg_spec_accept_length": 3.5,
|
"avg_spec_accept_length": 3.5,
|
||||||
@@ -769,16 +695,13 @@ mod test_pd_routing {
|
|||||||
]
|
]
|
||||||
});
|
});
|
||||||
|
|
||||||
// Verify structure matches what benchmark expects
|
|
||||||
assert!(server_info["internal_states"][0]["avg_spec_accept_length"].is_f64());
|
assert!(server_info["internal_states"][0]["avg_spec_accept_length"].is_f64());
|
||||||
assert!(server_info["internal_states"][0]["last_gen_throughput"].is_f64());
|
assert!(server_info["internal_states"][0]["last_gen_throughput"].is_f64());
|
||||||
assert!(server_info["prefill"].is_array());
|
assert!(server_info["prefill"].is_array());
|
||||||
assert!(server_info["decode"].is_array());
|
assert!(server_info["decode"].is_array());
|
||||||
}
|
}
|
||||||
|
|
||||||
// ========================================================================
|
|
||||||
// Comprehensive Endpoint Coverage Test
|
// Comprehensive Endpoint Coverage Test
|
||||||
// ========================================================================
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pd_endpoints_coverage() {
|
fn test_pd_endpoints_coverage() {
|
||||||
@@ -807,7 +730,6 @@ mod test_pd_routing {
|
|||||||
assert_eq!(implemented_count, 10);
|
assert_eq!(implemented_count, 10);
|
||||||
assert_eq!(total_count, 11);
|
assert_eq!(total_count, 11);
|
||||||
|
|
||||||
// Document the missing endpoint
|
|
||||||
let missing: Vec<_> = implemented_endpoints
|
let missing: Vec<_> = implemented_endpoints
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|(_, _, impl_status)| !impl_status)
|
.filter(|(_, _, impl_status)| !impl_status)
|
||||||
@@ -819,14 +741,12 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_large_batch_bootstrap_injection() {
|
fn test_large_batch_bootstrap_injection() {
|
||||||
// Test bootstrap injection performance with very large batches
|
|
||||||
// This simulates the bench_one_batch_server.py scenario
|
// This simulates the bench_one_batch_server.py scenario
|
||||||
let large_batch_sizes = vec![1024, 4096, 8192];
|
let large_batch_sizes = vec![1024, 4096, 8192];
|
||||||
|
|
||||||
for batch_size in large_batch_sizes {
|
for batch_size in large_batch_sizes {
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
// Simulate a large batch request
|
|
||||||
let mut large_batch_request = json!({
|
let mut large_batch_request = json!({
|
||||||
"input_ids": vec![vec![1, 2, 3, 4]; batch_size],
|
"input_ids": vec![vec![1, 2, 3, 4]; batch_size],
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
@@ -836,7 +756,6 @@ mod test_pd_routing {
|
|||||||
"stream": true
|
"stream": true
|
||||||
});
|
});
|
||||||
|
|
||||||
// Create a prefill worker to simulate injection
|
|
||||||
let prefill_worker: Box<dyn Worker> = Box::new(
|
let prefill_worker: Box<dyn Worker> = Box::new(
|
||||||
BasicWorkerBuilder::new("http://prefill:8080")
|
BasicWorkerBuilder::new("http://prefill:8080")
|
||||||
.worker_type(WorkerType::Prefill {
|
.worker_type(WorkerType::Prefill {
|
||||||
@@ -846,7 +765,6 @@ mod test_pd_routing {
|
|||||||
.build(),
|
.build(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Extract bootstrap port from worker type
|
|
||||||
let bootstrap_port = match prefill_worker.worker_type() {
|
let bootstrap_port = match prefill_worker.worker_type() {
|
||||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||||
_ => None,
|
_ => None,
|
||||||
@@ -861,7 +779,6 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
let elapsed = start.elapsed();
|
let elapsed = start.elapsed();
|
||||||
|
|
||||||
// Verify bootstrap fields are correctly sized
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
large_batch_request["bootstrap_host"]
|
large_batch_request["bootstrap_host"]
|
||||||
.as_array()
|
.as_array()
|
||||||
@@ -899,7 +816,6 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_payload_size_calculation() {
|
fn test_payload_size_calculation() {
|
||||||
// Test payload size estimation for bench_one_batch_server.py scenarios
|
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
(1, 1024, 16), // Small batch
|
(1, 1024, 16), // Small batch
|
||||||
(16, 1024, 16), // Medium batch
|
(16, 1024, 16), // Medium batch
|
||||||
@@ -937,14 +853,12 @@ mod test_pd_routing {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_policy_type_to_pd_selection_policy_mapping() {
|
fn test_policy_type_to_pd_selection_policy_mapping() {
|
||||||
// Test that PDSelectionPolicy doesn't include RoundRobin
|
|
||||||
let pd_policy_count = 3; // Random, PowerOfTwo, CacheAware
|
let pd_policy_count = 3; // Random, PowerOfTwo, CacheAware
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
pd_policy_count, 3,
|
pd_policy_count, 3,
|
||||||
"PDSelectionPolicy should have exactly 3 variants"
|
"PDSelectionPolicy should have exactly 3 variants"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Verify that each PDSelectionPolicy variant can be created
|
|
||||||
let _random = PDSelectionPolicy::Random;
|
let _random = PDSelectionPolicy::Random;
|
||||||
let _po2 = PDSelectionPolicy::PowerOfTwo;
|
let _po2 = PDSelectionPolicy::PowerOfTwo;
|
||||||
let _cache_aware = PDSelectionPolicy::CacheAware {
|
let _cache_aware = PDSelectionPolicy::CacheAware {
|
||||||
|
|||||||
@@ -84,7 +84,6 @@ fn test_sequence_operations() {
|
|||||||
for prompt in TEST_PROMPTS.iter() {
|
for prompt in TEST_PROMPTS.iter() {
|
||||||
let encoding = tokenizer.encode(prompt).expect("Failed to encode prompt");
|
let encoding = tokenizer.encode(prompt).expect("Failed to encode prompt");
|
||||||
|
|
||||||
// Test Sequence with append_text
|
|
||||||
let mut sequence = Sequence::new(tokenizer.clone());
|
let mut sequence = Sequence::new(tokenizer.clone());
|
||||||
sequence.append_text(prompt).expect("Failed to append text");
|
sequence.append_text(prompt).expect("Failed to append text");
|
||||||
|
|
||||||
@@ -95,7 +94,6 @@ fn test_sequence_operations() {
|
|||||||
);
|
);
|
||||||
assert_eq!(sequence.text().unwrap(), *prompt, "Sequence text mismatch");
|
assert_eq!(sequence.text().unwrap(), *prompt, "Sequence text mismatch");
|
||||||
|
|
||||||
// Test incremental decoding with append_token
|
|
||||||
let mut decoder = Sequence::new(tokenizer.clone());
|
let mut decoder = Sequence::new(tokenizer.clone());
|
||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
|
|
||||||
@@ -178,7 +176,6 @@ fn test_stop_sequence_decoder() {
|
|||||||
.expect("Failed to load tokenizer"),
|
.expect("Failed to load tokenizer"),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Test with various stop sequences
|
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
(
|
(
|
||||||
"Hello world! Stop here. Continue after.",
|
"Hello world! Stop here. Continue after.",
|
||||||
@@ -237,7 +234,6 @@ fn test_stop_sequence_decoder() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_factory_creation() {
|
fn test_factory_creation() {
|
||||||
// Test factory creation method
|
|
||||||
let tokenizer_path = ensure_tokenizer_cached();
|
let tokenizer_path = ensure_tokenizer_cached();
|
||||||
let tokenizer = factory::create_tokenizer(tokenizer_path.to_str().unwrap())
|
let tokenizer = factory::create_tokenizer(tokenizer_path.to_str().unwrap())
|
||||||
.expect("Failed to create tokenizer via factory");
|
.expect("Failed to create tokenizer via factory");
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ use sglang_router_rs::tool_parser::{DeepSeekParser, ParseState, StreamResult, To
|
|||||||
async fn test_deepseek_complete_parsing() {
|
async fn test_deepseek_complete_parsing() {
|
||||||
let parser = DeepSeekParser::new();
|
let parser = DeepSeekParser::new();
|
||||||
|
|
||||||
// Test single tool call
|
|
||||||
let input = r#"Let me help you with that.
|
let input = r#"Let me help you with that.
|
||||||
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
|
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
|
||||||
```json
|
```json
|
||||||
@@ -18,7 +17,6 @@ The weather in Tokyo is..."#;
|
|||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
assert_eq!(result[0].function.name, "get_weather");
|
assert_eq!(result[0].function.name, "get_weather");
|
||||||
|
|
||||||
// Verify arguments
|
|
||||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||||
assert_eq!(args["location"], "Tokyo");
|
assert_eq!(args["location"], "Tokyo");
|
||||||
assert_eq!(args["units"], "celsius");
|
assert_eq!(args["units"], "celsius");
|
||||||
|
|||||||
@@ -167,8 +167,6 @@ async fn test_unicode_edge_cases() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_nested_brackets_in_strings() {
|
async fn test_nested_brackets_in_strings() {
|
||||||
// Test that parsers correctly handle brackets within string literals
|
|
||||||
|
|
||||||
let mistral_parser = MistralParser::new();
|
let mistral_parser = MistralParser::new();
|
||||||
let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"text": "Array: [1, 2, 3]"}}]"#;
|
let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"text": "Array: [1, 2, 3]"}}]"#;
|
||||||
let result = mistral_parser.parse_complete(input).await.unwrap();
|
let result = mistral_parser.parse_complete(input).await.unwrap();
|
||||||
@@ -186,8 +184,6 @@ async fn test_nested_brackets_in_strings() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_multiple_formats_in_text() {
|
async fn test_multiple_formats_in_text() {
|
||||||
// Test that parsers don't get confused by other formats in the text
|
|
||||||
|
|
||||||
let json_parser = JsonParser::new();
|
let json_parser = JsonParser::new();
|
||||||
let input = r#"
|
let input = r#"
|
||||||
Here's some text with [TOOL_CALLS] that shouldn't trigger.
|
Here's some text with [TOOL_CALLS] that shouldn't trigger.
|
||||||
@@ -272,7 +268,6 @@ async fn test_partial_token_at_buffer_boundary() {
|
|||||||
let parser = QwenParser::new();
|
let parser = QwenParser::new();
|
||||||
let mut state = ParseState::new();
|
let mut state = ParseState::new();
|
||||||
|
|
||||||
// Test case that would fail with the bug:
|
|
||||||
// Send exactly "<tool" which is a 5-character prefix of "<tool_call>\n"
|
// Send exactly "<tool" which is a 5-character prefix of "<tool_call>\n"
|
||||||
let result = parser.parse_incremental("<tool", &mut state).await.unwrap();
|
let result = parser.parse_incremental("<tool", &mut state).await.unwrap();
|
||||||
assert!(matches!(result, StreamResult::Incomplete));
|
assert!(matches!(result, StreamResult::Incomplete));
|
||||||
@@ -303,7 +298,6 @@ async fn test_partial_token_at_buffer_boundary() {
|
|||||||
async fn test_exact_prefix_lengths() {
|
async fn test_exact_prefix_lengths() {
|
||||||
let parser = QwenParser::new();
|
let parser = QwenParser::new();
|
||||||
|
|
||||||
// Test various exact prefix lengths that would be missed by exclusive range
|
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
("<", 1), // 1-char prefix
|
("<", 1), // 1-char prefix
|
||||||
("<t", 2), // 2-char prefix
|
("<t", 2), // 2-char prefix
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ use sglang_router_rs::tool_parser::{Glm4MoeParser, ParseState, StreamResult, Too
|
|||||||
async fn test_glm4_complete_parsing() {
|
async fn test_glm4_complete_parsing() {
|
||||||
let parser = Glm4MoeParser::new();
|
let parser = Glm4MoeParser::new();
|
||||||
|
|
||||||
// Test single tool call
|
|
||||||
let input = r#"Let me search for that.
|
let input = r#"Let me search for that.
|
||||||
<tool_call>get_weather
|
<tool_call>get_weather
|
||||||
<arg_key>city</arg_key>
|
<arg_key>city</arg_key>
|
||||||
@@ -20,7 +19,6 @@ The weather will be..."#;
|
|||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
assert_eq!(result[0].function.name, "get_weather");
|
assert_eq!(result[0].function.name, "get_weather");
|
||||||
|
|
||||||
// Verify arguments
|
|
||||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||||
assert_eq!(args["city"], "Beijing");
|
assert_eq!(args["city"], "Beijing");
|
||||||
assert_eq!(args["date"], "2024-12-25");
|
assert_eq!(args["date"], "2024-12-25");
|
||||||
@@ -51,7 +49,6 @@ async fn test_glm4_multiple_tools() {
|
|||||||
async fn test_glm4_type_conversion() {
|
async fn test_glm4_type_conversion() {
|
||||||
let parser = Glm4MoeParser::new();
|
let parser = Glm4MoeParser::new();
|
||||||
|
|
||||||
// Test various value types
|
|
||||||
let input = r#"<tool_call>process
|
let input = r#"<tool_call>process
|
||||||
<arg_key>count</arg_key>
|
<arg_key>count</arg_key>
|
||||||
<arg_value>42</arg_value>
|
<arg_value>42</arg_value>
|
||||||
@@ -132,7 +129,6 @@ fn test_glm4_format_detection() {
|
|||||||
async fn test_glm4_python_literal_values() {
|
async fn test_glm4_python_literal_values() {
|
||||||
let parser = Glm4MoeParser::new();
|
let parser = Glm4MoeParser::new();
|
||||||
|
|
||||||
// Test Python-style boolean values
|
|
||||||
let input = r#"<tool_call>config
|
let input = r#"<tool_call>config
|
||||||
<arg_key>debug</arg_key>
|
<arg_key>debug</arg_key>
|
||||||
<arg_value>True</arg_value>
|
<arg_value>True</arg_value>
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ use sglang_router_rs::tool_parser::{GptOssParser, ParseState, StreamResult, Tool
|
|||||||
async fn test_gpt_oss_complete_parsing() {
|
async fn test_gpt_oss_complete_parsing() {
|
||||||
let parser = GptOssParser::new();
|
let parser = GptOssParser::new();
|
||||||
|
|
||||||
// Test single tool call
|
|
||||||
let input = r#"Let me search for that information.
|
let input = r#"Let me search for that information.
|
||||||
<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "rust programming", "limit": 10}<|call|>
|
<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "rust programming", "limit": 10}<|call|>
|
||||||
Here are the results..."#;
|
Here are the results..."#;
|
||||||
@@ -15,7 +14,6 @@ Here are the results..."#;
|
|||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
assert_eq!(result[0].function.name, "search");
|
assert_eq!(result[0].function.name, "search");
|
||||||
|
|
||||||
// Verify arguments
|
|
||||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||||
assert_eq!(args["query"], "rust programming");
|
assert_eq!(args["query"], "rust programming");
|
||||||
assert_eq!(args["limit"], 10);
|
assert_eq!(args["limit"], 10);
|
||||||
@@ -38,7 +36,6 @@ async fn test_gpt_oss_multiple_tools() {
|
|||||||
async fn test_gpt_oss_with_namespace() {
|
async fn test_gpt_oss_with_namespace() {
|
||||||
let parser = GptOssParser::new();
|
let parser = GptOssParser::new();
|
||||||
|
|
||||||
// Test with different namespace patterns
|
|
||||||
let input = r#"<|channel|>commentary to=api.users.create<|constrain|>json<|message|>{"name": "John", "email": "john@example.com"}<|call|>
|
let input = r#"<|channel|>commentary to=api.users.create<|constrain|>json<|message|>{"name": "John", "email": "john@example.com"}<|call|>
|
||||||
<|channel|>commentary to=tools.calculator.add<|constrain|>json<|message|>{"x": 10, "y": 20}<|call|>"#;
|
<|channel|>commentary to=tools.calculator.add<|constrain|>json<|message|>{"x": 10, "y": 20}<|call|>"#;
|
||||||
|
|
||||||
@@ -52,7 +49,6 @@ async fn test_gpt_oss_with_namespace() {
|
|||||||
async fn test_gpt_oss_with_assistant_prefix() {
|
async fn test_gpt_oss_with_assistant_prefix() {
|
||||||
let parser = GptOssParser::new();
|
let parser = GptOssParser::new();
|
||||||
|
|
||||||
// Test with <|start|>assistant prefix
|
|
||||||
let input = r#"<|start|>assistant<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"key": "value"}<|call|>"#;
|
let input = r#"<|start|>assistant<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"key": "value"}<|call|>"#;
|
||||||
|
|
||||||
let result = parser.parse_complete(input).await.unwrap();
|
let result = parser.parse_complete(input).await.unwrap();
|
||||||
@@ -64,7 +60,6 @@ async fn test_gpt_oss_with_assistant_prefix() {
|
|||||||
async fn test_gpt_oss_empty_args() {
|
async fn test_gpt_oss_empty_args() {
|
||||||
let parser = GptOssParser::new();
|
let parser = GptOssParser::new();
|
||||||
|
|
||||||
// Test with empty arguments
|
|
||||||
let input =
|
let input =
|
||||||
r#"<|channel|>commentary to=functions.get_time<|constrain|>json<|message|>{}<|call|>"#;
|
r#"<|channel|>commentary to=functions.get_time<|constrain|>json<|message|>{}<|call|>"#;
|
||||||
|
|
||||||
@@ -130,7 +125,6 @@ fn test_gpt_oss_format_detection() {
|
|||||||
async fn test_gpt_oss_with_whitespace() {
|
async fn test_gpt_oss_with_whitespace() {
|
||||||
let parser = GptOssParser::new();
|
let parser = GptOssParser::new();
|
||||||
|
|
||||||
// Test with whitespace after function name
|
|
||||||
let input = r#"<|channel|>commentary to=functions.test <|constrain|>json<|message|>{"key": "value"}<|call|>"#;
|
let input = r#"<|channel|>commentary to=functions.test <|constrain|>json<|message|>{"key": "value"}<|call|>"#;
|
||||||
|
|
||||||
let result = parser.parse_complete(input).await.unwrap();
|
let result = parser.parse_complete(input).await.unwrap();
|
||||||
@@ -142,7 +136,6 @@ async fn test_gpt_oss_with_whitespace() {
|
|||||||
async fn test_gpt_oss_complex_json() {
|
async fn test_gpt_oss_complex_json() {
|
||||||
let parser = GptOssParser::new();
|
let parser = GptOssParser::new();
|
||||||
|
|
||||||
// Test with complex nested JSON
|
|
||||||
let input = r#"<|channel|>commentary to=functions.process<|constrain|>json<|message|>{
|
let input = r#"<|channel|>commentary to=functions.process<|constrain|>json<|message|>{
|
||||||
"nested": {
|
"nested": {
|
||||||
"data": [1, 2, 3],
|
"data": [1, 2, 3],
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ use sglang_router_rs::tool_parser::{KimiK2Parser, ParseState, StreamResult, Tool
|
|||||||
async fn test_kimik2_complete_parsing() {
|
async fn test_kimik2_complete_parsing() {
|
||||||
let parser = KimiK2Parser::new();
|
let parser = KimiK2Parser::new();
|
||||||
|
|
||||||
// Test single tool call
|
|
||||||
let input = r#"Let me help you with that.
|
let input = r#"Let me help you with that.
|
||||||
<|tool_calls_section_begin|>
|
<|tool_calls_section_begin|>
|
||||||
<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Tokyo", "units": "celsius"}<|tool_call_end|>
|
<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Tokyo", "units": "celsius"}<|tool_call_end|>
|
||||||
@@ -17,7 +16,6 @@ The weather in Tokyo is..."#;
|
|||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
assert_eq!(result[0].function.name, "get_weather");
|
assert_eq!(result[0].function.name, "get_weather");
|
||||||
|
|
||||||
// Verify arguments
|
|
||||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||||
assert_eq!(args["location"], "Tokyo");
|
assert_eq!(args["location"], "Tokyo");
|
||||||
assert_eq!(args["units"], "celsius");
|
assert_eq!(args["units"], "celsius");
|
||||||
@@ -42,7 +40,6 @@ async fn test_kimik2_multiple_tools() {
|
|||||||
async fn test_kimik2_with_whitespace() {
|
async fn test_kimik2_with_whitespace() {
|
||||||
let parser = KimiK2Parser::new();
|
let parser = KimiK2Parser::new();
|
||||||
|
|
||||||
// Test with extra whitespace
|
|
||||||
let input = r#"<|tool_calls_section_begin|>
|
let input = r#"<|tool_calls_section_begin|>
|
||||||
<|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value", "num": 42} <|tool_call_end|>
|
<|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value", "num": 42} <|tool_call_end|>
|
||||||
<|tool_calls_section_end|>"#;
|
<|tool_calls_section_end|>"#;
|
||||||
@@ -114,7 +111,6 @@ fn test_kimik2_format_detection() {
|
|||||||
async fn test_kimik2_sequential_indices() {
|
async fn test_kimik2_sequential_indices() {
|
||||||
let parser = KimiK2Parser::new();
|
let parser = KimiK2Parser::new();
|
||||||
|
|
||||||
// Test with proper sequential indexing
|
|
||||||
let input = r#"<|tool_calls_section_begin|>
|
let input = r#"<|tool_calls_section_begin|>
|
||||||
<|tool_call_begin|>functions.first:0<|tool_call_argument_begin|>{"param": "a"}<|tool_call_end|>
|
<|tool_call_begin|>functions.first:0<|tool_call_argument_begin|>{"param": "a"}<|tool_call_end|>
|
||||||
<|tool_call_begin|>functions.second:1<|tool_call_argument_begin|>{"param": "b"}<|tool_call_end|>
|
<|tool_call_begin|>functions.second:1<|tool_call_argument_begin|>{"param": "b"}<|tool_call_end|>
|
||||||
|
|||||||
@@ -116,7 +116,6 @@ async fn test_llama_real_world_output() {
|
|||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
assert_eq!(result[0].function.name, "web_search");
|
assert_eq!(result[0].function.name, "web_search");
|
||||||
|
|
||||||
// Test with nicely formatted JSON
|
|
||||||
let formatted_input = r#"<|python_tag|>{
|
let formatted_input = r#"<|python_tag|>{
|
||||||
"name": "get_current_time",
|
"name": "get_current_time",
|
||||||
"arguments": {
|
"arguments": {
|
||||||
@@ -144,7 +143,6 @@ async fn test_llama_json_array_format() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_single_json() {
|
async fn test_single_json() {
|
||||||
// Test parsing plain JSON without python_tag
|
|
||||||
let parser = LlamaParser::new();
|
let parser = LlamaParser::new();
|
||||||
let text = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#;
|
let text = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#;
|
||||||
|
|
||||||
@@ -158,7 +156,6 @@ async fn test_single_json() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_multiple_json_with_separator() {
|
async fn test_multiple_json_with_separator() {
|
||||||
// Test multiple JSON objects with semicolon separator
|
|
||||||
let parser = LlamaParser::new();
|
let parser = LlamaParser::new();
|
||||||
let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {"city": "Paris"}};{"name": "get_tourist_attractions", "arguments": {"city": "Paris"}}"#;
|
let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {"city": "Paris"}};{"name": "get_tourist_attractions", "arguments": {"city": "Paris"}}"#;
|
||||||
|
|
||||||
@@ -170,7 +167,6 @@ async fn test_multiple_json_with_separator() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_multiple_json_with_separator_customized() {
|
async fn test_multiple_json_with_separator_customized() {
|
||||||
// Test multiple JSON objects with python_tag repeated
|
|
||||||
let parser = LlamaParser::new();
|
let parser = LlamaParser::new();
|
||||||
let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {}}<|python_tag|>{"name": "get_tourist_attractions", "arguments": {}}"#;
|
let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {}}<|python_tag|>{"name": "get_tourist_attractions", "arguments": {}}"#;
|
||||||
|
|
||||||
@@ -182,7 +178,6 @@ async fn test_multiple_json_with_separator_customized() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_json_with_trailing_text() {
|
async fn test_json_with_trailing_text() {
|
||||||
// Test JSON with trailing text after
|
|
||||||
let parser = LlamaParser::new();
|
let parser = LlamaParser::new();
|
||||||
let text = r#"{"name": "get_weather", "arguments": {}} Some follow-up text"#;
|
let text = r#"{"name": "get_weather", "arguments": {}} Some follow-up text"#;
|
||||||
|
|
||||||
@@ -193,7 +188,6 @@ async fn test_json_with_trailing_text() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_invalid_then_valid_json() {
|
async fn test_invalid_then_valid_json() {
|
||||||
// Test error recovery - invalid JSON followed by valid JSON
|
|
||||||
let parser = LlamaParser::new();
|
let parser = LlamaParser::new();
|
||||||
let text = r#"{"name": "get_weather", "arguments": {{"name": "get_weather", "arguments": {}}"#;
|
let text = r#"{"name": "get_weather", "arguments": {{"name": "get_weather", "arguments": {}}"#;
|
||||||
|
|
||||||
@@ -206,7 +200,6 @@ async fn test_invalid_then_valid_json() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_plain_text_only() {
|
async fn test_plain_text_only() {
|
||||||
// Test plain text with no tool calls
|
|
||||||
let parser = LlamaParser::new();
|
let parser = LlamaParser::new();
|
||||||
let text = "This is just plain explanation text.";
|
let text = "This is just plain explanation text.";
|
||||||
|
|
||||||
@@ -216,7 +209,6 @@ async fn test_plain_text_only() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_with_python_tag_prefix() {
|
async fn test_with_python_tag_prefix() {
|
||||||
// Test text before python_tag
|
|
||||||
let parser = LlamaParser::new();
|
let parser = LlamaParser::new();
|
||||||
let text = r#"Some intro. <|python_tag|>{"name": "get_weather", "arguments": {}}"#;
|
let text = r#"Some intro. <|python_tag|>{"name": "get_weather", "arguments": {}}"#;
|
||||||
|
|
||||||
@@ -225,9 +217,7 @@ async fn test_with_python_tag_prefix() {
|
|||||||
assert_eq!(result[0].function.name, "get_weather");
|
assert_eq!(result[0].function.name, "get_weather");
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================================================
|
|
||||||
// STREAMING TESTS
|
// STREAMING TESTS
|
||||||
// ============================================================================
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_llama_streaming_simple() {
|
async fn test_llama_streaming_simple() {
|
||||||
@@ -332,7 +322,6 @@ async fn test_llama_streaming_with_text_before() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_llama_streaming_multiple_tools() {
|
async fn test_llama_streaming_multiple_tools() {
|
||||||
// Test streaming multiple tool calls with semicolon separator
|
|
||||||
let parser = LlamaParser::new();
|
let parser = LlamaParser::new();
|
||||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||||
|
|
||||||
@@ -361,7 +350,6 @@ async fn test_llama_streaming_multiple_tools() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_llama_streaming_multiple_tools_chunked() {
|
async fn test_llama_streaming_multiple_tools_chunked() {
|
||||||
// Test streaming multiple tool calls arriving in chunks
|
|
||||||
let parser = LlamaParser::new();
|
let parser = LlamaParser::new();
|
||||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||||
|
|
||||||
|
|||||||
@@ -10,8 +10,6 @@ use sglang_router_rs::tool_parser::{
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_mixed_formats_in_text() {
|
async fn test_mixed_formats_in_text() {
|
||||||
// Test that parsers correctly ignore other formats' markers
|
|
||||||
|
|
||||||
let json_parser = JsonParser::new();
|
let json_parser = JsonParser::new();
|
||||||
let input = r#"
|
let input = r#"
|
||||||
Some text with [TOOL_CALLS] marker that shouldn't trigger.
|
Some text with [TOOL_CALLS] marker that shouldn't trigger.
|
||||||
@@ -37,8 +35,6 @@ async fn test_mixed_formats_in_text() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_format_markers_in_string_content() {
|
async fn test_format_markers_in_string_content() {
|
||||||
// Test that format markers inside string content don't interfere
|
|
||||||
|
|
||||||
let pythonic_parser = PythonicParser::new();
|
let pythonic_parser = PythonicParser::new();
|
||||||
let input = r#"[echo(text="Use [TOOL_CALLS] and <tool_call> in text")]"#;
|
let input = r#"[echo(text="Use [TOOL_CALLS] and <tool_call> in text")]"#;
|
||||||
|
|
||||||
@@ -101,7 +97,6 @@ async fn test_multiple_sequential_calls_different_formats() {
|
|||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
assert_eq!(result[0].function.name, "call1");
|
assert_eq!(result[0].function.name, "call1");
|
||||||
|
|
||||||
// Test plain JSON separately
|
|
||||||
let input2 = r#"{"name": "call2", "arguments": {"x": 1}}"#;
|
let input2 = r#"{"name": "call2", "arguments": {"x": 1}}"#;
|
||||||
let result2 = llama_parser.parse_complete(input2).await.unwrap();
|
let result2 = llama_parser.parse_complete(input2).await.unwrap();
|
||||||
assert_eq!(result2.len(), 1);
|
assert_eq!(result2.len(), 1);
|
||||||
@@ -133,7 +128,6 @@ async fn test_empty_and_whitespace_variations() {
|
|||||||
async fn test_special_json_values() {
|
async fn test_special_json_values() {
|
||||||
let json_parser = JsonParser::new();
|
let json_parser = JsonParser::new();
|
||||||
|
|
||||||
// Test various special JSON values
|
|
||||||
let input = r#"{
|
let input = r#"{
|
||||||
"name": "test_special",
|
"name": "test_special",
|
||||||
"arguments": {
|
"arguments": {
|
||||||
@@ -183,8 +177,6 @@ async fn test_parser_recovery_after_invalid_input() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_boundary_cases_for_extraction() {
|
async fn test_boundary_cases_for_extraction() {
|
||||||
// Test edge cases in JSON extraction from text
|
|
||||||
|
|
||||||
let json_parser = JsonParser::new();
|
let json_parser = JsonParser::new();
|
||||||
|
|
||||||
// JSON at the very beginning
|
// JSON at the very beginning
|
||||||
@@ -259,7 +251,6 @@ async fn test_mistral_with_pretty_json() {
|
|||||||
async fn test_qwen_with_cdata_like_content() {
|
async fn test_qwen_with_cdata_like_content() {
|
||||||
let parser = QwenParser::new();
|
let parser = QwenParser::new();
|
||||||
|
|
||||||
// Test with content that looks like CDATA but isn't
|
|
||||||
// Note: QwenParser expects exactly "<tool_call>\n" with the newline
|
// Note: QwenParser expects exactly "<tool_call>\n" with the newline
|
||||||
let input = r#"<tool_call>
|
let input = r#"<tool_call>
|
||||||
{"name": "process", "arguments": {"xml": "<![CDATA[some data]]>"}}
|
{"name": "process", "arguments": {"xml": "<![CDATA[some data]]>"}}
|
||||||
|
|||||||
@@ -180,7 +180,6 @@ These functions will provide the information you need."#;
|
|||||||
async fn test_pythonic_nested_brackets_in_lists() {
|
async fn test_pythonic_nested_brackets_in_lists() {
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
|
|
||||||
// Test nested brackets within list arguments
|
|
||||||
let input = r#"[process_matrix(data=[[1, 2], [3, 4]], labels=["row[0]", "row[1]"])]"#;
|
let input = r#"[process_matrix(data=[[1, 2], [3, 4]], labels=["row[0]", "row[1]"])]"#;
|
||||||
|
|
||||||
let result = parser.parse_complete(input).await.unwrap();
|
let result = parser.parse_complete(input).await.unwrap();
|
||||||
@@ -196,7 +195,6 @@ async fn test_pythonic_nested_brackets_in_lists() {
|
|||||||
async fn test_pythonic_nested_brackets_in_dicts() {
|
async fn test_pythonic_nested_brackets_in_dicts() {
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
|
|
||||||
// Test nested brackets within dictionary arguments
|
|
||||||
let input =
|
let input =
|
||||||
r#"[analyze(config={"patterns": ["[a-z]+", "[0-9]+"], "nested": {"list": [1, [2, 3]]}})]"#;
|
r#"[analyze(config={"patterns": ["[a-z]+", "[0-9]+"], "nested": {"list": [1, [2, 3]]}})]"#;
|
||||||
|
|
||||||
@@ -213,7 +211,6 @@ async fn test_pythonic_nested_brackets_in_dicts() {
|
|||||||
async fn test_pythonic_mixed_quotes() {
|
async fn test_pythonic_mixed_quotes() {
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
|
|
||||||
// Test mixed quote types in arguments
|
|
||||||
let input = r#"[format_text(single='Hello', double="World", mixed="It's \"quoted\"")]"#;
|
let input = r#"[format_text(single='Hello', double="World", mixed="It's \"quoted\"")]"#;
|
||||||
|
|
||||||
let result = parser.parse_complete(input).await.unwrap();
|
let result = parser.parse_complete(input).await.unwrap();
|
||||||
@@ -230,7 +227,6 @@ async fn test_pythonic_mixed_quotes() {
|
|||||||
async fn test_pythonic_complex_nesting() {
|
async fn test_pythonic_complex_nesting() {
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
|
|
||||||
// Test complex nested structures
|
|
||||||
let input = r#"[transform(
|
let input = r#"[transform(
|
||||||
matrix=[[1, [2, 3]], [4, [5, [6, 7]]]],
|
matrix=[[1, [2, 3]], [4, [5, [6, 7]]]],
|
||||||
operations=[{"type": "scale", "factor": [2, 3]}, {"type": "rotate", "angle": 90}],
|
operations=[{"type": "scale", "factor": [2, 3]}, {"type": "rotate", "angle": 90}],
|
||||||
@@ -250,7 +246,6 @@ async fn test_pythonic_complex_nesting() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_parse_streaming_no_brackets() {
|
async fn test_parse_streaming_no_brackets() {
|
||||||
// Test parsing text with no brackets (no tool calls)
|
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||||
|
|
||||||
@@ -268,7 +263,6 @@ async fn test_parse_streaming_no_brackets() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_parse_streaming_complete_tool_call() {
|
async fn test_parse_streaming_complete_tool_call() {
|
||||||
// Test parsing a complete tool call
|
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||||
|
|
||||||
@@ -289,7 +283,6 @@ async fn test_parse_streaming_complete_tool_call() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_parse_streaming_text_before_tool_call() {
|
async fn test_parse_streaming_text_before_tool_call() {
|
||||||
// Test parsing text that appears before a tool call
|
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||||
|
|
||||||
@@ -308,7 +301,6 @@ async fn test_parse_streaming_text_before_tool_call() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_parse_streaming_partial_tool_call() {
|
async fn test_parse_streaming_partial_tool_call() {
|
||||||
// Test parsing a partial tool call that spans multiple chunks
|
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||||
|
|
||||||
@@ -340,7 +332,6 @@ async fn test_parse_streaming_partial_tool_call() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_parse_streaming_bracket_without_text_before() {
|
async fn test_parse_streaming_bracket_without_text_before() {
|
||||||
// Test parsing a tool call that starts at the beginning of the text
|
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||||
|
|
||||||
@@ -359,7 +350,6 @@ async fn test_parse_streaming_bracket_without_text_before() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_parse_streaming_text_after_tool_call() {
|
async fn test_parse_streaming_text_after_tool_call() {
|
||||||
// Test parsing text that appears after a tool call
|
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||||
|
|
||||||
@@ -379,7 +369,6 @@ async fn test_parse_streaming_text_after_tool_call() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_parse_streaming_multiple_tool_calls() {
|
async fn test_parse_streaming_multiple_tool_calls() {
|
||||||
// Test parsing multiple tool calls in sequence
|
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||||
|
|
||||||
@@ -401,7 +390,6 @@ async fn test_parse_streaming_multiple_tool_calls() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_parse_streaming_opening_bracket_only() {
|
async fn test_parse_streaming_opening_bracket_only() {
|
||||||
// Test parsing text with only an opening bracket but no closing bracket
|
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||||
|
|
||||||
@@ -418,7 +406,6 @@ async fn test_parse_streaming_opening_bracket_only() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_parse_streaming_nested_brackets() {
|
async fn test_parse_streaming_nested_brackets() {
|
||||||
// Test parsing tool calls with nested brackets in arguments
|
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||||
|
|
||||||
@@ -439,7 +426,6 @@ async fn test_parse_streaming_nested_brackets() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_parse_streaming_nested_brackets_dict() {
|
async fn test_parse_streaming_nested_brackets_dict() {
|
||||||
// Test parsing tool calls with nested dictionaries and lists
|
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||||
|
|
||||||
@@ -460,7 +446,6 @@ async fn test_parse_streaming_nested_brackets_dict() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_parse_streaming_multiple_tools_with_nested_brackets() {
|
async fn test_parse_streaming_multiple_tools_with_nested_brackets() {
|
||||||
// Test parsing multiple tool calls with nested brackets
|
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||||
|
|
||||||
@@ -480,7 +465,6 @@ async fn test_parse_streaming_multiple_tools_with_nested_brackets() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_parse_streaming_partial_nested_brackets() {
|
async fn test_parse_streaming_partial_nested_brackets() {
|
||||||
// Test parsing partial tool calls with nested brackets across chunks
|
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||||
|
|
||||||
@@ -514,7 +498,6 @@ async fn test_parse_streaming_partial_nested_brackets() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_parse_streaming_with_python_start_and_end_token() {
|
async fn test_parse_streaming_with_python_start_and_end_token() {
|
||||||
// Test parsing a message that starts with <|python_start|> and <|python_end|> across chunks
|
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||||
|
|
||||||
@@ -544,7 +527,6 @@ async fn test_parse_streaming_with_python_start_and_end_token() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_detect_and_parse_with_python_start_and_end_token() {
|
async fn test_detect_and_parse_with_python_start_and_end_token() {
|
||||||
// Test parsing a message that starts with <|python_start|> and contains a valid tool call
|
|
||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
|
|
||||||
let text = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars.";
|
let text = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars.";
|
||||||
|
|||||||
@@ -189,7 +189,6 @@ async fn test_buffer_drain_optimization() {
|
|||||||
// First chunk - incomplete tool call
|
// First chunk - incomplete tool call
|
||||||
let chunk1 = "<tool_call>\n{\"name\": \"test1\", ";
|
let chunk1 = "<tool_call>\n{\"name\": \"test1\", ";
|
||||||
let _result = parser.parse_incremental(chunk1, &mut state).await.unwrap();
|
let _result = parser.parse_incremental(chunk1, &mut state).await.unwrap();
|
||||||
// Phase 2 simplified streaming might not handle partial JSON correctly
|
|
||||||
// The important thing is buffer accumulation works
|
// The important thing is buffer accumulation works
|
||||||
assert!(!state.buffer.is_empty());
|
assert!(!state.buffer.is_empty());
|
||||||
|
|
||||||
@@ -197,32 +196,23 @@ async fn test_buffer_drain_optimization() {
|
|||||||
let chunk2 = "\"arguments\": {}}\n</tool_call><tool_call>\n{\"name\": \"test2\", ";
|
let chunk2 = "\"arguments\": {}}\n</tool_call><tool_call>\n{\"name\": \"test2\", ";
|
||||||
let result = parser.parse_incremental(chunk2, &mut state).await.unwrap();
|
let result = parser.parse_incremental(chunk2, &mut state).await.unwrap();
|
||||||
|
|
||||||
match result {
|
if let StreamResult::ToolComplete(tool) = result {
|
||||||
StreamResult::ToolComplete(tool) => {
|
assert_eq!(tool.function.name, "test1");
|
||||||
assert_eq!(tool.function.name, "test1");
|
// After consuming the first tool, buffer should contain only the second tool start
|
||||||
// After consuming the first tool, buffer should contain only the second tool start
|
assert!(state.buffer.starts_with("<tool_call>"));
|
||||||
assert!(state.buffer.starts_with("<tool_call>"));
|
assert!(state.buffer.contains("test2"));
|
||||||
assert!(state.buffer.contains("test2"));
|
} else {
|
||||||
}
|
// The important thing is the buffer is managed correctly
|
||||||
_ => {
|
|
||||||
// Phase 2 simplified streaming might return Incomplete
|
|
||||||
// The important thing is the buffer is managed correctly
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Complete the second tool
|
// Complete the second tool
|
||||||
let chunk3 = "\"arguments\": {\"x\": 1}}\n</tool_call>";
|
let chunk3 = "\"arguments\": {\"x\": 1}}\n</tool_call>";
|
||||||
let result = parser.parse_incremental(chunk3, &mut state).await.unwrap();
|
let result = parser.parse_incremental(chunk3, &mut state).await.unwrap();
|
||||||
|
|
||||||
match result {
|
if let StreamResult::ToolComplete(tool) = result {
|
||||||
StreamResult::ToolComplete(tool) => {
|
assert_eq!(tool.function.name, "test2");
|
||||||
assert_eq!(tool.function.name, "test2");
|
// Buffer should be empty after consuming all tools
|
||||||
// Buffer should be empty after consuming all tools
|
assert!(state.buffer.is_empty() || !state.buffer.contains("</tool_call>"));
|
||||||
assert!(state.buffer.is_empty() || !state.buffer.contains("</tool_call>"));
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
// Phase 2 simplified streaming might handle this differently
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -253,7 +243,4 @@ async fn test_buffer_efficiency_with_multiple_tools() {
|
|||||||
// Simplified streaming might return Incomplete
|
// Simplified streaming might return Incomplete
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify no memory issues or panics occurred with drain()
|
|
||||||
// Test passes if we reach this point without panic
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -126,7 +126,6 @@ async fn test_unknown_model_fallback() {
|
|||||||
async fn test_pattern_specificity() {
|
async fn test_pattern_specificity() {
|
||||||
let registry = ParserRegistry::new();
|
let registry = ParserRegistry::new();
|
||||||
|
|
||||||
// Test that more specific patterns take precedence
|
|
||||||
// llama-4* should match before llama-*
|
// llama-4* should match before llama-*
|
||||||
let parser = registry.get_parser("llama-4-70b").unwrap();
|
let parser = registry.get_parser("llama-4-70b").unwrap();
|
||||||
assert!(parser.detect_format(r#"[test_function(x=1)]"#)); // Pythonic format
|
assert!(parser.detect_format(r#"[test_function(x=1)]"#)); // Pythonic format
|
||||||
@@ -139,7 +138,6 @@ async fn test_pattern_specificity() {
|
|||||||
async fn test_real_world_model_outputs() {
|
async fn test_real_world_model_outputs() {
|
||||||
let registry = ParserRegistry::new();
|
let registry = ParserRegistry::new();
|
||||||
|
|
||||||
// Test with realistic outputs from different models
|
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
(
|
(
|
||||||
"gpt-4",
|
"gpt-4",
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ use sglang_router_rs::tool_parser::{ParseState, Step3Parser, StreamResult, ToolP
|
|||||||
async fn test_step3_complete_parsing() {
|
async fn test_step3_complete_parsing() {
|
||||||
let parser = Step3Parser::new();
|
let parser = Step3Parser::new();
|
||||||
|
|
||||||
// Test single tool call
|
|
||||||
let input = r#"Let me help you.
|
let input = r#"Let me help you.
|
||||||
<|tool_calls_begin|>
|
<|tool_calls_begin|>
|
||||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="search">
|
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="search">
|
||||||
@@ -20,7 +19,6 @@ Here are the results..."#;
|
|||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
assert_eq!(result[0].function.name, "search");
|
assert_eq!(result[0].function.name, "search");
|
||||||
|
|
||||||
// Verify arguments
|
|
||||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||||
assert_eq!(args["query"], "rust programming");
|
assert_eq!(args["query"], "rust programming");
|
||||||
assert_eq!(args["limit"], 10);
|
assert_eq!(args["limit"], 10);
|
||||||
@@ -127,7 +125,6 @@ fn test_step3_format_detection() {
|
|||||||
async fn test_step3_nested_steptml() {
|
async fn test_step3_nested_steptml() {
|
||||||
let parser = Step3Parser::new();
|
let parser = Step3Parser::new();
|
||||||
|
|
||||||
// Test with complex parameter values
|
|
||||||
let input = r#"<|tool_calls_begin|>
|
let input = r#"<|tool_calls_begin|>
|
||||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="config">
|
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="config">
|
||||||
<steptml:parameter name="settings">{"nested": {"key": "value"}}</steptml:parameter>
|
<steptml:parameter name="settings">{"nested": {"key": "value"}}</steptml:parameter>
|
||||||
@@ -148,7 +145,6 @@ async fn test_step3_nested_steptml() {
|
|||||||
async fn test_step3_python_literals() {
|
async fn test_step3_python_literals() {
|
||||||
let parser = Step3Parser::new();
|
let parser = Step3Parser::new();
|
||||||
|
|
||||||
// Test Python-style literals
|
|
||||||
let input = r#"<|tool_calls_begin|>
|
let input = r#"<|tool_calls_begin|>
|
||||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="test">
|
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="test">
|
||||||
<steptml:parameter name="bool_true">True</steptml:parameter>
|
<steptml:parameter name="bool_true">True</steptml:parameter>
|
||||||
@@ -211,7 +207,6 @@ async fn test_json_parameter_values() {
|
|||||||
async fn test_step3_parameter_with_angle_brackets() {
|
async fn test_step3_parameter_with_angle_brackets() {
|
||||||
let parser = Step3Parser::new();
|
let parser = Step3Parser::new();
|
||||||
|
|
||||||
// Test parameter value containing < character
|
|
||||||
let input = r#"<|tool_calls_begin|>
|
let input = r#"<|tool_calls_begin|>
|
||||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="compare">
|
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="compare">
|
||||||
<steptml:parameter name="expression">a < b && b > c</steptml:parameter>
|
<steptml:parameter name="expression">a < b && b > c</steptml:parameter>
|
||||||
@@ -223,7 +218,6 @@ async fn test_step3_parameter_with_angle_brackets() {
|
|||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
assert_eq!(result[0].function.name, "compare");
|
assert_eq!(result[0].function.name, "compare");
|
||||||
|
|
||||||
// Verify the parameter value was parsed correctly
|
|
||||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||||
assert_eq!(args["expression"], "a < b && b > c");
|
assert_eq!(args["expression"], "a < b && b > c");
|
||||||
assert_eq!(args["context"], "comparison test");
|
assert_eq!(args["context"], "comparison test");
|
||||||
@@ -233,7 +227,6 @@ async fn test_step3_parameter_with_angle_brackets() {
|
|||||||
async fn test_step3_empty_function_name() {
|
async fn test_step3_empty_function_name() {
|
||||||
let parser = Step3Parser::new();
|
let parser = Step3Parser::new();
|
||||||
|
|
||||||
// Test empty function name
|
|
||||||
let input = r#"<|tool_calls_begin|>
|
let input = r#"<|tool_calls_begin|>
|
||||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="">
|
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="">
|
||||||
<steptml:parameter name="param">value</steptml:parameter>
|
<steptml:parameter name="param">value</steptml:parameter>
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ async fn test_json_streaming_simple() {
|
|||||||
let parser = JsonParser::new();
|
let parser = JsonParser::new();
|
||||||
let mut state = ParseState::new();
|
let mut state = ParseState::new();
|
||||||
|
|
||||||
// Phase 2 note: This test sends the full JSON at once in the last chunk
|
|
||||||
// In real streaming, chunks would be smaller
|
|
||||||
let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
|
let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
|
||||||
|
|
||||||
let result = parser
|
let result = parser
|
||||||
@@ -21,7 +19,6 @@ async fn test_json_streaming_simple() {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// With complete JSON sent at once, we should get ToolComplete
|
|
||||||
match result {
|
match result {
|
||||||
StreamResult::ToolComplete(tool) => {
|
StreamResult::ToolComplete(tool) => {
|
||||||
assert_eq!(tool.function.name, "get_weather");
|
assert_eq!(tool.function.name, "get_weather");
|
||||||
@@ -37,7 +34,6 @@ async fn test_json_streaming_array() {
|
|||||||
let parser = JsonParser::new();
|
let parser = JsonParser::new();
|
||||||
let mut state = ParseState::new();
|
let mut state = ParseState::new();
|
||||||
|
|
||||||
// Stream a JSON array of tools
|
|
||||||
let chunks = vec![
|
let chunks = vec![
|
||||||
r#"["#,
|
r#"["#,
|
||||||
r#"{"name": "tool1", "#,
|
r#"{"name": "tool1", "#,
|
||||||
@@ -57,7 +53,6 @@ async fn test_json_streaming_array() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Current implementation may handle this differently
|
// Current implementation may handle this differently
|
||||||
// We're mainly testing that it doesn't crash
|
|
||||||
assert!(tool_count <= 2, "Should parse at most 2 tools");
|
assert!(tool_count <= 2, "Should parse at most 2 tools");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,7 +90,6 @@ async fn test_pythonic_streaming() {
|
|||||||
let parser = PythonicParser::new();
|
let parser = PythonicParser::new();
|
||||||
let mut state = ParseState::new();
|
let mut state = ParseState::new();
|
||||||
|
|
||||||
// Send complete pythonic format at once
|
|
||||||
let full_input = r#"[get_weather(city="London", units="celsius")]"#;
|
let full_input = r#"[get_weather(city="London", units="celsius")]"#;
|
||||||
|
|
||||||
let result = parser
|
let result = parser
|
||||||
@@ -149,7 +143,6 @@ async fn test_qwen_streaming() {
|
|||||||
let parser = QwenParser::new();
|
let parser = QwenParser::new();
|
||||||
let mut state = ParseState::new();
|
let mut state = ParseState::new();
|
||||||
|
|
||||||
// Send complete Qwen format at once (with exact format expected by parser)
|
|
||||||
// Note: Parser expects newline after both tags
|
// Note: Parser expects newline after both tags
|
||||||
let full_input = "<tool_call>\n{\"name\": \"translate\", \"arguments\": {\"text\": \"hello\", \"to\": \"zh\"}}\n</tool_call>";
|
let full_input = "<tool_call>\n{\"name\": \"translate\", \"arguments\": {\"text\": \"hello\", \"to\": \"zh\"}}\n</tool_call>";
|
||||||
|
|
||||||
@@ -176,12 +169,10 @@ async fn test_streaming_incomplete_stays_incomplete() {
|
|||||||
let parser = JsonParser::new();
|
let parser = JsonParser::new();
|
||||||
let mut state = ParseState::new();
|
let mut state = ParseState::new();
|
||||||
|
|
||||||
// Send truly incomplete JSON that can't be auto-completed
|
|
||||||
let chunks = vec![r#"{"na"#, r#"me": "#];
|
let chunks = vec![r#"{"na"#, r#"me": "#];
|
||||||
|
|
||||||
for chunk in chunks {
|
for chunk in chunks {
|
||||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||||
// Should return Incomplete for partial JSON that can't be auto-completed
|
|
||||||
assert!(
|
assert!(
|
||||||
matches!(result, StreamResult::Incomplete),
|
matches!(result, StreamResult::Incomplete),
|
||||||
"Should return Incomplete for partial JSON, got: {:?}",
|
"Should return Incomplete for partial JSON, got: {:?}",
|
||||||
@@ -189,7 +180,6 @@ async fn test_streaming_incomplete_stays_incomplete() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Buffer should contain the accumulated incomplete JSON
|
|
||||||
assert!(!state.buffer.is_empty());
|
assert!(!state.buffer.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,8 +188,6 @@ async fn test_streaming_with_text_before_tool() {
|
|||||||
let parser = JsonParser::new();
|
let parser = JsonParser::new();
|
||||||
let mut state = ParseState::new();
|
let mut state = ParseState::new();
|
||||||
|
|
||||||
// For streaming, the parser expects clean JSON
|
|
||||||
// Mixed text extraction only works in parse_complete, not parse_incremental
|
|
||||||
let full_input = r#"{"name": "test", "arguments": {}}"#;
|
let full_input = r#"{"name": "test", "arguments": {}}"#;
|
||||||
|
|
||||||
let result = parser
|
let result = parser
|
||||||
@@ -221,10 +209,8 @@ async fn test_streaming_with_text_before_tool() {
|
|||||||
async fn test_streaming_buffer_accumulation() {
|
async fn test_streaming_buffer_accumulation() {
|
||||||
let parser = JsonParser::new();
|
let parser = JsonParser::new();
|
||||||
|
|
||||||
// Test: Complete JSON should clear buffer after parsing
|
|
||||||
let mut state = ParseState::new();
|
let mut state = ParseState::new();
|
||||||
|
|
||||||
// Send partial JSON that can't be interpreted as complete
|
|
||||||
let result1 = parser
|
let result1 = parser
|
||||||
.parse_incremental(r#"{"na"#, &mut state)
|
.parse_incremental(r#"{"na"#, &mut state)
|
||||||
.await
|
.await
|
||||||
@@ -236,7 +222,6 @@ async fn test_streaming_buffer_accumulation() {
|
|||||||
"Buffer should accumulate incomplete JSON"
|
"Buffer should accumulate incomplete JSON"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Send rest of JSON
|
|
||||||
let result2 = parser
|
let result2 = parser
|
||||||
.parse_incremental(r#"me": "test", "arguments": {}}"#, &mut state)
|
.parse_incremental(r#"me": "test", "arguments": {}}"#, &mut state)
|
||||||
.await
|
.await
|
||||||
@@ -262,7 +247,6 @@ async fn test_streaming_multiple_tools_sequential() {
|
|||||||
let parser = QwenParser::new();
|
let parser = QwenParser::new();
|
||||||
let mut state = ParseState::new();
|
let mut state = ParseState::new();
|
||||||
|
|
||||||
// Send complete Qwen format with newlines
|
|
||||||
let full_input = r#"<tool_call>
|
let full_input = r#"<tool_call>
|
||||||
{"name": "tool1", "arguments": {}}
|
{"name": "tool1", "arguments": {}}
|
||||||
</tool_call>"#;
|
</tool_call>"#;
|
||||||
@@ -286,13 +270,11 @@ async fn test_streaming_multiple_tools_sequential() {
|
|||||||
async fn test_streaming_reset_after_error() {
|
async fn test_streaming_reset_after_error() {
|
||||||
let parser = JsonParser::new();
|
let parser = JsonParser::new();
|
||||||
|
|
||||||
// First attempt with invalid JSON
|
|
||||||
let mut state1 = ParseState::new();
|
let mut state1 = ParseState::new();
|
||||||
let _ = parser
|
let _ = parser
|
||||||
.parse_incremental(r#"{"name": invalid}"#, &mut state1)
|
.parse_incremental(r#"{"name": invalid}"#, &mut state1)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Second attempt with valid JSON should work with fresh state
|
|
||||||
let mut state2 = ParseState::new();
|
let mut state2 = ParseState::new();
|
||||||
let result = parser
|
let result = parser
|
||||||
.parse_incremental(r#"{"name": "test", "arguments": {}}"#, &mut state2)
|
.parse_incremental(r#"{"name": "test", "arguments": {}}"#, &mut state2)
|
||||||
@@ -309,7 +291,6 @@ async fn test_streaming_with_unicode_chunks() {
|
|||||||
let parser = JsonParser::new();
|
let parser = JsonParser::new();
|
||||||
let mut state = ParseState::new();
|
let mut state = ParseState::new();
|
||||||
|
|
||||||
// Send complete JSON with unicode
|
|
||||||
let full_input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍"}}"#;
|
let full_input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍"}}"#;
|
||||||
|
|
||||||
let result = parser
|
let result = parser
|
||||||
@@ -317,8 +298,6 @@ async fn test_streaming_with_unicode_chunks() {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Phase 2 may return partial results even with complete JSON
|
|
||||||
// The important thing is that unicode is handled without crashes
|
|
||||||
match result {
|
match result {
|
||||||
StreamResult::ToolComplete(tool) => {
|
StreamResult::ToolComplete(tool) => {
|
||||||
assert_eq!(tool.function.name, "translate");
|
assert_eq!(tool.function.name, "translate");
|
||||||
@@ -327,10 +306,8 @@ async fn test_streaming_with_unicode_chunks() {
|
|||||||
}
|
}
|
||||||
StreamResult::ToolName { name, .. } => {
|
StreamResult::ToolName { name, .. } => {
|
||||||
assert_eq!(name, "translate");
|
assert_eq!(name, "translate");
|
||||||
// Phase 2 partial streaming behavior - acceptable
|
|
||||||
}
|
}
|
||||||
StreamResult::ToolArguments { arguments, .. } => {
|
StreamResult::ToolArguments { arguments, .. } => {
|
||||||
// Verify unicode was preserved
|
|
||||||
let args: serde_json::Value = serde_json::from_str(&arguments).unwrap();
|
let args: serde_json::Value = serde_json::from_str(&arguments).unwrap();
|
||||||
assert!(args["text"].as_str().unwrap().contains("世界"));
|
assert!(args["text"].as_str().unwrap().contains("世界"));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,20 +25,17 @@ async fn test_json_with_xml_style_wrapper() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_json_with_multiple_wrapper_pairs() {
|
async fn test_json_with_multiple_wrapper_pairs() {
|
||||||
// Test with multiple start/end token pairs
|
|
||||||
let parser = JsonParser::with_config(TokenConfig {
|
let parser = JsonParser::with_config(TokenConfig {
|
||||||
start_tokens: vec!["<tool>".to_string(), "<<TOOL>>".to_string()],
|
start_tokens: vec!["<tool>".to_string(), "<<TOOL>>".to_string()],
|
||||||
end_tokens: vec!["</tool>".to_string(), "<</TOOL>>".to_string()],
|
end_tokens: vec!["</tool>".to_string(), "<</TOOL>>".to_string()],
|
||||||
separator: ", ".to_string(),
|
separator: ", ".to_string(),
|
||||||
});
|
});
|
||||||
|
|
||||||
// Test first pair
|
|
||||||
let input1 = r#"<tool>{"name": "tool1", "arguments": {}}</tool>"#;
|
let input1 = r#"<tool>{"name": "tool1", "arguments": {}}</tool>"#;
|
||||||
let result1 = parser.parse_complete(input1).await.unwrap();
|
let result1 = parser.parse_complete(input1).await.unwrap();
|
||||||
assert_eq!(result1.len(), 1);
|
assert_eq!(result1.len(), 1);
|
||||||
assert_eq!(result1[0].function.name, "tool1");
|
assert_eq!(result1[0].function.name, "tool1");
|
||||||
|
|
||||||
// Test second pair
|
|
||||||
let input2 = r#"<<TOOL>>{"name": "tool2", "arguments": {}}<</TOOL>>"#;
|
let input2 = r#"<<TOOL>>{"name": "tool2", "arguments": {}}<</TOOL>>"#;
|
||||||
let result2 = parser.parse_complete(input2).await.unwrap();
|
let result2 = parser.parse_complete(input2).await.unwrap();
|
||||||
assert_eq!(result2.len(), 1);
|
assert_eq!(result2.len(), 1);
|
||||||
@@ -47,7 +44,6 @@ async fn test_json_with_multiple_wrapper_pairs() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_json_with_only_start_token() {
|
async fn test_json_with_only_start_token() {
|
||||||
// Test when only start token is provided (no end token)
|
|
||||||
let parser = JsonParser::with_config(TokenConfig {
|
let parser = JsonParser::with_config(TokenConfig {
|
||||||
start_tokens: vec![">>>FUNCTION:".to_string()],
|
start_tokens: vec![">>>FUNCTION:".to_string()],
|
||||||
end_tokens: vec!["".to_string()], // Empty end token
|
end_tokens: vec!["".to_string()], // Empty end token
|
||||||
@@ -232,7 +228,6 @@ async fn test_json_incomplete_wrapper_tokens() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_json_empty_wrapper_tokens() {
|
async fn test_json_empty_wrapper_tokens() {
|
||||||
// Test with empty wrapper tokens (should behave like default)
|
|
||||||
let parser = JsonParser::with_config(TokenConfig {
|
let parser = JsonParser::with_config(TokenConfig {
|
||||||
start_tokens: vec![],
|
start_tokens: vec![],
|
||||||
end_tokens: vec![],
|
end_tokens: vec![],
|
||||||
|
|||||||
Reference in New Issue
Block a user