[router] add cargo clippy in CI and fix-up linting errors (#9242)
This commit is contained in:
6
.github/workflows/pr-test-rust.yml
vendored
6
.github/workflows/pr-test-rust.yml
vendored
@@ -27,6 +27,12 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
bash scripts/ci/ci_install_rust.sh
|
bash scripts/ci/ci_install_rust.sh
|
||||||
|
|
||||||
|
- name: Run lint
|
||||||
|
run: |
|
||||||
|
source "$HOME/.cargo/env"
|
||||||
|
cd sgl-router/
|
||||||
|
cargo clippy --all-targets --all-features -- -D warnings
|
||||||
|
|
||||||
- name: Run fmt
|
- name: Run fmt
|
||||||
run: |
|
run: |
|
||||||
source "$HOME/.cargo/env"
|
source "$HOME/.cargo/env"
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ fn create_test_worker() -> BasicWorker {
|
|||||||
fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option<u16>) {
|
fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option<u16>) {
|
||||||
let hostname = get_hostname(worker.url());
|
let hostname = get_hostname(worker.url());
|
||||||
let bootstrap_port = match worker.worker_type() {
|
let bootstrap_port = match worker.worker_type() {
|
||||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port.clone(),
|
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
(hostname, bootstrap_port)
|
(hostname, bootstrap_port)
|
||||||
|
|||||||
@@ -137,8 +137,7 @@ mod tests {
|
|||||||
fn test_worker_result_type_alias() {
|
fn test_worker_result_type_alias() {
|
||||||
// Test Ok variant
|
// Test Ok variant
|
||||||
let result: WorkerResult<i32> = Ok(42);
|
let result: WorkerResult<i32> = Ok(42);
|
||||||
assert!(result.is_ok());
|
assert!(matches!(result, Ok(42)));
|
||||||
assert_eq!(result.unwrap(), 42);
|
|
||||||
|
|
||||||
// Test Err variant
|
// Test Err variant
|
||||||
let error = WorkerError::WorkerNotFound {
|
let error = WorkerError::WorkerNotFound {
|
||||||
|
|||||||
@@ -311,13 +311,7 @@ impl Worker for BasicWorker {
|
|||||||
|
|
||||||
// Use the shared client with a custom timeout for this request
|
// Use the shared client with a custom timeout for this request
|
||||||
let health_result = match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
|
let health_result = match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
|
||||||
Ok(response) => {
|
Ok(response) => response.status().is_success(),
|
||||||
if response.status().is_success() {
|
|
||||||
true
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(_) => false,
|
Err(_) => false,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -571,6 +565,7 @@ impl WorkerFactory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create workers from URLs with automatic type detection
|
/// Create workers from URLs with automatic type detection
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
pub fn create_from_urls(
|
pub fn create_from_urls(
|
||||||
regular_urls: Vec<String>,
|
regular_urls: Vec<String>,
|
||||||
prefill_urls: Vec<(String, Option<u16>)>,
|
prefill_urls: Vec<(String, Option<u16>)>,
|
||||||
@@ -1202,12 +1197,6 @@ mod tests {
|
|||||||
for handle in handles {
|
for handle in handles {
|
||||||
handle.await.unwrap();
|
handle.await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Final state should be deterministic (last write wins)
|
|
||||||
// We can't predict the exact final state due to scheduling,
|
|
||||||
// but we can verify no data corruption occurred
|
|
||||||
let final_health = worker.is_healthy();
|
|
||||||
assert!(final_health == true || final_health == false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test WorkerFactory
|
// Test WorkerFactory
|
||||||
|
|||||||
@@ -249,6 +249,7 @@ impl Router {
|
|||||||
health_check_interval_secs = 60,
|
health_check_interval_secs = 60,
|
||||||
health_check_endpoint = String::from("/health"),
|
health_check_endpoint = String::from("/health"),
|
||||||
))]
|
))]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn new(
|
fn new(
|
||||||
worker_urls: Vec<String>,
|
worker_urls: Vec<String>,
|
||||||
policy: PolicyType,
|
policy: PolicyType,
|
||||||
|
|||||||
@@ -510,25 +510,9 @@ mod tests {
|
|||||||
|
|
||||||
// ============= Duration Bucket Tests =============
|
// ============= Duration Bucket Tests =============
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_duration_bucket_values() {
|
|
||||||
let expected_buckets = vec![
|
|
||||||
0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0,
|
|
||||||
60.0, 90.0, 120.0, 180.0, 240.0,
|
|
||||||
];
|
|
||||||
|
|
||||||
// The buckets are defined in start_prometheus function
|
|
||||||
assert_eq!(expected_buckets.len(), 20);
|
|
||||||
|
|
||||||
// Verify proper ordering
|
|
||||||
for i in 1..expected_buckets.len() {
|
|
||||||
assert!(expected_buckets[i] > expected_buckets[i - 1]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_duration_bucket_coverage() {
|
fn test_duration_bucket_coverage() {
|
||||||
let test_cases = vec![
|
let test_cases: [(f64, &str); 7] = [
|
||||||
(0.0005, "sub-millisecond"),
|
(0.0005, "sub-millisecond"),
|
||||||
(0.005, "5ms"),
|
(0.005, "5ms"),
|
||||||
(0.05, "50ms"),
|
(0.05, "50ms"),
|
||||||
@@ -538,7 +522,7 @@ mod tests {
|
|||||||
(240.0, "4m"),
|
(240.0, "4m"),
|
||||||
];
|
];
|
||||||
|
|
||||||
let buckets = vec![
|
let buckets: [f64; 20] = [
|
||||||
0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0,
|
0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0,
|
||||||
60.0, 90.0, 120.0, 180.0, 240.0,
|
60.0, 90.0, 120.0, 180.0, 240.0,
|
||||||
];
|
];
|
||||||
@@ -546,7 +530,7 @@ mod tests {
|
|||||||
for (duration, label) in test_cases {
|
for (duration, label) in test_cases {
|
||||||
let bucket_found = buckets
|
let bucket_found = buckets
|
||||||
.iter()
|
.iter()
|
||||||
.any(|&b| ((b - duration) as f64).abs() < 0.0001 || b > duration);
|
.any(|&b| (b - duration).abs() < 0.0001 || b > duration);
|
||||||
assert!(bucket_found, "No bucket found for {} ({})", duration, label);
|
assert!(bucket_found, "No bucket found for {} ({})", duration, label);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -558,14 +542,13 @@ mod tests {
|
|||||||
let matcher = Matcher::Suffix(String::from("duration_seconds"));
|
let matcher = Matcher::Suffix(String::from("duration_seconds"));
|
||||||
|
|
||||||
// Test matching behavior
|
// Test matching behavior
|
||||||
let _matching_metrics = vec![
|
let _matching_metrics = [
|
||||||
"request_duration_seconds",
|
"request_duration_seconds",
|
||||||
"response_duration_seconds",
|
"response_duration_seconds",
|
||||||
"sgl_router_request_duration_seconds",
|
"sgl_router_request_duration_seconds",
|
||||||
];
|
];
|
||||||
|
|
||||||
let _non_matching_metrics =
|
let _non_matching_metrics = ["duration_total", "duration_seconds_total", "other_metric"];
|
||||||
vec!["duration_total", "duration_seconds_total", "other_metric"];
|
|
||||||
|
|
||||||
// Note: We can't directly test Matcher matching without the internals,
|
// Note: We can't directly test Matcher matching without the internals,
|
||||||
// but we can verify the matcher is created correctly
|
// but we can verify the matcher is created correctly
|
||||||
@@ -611,8 +594,8 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_custom_buckets_for_different_metrics() {
|
fn test_custom_buckets_for_different_metrics() {
|
||||||
// Test that we can create different bucket configurations
|
// Test that we can create different bucket configurations
|
||||||
let request_buckets = vec![0.001, 0.01, 0.1, 1.0, 10.0];
|
let request_buckets = [0.001, 0.01, 0.1, 1.0, 10.0];
|
||||||
let generate_buckets = vec![0.1, 0.5, 1.0, 5.0, 30.0, 60.0];
|
let generate_buckets = [0.1, 0.5, 1.0, 5.0, 30.0, 60.0];
|
||||||
|
|
||||||
assert_eq!(request_buckets.len(), 5);
|
assert_eq!(request_buckets.len(), 5);
|
||||||
assert_eq!(generate_buckets.len(), 6);
|
assert_eq!(generate_buckets.len(), 6);
|
||||||
@@ -730,9 +713,6 @@ mod tests {
|
|||||||
for handle in handles {
|
for handle in handles {
|
||||||
handle.join().unwrap();
|
handle.join().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we get here without panic, concurrent access works
|
|
||||||
assert!(true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Edge Cases Tests =============
|
// ============= Edge Cases Tests =============
|
||||||
@@ -743,9 +723,6 @@ mod tests {
|
|||||||
RouterMetrics::record_request("");
|
RouterMetrics::record_request("");
|
||||||
RouterMetrics::set_worker_health("", true);
|
RouterMetrics::set_worker_health("", true);
|
||||||
RouterMetrics::record_policy_decision("", "");
|
RouterMetrics::record_policy_decision("", "");
|
||||||
|
|
||||||
// If we get here without panic, empty strings are handled
|
|
||||||
assert!(true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -754,14 +731,11 @@ mod tests {
|
|||||||
|
|
||||||
RouterMetrics::record_request(&long_label);
|
RouterMetrics::record_request(&long_label);
|
||||||
RouterMetrics::set_worker_health(&long_label, false);
|
RouterMetrics::set_worker_health(&long_label, false);
|
||||||
|
|
||||||
// If we get here without panic, long labels are handled
|
|
||||||
assert!(true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_special_characters_in_labels() {
|
fn test_special_characters_in_labels() {
|
||||||
let special_labels = vec![
|
let special_labels = [
|
||||||
"test/with/slashes",
|
"test/with/slashes",
|
||||||
"test-with-dashes",
|
"test-with-dashes",
|
||||||
"test_with_underscores",
|
"test_with_underscores",
|
||||||
@@ -773,9 +747,6 @@ mod tests {
|
|||||||
RouterMetrics::record_request(label);
|
RouterMetrics::record_request(label);
|
||||||
RouterMetrics::set_worker_health(label, true);
|
RouterMetrics::set_worker_health(label, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we get here without panic, special characters are handled
|
|
||||||
assert!(true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -788,9 +759,7 @@ mod tests {
|
|||||||
RouterMetrics::set_worker_load("worker", usize::MAX);
|
RouterMetrics::set_worker_load("worker", usize::MAX);
|
||||||
|
|
||||||
RouterMetrics::record_request_duration("route", Duration::from_nanos(1));
|
RouterMetrics::record_request_duration("route", Duration::from_nanos(1));
|
||||||
RouterMetrics::record_request_duration("route", Duration::from_secs(86400)); // 24 hours
|
// 24 hours
|
||||||
|
RouterMetrics::record_request_duration("route", Duration::from_secs(86400));
|
||||||
// If we get here without panic, extreme values are handled
|
|
||||||
assert!(true);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ mod tests {
|
|||||||
vec![Box::new(worker1), Box::new(worker2), Box::new(worker3)];
|
vec![Box::new(worker1), Box::new(worker2), Box::new(worker3)];
|
||||||
|
|
||||||
// Run multiple selections
|
// Run multiple selections
|
||||||
let mut selected_counts = vec![0; 3];
|
let mut selected_counts = [0; 3];
|
||||||
for _ in 0..100 {
|
for _ in 0..100 {
|
||||||
if let Some(idx) = policy.select_worker(&workers, None) {
|
if let Some(idx) = policy.select_worker(&workers, None) {
|
||||||
selected_counts[idx] += 1;
|
selected_counts[idx] += 1;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use axum::body::Body;
|
use axum::body::Body;
|
||||||
use axum::extract::Request;
|
use axum::extract::Request;
|
||||||
use axum::http::{HeaderMap, HeaderName, HeaderValue};
|
use axum::http::HeaderMap;
|
||||||
|
|
||||||
/// Copy request headers to a Vec of name-value string pairs
|
/// Copy request headers to a Vec of name-value string pairs
|
||||||
/// Used for forwarding headers to backend workers
|
/// Used for forwarding headers to backend workers
|
||||||
|
|||||||
@@ -363,6 +363,7 @@ impl PDRouter {
|
|||||||
Ok(format!("Successfully removed decode server: {}", url))
|
Ok(format!("Successfully removed decode server: {}", url))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn new(
|
pub async fn new(
|
||||||
prefill_urls: Vec<(String, Option<u16>)>,
|
prefill_urls: Vec<(String, Option<u16>)>,
|
||||||
decode_urls: Vec<String>,
|
decode_urls: Vec<String>,
|
||||||
@@ -733,6 +734,7 @@ impl PDRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Internal method that performs the actual dual dispatch (without retry logic)
|
// Internal method that performs the actual dual dispatch (without retry logic)
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
async fn execute_dual_dispatch_internal(
|
async fn execute_dual_dispatch_internal(
|
||||||
&self,
|
&self,
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
@@ -1145,7 +1147,7 @@ impl PDRouter {
|
|||||||
*response.status_mut() = status;
|
*response.status_mut() = status;
|
||||||
|
|
||||||
// Use provided headers or create new ones, then ensure content-type is set for streaming
|
// Use provided headers or create new ones, then ensure content-type is set for streaming
|
||||||
let mut headers = headers.unwrap_or_else(HeaderMap::new);
|
let mut headers = headers.unwrap_or_default();
|
||||||
headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||||
*response.headers_mut() = headers;
|
*response.headers_mut() = headers;
|
||||||
|
|
||||||
@@ -1160,41 +1162,41 @@ impl PDRouter {
|
|||||||
return_logprob: bool,
|
return_logprob: bool,
|
||||||
prefill_body: Option<bytes::Bytes>,
|
prefill_body: Option<bytes::Bytes>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
match res.bytes().await {
|
let response = res.bytes().await;
|
||||||
Ok(decode_body) => {
|
let decode_body = match response {
|
||||||
if return_logprob && prefill_body.is_some() {
|
Ok(decode_body) => decode_body,
|
||||||
// Merge logprobs from prefill and decode
|
|
||||||
let prefill_body = prefill_body.as_ref().unwrap();
|
|
||||||
match (
|
|
||||||
serde_json::from_slice::<Value>(prefill_body),
|
|
||||||
serde_json::from_slice::<Value>(&decode_body),
|
|
||||||
) {
|
|
||||||
(Ok(prefill_json), Ok(mut decode_json)) => {
|
|
||||||
// Use helper to merge logprobs
|
|
||||||
Self::merge_logprobs_in_json(&prefill_json, &mut decode_json);
|
|
||||||
|
|
||||||
// Return merged response
|
|
||||||
match serde_json::to_vec(&decode_json) {
|
|
||||||
Ok(body) => (status, body).into_response(),
|
|
||||||
Err(e) => {
|
|
||||||
error!("Failed to serialize merged response: {}", e);
|
|
||||||
(status, decode_body).into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
// If parsing fails, just return decode response
|
|
||||||
warn!("Failed to parse responses for logprob merging");
|
|
||||||
(status, decode_body).into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
(status, decode_body).into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to read decode response: {}", e);
|
error!("Failed to read decode response: {}", e);
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response").into_response()
|
return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if !return_logprob {
|
||||||
|
return (status, decode_body).into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(prefill_body) = prefill_body else {
|
||||||
|
return (status, decode_body).into_response();
|
||||||
|
};
|
||||||
|
|
||||||
|
// Merge logprobs from prefill and decode
|
||||||
|
let (Ok(prefill_json), Ok(mut decode_json)) = (
|
||||||
|
serde_json::from_slice::<Value>(&prefill_body),
|
||||||
|
serde_json::from_slice::<Value>(&decode_body),
|
||||||
|
) else {
|
||||||
|
warn!("Failed to parse responses for logprob merging");
|
||||||
|
return (status, decode_body).into_response();
|
||||||
|
};
|
||||||
|
|
||||||
|
Self::merge_logprobs_in_json(&prefill_json, &mut decode_json);
|
||||||
|
|
||||||
|
// Return merged response
|
||||||
|
match serde_json::to_vec(&decode_json) {
|
||||||
|
Ok(body) => (status, body).into_response(),
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to serialize merged response: {}", e);
|
||||||
|
(status, decode_body).into_response()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ pub struct Router {
|
|||||||
|
|
||||||
impl Router {
|
impl Router {
|
||||||
/// Create a new router with injected policy and client
|
/// Create a new router with injected policy and client
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn new(
|
pub async fn new(
|
||||||
worker_urls: Vec<String>,
|
worker_urls: Vec<String>,
|
||||||
policy: Arc<dyn LoadBalancingPolicy>,
|
policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ struct EvictionEntry {
|
|||||||
|
|
||||||
impl Eq for EvictionEntry {}
|
impl Eq for EvictionEntry {}
|
||||||
|
|
||||||
|
#[allow(clippy::non_canonical_partial_ord_impl)]
|
||||||
impl PartialOrd for EvictionEntry {
|
impl PartialOrd for EvictionEntry {
|
||||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||||
Some(self.timestamp.cmp(&other.timestamp))
|
Some(self.timestamp.cmp(&other.timestamp))
|
||||||
@@ -862,8 +863,8 @@ mod tests {
|
|||||||
// spawn 3 threads for insert
|
// spawn 3 threads for insert
|
||||||
let tree_clone = Arc::clone(&tree);
|
let tree_clone = Arc::clone(&tree);
|
||||||
|
|
||||||
let texts = vec!["hello", "apple", "banana"];
|
let texts = ["hello", "apple", "banana"];
|
||||||
let tenants = vec!["tenant1", "tenant2", "tenant3"];
|
let tenants = ["tenant1", "tenant2", "tenant3"];
|
||||||
|
|
||||||
let mut handles = vec![];
|
let mut handles = vec![];
|
||||||
|
|
||||||
@@ -916,13 +917,12 @@ mod tests {
|
|||||||
// spawn 3 threads for insert
|
// spawn 3 threads for insert
|
||||||
let tree_clone = Arc::clone(&tree);
|
let tree_clone = Arc::clone(&tree);
|
||||||
|
|
||||||
let texts = vec!["apple", "apabc", "acbdeds"];
|
static TEXTS: [&str; 3] = ["apple", "apabc", "acbdeds"];
|
||||||
|
|
||||||
let mut handles = vec![];
|
let mut handles = vec![];
|
||||||
|
|
||||||
for i in 0..3 {
|
for text in TEXTS.iter() {
|
||||||
let tree_clone = Arc::clone(&tree_clone);
|
let tree_clone = Arc::clone(&tree_clone);
|
||||||
let text = texts[i];
|
|
||||||
let tenant = "tenant0";
|
let tenant = "tenant0";
|
||||||
|
|
||||||
let handle = thread::spawn(move || {
|
let handle = thread::spawn(move || {
|
||||||
@@ -942,14 +942,13 @@ mod tests {
|
|||||||
|
|
||||||
let tree_clone = Arc::clone(&tree);
|
let tree_clone = Arc::clone(&tree);
|
||||||
|
|
||||||
for i in 0..3 {
|
for text in TEXTS.iter() {
|
||||||
let tree_clone = Arc::clone(&tree_clone);
|
let tree_clone = Arc::clone(&tree_clone);
|
||||||
let text = texts[i];
|
|
||||||
let tenant = "tenant0";
|
let tenant = "tenant0";
|
||||||
|
|
||||||
let handle = thread::spawn(move || {
|
let handle = thread::spawn(move || {
|
||||||
let (matched_text, matched_tenant) = tree_clone.prefix_match(text);
|
let (matched_text, matched_tenant) = tree_clone.prefix_match(text);
|
||||||
assert_eq!(matched_text, text);
|
assert_eq!(matched_text, *text);
|
||||||
assert_eq!(matched_tenant, tenant);
|
assert_eq!(matched_tenant, tenant);
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -964,13 +963,13 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_group_prefix_insert_match_concurrent() {
|
fn test_group_prefix_insert_match_concurrent() {
|
||||||
let prefix = vec![
|
static PREFIXES: [&str; 4] = [
|
||||||
"Clock strikes midnight, I'm still wide awake",
|
"Clock strikes midnight, I'm still wide awake",
|
||||||
"Got dreams bigger than these city lights",
|
"Got dreams bigger than these city lights",
|
||||||
"Time waits for no one, gotta make my move",
|
"Time waits for no one, gotta make my move",
|
||||||
"Started from the bottom, that's no metaphor",
|
"Started from the bottom, that's no metaphor",
|
||||||
];
|
];
|
||||||
let suffix = vec![
|
let suffixes = [
|
||||||
"Got too much to prove, ain't got time to lose",
|
"Got too much to prove, ain't got time to lose",
|
||||||
"History in the making, yeah, you can't erase this",
|
"History in the making, yeah, you can't erase this",
|
||||||
];
|
];
|
||||||
@@ -978,10 +977,10 @@ mod tests {
|
|||||||
|
|
||||||
let mut handles = vec![];
|
let mut handles = vec![];
|
||||||
|
|
||||||
for i in 0..prefix.len() {
|
for (i, prefix) in PREFIXES.iter().enumerate() {
|
||||||
for j in 0..suffix.len() {
|
for suffix in suffixes.iter() {
|
||||||
let tree_clone = Arc::clone(&tree);
|
let tree_clone = Arc::clone(&tree);
|
||||||
let text = format!("{} {}", prefix[i], suffix[j]);
|
let text = format!("{} {}", prefix, suffix);
|
||||||
let tenant = format!("tenant{}", i);
|
let tenant = format!("tenant{}", i);
|
||||||
|
|
||||||
let handle = thread::spawn(move || {
|
let handle = thread::spawn(move || {
|
||||||
@@ -1000,17 +999,15 @@ mod tests {
|
|||||||
tree.pretty_print();
|
tree.pretty_print();
|
||||||
|
|
||||||
// check matching using multi threads
|
// check matching using multi threads
|
||||||
|
|
||||||
let mut handles = vec![];
|
let mut handles = vec![];
|
||||||
|
|
||||||
for i in 0..prefix.len() {
|
for (i, prefix) in PREFIXES.iter().enumerate() {
|
||||||
let tree_clone = Arc::clone(&tree);
|
let tree_clone = Arc::clone(&tree);
|
||||||
let text = prefix[i];
|
|
||||||
|
|
||||||
let handle = thread::spawn(move || {
|
let handle = thread::spawn(move || {
|
||||||
let (matched_text, matched_tenant) = tree_clone.prefix_match(text);
|
let (matched_text, matched_tenant) = tree_clone.prefix_match(prefix);
|
||||||
let tenant = format!("tenant{}", i);
|
let tenant = format!("tenant{}", i);
|
||||||
assert_eq!(matched_text, text);
|
assert_eq!(matched_text, *prefix);
|
||||||
assert_eq!(matched_tenant, tenant);
|
assert_eq!(matched_tenant, tenant);
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -1027,13 +1024,13 @@ mod tests {
|
|||||||
fn test_mixed_concurrent_insert_match() {
|
fn test_mixed_concurrent_insert_match() {
|
||||||
// ensure it does not deadlock instead of doing correctness check
|
// ensure it does not deadlock instead of doing correctness check
|
||||||
|
|
||||||
let prefix = vec![
|
static PREFIXES: [&str; 4] = [
|
||||||
"Clock strikes midnight, I'm still wide awake",
|
"Clock strikes midnight, I'm still wide awake",
|
||||||
"Got dreams bigger than these city lights",
|
"Got dreams bigger than these city lights",
|
||||||
"Time waits for no one, gotta make my move",
|
"Time waits for no one, gotta make my move",
|
||||||
"Started from the bottom, that's no metaphor",
|
"Started from the bottom, that's no metaphor",
|
||||||
];
|
];
|
||||||
let suffix = vec![
|
let suffixes = [
|
||||||
"Got too much to prove, ain't got time to lose",
|
"Got too much to prove, ain't got time to lose",
|
||||||
"History in the making, yeah, you can't erase this",
|
"History in the making, yeah, you can't erase this",
|
||||||
];
|
];
|
||||||
@@ -1041,10 +1038,10 @@ mod tests {
|
|||||||
|
|
||||||
let mut handles = vec![];
|
let mut handles = vec![];
|
||||||
|
|
||||||
for i in 0..prefix.len() {
|
for (i, prefix) in PREFIXES.iter().enumerate() {
|
||||||
for j in 0..suffix.len() {
|
for suffix in suffixes.iter() {
|
||||||
let tree_clone = Arc::clone(&tree);
|
let tree_clone = Arc::clone(&tree);
|
||||||
let text = format!("{} {}", prefix[i], suffix[j]);
|
let text = format!("{} {}", prefix, suffix);
|
||||||
let tenant = format!("tenant{}", i);
|
let tenant = format!("tenant{}", i);
|
||||||
|
|
||||||
let handle = thread::spawn(move || {
|
let handle = thread::spawn(move || {
|
||||||
@@ -1056,13 +1053,11 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check matching using multi threads
|
// check matching using multi threads
|
||||||
|
for prefix in PREFIXES.iter() {
|
||||||
for i in 0..prefix.len() {
|
|
||||||
let tree_clone = Arc::clone(&tree);
|
let tree_clone = Arc::clone(&tree);
|
||||||
let text = prefix[i];
|
|
||||||
|
|
||||||
let handle = thread::spawn(move || {
|
let handle = thread::spawn(move || {
|
||||||
let (_matched_text, _matched_tenant) = tree_clone.prefix_match(text);
|
let (_matched_text, _matched_tenant) = tree_clone.prefix_match(prefix);
|
||||||
});
|
});
|
||||||
|
|
||||||
handles.push(handle);
|
handles.push(handle);
|
||||||
@@ -1080,16 +1075,14 @@ mod tests {
|
|||||||
// use .chars() to get the iterator of the utf-8 value
|
// use .chars() to get the iterator of the utf-8 value
|
||||||
let tree = Arc::new(Tree::new());
|
let tree = Arc::new(Tree::new());
|
||||||
|
|
||||||
let test_pairs = vec![
|
static TEST_PAIRS: [(&str, &str); 3] = [
|
||||||
("你好嗎", "tenant1"),
|
("你好嗎", "tenant1"),
|
||||||
("你好喔", "tenant2"),
|
("你好喔", "tenant2"),
|
||||||
("你心情好嗎", "tenant3"),
|
("你心情好嗎", "tenant3"),
|
||||||
];
|
];
|
||||||
|
|
||||||
// Insert sequentially
|
// Insert sequentially
|
||||||
for i in 0..test_pairs.len() {
|
for (text, tenant) in TEST_PAIRS.iter() {
|
||||||
let text = test_pairs[i].0;
|
|
||||||
let tenant = test_pairs[i].1;
|
|
||||||
tree.insert(text, tenant);
|
tree.insert(text, tenant);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1097,10 +1090,10 @@ mod tests {
|
|||||||
|
|
||||||
// Test sequentially
|
// Test sequentially
|
||||||
|
|
||||||
for i in 0..test_pairs.len() {
|
for (text, tenant) in TEST_PAIRS.iter() {
|
||||||
let (matched_text, matched_tenant) = tree.prefix_match(test_pairs[i].0);
|
let (matched_text, matched_tenant) = tree.prefix_match(text);
|
||||||
assert_eq!(matched_text, test_pairs[i].0);
|
assert_eq!(matched_text, *text);
|
||||||
assert_eq!(matched_tenant, test_pairs[i].1);
|
assert_eq!(matched_tenant, *tenant);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1108,7 +1101,7 @@ mod tests {
|
|||||||
fn test_utf8_split_concurrent() {
|
fn test_utf8_split_concurrent() {
|
||||||
let tree = Arc::new(Tree::new());
|
let tree = Arc::new(Tree::new());
|
||||||
|
|
||||||
let test_pairs = vec![
|
static TEST_PAIRS: [(&str, &str); 3] = [
|
||||||
("你好嗎", "tenant1"),
|
("你好嗎", "tenant1"),
|
||||||
("你好喔", "tenant2"),
|
("你好喔", "tenant2"),
|
||||||
("你心情好嗎", "tenant3"),
|
("你心情好嗎", "tenant3"),
|
||||||
@@ -1117,13 +1110,11 @@ mod tests {
|
|||||||
// Create multiple threads for insertion
|
// Create multiple threads for insertion
|
||||||
let mut handles = vec![];
|
let mut handles = vec![];
|
||||||
|
|
||||||
for i in 0..test_pairs.len() {
|
for (text, tenant) in TEST_PAIRS.iter() {
|
||||||
let tree_clone = Arc::clone(&tree);
|
let tree_clone = Arc::clone(&tree);
|
||||||
let text = test_pairs[i].0.to_string();
|
|
||||||
let tenant = test_pairs[i].1.to_string();
|
|
||||||
|
|
||||||
let handle = thread::spawn(move || {
|
let handle = thread::spawn(move || {
|
||||||
tree_clone.insert(&text, &tenant);
|
tree_clone.insert(text, tenant);
|
||||||
});
|
});
|
||||||
|
|
||||||
handles.push(handle);
|
handles.push(handle);
|
||||||
@@ -1139,15 +1130,13 @@ mod tests {
|
|||||||
// Create multiple threads for matching
|
// Create multiple threads for matching
|
||||||
let mut handles = vec![];
|
let mut handles = vec![];
|
||||||
|
|
||||||
for i in 0..test_pairs.len() {
|
for (text, tenant) in TEST_PAIRS.iter() {
|
||||||
let tree_clone = Arc::clone(&tree);
|
let tree_clone = Arc::clone(&tree);
|
||||||
let text = test_pairs[i].0.to_string();
|
|
||||||
let tenant = test_pairs[i].1.to_string();
|
|
||||||
|
|
||||||
let handle = thread::spawn(move || {
|
let handle = thread::spawn(move || {
|
||||||
let (matched_text, matched_tenant) = tree_clone.prefix_match(&text);
|
let (matched_text, matched_tenant) = tree_clone.prefix_match(text);
|
||||||
assert_eq!(matched_text, text);
|
assert_eq!(matched_text, *text);
|
||||||
assert_eq!(matched_tenant, tenant);
|
assert_eq!(matched_tenant, *tenant);
|
||||||
});
|
});
|
||||||
|
|
||||||
handles.push(handle);
|
handles.push(handle);
|
||||||
@@ -1202,7 +1191,7 @@ mod tests {
|
|||||||
let max_size: usize = 100;
|
let max_size: usize = 100;
|
||||||
|
|
||||||
// Define prefixes
|
// Define prefixes
|
||||||
let prefixes = vec!["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"];
|
let prefixes = ["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"];
|
||||||
|
|
||||||
// Insert strings with shared prefixes
|
// Insert strings with shared prefixes
|
||||||
for _i in 0..100 {
|
for _i in 0..100 {
|
||||||
|
|||||||
@@ -718,7 +718,7 @@ mod worker_management_tests {
|
|||||||
// Add the worker
|
// Add the worker
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method("POST")
|
.method("POST")
|
||||||
.uri(&format!("/add_worker?url={}", url))
|
.uri(format!("/add_worker?url={}", url))
|
||||||
.body(Body::empty())
|
.body(Body::empty())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -776,7 +776,7 @@ mod worker_management_tests {
|
|||||||
// Remove the worker
|
// Remove the worker
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method("POST")
|
.method("POST")
|
||||||
.uri(&format!("/remove_worker?url={}", worker_url))
|
.uri(format!("/remove_worker?url={}", worker_url))
|
||||||
.body(Body::empty())
|
.body(Body::empty())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -856,7 +856,7 @@ mod worker_management_tests {
|
|||||||
// Add worker first time
|
// Add worker first time
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method("POST")
|
.method("POST")
|
||||||
.uri(&format!("/add_worker?url={}", url))
|
.uri(format!("/add_worker?url={}", url))
|
||||||
.body(Body::empty())
|
.body(Body::empty())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let resp = app.clone().oneshot(req).await.unwrap();
|
let resp = app.clone().oneshot(req).await.unwrap();
|
||||||
@@ -867,7 +867,7 @@ mod worker_management_tests {
|
|||||||
// Try to add same worker again
|
// Try to add same worker again
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method("POST")
|
.method("POST")
|
||||||
.uri(&format!("/add_worker?url={}", url))
|
.uri(format!("/add_worker?url={}", url))
|
||||||
.body(Body::empty())
|
.body(Body::empty())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let resp = app.oneshot(req).await.unwrap();
|
let resp = app.oneshot(req).await.unwrap();
|
||||||
@@ -896,7 +896,7 @@ mod worker_management_tests {
|
|||||||
// Try to add unhealthy worker
|
// Try to add unhealthy worker
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method("POST")
|
.method("POST")
|
||||||
.uri(&format!("/add_worker?url={}", url))
|
.uri(format!("/add_worker?url={}", url))
|
||||||
.body(Body::empty())
|
.body(Body::empty())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let resp = app.oneshot(req).await.unwrap();
|
let resp = app.oneshot(req).await.unwrap();
|
||||||
@@ -1412,7 +1412,7 @@ mod pd_mode_tests {
|
|||||||
// Extract port from prefill URL
|
// Extract port from prefill URL
|
||||||
let prefill_port = prefill_url
|
let prefill_port = prefill_url
|
||||||
.split(':')
|
.split(':')
|
||||||
.last()
|
.next_back()
|
||||||
.and_then(|p| p.trim_end_matches('/').parse::<u16>().ok())
|
.and_then(|p| p.trim_end_matches('/').parse::<u16>().ok())
|
||||||
.unwrap_or(9000);
|
.unwrap_or(9000);
|
||||||
|
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ fn default_completion_request() -> CompletionRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
fn create_test_worker() -> BasicWorker {
|
fn create_test_worker() -> BasicWorker {
|
||||||
BasicWorker::new(
|
BasicWorker::new(
|
||||||
"http://test-server:8000".to_string(),
|
"http://test-server:8000".to_string(),
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ use sglang_router_rs::{
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Create a test Axum application using the actual server's build_app function
|
/// Create a test Axum application using the actual server's build_app function
|
||||||
|
#[allow(dead_code)]
|
||||||
pub fn create_test_app(
|
pub fn create_test_app(
|
||||||
router: Arc<dyn RouterTrait>,
|
router: Arc<dyn RouterTrait>,
|
||||||
client: Client,
|
client: Client,
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ impl TestContext {
|
|||||||
let worker_url = &worker_urls[0];
|
let worker_url = &worker_urls[0];
|
||||||
|
|
||||||
let response = client
|
let response = client
|
||||||
.post(&format!("{}{}", worker_url, endpoint))
|
.post(format!("{}{}", worker_url, endpoint))
|
||||||
.json(&body)
|
.json(&body)
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ impl TestContext {
|
|||||||
let worker_url = &worker_urls[0];
|
let worker_url = &worker_urls[0];
|
||||||
|
|
||||||
let response = client
|
let response = client
|
||||||
.post(&format!("{}{}", worker_url, endpoint))
|
.post(format!("{}{}", worker_url, endpoint))
|
||||||
.json(&body)
|
.json(&body)
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
@@ -128,8 +128,8 @@ impl TestContext {
|
|||||||
if let Ok(bytes) = chunk {
|
if let Ok(bytes) = chunk {
|
||||||
let text = String::from_utf8_lossy(&bytes);
|
let text = String::from_utf8_lossy(&bytes);
|
||||||
for line in text.lines() {
|
for line in text.lines() {
|
||||||
if line.starts_with("data: ") {
|
if let Some(stripped) = line.strip_prefix("data: ") {
|
||||||
events.push(line[6..].to_string());
|
events.push(stripped.to_string());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test_pd_routing {
|
mod test_pd_routing {
|
||||||
use rand::Rng;
|
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sglang_router_rs::config::{
|
use sglang_router_rs::config::{
|
||||||
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||||
@@ -421,41 +420,6 @@ mod test_pd_routing {
|
|||||||
assert_eq!(received_loads.get("http://decode2:8080"), Some(&15));
|
assert_eq!(received_loads.get("http://decode2:8080"), Some(&15));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_power_of_two_load_selection() {
|
|
||||||
// Test the power-of-two selection logic with different load scenarios
|
|
||||||
|
|
||||||
// Scenario 1: Clear winner for both prefill and decode
|
|
||||||
let _loads = vec![
|
|
||||||
("prefill1", 100),
|
|
||||||
("prefill2", 10), // Should be selected
|
|
||||||
("decode1", 50),
|
|
||||||
("decode2", 5), // Should be selected
|
|
||||||
];
|
|
||||||
|
|
||||||
// In actual implementation, the lower load should be selected
|
|
||||||
assert!(10 < 100);
|
|
||||||
assert!(5 < 50);
|
|
||||||
|
|
||||||
// Scenario 2: Equal loads (should select first)
|
|
||||||
let _equal_loads = vec![
|
|
||||||
("prefill1", 20),
|
|
||||||
("prefill2", 20), // Either could be selected
|
|
||||||
("decode1", 30),
|
|
||||||
("decode2", 30), // Either could be selected
|
|
||||||
];
|
|
||||||
|
|
||||||
// When loads are equal, <= comparison means first is selected
|
|
||||||
assert!(20 <= 20);
|
|
||||||
assert!(30 <= 30);
|
|
||||||
|
|
||||||
// Scenario 3: Missing load data (should default to usize::MAX)
|
|
||||||
// This tests the unwrap_or(usize::MAX) behavior
|
|
||||||
let missing_load = usize::MAX;
|
|
||||||
assert!(10 < missing_load);
|
|
||||||
assert!(missing_load > 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_load_monitoring_configuration() {
|
fn test_load_monitoring_configuration() {
|
||||||
// Test that load monitoring is only enabled for PowerOfTwo policy
|
// Test that load monitoring is only enabled for PowerOfTwo policy
|
||||||
@@ -605,12 +569,10 @@ mod test_pd_routing {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_streaming_response_parsing() {
|
fn test_streaming_response_parsing() {
|
||||||
// Test SSE format parsing from streaming responses
|
// Test SSE format parsing from streaming responses
|
||||||
let sse_chunks = vec![
|
let sse_chunks = ["data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}",
|
||||||
"data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}",
|
|
||||||
"data: {\"text\":\" world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}",
|
"data: {\"text\":\" world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}",
|
||||||
"data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}",
|
"data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}",
|
||||||
"data: [DONE]",
|
"data: [DONE]"];
|
||||||
];
|
|
||||||
|
|
||||||
for chunk in &sse_chunks[..3] {
|
for chunk in &sse_chunks[..3] {
|
||||||
assert!(chunk.starts_with("data: "));
|
assert!(chunk.starts_with("data: "));
|
||||||
@@ -848,7 +810,7 @@ mod test_pd_routing {
|
|||||||
large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
||||||
large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||||
large_batch_request["bootstrap_room"] = json!((0..batch_size)
|
large_batch_request["bootstrap_room"] = json!((0..batch_size)
|
||||||
.map(|_| rand::thread_rng().gen::<u64>())
|
.map(|_| rand::random::<u64>())
|
||||||
.collect::<Vec<_>>());
|
.collect::<Vec<_>>());
|
||||||
|
|
||||||
let elapsed = start.elapsed();
|
let elapsed = start.elapsed();
|
||||||
|
|||||||
Reference in New Issue
Block a user