diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml index f95cea28e..77380fc3a 100644 --- a/.github/workflows/pr-test-rust.yml +++ b/.github/workflows/pr-test-rust.yml @@ -54,7 +54,9 @@ jobs: run: | source "$HOME/.cargo/env" cd sgl-router/ - cargo fmt -- --check + rustup component add --toolchain nightly-x86_64-unknown-linux-gnu rustfmt + rustup toolchain install nightly --profile minimal + cargo +nightly fmt -- --check - name: Run Rust tests timeout-minutes: 20 diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index f9bce7942..03cf123f9 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -1,14 +1,18 @@ -use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; -use serde_json::{from_str, to_string, to_value, to_vec}; use std::time::Instant; -use sglang_router_rs::core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType}; -use sglang_router_rs::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}; -use sglang_router_rs::protocols::common::StringOrArray; -use sglang_router_rs::protocols::completion::CompletionRequest; -use sglang_router_rs::protocols::generate::GenerateRequest; -use sglang_router_rs::protocols::sampling_params::SamplingParams; -use sglang_router_rs::routers::http::pd_types::{generate_room_id, RequestWithBootstrap}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use serde_json::{from_str, to_string, to_value, to_vec}; +use sglang_router_rs::{ + core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType}, + protocols::{ + chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}, + common::StringOrArray, + completion::CompletionRequest, + generate::GenerateRequest, + sampling_params::SamplingParams, + }, + routers::http::pd_types::{generate_room_id, RequestWithBootstrap}, +}; fn create_test_worker() -> BasicWorker { BasicWorkerBuilder::new("http://test-server:8000") diff --git a/sgl-router/benches/tokenizer_benchmark.rs b/sgl-router/benches/tokenizer_benchmark.rs index a40abcc4e..6830d6797 100644 --- a/sgl-router/benches/tokenizer_benchmark.rs +++ b/sgl-router/benches/tokenizer_benchmark.rs @@ -1,16 +1,21 @@ //! Comprehensive tokenizer benchmark with clean summary output //! Each test adds a row to the final summary table +use std::{ + collections::BTreeMap, + path::PathBuf, + sync::{ + atomic::{AtomicBool, AtomicU64, Ordering}, + Arc, Mutex, OnceLock, + }, + thread, + time::{Duration, Instant}, +}; + use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput}; use sglang_router_rs::tokenizer::{ huggingface::HuggingFaceTokenizer, sequence::Sequence, stop::*, stream::DecodeStream, traits::*, }; -use std::collections::BTreeMap; -use std::path::PathBuf; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::{Arc, Mutex, OnceLock}; -use std::thread; -use std::time::{Duration, Instant}; // Include the common test utilities #[path = "../tests/common/mod.rs"] diff --git a/sgl-router/benches/tool_parser_benchmark.rs b/sgl-router/benches/tool_parser_benchmark.rs index 55c965fb7..0e09bbd4c 100644 --- a/sgl-router/benches/tool_parser_benchmark.rs +++ b/sgl-router/benches/tool_parser_benchmark.rs @@ -7,15 +7,22 @@ //! - Streaming vs complete parsing //! - Different model formats (JSON, Mistral, Qwen, Pythonic, etc.) +use std::{ + collections::BTreeMap, + sync::{ + atomic::{AtomicBool, AtomicU64, Ordering}, + Arc, Mutex, + }, + thread, + time::{Duration, Instant}, +}; + use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput}; use serde_json::json; -use sglang_router_rs::protocols::common::{Function, Tool}; -use sglang_router_rs::tool_parser::{JsonParser, ParserFactory as ToolParserFactory, ToolParser}; -use std::collections::BTreeMap; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::{Arc, Mutex}; -use std::thread; -use std::time::{Duration, Instant}; +use sglang_router_rs::{ + protocols::common::{Function, Tool}, + tool_parser::{JsonParser, ParserFactory as ToolParserFactory, ToolParser}, +}; use tokio::runtime::Runtime; // Test data for different parser formats - realistic complex examples diff --git a/sgl-router/rustfmt.toml b/sgl-router/rustfmt.toml new file mode 100644 index 000000000..19e1ab31f --- /dev/null +++ b/sgl-router/rustfmt.toml @@ -0,0 +1,8 @@ +# Rust formatting configuration + +# Enforce grouped imports by crate +imports_granularity = "Crate" +# Group std, external crates, and local crate imports separately +group_imports = "StdExternalCrate" +reorder_imports = true +reorder_modules = true diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index cdb972092..dc704657b 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -1,7 +1,9 @@ -use super::ConfigResult; -use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use serde::{Deserialize, Serialize}; + +use super::ConfigResult; + /// Main router configuration #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RouterConfig { diff --git a/sgl-router/src/core/circuit_breaker.rs b/sgl-router/src/core/circuit_breaker.rs index 47542829f..3a91f32fc 100644 --- a/sgl-router/src/core/circuit_breaker.rs +++ b/sgl-router/src/core/circuit_breaker.rs @@ -1,6 +1,11 @@ -use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; -use std::sync::{Arc, RwLock}; -use std::time::{Duration, Instant}; +use std::{ + sync::{ + atomic::{AtomicU32, AtomicU64, Ordering}, + Arc, RwLock, + }, + time::{Duration, Instant}, +}; + use tracing::info; /// Circuit breaker configuration @@ -316,9 +321,10 @@ pub struct CircuitBreakerStats { #[cfg(test)] mod tests { - use super::*; use std::thread; + use super::*; + #[test] fn test_circuit_breaker_initial_state() { let cb = CircuitBreaker::new(); diff --git a/sgl-router/src/core/error.rs b/sgl-router/src/core/error.rs index fbe033590..740e9205f 100644 --- a/sgl-router/src/core/error.rs +++ b/sgl-router/src/core/error.rs @@ -68,9 +68,10 @@ impl From for WorkerError { #[cfg(test)] mod tests { - use super::*; use std::error::Error; + use super::*; + #[test] fn test_health_check_failed_display() { let error = WorkerError::HealthCheckFailed { diff --git a/sgl-router/src/core/job_queue.rs b/sgl-router/src/core/job_queue.rs index 3324ffafa..b6d3e86e1 100644 --- a/sgl-router/src/core/job_queue.rs +++ b/sgl-router/src/core/job_queue.rs @@ -3,16 +3,22 @@ //! Provides non-blocking worker management by queuing operations and processing //! them asynchronously in background worker tasks. -use crate::core::WorkerManager; -use crate::protocols::worker_spec::{JobStatus, WorkerConfigRequest}; -use crate::server::AppContext; +use std::{ + sync::{Arc, Weak}, + time::{Duration, SystemTime}, +}; + use dashmap::DashMap; use metrics::{counter, gauge, histogram}; -use std::sync::{Arc, Weak}; -use std::time::{Duration, SystemTime}; use tokio::sync::mpsc; use tracing::{debug, error, info, warn}; +use crate::{ + core::WorkerManager, + protocols::worker_spec::{JobStatus, WorkerConfigRequest}, + server::AppContext, +}; + /// Job types for control plane operations #[derive(Debug, Clone)] pub enum Job { diff --git a/sgl-router/src/core/retry.rs b/sgl-router/src/core/retry.rs index a6584375d..d2ced03a1 100644 --- a/sgl-router/src/core/retry.rs +++ b/sgl-router/src/core/retry.rs @@ -1,10 +1,11 @@ -use crate::config::types::RetryConfig; -use axum::http::StatusCode; -use axum::response::Response; -use rand::Rng; use std::time::Duration; + +use axum::{http::StatusCode, response::Response}; +use rand::Rng; use tracing::debug; +use crate::config::types::RetryConfig; + /// Check if an HTTP status code indicates a retryable error pub fn is_retryable_status(status: StatusCode) -> bool { matches!( @@ -162,11 +163,14 @@ impl RetryExecutor { #[cfg(test)] mod tests { + use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }; + + use axum::{http::StatusCode, response::IntoResponse}; + use super::*; - use axum::http::StatusCode; - use axum::response::IntoResponse; - use std::sync::atomic::{AtomicU32, Ordering}; - use std::sync::Arc; fn base_retry_config() -> RetryConfig { RetryConfig { diff --git a/sgl-router/src/core/token_bucket.rs b/sgl-router/src/core/token_bucket.rs index 781f27e7f..af862e60b 100644 --- a/sgl-router/src/core/token_bucket.rs +++ b/sgl-router/src/core/token_bucket.rs @@ -1,5 +1,8 @@ -use std::sync::Arc; -use std::time::{Duration, Instant}; +use std::{ + sync::Arc, + time::{Duration, Instant}, +}; + use tokio::sync::{Mutex, Notify}; use tracing::{debug, trace}; diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index 249d71592..2284b789d 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -1,19 +1,27 @@ -use super::{CircuitBreaker, WorkerError, WorkerResult}; -use crate::core::CircuitState; -use crate::core::{BasicWorkerBuilder, DPAwareWorkerBuilder}; -use crate::grpc_client::SglangSchedulerClient; -use crate::metrics::RouterMetrics; -use crate::protocols::worker_spec::WorkerInfo; +use std::{ + fmt, + sync::{ + atomic::{AtomicBool, AtomicUsize, Ordering}, + Arc, LazyLock, + }, + time::{Duration, Instant}, +}; + use async_trait::async_trait; use futures; use serde_json; -use std::fmt; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::sync::{Arc, LazyLock}; -use std::time::Duration; -use std::time::Instant; -use tokio::sync::{Mutex, RwLock}; -use tokio::time; +use tokio::{ + sync::{Mutex, RwLock}, + time, +}; + +use super::{CircuitBreaker, WorkerError, WorkerResult}; +use crate::{ + core::{BasicWorkerBuilder, CircuitState, DPAwareWorkerBuilder}, + grpc_client::SglangSchedulerClient, + metrics::RouterMetrics, + protocols::worker_spec::WorkerInfo, +}; static WORKER_CLIENT: LazyLock = LazyLock::new(|| { reqwest::Client::builder() @@ -1024,10 +1032,10 @@ pub fn worker_to_info(worker: &Arc) -> WorkerInfo { #[cfg(test)] mod tests { + use std::{thread, time::Duration}; + use super::*; use crate::core::CircuitBreakerConfig; - use std::thread; - use std::time::Duration; #[test] fn test_worker_type_display() { @@ -1502,9 +1510,10 @@ mod tests { #[test] fn test_load_counter_performance() { - use crate::core::BasicWorkerBuilder; use std::time::Instant; + use crate::core::BasicWorkerBuilder; + let worker = BasicWorkerBuilder::new("http://test:8080") .worker_type(WorkerType::Regular) .build(); diff --git a/sgl-router/src/core/worker_builder.rs b/sgl-router/src/core/worker_builder.rs index 77863263b..fd30c4bd8 100644 --- a/sgl-router/src/core/worker_builder.rs +++ b/sgl-router/src/core/worker_builder.rs @@ -1,9 +1,12 @@ -use super::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig}; -use super::worker::{ - BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerMetadata, WorkerType, +use std::collections::HashMap; + +use super::{ + circuit_breaker::{CircuitBreaker, CircuitBreakerConfig}, + worker::{ + BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerMetadata, WorkerType, + }, }; use crate::grpc_client::SglangSchedulerClient; -use std::collections::HashMap; /// Builder for creating BasicWorker instances with fluent API pub struct BasicWorkerBuilder { @@ -100,6 +103,7 @@ impl BasicWorkerBuilder { atomic::{AtomicBool, AtomicUsize}, Arc, }; + use tokio::sync::{Mutex, RwLock}; let bootstrap_host = match url::Url::parse(&self.url) { @@ -282,9 +286,10 @@ impl DPAwareWorkerBuilder { #[cfg(test)] mod tests { + use std::time::Duration; + use super::*; use crate::core::worker::Worker; - use std::time::Duration; #[test] fn test_basic_worker_builder_minimal() { diff --git a/sgl-router/src/core/worker_manager.rs b/sgl-router/src/core/worker_manager.rs index c0c0c3c3e..58d397da6 100644 --- a/sgl-router/src/core/worker_manager.rs +++ b/sgl-router/src/core/worker_manager.rs @@ -3,31 +3,35 @@ //! Handles all aspects of worker lifecycle including discovery, initialization, //! runtime management, and health monitoring. -use crate::config::types::{ - CircuitBreakerConfig as ConfigCircuitBreakerConfig, ConnectionMode as ConfigConnectionMode, - HealthCheckConfig, RouterConfig, RoutingMode, -}; -use crate::core::{ - BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder, HealthConfig, - Worker, WorkerFactory, WorkerRegistry, WorkerType, -}; -use crate::grpc_client::SglangSchedulerClient; -use crate::policies::PolicyRegistry; -use crate::protocols::worker_spec::{ - FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult, -}; -use crate::server::AppContext; +use std::{collections::HashMap, sync::Arc, time::Duration}; + use futures::future; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; -use tokio::sync::{watch, Mutex}; -use tokio::task::JoinHandle; +use tokio::{ + sync::{watch, Mutex}, + task::JoinHandle, +}; use tracing::{debug, error, info, warn}; +use crate::{ + config::types::{ + CircuitBreakerConfig as ConfigCircuitBreakerConfig, ConnectionMode as ConfigConnectionMode, + HealthCheckConfig, RouterConfig, RoutingMode, + }, + core::{ + BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder, + HealthConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType, + }, + grpc_client::SglangSchedulerClient, + policies::PolicyRegistry, + protocols::worker_spec::{ + FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult, + }, + server::AppContext, +}; + static HTTP_CLIENT: Lazy = Lazy::new(|| { reqwest::Client::builder() .timeout(Duration::from_secs(10)) @@ -1803,9 +1807,10 @@ impl Drop for LoadMonitor { #[cfg(test)] mod tests { - use super::*; use std::collections::HashMap; + use super::*; + #[test] fn test_parse_server_info() { let json = serde_json::json!({ diff --git a/sgl-router/src/core/worker_registry.rs b/sgl-router/src/core/worker_registry.rs index 0e58681e6..95e09f87c 100644 --- a/sgl-router/src/core/worker_registry.rs +++ b/sgl-router/src/core/worker_registry.rs @@ -2,11 +2,13 @@ //! //! Provides centralized registry for workers with model-based indexing -use crate::core::{ConnectionMode, Worker, WorkerType}; -use dashmap::DashMap; use std::sync::{Arc, RwLock}; + +use dashmap::DashMap; use uuid::Uuid; +use crate::core::{ConnectionMode, Worker, WorkerType}; + /// Unique identifier for a worker #[derive(Debug, Clone, Hash, Eq, PartialEq)] pub struct WorkerId(String); @@ -363,8 +365,10 @@ impl WorkerRegistry { /// Start a health checker for all workers in the registry /// This should be called once after the registry is populated with workers pub fn start_health_checker(&self, check_interval_secs: u64) -> crate::core::HealthChecker { - use std::sync::atomic::{AtomicBool, Ordering}; - use std::sync::Arc; + use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }; let shutdown = Arc::new(AtomicBool::new(false)); let shutdown_clone = shutdown.clone(); @@ -433,9 +437,10 @@ pub struct WorkerRegistryStats { #[cfg(test)] mod tests { + use std::collections::HashMap; + use super::*; use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig}; - use std::collections::HashMap; #[test] fn test_worker_registry() { diff --git a/sgl-router/src/data_connector/conversation_item_memory_store.rs b/sgl-router/src/data_connector/conversation_item_memory_store.rs index d15d17031..1d4ad1b36 100644 --- a/sgl-router/src/data_connector/conversation_item_memory_store.rs +++ b/sgl-router/src/data_connector/conversation_item_memory_store.rs @@ -1,14 +1,18 @@ -use std::collections::{BTreeMap, HashMap}; -use std::sync::RwLock; +use std::{ + collections::{BTreeMap, HashMap}, + sync::RwLock, +}; use async_trait::async_trait; use chrono::{DateTime, Utc}; -use super::conversation_items::{ - make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, ListParams, - Result, SortOrder, +use super::{ + conversation_items::{ + make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, ListParams, + Result, SortOrder, + }, + conversations::ConversationId, }; -use super::conversations::ConversationId; #[derive(Default)] pub struct MemoryConversationItemStorage { @@ -190,9 +194,10 @@ impl ConversationItemStorage for MemoryConversationItemStorage { #[cfg(test)] mod tests { - use super::*; use chrono::{TimeZone, Utc}; + use super::*; + fn make_item( item_type: &str, role: Option<&str>, diff --git a/sgl-router/src/data_connector/conversation_item_oracle_store.rs b/sgl-router/src/data_connector/conversation_item_oracle_store.rs index 608c9376c..70d16f8bd 100644 --- a/sgl-router/src/data_connector/conversation_item_oracle_store.rs +++ b/sgl-router/src/data_connector/conversation_item_oracle_store.rs @@ -1,18 +1,21 @@ -use crate::config::OracleConfig; -use crate::data_connector::conversation_items::{ - make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, - ConversationItemStorageError, ListParams, Result as ItemResult, SortOrder, -}; -use crate::data_connector::conversations::ConversationId; +use std::{path::Path, sync::Arc, time::Duration}; + use async_trait::async_trait; use chrono::{DateTime, Utc}; use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult}; -use oracle::sql_type::ToSql; -use oracle::Connection; +use oracle::{sql_type::ToSql, Connection}; use serde_json::Value; -use std::path::Path; -use std::sync::Arc; -use std::time::Duration; + +use crate::{ + config::OracleConfig, + data_connector::{ + conversation_items::{ + make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, + ConversationItemStorageError, ListParams, Result as ItemResult, SortOrder, + }, + conversations::ConversationId, + }, +}; #[derive(Clone)] pub struct OracleConversationItemStorage { diff --git a/sgl-router/src/data_connector/conversation_items.rs b/sgl-router/src/data_connector/conversation_items.rs index 6c99b2b8f..3fa2009b1 100644 --- a/sgl-router/src/data_connector/conversation_items.rs +++ b/sgl-router/src/data_connector/conversation_items.rs @@ -1,10 +1,13 @@ +use std::{ + fmt::{Display, Formatter}, + sync::Arc, +}; + use async_trait::async_trait; use chrono::{DateTime, Utc}; use rand::RngCore; use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::fmt::{Display, Formatter}; -use std::sync::Arc; use super::conversations::ConversationId; diff --git a/sgl-router/src/data_connector/conversation_memory_store.rs b/sgl-router/src/data_connector/conversation_memory_store.rs index c2091c6b2..5b7896c6c 100644 --- a/sgl-router/src/data_connector/conversation_memory_store.rs +++ b/sgl-router/src/data_connector/conversation_memory_store.rs @@ -1,7 +1,7 @@ +use std::{collections::HashMap, sync::Arc}; + use async_trait::async_trait; use parking_lot::RwLock; -use std::collections::HashMap; -use std::sync::Arc; use super::conversations::{ Conversation, ConversationId, ConversationMetadata, ConversationStorage, NewConversation, diff --git a/sgl-router/src/data_connector/conversation_oracle_store.rs b/sgl-router/src/data_connector/conversation_oracle_store.rs index 452b85c2c..f534a9c6e 100644 --- a/sgl-router/src/data_connector/conversation_oracle_store.rs +++ b/sgl-router/src/data_connector/conversation_oracle_store.rs @@ -1,16 +1,18 @@ -use crate::config::OracleConfig; -use crate::data_connector::conversations::{ - Conversation, ConversationId, ConversationMetadata, ConversationStorage, - ConversationStorageError, NewConversation, Result, -}; +use std::{path::Path, sync::Arc, time::Duration}; + use async_trait::async_trait; use chrono::Utc; use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult}; use oracle::{sql_type::OracleType, Connection}; use serde_json::Value; -use std::path::Path; -use std::sync::Arc; -use std::time::Duration; + +use crate::{ + config::OracleConfig, + data_connector::conversations::{ + Conversation, ConversationId, ConversationMetadata, ConversationStorage, + ConversationStorageError, NewConversation, Result, + }, +}; #[derive(Clone)] pub struct OracleConversationStorage { diff --git a/sgl-router/src/data_connector/conversations.rs b/sgl-router/src/data_connector/conversations.rs index 3c27555cb..96487c2f6 100644 --- a/sgl-router/src/data_connector/conversations.rs +++ b/sgl-router/src/data_connector/conversations.rs @@ -1,10 +1,13 @@ +use std::{ + fmt::{Display, Formatter}, + sync::Arc, +}; + use async_trait::async_trait; use chrono::{DateTime, Utc}; use rand::RngCore; use serde::{Deserialize, Serialize}; use serde_json::{Map as JsonMap, Value}; -use std::fmt::{Display, Formatter}; -use std::sync::Arc; #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] pub struct ConversationId(pub String); diff --git a/sgl-router/src/data_connector/response_memory_store.rs b/sgl-router/src/data_connector/response_memory_store.rs index d9067ef54..767c2f5c0 100644 --- a/sgl-router/src/data_connector/response_memory_store.rs +++ b/sgl-router/src/data_connector/response_memory_store.rs @@ -1,7 +1,7 @@ +use std::{collections::HashMap, sync::Arc}; + use async_trait::async_trait; use parking_lot::RwLock; -use std::collections::HashMap; -use std::sync::Arc; use super::responses::{ResponseChain, ResponseId, ResponseStorage, Result, StoredResponse}; diff --git a/sgl-router/src/data_connector/response_oracle_store.rs b/sgl-router/src/data_connector/response_oracle_store.rs index 5ad3fab5f..cc4a8d5cc 100644 --- a/sgl-router/src/data_connector/response_oracle_store.rs +++ b/sgl-router/src/data_connector/response_oracle_store.rs @@ -1,16 +1,17 @@ -use crate::config::OracleConfig; -use crate::data_connector::responses::{ - ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult, - StoredResponse, -}; +use std::{collections::HashMap, path::Path, sync::Arc, time::Duration}; + use async_trait::async_trait; use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult}; use oracle::{Connection, Row}; use serde_json::Value; -use std::collections::HashMap; -use std::path::Path; -use std::sync::Arc; -use std::time::Duration; + +use crate::{ + config::OracleConfig, + data_connector::responses::{ + ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult, + StoredResponse, + }, +}; const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \ tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response FROM responses"; @@ -510,9 +511,10 @@ impl OracleErrorExt for ResponseStorageError { #[cfg(test)] mod tests { - use super::*; use serde_json::json; + use super::*; + #[test] fn parse_tool_calls_handles_empty_input() { assert!(parse_tool_calls(None).unwrap().is_empty()); diff --git a/sgl-router/src/data_connector/responses.rs b/sgl-router/src/data_connector/responses.rs index bb203652b..a19bd7dfd 100644 --- a/sgl-router/src/data_connector/responses.rs +++ b/sgl-router/src/data_connector/responses.rs @@ -1,8 +1,8 @@ +use std::{collections::HashMap, sync::Arc}; + use async_trait::async_trait; use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::collections::HashMap; -use std::sync::Arc; /// Response identifier #[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)] diff --git a/sgl-router/src/grpc_client/sglang_scheduler.rs b/sgl-router/src/grpc_client/sglang_scheduler.rs index 92413116b..89a169b07 100644 --- a/sgl-router/src/grpc_client/sglang_scheduler.rs +++ b/sgl-router/src/grpc_client/sglang_scheduler.rs @@ -1,16 +1,23 @@ -use std::convert::TryFrom; -use std::pin::Pin; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::time::Duration; +use std::{ + convert::TryFrom, + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + task::{Context, Poll}, + time::Duration, +}; + use tonic::{transport::Channel, Request, Streaming}; use tracing::{debug, warn}; -use crate::protocols::chat::ChatCompletionRequest; -use crate::protocols::common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue}; -use crate::protocols::generate::GenerateRequest; -use crate::protocols::sampling_params::SamplingParams as GenerateSamplingParams; +use crate::protocols::{ + chat::ChatCompletionRequest, + common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue}, + generate::GenerateRequest, + sampling_params::SamplingParams as GenerateSamplingParams, +}; // Include the generated protobuf code pub mod proto { diff --git a/sgl-router/src/logging.rs b/sgl-router/src/logging.rs index c92139ec0..7590e9e15 100644 --- a/sgl-router/src/logging.rs +++ b/sgl-router/src/logging.rs @@ -1,12 +1,14 @@ use std::path::PathBuf; + use tracing::Level; -use tracing_appender::non_blocking::WorkerGuard; -use tracing_appender::rolling::{RollingFileAppender, Rotation}; +use tracing_appender::{ + non_blocking::WorkerGuard, + rolling::{RollingFileAppender, Rotation}, +}; use tracing_log::LogTracer; -use tracing_subscriber::fmt::time::ChronoUtc; -use tracing_subscriber::layer::SubscriberExt; -use tracing_subscriber::util::SubscriberInitExt; -use tracing_subscriber::{EnvFilter, Layer}; +use tracing_subscriber::{ + fmt::time::ChronoUtc, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer, +}; #[derive(Debug, Clone)] pub struct LoggingConfig { diff --git a/sgl-router/src/main.rs b/sgl-router/src/main.rs index ba1dc0a4b..f32c8ddf3 100644 --- a/sgl-router/src/main.rs +++ b/sgl-router/src/main.rs @@ -1,14 +1,17 @@ -use clap::{ArgAction, Parser, ValueEnum}; -use sglang_router_rs::config::{ - CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig, - HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig, - RouterConfig, RoutingMode, -}; -use sglang_router_rs::metrics::PrometheusConfig; -use sglang_router_rs::server::{self, ServerConfig}; -use sglang_router_rs::service_discovery::ServiceDiscoveryConfig; use std::collections::HashMap; +use clap::{ArgAction, Parser, ValueEnum}; +use sglang_router_rs::{ + config::{ + CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig, + HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig, + RouterConfig, RoutingMode, + }, + metrics::PrometheusConfig, + server::{self, ServerConfig}, + service_discovery::ServiceDiscoveryConfig, +}; + fn parse_prefill_args() -> Vec<(String, Option)> { let args: Vec = std::env::args().collect(); let mut prefill_entries = Vec::new(); diff --git a/sgl-router/src/mcp/client_manager.rs b/sgl-router/src/mcp/client_manager.rs index b9b618a42..e5c5f128a 100644 --- a/sgl-router/src/mcp/client_manager.rs +++ b/sgl-router/src/mcp/client_manager.rs @@ -1,3 +1,5 @@ +use std::{borrow::Cow, collections::HashMap, time::Duration}; + use backoff::ExponentialBackoffBuilder; use dashmap::DashMap; use rmcp::{ @@ -13,7 +15,6 @@ use rmcp::{ RoleClient, ServiceExt, }; use serde::{Deserialize, Serialize}; -use std::{borrow::Cow, collections::HashMap, time::Duration}; use crate::mcp::{ config::{McpConfig, McpServerConfig, McpTransport}, diff --git a/sgl-router/src/mcp/config.rs b/sgl-router/src/mcp/config.rs index 1adf6a7d7..e94208e5b 100644 --- a/sgl-router/src/mcp/config.rs +++ b/sgl-router/src/mcp/config.rs @@ -1,6 +1,7 @@ -use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use serde::{Deserialize, Serialize}; + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct McpConfig { pub servers: Vec, diff --git a/sgl-router/src/mcp/oauth.rs b/sgl-router/src/mcp/oauth.rs index 3d13ea2be..e0de50db8 100644 --- a/sgl-router/src/mcp/oauth.rs +++ b/sgl-router/src/mcp/oauth.rs @@ -1,5 +1,7 @@ // OAuth authentication support for MCP servers +use std::{net::SocketAddr, sync::Arc}; + use axum::{ extract::{Query, State}, response::Html, @@ -8,7 +10,6 @@ use axum::{ }; use rmcp::transport::auth::OAuthState; use serde::Deserialize; -use std::{net::SocketAddr, sync::Arc}; use tokio::sync::{oneshot, Mutex}; use crate::mcp::error::{McpError, McpResult}; diff --git a/sgl-router/src/metrics.rs b/sgl-router/src/metrics.rs index 6cd21b14b..14dfb4ca5 100644 --- a/sgl-router/src/metrics.rs +++ b/sgl-router/src/metrics.rs @@ -1,7 +1,10 @@ +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::Duration, +}; + use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder}; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::time::Duration; #[derive(Debug, Clone)] pub struct PrometheusConfig { @@ -620,9 +623,10 @@ impl TokenizerMetrics { #[cfg(test)] mod tests { - use super::*; use std::net::TcpListener; + use super::*; + #[test] fn test_prometheus_config_default() { let config = PrometheusConfig::default(); @@ -912,9 +916,13 @@ mod tests { #[test] fn test_concurrent_metric_updates() { - use std::sync::atomic::{AtomicBool, Ordering}; - use std::sync::Arc; - use std::thread; + use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + thread, + }; let done = Arc::new(AtomicBool::new(false)); let mut handles = vec![]; diff --git a/sgl-router/src/middleware.rs b/sgl-router/src/middleware.rs index 924edcee6..e9fb86220 100644 --- a/sgl-router/src/middleware.rs +++ b/sgl-router/src/middleware.rs @@ -1,12 +1,19 @@ +use std::{ + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::{Duration, Instant}, +}; + use axum::{ - body::Body, extract::Request, extract::State, http::header, http::HeaderValue, - http::StatusCode, middleware::Next, response::IntoResponse, response::Response, + body::Body, + extract::{Request, State}, + http::{header, HeaderValue, StatusCode}, + middleware::Next, + response::{IntoResponse, Response}, }; use rand::Rng; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::Arc; -use std::time::Duration; -use std::time::Instant; use subtle::ConstantTimeEq; use tokio::sync::{mpsc, oneshot}; use tower::{Layer, Service}; @@ -14,9 +21,7 @@ use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer}; use tracing::{debug, error, field::Empty, info, info_span, warn, Span}; pub use crate::core::token_bucket::TokenBucket; - -use crate::metrics::RouterMetrics; -use crate::server::AppState; +use crate::{metrics::RouterMetrics, server::AppState}; #[derive(Clone)] pub struct AuthConfig { diff --git a/sgl-router/src/policies/cache_aware.rs b/sgl-router/src/policies/cache_aware.rs index 65244c779..99fbc1558 100644 --- a/sgl-router/src/policies/cache_aware.rs +++ b/sgl-router/src/policies/cache_aware.rs @@ -59,17 +59,15 @@ during the next eviction cycle. */ -use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy}; -use crate::core::Worker; -use crate::metrics::RouterMetrics; -use crate::tree::Tree; +use std::{sync::Arc, thread, time::Duration}; + use dashmap::DashMap; use rand::Rng; -use std::sync::Arc; -use std::thread; -use std::time::Duration; use tracing::debug; +use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy}; +use crate::{core::Worker, metrics::RouterMetrics, tree::Tree}; + /// Cache-aware routing policy /// /// Routes requests based on cache affinity when load is balanced, diff --git a/sgl-router/src/policies/factory.rs b/sgl-router/src/policies/factory.rs index 96164e4d3..f03e8f1a0 100644 --- a/sgl-router/src/policies/factory.rs +++ b/sgl-router/src/policies/factory.rs @@ -1,11 +1,12 @@ //! Factory for creating load balancing policies +use std::sync::Arc; + use super::{ CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy, RoundRobinPolicy, }; use crate::config::PolicyConfig; -use std::sync::Arc; /// Factory for creating policy instances pub struct PolicyFactory; diff --git a/sgl-router/src/policies/mod.rs b/sgl-router/src/policies/mod.rs index 82aea5c7d..564daa73e 100644 --- a/sgl-router/src/policies/mod.rs +++ b/sgl-router/src/policies/mod.rs @@ -3,9 +3,9 @@ //! This module provides a unified abstraction for routing policies that work //! across both regular and prefill-decode (PD) routing modes. +use std::{fmt::Debug, sync::Arc}; + use crate::core::Worker; -use std::fmt::Debug; -use std::sync::Arc; mod cache_aware; mod factory; diff --git a/sgl-router/src/policies/power_of_two.rs b/sgl-router/src/policies/power_of_two.rs index d21f42a46..b7edef822 100644 --- a/sgl-router/src/policies/power_of_two.rs +++ b/sgl-router/src/policies/power_of_two.rs @@ -1,13 +1,16 @@ //! Power-of-two choices load balancing policy -use super::{get_healthy_worker_indices, LoadBalancingPolicy}; -use crate::core::Worker; -use crate::metrics::RouterMetrics; +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, +}; + use rand::Rng; -use std::collections::HashMap; -use std::sync::{Arc, RwLock}; use tracing::info; +use super::{get_healthy_worker_indices, LoadBalancingPolicy}; +use crate::{core::Worker, metrics::RouterMetrics}; + /// Power-of-two choices policy /// /// Randomly selects two workers and routes to the one with lower load. diff --git a/sgl-router/src/policies/random.rs b/sgl-router/src/policies/random.rs index 8569d441b..5b92b2d73 100644 --- a/sgl-router/src/policies/random.rs +++ b/sgl-router/src/policies/random.rs @@ -1,11 +1,12 @@ //! Random load balancing policy -use super::{get_healthy_worker_indices, LoadBalancingPolicy}; -use crate::core::Worker; -use crate::metrics::RouterMetrics; -use rand::Rng; use std::sync::Arc; +use rand::Rng; + +use super::{get_healthy_worker_indices, LoadBalancingPolicy}; +use crate::{core::Worker, metrics::RouterMetrics}; + /// Random selection policy /// /// Selects workers randomly with uniform distribution among healthy workers. @@ -50,9 +51,10 @@ impl LoadBalancingPolicy for RandomPolicy { #[cfg(test)] mod tests { + use std::collections::HashMap; + use super::*; use crate::core::{BasicWorkerBuilder, WorkerType}; - use std::collections::HashMap; #[test] fn test_random_selection() { diff --git a/sgl-router/src/policies/registry.rs b/sgl-router/src/policies/registry.rs index 8d9de51d3..2904340ef 100644 --- a/sgl-router/src/policies/registry.rs +++ b/sgl-router/src/policies/registry.rs @@ -1,3 +1,10 @@ +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, +}; + +use tracing::{debug, info, warn}; + /// Policy Registry for managing model-to-policy mappings /// /// This registry manages the dynamic assignment of load balancing policies to models. @@ -8,11 +15,7 @@ use super::{ CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy, RoundRobinPolicy, }; -use crate::config::types::PolicyConfig; -use crate::core::Worker; -use std::collections::HashMap; -use std::sync::{Arc, RwLock}; -use tracing::{debug, info, warn}; +use crate::{config::types::PolicyConfig, core::Worker}; /// Registry for managing model-to-policy mappings #[derive(Clone)] diff --git a/sgl-router/src/policies/round_robin.rs b/sgl-router/src/policies/round_robin.rs index 47e3c6e92..5b0776253 100644 --- a/sgl-router/src/policies/round_robin.rs +++ b/sgl-router/src/policies/round_robin.rs @@ -1,10 +1,12 @@ //! Round-robin load balancing policy +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + use super::{get_healthy_worker_indices, LoadBalancingPolicy}; -use crate::core::Worker; -use crate::metrics::RouterMetrics; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; +use crate::{core::Worker, metrics::RouterMetrics}; /// Round-robin selection policy /// diff --git a/sgl-router/src/protocols/chat.rs b/sgl-router/src/protocols/chat.rs index 55854e6c7..f81105ec6 100644 --- a/sgl-router/src/protocols/chat.rs +++ b/sgl-router/src/protocols/chat.rs @@ -1,10 +1,13 @@ +use std::collections::HashMap; + use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::collections::HashMap; use validator::Validate; -use super::common::*; -use super::sampling_params::{validate_top_k_value, validate_top_p_value}; +use super::{ + common::*, + sampling_params::{validate_top_k_value, validate_top_p_value}, +}; use crate::protocols::validated::Normalizable; // ============================================================================ @@ -532,11 +535,12 @@ impl Normalizable for ChatCompletionRequest { // Apply tool_choice defaults if self.tool_choice.is_none() { if let Some(tools) = &self.tools { - self.tool_choice = if !tools.is_empty() { - Some(ToolChoice::Value(ToolChoiceValue::Auto)) + let choice_value = if !tools.is_empty() { + ToolChoiceValue::Auto } else { - Some(ToolChoice::Value(ToolChoiceValue::None)) + ToolChoiceValue::None }; + self.tool_choice = Some(ToolChoice::Value(choice_value)); } // If tools is None, leave tool_choice as None (don't set it) } diff --git a/sgl-router/src/protocols/common.rs b/sgl-router/src/protocols/common.rs index 339045fc1..e92138e56 100644 --- a/sgl-router/src/protocols/common.rs +++ b/sgl-router/src/protocols/common.rs @@ -1,6 +1,7 @@ +use std::collections::HashMap; + use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::collections::HashMap; // ============================================================================ // Default value helpers diff --git a/sgl-router/src/protocols/completion.rs b/sgl-router/src/protocols/completion.rs index a7bdfcfde..c6a4f638a 100644 --- a/sgl-router/src/protocols/completion.rs +++ b/sgl-router/src/protocols/completion.rs @@ -1,6 +1,7 @@ +use std::collections::HashMap; + use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; -use std::collections::HashMap; use super::common::*; diff --git a/sgl-router/src/protocols/generate.rs b/sgl-router/src/protocols/generate.rs index 3aac25640..4f3c1301a 100644 --- a/sgl-router/src/protocols/generate.rs +++ b/sgl-router/src/protocols/generate.rs @@ -1,10 +1,13 @@ +use std::collections::HashMap; + use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::collections::HashMap; use validator::Validate; -use super::common::{default_true, GenerationRequest, InputIds}; -use super::sampling_params::SamplingParams; +use super::{ + common::{default_true, GenerationRequest, InputIds}, + sampling_params::SamplingParams, +}; use crate::protocols::validated::Normalizable; // ============================================================================ diff --git a/sgl-router/src/protocols/rerank.rs b/sgl-router/src/protocols/rerank.rs index 584a66b4b..6775f5d8a 100644 --- a/sgl-router/src/protocols/rerank.rs +++ b/sgl-router/src/protocols/rerank.rs @@ -1,6 +1,7 @@ +use std::collections::HashMap; + use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::collections::HashMap; use validator::Validate; use super::common::{default_model, default_true, GenerationRequest, StringOrArray, UsageInfo}; diff --git a/sgl-router/src/protocols/responses.rs b/sgl-router/src/protocols/responses.rs index 1fbda00ad..4fabf52e9 100644 --- a/sgl-router/src/protocols/responses.rs +++ b/sgl-router/src/protocols/responses.rs @@ -1,9 +1,10 @@ // OpenAI Responses API types // https://platform.openai.com/docs/api-reference/responses +use std::collections::HashMap; + use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::collections::HashMap; // Import shared types from common module use super::common::{ diff --git a/sgl-router/src/protocols/validated.rs b/sgl-router/src/protocols/validated.rs index 4e88def58..7eb5a812b 100644 --- a/sgl-router/src/protocols/validated.rs +++ b/sgl-router/src/protocols/validated.rs @@ -117,10 +117,11 @@ impl std::ops::DerefMut for ValidatedJson { #[cfg(test)] mod tests { - use super::*; use serde::{Deserialize, Serialize}; use validator::Validate; + use super::*; + #[derive(Debug, Deserialize, Serialize, Validate)] struct TestRequest { #[validate(range(min = 0.0, max = 1.0))] diff --git a/sgl-router/src/protocols/worker_spec.rs b/sgl-router/src/protocols/worker_spec.rs index 1ae3fb832..44d4297ee 100644 --- a/sgl-router/src/protocols/worker_spec.rs +++ b/sgl-router/src/protocols/worker_spec.rs @@ -2,9 +2,10 @@ //! //! Defines the request/response structures for worker management endpoints -use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use serde::{Deserialize, Serialize}; + /// Worker configuration for API requests #[derive(Debug, Clone, Deserialize, Serialize)] pub struct WorkerConfigRequest { diff --git a/sgl-router/src/reasoning_parser/factory.rs b/sgl-router/src/reasoning_parser/factory.rs index 28f3d0836..a0d57870d 100644 --- a/sgl-router/src/reasoning_parser/factory.rs +++ b/sgl-router/src/reasoning_parser/factory.rs @@ -1,16 +1,20 @@ // Factory and registry for creating model-specific reasoning parsers. // Now with parser pooling support for efficient reuse across requests. -use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, +}; use tokio::sync::Mutex; -use crate::reasoning_parser::parsers::{ - BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser, - QwenThinkingParser, Step3Parser, +use crate::reasoning_parser::{ + parsers::{ + BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser, + QwenThinkingParser, Step3Parser, + }, + traits::{ParseError, ParserConfig, ReasoningParser}, }; -use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser}; /// Type alias for pooled parser instances. /// Uses tokio::Mutex to avoid blocking the async executor. @@ -402,8 +406,10 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 8)] async fn test_high_concurrency_parser_access() { - use std::sync::atomic::{AtomicUsize, Ordering}; - use std::time::Instant; + use std::{ + sync::atomic::{AtomicUsize, Ordering}, + time::Instant, + }; let factory = ParserFactory::new(); let num_tasks = 100; diff --git a/sgl-router/src/reasoning_parser/parsers/deepseek_r1.rs b/sgl-router/src/reasoning_parser/parsers/deepseek_r1.rs index 1bb2f4c48..4982cecc5 100644 --- a/sgl-router/src/reasoning_parser/parsers/deepseek_r1.rs +++ b/sgl-router/src/reasoning_parser/parsers/deepseek_r1.rs @@ -2,8 +2,10 @@ // This parser starts with in_reasoning=true, assuming all text is reasoning // until an end token is encountered. -use crate::reasoning_parser::parsers::BaseReasoningParser; -use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser}; +use crate::reasoning_parser::{ + parsers::BaseReasoningParser, + traits::{ParseError, ParserConfig, ParserResult, ReasoningParser}, +}; /// DeepSeek-R1 reasoning parser. /// diff --git a/sgl-router/src/reasoning_parser/parsers/glm45.rs b/sgl-router/src/reasoning_parser/parsers/glm45.rs index 5b2778993..f21124d1e 100644 --- a/sgl-router/src/reasoning_parser/parsers/glm45.rs +++ b/sgl-router/src/reasoning_parser/parsers/glm45.rs @@ -1,8 +1,10 @@ // GLM45 specific reasoning parser. // Uses the same format as Qwen3 but has its own implementation for debugging. -use crate::reasoning_parser::parsers::BaseReasoningParser; -use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser}; +use crate::reasoning_parser::{ + parsers::BaseReasoningParser, + traits::{ParseError, ParserConfig, ParserResult, ReasoningParser}, +}; /// GLM45 reasoning parser. /// diff --git a/sgl-router/src/reasoning_parser/parsers/kimi.rs b/sgl-router/src/reasoning_parser/parsers/kimi.rs index 0095f94f0..998273172 100644 --- a/sgl-router/src/reasoning_parser/parsers/kimi.rs +++ b/sgl-router/src/reasoning_parser/parsers/kimi.rs @@ -1,8 +1,10 @@ // Kimi specific reasoning parser. // This parser uses Unicode tokens and starts with in_reasoning=false. -use crate::reasoning_parser::parsers::BaseReasoningParser; -use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser}; +use crate::reasoning_parser::{ + parsers::BaseReasoningParser, + traits::{ParseError, ParserConfig, ParserResult, ReasoningParser}, +}; /// Kimi reasoning parser. /// diff --git a/sgl-router/src/reasoning_parser/parsers/qwen3.rs b/sgl-router/src/reasoning_parser/parsers/qwen3.rs index 038e7db8d..3233808f8 100644 --- a/sgl-router/src/reasoning_parser/parsers/qwen3.rs +++ b/sgl-router/src/reasoning_parser/parsers/qwen3.rs @@ -2,8 +2,10 @@ // This parser starts with in_reasoning=false, requiring an explicit // start token to enter reasoning mode. -use crate::reasoning_parser::parsers::BaseReasoningParser; -use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser}; +use crate::reasoning_parser::{ + parsers::BaseReasoningParser, + traits::{ParseError, ParserConfig, ParserResult, ReasoningParser}, +}; /// Qwen3 reasoning parser. /// diff --git a/sgl-router/src/reasoning_parser/parsers/step3.rs b/sgl-router/src/reasoning_parser/parsers/step3.rs index 155e340cc..de30c438d 100644 --- a/sgl-router/src/reasoning_parser/parsers/step3.rs +++ b/sgl-router/src/reasoning_parser/parsers/step3.rs @@ -1,8 +1,10 @@ // Step3 specific reasoning parser. // Uses the same format as DeepSeek-R1 but has its own implementation for debugging. -use crate::reasoning_parser::parsers::BaseReasoningParser; -use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser}; +use crate::reasoning_parser::{ + parsers::BaseReasoningParser, + traits::{ParseError, ParserConfig, ParserResult, ReasoningParser}, +}; /// Step3 reasoning parser. /// diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index 5a00fa7f5..669c123fd 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -1,16 +1,18 @@ //! Factory for creating router instances -use super::grpc::pd_router::GrpcPDRouter; -use super::grpc::router::GrpcRouter; +use std::sync::Arc; + use super::{ + grpc::{pd_router::GrpcPDRouter, router::GrpcRouter}, http::{pd_router::PDRouter, router::Router}, openai::OpenAIRouter, RouterTrait, }; -use crate::config::{ConnectionMode, PolicyConfig, RoutingMode}; -use crate::policies::PolicyFactory; -use crate::server::AppContext; -use std::sync::Arc; +use crate::{ + config::{ConnectionMode, PolicyConfig, RoutingMode}, + policies::PolicyFactory, + server::AppContext, +}; /// Factory for creating router instances based on configuration pub struct RouterFactory; diff --git a/sgl-router/src/routers/grpc/context.rs b/sgl-router/src/routers/grpc/context.rs index dc1f7a3c2..1d251dd64 100644 --- a/sgl-router/src/routers/grpc/context.rs +++ b/sgl-router/src/routers/grpc/context.rs @@ -4,20 +4,22 @@ //! eliminating deep parameter passing chains and providing a single source of truth //! for request state. -use std::collections::HashMap; -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use axum::http::HeaderMap; use serde_json::Value; -use crate::core::Worker; -use crate::grpc_client::{proto, SglangSchedulerClient}; -use crate::protocols::chat::{ChatCompletionRequest, ChatCompletionResponse}; -use crate::protocols::generate::{GenerateRequest, GenerateResponse}; -use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; -use crate::tokenizer::stop::StopSequenceDecoder; -use crate::tokenizer::traits::Tokenizer; -use crate::tool_parser::ParserFactory as ToolParserFactory; +use crate::{ + core::Worker, + grpc_client::{proto, SglangSchedulerClient}, + protocols::{ + chat::{ChatCompletionRequest, ChatCompletionResponse}, + generate::{GenerateRequest, GenerateResponse}, + }, + reasoning_parser::ParserFactory as ReasoningParserFactory, + tokenizer::{stop::StopSequenceDecoder, traits::Tokenizer}, + tool_parser::ParserFactory as ToolParserFactory, +}; // ============================================================================ // Core Context Types diff --git a/sgl-router/src/routers/grpc/mod.rs b/sgl-router/src/routers/grpc/mod.rs index 14ed36de4..920c49645 100644 --- a/sgl-router/src/routers/grpc/mod.rs +++ b/sgl-router/src/routers/grpc/mod.rs @@ -1,7 +1,6 @@ //! gRPC router implementations -use crate::grpc_client::proto; -use crate::protocols::common::StringOrArray; +use crate::{grpc_client::proto, protocols::common::StringOrArray}; pub mod context; pub mod pd_router; diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index 1e524c27c..243df61b5 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -1,19 +1,7 @@ // PD (Prefill-Decode) gRPC Router Implementation -use crate::config::types::RetryConfig; -use crate::core::{ConnectionMode, WorkerRegistry, WorkerType}; -use crate::policies::PolicyRegistry; -use crate::protocols::chat::ChatCompletionRequest; -use crate::protocols::completion::CompletionRequest; -use crate::protocols::embedding::EmbeddingRequest; -use crate::protocols::generate::GenerateRequest; -use crate::protocols::rerank::RerankRequest; -use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest}; -use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; -use crate::routers::RouterTrait; -use crate::server::AppContext; -use crate::tokenizer::traits::Tokenizer; -use crate::tool_parser::ParserFactory as ToolParserFactory; +use std::sync::Arc; + use async_trait::async_trait; use axum::{ body::Body, @@ -21,12 +9,27 @@ use axum::{ http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, }; -use std::sync::Arc; - use tracing::debug; -use super::context::SharedComponents; -use super::pipeline::RequestPipeline; +use super::{context::SharedComponents, pipeline::RequestPipeline}; +use crate::{ + config::types::RetryConfig, + core::{ConnectionMode, WorkerRegistry, WorkerType}, + policies::PolicyRegistry, + protocols::{ + chat::ChatCompletionRequest, + completion::CompletionRequest, + embedding::EmbeddingRequest, + generate::GenerateRequest, + rerank::RerankRequest, + responses::{ResponsesGetParams, ResponsesRequest}, + }, + reasoning_parser::ParserFactory as ReasoningParserFactory, + routers::RouterTrait, + server::AppContext, + tokenizer::traits::Tokenizer, + tool_parser::ParserFactory as ToolParserFactory, +}; /// gRPC PD (Prefill-Decode) router implementation for SGLang #[derive(Clone)] diff --git a/sgl-router/src/routers/grpc/pipeline.rs b/sgl-router/src/routers/grpc/pipeline.rs index 9b21a8281..8c07bc25a 100644 --- a/sgl-router/src/routers/grpc/pipeline.rs +++ b/sgl-router/src/routers/grpc/pipeline.rs @@ -3,29 +3,29 @@ //! This module defines the core pipeline abstraction and individual processing stages //! that transform a RequestContext through its lifecycle. +use std::{ + sync::Arc, + time::{Instant, SystemTime, UNIX_EPOCH}, +}; + use async_trait::async_trait; use axum::response::{IntoResponse, Response}; -use tracing::{debug, error, warn}; - -use super::context::*; -use super::processing; -use super::streaming; -use super::utils; -use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType}; -use crate::grpc_client::proto; -use crate::policies::PolicyRegistry; -use crate::protocols::chat::ChatCompletionRequest; -use crate::protocols::common::InputIds; -use crate::protocols::generate::GenerateRequest; -use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; -use crate::tokenizer::traits::Tokenizer; -use crate::tool_parser::ParserFactory as ToolParserFactory; use proto::DisaggregatedParams; use rand::Rng; -use std::sync::Arc; -use std::time::{Instant, SystemTime, UNIX_EPOCH}; +use tracing::{debug, error, warn}; use uuid::Uuid; +use super::{context::*, processing, streaming, utils}; +use crate::{ + core::{ConnectionMode, Worker, WorkerRegistry, WorkerType}, + grpc_client::proto, + policies::PolicyRegistry, + protocols::{chat::ChatCompletionRequest, common::InputIds, generate::GenerateRequest}, + reasoning_parser::ParserFactory as ReasoningParserFactory, + tokenizer::traits::Tokenizer, + tool_parser::ParserFactory as ToolParserFactory, +}; + // ============================================================================ // Pipeline Trait // ============================================================================ diff --git a/sgl-router/src/routers/grpc/processing.rs b/sgl-router/src/routers/grpc/processing.rs index 886ec0d95..22b7a3981 100644 --- a/sgl-router/src/routers/grpc/processing.rs +++ b/sgl-router/src/routers/grpc/processing.rs @@ -3,28 +3,30 @@ //! This module contains response processing functions that are shared between //! the regular router and PD router, eliminating ~1,200 lines of exact duplicates. -use std::sync::Arc; +use std::{sync::Arc, time::Instant}; +use proto::generate_complete::MatchedStop; use serde_json::Value; use tracing::error; -use crate::grpc_client::proto; -use crate::protocols::chat::{ - ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, +use super::{ + context::{DispatchMetadata, ExecutionResult}, + utils, }; -use crate::protocols::common::{ - FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue, Usage, +use crate::{ + grpc_client::proto, + protocols::{ + chat::{ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse}, + common::{FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue, Usage}, + generate::{GenerateMetaInfo, GenerateRequest, GenerateResponse}, + }, + reasoning_parser::ParserFactory as ReasoningParserFactory, + tokenizer::{ + stop::{SequenceDecoderOutput, StopSequenceDecoder}, + traits::Tokenizer, + }, + tool_parser::ParserFactory as ToolParserFactory, }; -use crate::protocols::generate::{GenerateMetaInfo, GenerateRequest, GenerateResponse}; -use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; -use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder}; -use crate::tokenizer::traits::Tokenizer; -use crate::tool_parser::ParserFactory as ToolParserFactory; -use proto::generate_complete::MatchedStop; -use std::time::Instant; - -use super::context::{DispatchMetadata, ExecutionResult}; -use super::utils; // ============================================================================ // Response Processor - Main Entry Point diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index d798c851b..ceed70e62 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -11,23 +11,25 @@ use axum::{ }; use tracing::debug; -use crate::config::types::RetryConfig; -use crate::core::WorkerRegistry; -use crate::policies::PolicyRegistry; -use crate::protocols::chat::ChatCompletionRequest; -use crate::protocols::completion::CompletionRequest; -use crate::protocols::embedding::EmbeddingRequest; -use crate::protocols::generate::GenerateRequest; -use crate::protocols::rerank::RerankRequest; -use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest}; -use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; -use crate::routers::RouterTrait; -use crate::server::AppContext; -use crate::tokenizer::traits::Tokenizer; -use crate::tool_parser::ParserFactory as ToolParserFactory; - -use super::context::SharedComponents; -use super::pipeline::RequestPipeline; +use super::{context::SharedComponents, pipeline::RequestPipeline}; +use crate::{ + config::types::RetryConfig, + core::WorkerRegistry, + policies::PolicyRegistry, + protocols::{ + chat::ChatCompletionRequest, + completion::CompletionRequest, + embedding::EmbeddingRequest, + generate::GenerateRequest, + rerank::RerankRequest, + responses::{ResponsesGetParams, ResponsesRequest}, + }, + reasoning_parser::ParserFactory as ReasoningParserFactory, + routers::RouterTrait, + server::AppContext, + tokenizer::traits::Tokenizer, + tool_parser::ParserFactory as ToolParserFactory, +}; /// gRPC router implementation for SGLang #[derive(Clone)] diff --git a/sgl-router/src/routers/grpc/streaming.rs b/sgl-router/src/routers/grpc/streaming.rs index dcc127dab..1202ed69b 100644 --- a/sgl-router/src/routers/grpc/streaming.rs +++ b/sgl-router/src/routers/grpc/streaming.rs @@ -3,38 +3,40 @@ //! This module contains shared streaming logic for both Regular and PD routers, //! eliminating ~600 lines of duplication. -use axum::response::Response; -use axum::{body::Body, http::StatusCode}; +use std::{collections::HashMap, io, sync::Arc, time::Instant}; + +use axum::{body::Body, http::StatusCode, response::Response}; use bytes::Bytes; use http::header::{HeaderValue, CONTENT_TYPE}; +use proto::{ + generate_complete::MatchedStop::{MatchedStopStr, MatchedTokenId}, + generate_response::Response::{Chunk, Complete, Error}, +}; use serde_json::{json, Value}; -use std::collections::HashMap; -use std::io; -use std::sync::Arc; -use tokio::sync::mpsc::UnboundedSender; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tokio_stream::StreamExt; +use tokio::sync::{mpsc, mpsc::UnboundedSender}; +use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt}; use tracing::{debug, error, warn}; -use super::context; -use super::utils; -use crate::grpc_client::proto; -use crate::protocols::chat::{ - ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice, +use super::{context, utils}; +use crate::{ + grpc_client::proto, + protocols::{ + chat::{ + ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice, + }, + common::{ + ChatLogProbs, FunctionCallDelta, StringOrArray, Tool, ToolCallDelta, ToolChoice, + ToolChoiceValue, Usage, + }, + generate::GenerateRequest, + }, + reasoning_parser::ReasoningParser, + tokenizer::{ + stop::{SequenceDecoderOutput, StopSequenceDecoder}, + traits::Tokenizer, + }, + tool_parser::ToolParser, }; -use crate::protocols::common::{ - ChatLogProbs, FunctionCallDelta, StringOrArray, Tool, ToolCallDelta, ToolChoice, - ToolChoiceValue, Usage, -}; -use crate::protocols::generate::GenerateRequest; -use crate::reasoning_parser::ReasoningParser; -use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder}; -use crate::tokenizer::traits::Tokenizer; -use crate::tool_parser::ToolParser; -use proto::generate_complete::MatchedStop::{MatchedStopStr, MatchedTokenId}; -use proto::generate_response::Response::{Chunk, Complete, Error}; -use std::time::Instant; -use tokio::sync::mpsc; /// Shared streaming processor for both single and dual dispatch modes #[derive(Clone)] diff --git a/sgl-router/src/routers/grpc/utils.rs b/sgl-router/src/routers/grpc/utils.rs index 86e1532d5..e2c06b2a8 100644 --- a/sgl-router/src/routers/grpc/utils.rs +++ b/sgl-router/src/routers/grpc/utils.rs @@ -1,19 +1,7 @@ //! Shared utilities for gRPC routers -use super::ProcessedMessages; -use crate::core::Worker; -use crate::grpc_client::sglang_scheduler::AbortOnDropStream; -use crate::grpc_client::{proto, SglangSchedulerClient}; -use crate::protocols::chat::{ChatCompletionRequest, ChatMessage}; -use crate::protocols::common::{ - ChatLogProbs, ChatLogProbsContent, FunctionCallResponse, StringOrArray, Tool, ToolCall, - ToolChoice, ToolChoiceValue, TopLogProb, -}; -use crate::protocols::generate::GenerateFinishReason; -use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams}; -use crate::tokenizer::traits::Tokenizer; -use crate::tokenizer::HuggingFaceTokenizer; -pub use crate::tokenizer::StopSequenceDecoder; +use std::{collections::HashMap, sync::Arc}; + use axum::{ http::StatusCode, response::{IntoResponse, Response}, @@ -21,11 +9,29 @@ use axum::{ }; use futures::StreamExt; use serde_json::{json, Map, Value}; -use std::collections::HashMap; -use std::sync::Arc; use tracing::{error, warn}; use uuid::Uuid; +use super::ProcessedMessages; +pub use crate::tokenizer::StopSequenceDecoder; +use crate::{ + core::Worker, + grpc_client::{proto, sglang_scheduler::AbortOnDropStream, SglangSchedulerClient}, + protocols::{ + chat::{ChatCompletionRequest, ChatMessage}, + common::{ + ChatLogProbs, ChatLogProbsContent, FunctionCallResponse, StringOrArray, Tool, ToolCall, + ToolChoice, ToolChoiceValue, TopLogProb, + }, + generate::GenerateFinishReason, + }, + tokenizer::{ + chat_template::{ChatTemplateContentFormat, ChatTemplateParams}, + traits::Tokenizer, + HuggingFaceTokenizer, + }, +}; + /// Get gRPC client from worker, returning appropriate error response on failure pub async fn get_grpc_client_from_worker( worker: &Arc, @@ -953,12 +959,17 @@ pub fn parse_finish_reason(reason_str: &str, completion_tokens: i32) -> Generate #[cfg(test)] mod tests { - use super::*; - use crate::protocols::chat::{ChatMessage, UserMessageContent}; - use crate::protocols::common::{ContentPart, ImageUrl}; - use crate::tokenizer::chat_template::ChatTemplateContentFormat; use serde_json::json; + use super::*; + use crate::{ + protocols::{ + chat::{ChatMessage, UserMessageContent}, + common::{ContentPart, ImageUrl}, + }, + tokenizer::chat_template::ChatTemplateContentFormat, + }; + #[test] fn test_transform_messages_string_format() { let messages = vec![ChatMessage::User { diff --git a/sgl-router/src/routers/header_utils.rs b/sgl-router/src/routers/header_utils.rs index 13b6f04ef..9ddde8aed 100644 --- a/sgl-router/src/routers/header_utils.rs +++ b/sgl-router/src/routers/header_utils.rs @@ -1,6 +1,4 @@ -use axum::body::Body; -use axum::extract::Request; -use axum::http::HeaderMap; +use axum::{body::Body, extract::Request, http::HeaderMap}; /// Copy request headers to a Vec of name-value string pairs /// Used for forwarding headers to backend workers diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index 8b6864a62..078bb5080 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -1,19 +1,5 @@ -use super::pd_types::api_path; -use crate::config::types::RetryConfig; -use crate::core::{ - is_retryable_status, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType, -}; -use crate::metrics::RouterMetrics; -use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; -use crate::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}; -use crate::protocols::common::{InputIds, StringOrArray}; -use crate::protocols::completion::CompletionRequest; -use crate::protocols::embedding::EmbeddingRequest; -use crate::protocols::generate::GenerateRequest; -use crate::protocols::rerank::RerankRequest; -use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest}; -use crate::routers::header_utils; -use crate::routers::RouterTrait; +use std::{sync::Arc, time::Instant}; + use async_trait::async_trait; use axum::{ body::Body, @@ -25,11 +11,29 @@ use futures_util::StreamExt; use reqwest::Client; use serde::Serialize; use serde_json::{json, Value}; -use std::sync::Arc; -use std::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error, warn}; +use super::pd_types::api_path; +use crate::{ + config::types::RetryConfig, + core::{ + is_retryable_status, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType, + }, + metrics::RouterMetrics, + policies::{LoadBalancingPolicy, PolicyRegistry}, + protocols::{ + chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}, + common::{InputIds, StringOrArray}, + completion::CompletionRequest, + embedding::EmbeddingRequest, + generate::GenerateRequest, + rerank::RerankRequest, + responses::{ResponsesGetParams, ResponsesRequest}, + }, + routers::{header_utils, RouterTrait}, +}; + #[derive(Debug)] pub struct PDRouter { pub worker_registry: Arc, diff --git a/sgl-router/src/routers/http/router.rs b/sgl-router/src/routers/http/router.rs index b20203166..911859a0d 100644 --- a/sgl-router/src/routers/http/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -1,35 +1,39 @@ -use crate::config::types::RetryConfig; -use crate::core::{ - is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType, -}; -use crate::metrics::RouterMetrics; -use crate::policies::PolicyRegistry; -use crate::protocols::chat::ChatCompletionRequest; -use crate::protocols::common::GenerationRequest; -use crate::protocols::completion::CompletionRequest; -use crate::protocols::embedding::EmbeddingRequest; -use crate::protocols::generate::GenerateRequest; -use crate::protocols::rerank::{RerankRequest, RerankResponse, RerankResult}; -use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest}; -use crate::routers::header_utils; -use crate::routers::RouterTrait; -use axum::body::to_bytes; +use std::{sync::Arc, time::Instant}; + use axum::{ - body::Body, + body::{to_bytes, Body}, extract::Request, http::{ - header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, Method, StatusCode, + header::{CONTENT_LENGTH, CONTENT_TYPE}, + HeaderMap, HeaderValue, Method, StatusCode, }, response::{IntoResponse, Response}, Json, }; use futures_util::StreamExt; use reqwest::Client; -use std::sync::Arc; -use std::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error}; +use crate::{ + config::types::RetryConfig, + core::{ + is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType, + }, + metrics::RouterMetrics, + policies::PolicyRegistry, + protocols::{ + chat::ChatCompletionRequest, + common::GenerationRequest, + completion::CompletionRequest, + embedding::EmbeddingRequest, + generate::GenerateRequest, + rerank::{RerankRequest, RerankResponse, RerankResult}, + responses::{ResponsesGetParams, ResponsesRequest}, + }, + routers::{header_utils, RouterTrait}, +}; + /// Regular router that uses injected load balancing policies #[derive(Debug)] pub struct Router { diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index 410976089..b034605dd 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -1,5 +1,7 @@ //! Router implementations +use std::fmt::Debug; + use async_trait::async_trait; use axum::{ body::Body, @@ -7,16 +9,17 @@ use axum::{ http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, }; -use std::fmt::Debug; - -use crate::protocols::chat::ChatCompletionRequest; -use crate::protocols::completion::CompletionRequest; -use crate::protocols::embedding::EmbeddingRequest; -use crate::protocols::generate::GenerateRequest; -use crate::protocols::rerank::RerankRequest; -use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest}; use serde_json::Value; +use crate::protocols::{ + chat::ChatCompletionRequest, + completion::CompletionRequest, + embedding::EmbeddingRequest, + generate::GenerateRequest, + rerank::RerankRequest, + responses::{ResponsesGetParams, ResponsesRequest}, +}; + pub mod factory; pub mod grpc; pub mod header_utils; @@ -25,7 +28,6 @@ pub mod openai; // New refactored OpenAI router module pub mod router_manager; pub use factory::RouterFactory; - // Re-export HTTP routers for convenience pub use http::{pd_router, pd_types, router}; diff --git a/sgl-router/src/routers/openai/conversations.rs b/sgl-router/src/routers/openai/conversations.rs index 6fdadde53..8d9b8c0a1 100644 --- a/sgl-router/src/routers/openai/conversations.rs +++ b/sgl-router/src/routers/openai/conversations.rs @@ -1,22 +1,26 @@ //! Conversation CRUD operations and persistence -use crate::data_connector::{ - conversation_items::ListParams, conversation_items::SortOrder, Conversation, ConversationId, - ConversationItemId, ConversationItemStorage, ConversationStorage, NewConversation, - NewConversationItem, ResponseId, ResponseStorage, SharedConversationItemStorage, - SharedConversationStorage, +use std::{collections::HashMap, sync::Arc}; + +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, }; -use crate::protocols::responses::{ResponseInput, ResponseInputOutputItem, ResponsesRequest}; -use axum::http::StatusCode; -use axum::response::{IntoResponse, Response}; -use axum::Json; use chrono::Utc; use serde_json::{json, Value}; -use std::collections::HashMap; -use std::sync::Arc; use tracing::{debug, info, warn}; use super::responses::build_stored_response; +use crate::{ + data_connector::{ + conversation_items::{ListParams, SortOrder}, + Conversation, ConversationId, ConversationItemId, ConversationItemStorage, + ConversationStorage, NewConversation, NewConversationItem, ResponseId, ResponseStorage, + SharedConversationItemStorage, SharedConversationStorage, + }, + protocols::responses::{ResponseInput, ResponseInputOutputItem, ResponsesRequest}, +}; /// Maximum number of properties allowed in conversation metadata pub(crate) const MAX_METADATA_PROPERTIES: usize = 16; diff --git a/sgl-router/src/routers/openai/mcp.rs b/sgl-router/src/routers/openai/mcp.rs index de86690d7..c53d54d51 100644 --- a/sgl-router/src/routers/openai/mcp.rs +++ b/sgl-router/src/routers/openai/mcp.rs @@ -8,19 +8,20 @@ //! - Payload transformation for MCP tool interception //! - Metadata injection for MCP operations -use crate::mcp::McpClientManager; -use crate::protocols::responses::{ - ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest, -}; -use crate::routers::header_utils::apply_request_headers; +use std::{io, sync::Arc}; + use axum::http::HeaderMap; use bytes::Bytes; use serde_json::{json, to_value, Value}; -use std::{io, sync::Arc}; use tokio::sync::mpsc; use tracing::{info, warn}; use super::utils::event_types; +use crate::{ + mcp::McpClientManager, + protocols::responses::{ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest}, + routers::header_utils::apply_request_headers, +}; // ============================================================================ // Configuration and State Types diff --git a/sgl-router/src/routers/openai/responses.rs b/sgl-router/src/routers/openai/responses.rs index fbd3a1ee2..3931ac019 100644 --- a/sgl-router/src/routers/openai/responses.rs +++ b/sgl-router/src/routers/openai/responses.rs @@ -1,12 +1,15 @@ //! Response storage, patching, and extraction utilities -use crate::data_connector::{ResponseId, StoredResponse}; -use crate::protocols::responses::{ResponseInput, ResponseToolType, ResponsesRequest}; -use serde_json::{json, Value}; use std::collections::HashMap; + +use serde_json::{json, Value}; use tracing::warn; use super::utils::event_types; +use crate::{ + data_connector::{ResponseId, StoredResponse}, + protocols::responses::{ResponseInput, ResponseToolType, ResponsesRequest}, +}; // ============================================================================ // Response Storage Operations diff --git a/sgl-router/src/routers/openai/router.rs b/sgl-router/src/routers/openai/router.rs index dd6d2307e..3008929ea 100644 --- a/sgl-router/src/routers/openai/router.rs +++ b/sgl-router/src/routers/openai/router.rs @@ -1,21 +1,10 @@ //! OpenAI router - main coordinator that delegates to specialized modules -use crate::config::CircuitBreakerConfig; -use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig}; -use crate::data_connector::{ - conversation_items::ListParams, conversation_items::SortOrder, ConversationId, ResponseId, - SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage, +use std::{ + any::Any, + sync::{atomic::AtomicBool, Arc}, }; -use crate::protocols::chat::ChatCompletionRequest; -use crate::protocols::completion::CompletionRequest; -use crate::protocols::embedding::EmbeddingRequest; -use crate::protocols::generate::GenerateRequest; -use crate::protocols::rerank::RerankRequest; -use crate::protocols::responses::{ - ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams, - ResponsesRequest, -}; -use crate::routers::header_utils::apply_request_headers; + use axum::{ body::Body, extract::Request, @@ -25,10 +14,6 @@ use axum::{ }; use futures_util::StreamExt; use serde_json::{json, to_value, Value}; -use std::{ - any::Any, - sync::{atomic::AtomicBool, Arc}, -}; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::warn; @@ -39,12 +24,35 @@ use super::conversations::{ get_conversation, get_conversation_item, list_conversation_items, persist_conversation_items, update_conversation, }; -use super::mcp::{ - execute_tool_loop, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming, - McpLoopConfig, +use super::{ + mcp::{ + execute_tool_loop, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming, + McpLoopConfig, + }, + responses::{mask_tools_as_mcp, patch_streaming_response_json}, + streaming::handle_streaming_response, +}; +use crate::{ + config::CircuitBreakerConfig, + core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig}, + data_connector::{ + conversation_items::{ListParams, SortOrder}, + ConversationId, ResponseId, SharedConversationItemStorage, SharedConversationStorage, + SharedResponseStorage, + }, + protocols::{ + chat::ChatCompletionRequest, + completion::CompletionRequest, + embedding::EmbeddingRequest, + generate::GenerateRequest, + rerank::RerankRequest, + responses::{ + ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams, + ResponsesRequest, + }, + }, + routers::header_utils::apply_request_headers, }; -use super::responses::{mask_tools_as_mcp, patch_streaming_response_json}; -use super::streaming::handle_streaming_response; // ============================================================================ // OpenAIRouter Struct diff --git a/sgl-router/src/routers/openai/streaming.rs b/sgl-router/src/routers/openai/streaming.rs index 349531d0e..804144446 100644 --- a/sgl-router/src/routers/openai/streaming.rs +++ b/sgl-router/src/routers/openai/streaming.rs @@ -7,11 +7,8 @@ //! - MCP tool execution loops within streaming responses //! - Event transformation and output index remapping -use crate::data_connector::{ - SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage, -}; -use crate::protocols::responses::{ResponseToolType, ResponsesRequest}; -use crate::routers::header_utils::{apply_request_headers, preserve_response_headers}; +use std::{borrow::Cow, io, sync::Arc}; + use axum::{ body::Body, http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, @@ -20,20 +17,28 @@ use axum::{ use bytes::Bytes; use futures_util::StreamExt; use serde_json::{json, Value}; -use std::{borrow::Cow, io, sync::Arc}; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::warn; // Import from sibling modules use super::conversations::persist_conversation_items; -use super::mcp::{ - build_resume_payload, execute_streaming_tool_calls, inject_mcp_metadata_streaming, - mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming, send_mcp_list_tools_events, - McpLoopConfig, ToolLoopState, +use super::{ + mcp::{ + build_resume_payload, execute_streaming_tool_calls, inject_mcp_metadata_streaming, + mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming, + send_mcp_list_tools_events, McpLoopConfig, ToolLoopState, + }, + responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block}, + utils::{event_types, FunctionCallInProgress, OutputIndexMapper, StreamAction}, +}; +use crate::{ + data_connector::{ + SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage, + }, + protocols::responses::{ResponseToolType, ResponsesRequest}, + routers::header_utils::{apply_request_headers, preserve_response_headers}, }; -use super::responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block}; -use super::utils::{event_types, FunctionCallInProgress, OutputIndexMapper, StreamAction}; // ============================================================================ // Streaming Response Accumulator diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs index 23f19f20f..3bf4be4fb 100644 --- a/sgl-router/src/routers/router_manager.rs +++ b/sgl-router/src/routers/router_manager.rs @@ -4,16 +4,8 @@ //! - Single Router Mode (enable_igw=false): Router owns workers directly //! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything -use crate::config::{ConnectionMode, RoutingMode}; -use crate::core::{WorkerRegistry, WorkerType}; -use crate::protocols::chat::ChatCompletionRequest; -use crate::protocols::completion::CompletionRequest; -use crate::protocols::embedding::EmbeddingRequest; -use crate::protocols::generate::GenerateRequest; -use crate::protocols::rerank::RerankRequest; -use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest}; -use crate::routers::RouterTrait; -use crate::server::{AppContext, ServerConfig}; +use std::sync::Arc; + use async_trait::async_trait; use axum::{ body::Body, @@ -23,9 +15,23 @@ use axum::{ }; use dashmap::DashMap; use serde_json::Value; -use std::sync::Arc; use tracing::{debug, info, warn}; +use crate::{ + config::{ConnectionMode, RoutingMode}, + core::{WorkerRegistry, WorkerType}, + protocols::{ + chat::ChatCompletionRequest, + completion::CompletionRequest, + embedding::EmbeddingRequest, + generate::GenerateRequest, + rerank::RerankRequest, + responses::{ResponsesGetParams, ResponsesRequest}, + }, + routers::RouterTrait, + server::{AppContext, ServerConfig}, +}; + #[derive(Debug, Clone, Hash, Eq, PartialEq)] pub struct RouterId(String); diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 2eb7484b3..b699ff08d 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -1,3 +1,24 @@ +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, OnceLock, + }, + time::Duration, +}; + +use axum::{ + extract::{Path, Query, Request, State}, + http::StatusCode, + response::{IntoResponse, Response}, + routing::{delete, get, post}, + serve, Json, Router, +}; +use reqwest::Client; +use serde::Deserialize; +use serde_json::{json, Value}; +use tokio::{net::TcpListener, signal, spawn}; +use tracing::{error, info, warn, Level}; + use crate::{ config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode}, core::{ @@ -30,24 +51,6 @@ use crate::{ tokenizer::{factory as tokenizer_factory, traits::Tokenizer}, tool_parser::ParserFactory as ToolParserFactory, }; -use axum::{ - extract::{Path, Query, Request, State}, - http::StatusCode, - response::{IntoResponse, Response}, - routing::{delete, get, post}, - serve, Json, Router, -}; -use reqwest::Client; -use serde::Deserialize; -use serde_json::{json, Value}; -use std::sync::OnceLock; -use std::{ - sync::atomic::{AtomicBool, Ordering}, - sync::Arc, - time::Duration, -}; -use tokio::{net::TcpListener, signal, spawn}; -use tracing::{error, info, warn, Level}; // diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index 64666992e..8f2755762 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -1,24 +1,25 @@ -use crate::core::WorkerManager; -use crate::protocols::worker_spec::WorkerConfigRequest; -use crate::server::AppContext; +use std::{ + collections::{HashMap, HashSet}, + sync::{Arc, Mutex}, + time::Duration, +}; use futures::{StreamExt, TryStreamExt}; use k8s_openapi::api::core::v1::Pod; use kube::{ api::Api, - runtime::watcher::{watcher, Config}, - runtime::WatchStreamExt, + runtime::{ + watcher::{watcher, Config}, + WatchStreamExt, + }, Client, }; -use std::collections::{HashMap, HashSet}; - use rustls; -use std::sync::{Arc, Mutex}; -use std::time::Duration; -use tokio::task; -use tokio::time; +use tokio::{task, time}; use tracing::{debug, error, info, warn}; +use crate::{core::WorkerManager, protocols::worker_spec::WorkerConfigRequest, server::AppContext}; + #[derive(Debug, Clone)] pub struct ServiceDiscoveryConfig { pub enabled: bool, @@ -452,10 +453,12 @@ async fn handle_pod_deletion( #[cfg(test)] mod tests { + use k8s_openapi::{ + api::core::v1::{Pod, PodCondition, PodSpec, PodStatus}, + apimachinery::pkg::apis::meta::v1::{ObjectMeta, Time}, + }; + use super::*; - use k8s_openapi::api::core::v1::{Pod, PodCondition, PodSpec, PodStatus}; - use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta; - use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time; fn create_k8s_pod( name: Option<&str>, @@ -535,8 +538,7 @@ mod tests { } async fn create_test_app_context() -> Arc { - use crate::config::RouterConfig; - use crate::middleware::TokenBucket; + use crate::{config::RouterConfig, middleware::TokenBucket}; let router_config = RouterConfig { worker_startup_timeout_secs: 1, diff --git a/sgl-router/src/tokenizer/chat_template.rs b/sgl-router/src/tokenizer/chat_template.rs index e82544ca4..a575e8c44 100644 --- a/sgl-router/src/tokenizer/chat_template.rs +++ b/sgl-router/src/tokenizer/chat_template.rs @@ -3,12 +3,16 @@ //! This module provides functionality to apply chat templates to messages, //! similar to HuggingFace transformers' apply_chat_template method. -use anyhow::{anyhow, Result}; -use minijinja::machinery::ast::{Expr, Stmt}; -use minijinja::{context, Environment, Value}; -use serde_json; use std::collections::HashMap; +use anyhow::{anyhow, Result}; +use minijinja::{ + context, + machinery::ast::{Expr, Stmt}, + Environment, Value, +}; +use serde_json; + /// Chat template content format #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ChatTemplateContentFormat { @@ -319,8 +323,10 @@ impl<'a> Detector<'a> { /// AST-based detection using minijinja's unstable machinery /// Single-pass detector with scope tracking fn detect_format_with_ast(template: &str) -> Option { - use minijinja::machinery::{parse, WhitespaceConfig}; - use minijinja::syntax::SyntaxConfig; + use minijinja::{ + machinery::{parse, WhitespaceConfig}, + syntax::SyntaxConfig, + }; let ast = match parse( template, diff --git a/sgl-router/src/tokenizer/factory.rs b/sgl-router/src/tokenizer/factory.rs index 3f5d4f67c..8d8cde5f7 100644 --- a/sgl-router/src/tokenizer/factory.rs +++ b/sgl-router/src/tokenizer/factory.rs @@ -1,13 +1,9 @@ -use super::traits; +use std::{fs::File, io::Read, path::Path, sync::Arc}; + use anyhow::{Error, Result}; -use std::fs::File; -use std::io::Read; -use std::path::Path; -use std::sync::Arc; use tracing::{debug, info}; -use super::huggingface::HuggingFaceTokenizer; -use super::tiktoken::TiktokenTokenizer; +use super::{huggingface::HuggingFaceTokenizer, tiktoken::TiktokenTokenizer, traits}; use crate::tokenizer::hub::download_tokenizer_from_hf; /// Represents the type of tokenizer being used @@ -379,8 +375,7 @@ pub fn get_tokenizer_info(file_path: &str) -> Result { Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())), _ => { // Try auto-detection - use std::fs::File; - use std::io::Read; + use std::{fs::File, io::Read}; let mut file = File::open(file_path)?; let mut buffer = vec![0u8; 512]; diff --git a/sgl-router/src/tokenizer/hub.rs b/sgl-router/src/tokenizer/hub.rs index f9c344f57..c67616c16 100644 --- a/sgl-router/src/tokenizer/hub.rs +++ b/sgl-router/src/tokenizer/hub.rs @@ -1,6 +1,9 @@ +use std::{ + env, + path::{Path, PathBuf}, +}; + use hf_hub::api::tokio::ApiBuilder; -use std::env; -use std::path::{Path, PathBuf}; const IGNORED: [&str; 5] = [ ".gitattributes", diff --git a/sgl-router/src/tokenizer/huggingface.rs b/sgl-router/src/tokenizer/huggingface.rs index beaf98eb7..727d35715 100644 --- a/sgl-router/src/tokenizer/huggingface.rs +++ b/sgl-router/src/tokenizer/huggingface.rs @@ -3,12 +3,12 @@ use std::collections::HashMap; use anyhow::{Error, Result}; use tokenizers::tokenizer::Tokenizer as HfTokenizer; -use super::chat_template::{ - detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, - ChatTemplateProcessor, -}; -use super::traits::{ - Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait, +use super::{ + chat_template::{ + detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, + ChatTemplateProcessor, + }, + traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait}, }; /// HuggingFace tokenizer wrapper diff --git a/sgl-router/src/tokenizer/mock.rs b/sgl-router/src/tokenizer/mock.rs index 9b0cd5cdf..ab918db37 100644 --- a/sgl-router/src/tokenizer/mock.rs +++ b/sgl-router/src/tokenizer/mock.rs @@ -1,9 +1,11 @@ //! Mock tokenizer implementation for testing -use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait}; -use anyhow::Result; use std::collections::HashMap; +use anyhow::Result; + +use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait}; + /// Mock tokenizer for testing purposes pub struct MockTokenizer { vocab: HashMap, diff --git a/sgl-router/src/tokenizer/mod.rs b/sgl-router/src/tokenizer/mod.rs index dbf3ba940..651b340de 100644 --- a/sgl-router/src/tokenizer/mod.rs +++ b/sgl-router/src/tokenizer/mod.rs @@ -1,6 +1,6 @@ +use std::{ops::Deref, sync::Arc}; + use anyhow::Result; -use std::ops::Deref; -use std::sync::Arc; pub mod factory; pub mod hub; @@ -27,14 +27,12 @@ pub use factory::{ create_tokenizer_from_file, create_tokenizer_with_chat_template, create_tokenizer_with_chat_template_blocking, TokenizerType, }; +pub use huggingface::HuggingFaceTokenizer; pub use sequence::Sequence; pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder}; pub use stream::DecodeStream; -pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait}; - -pub use huggingface::HuggingFaceTokenizer; - pub use tiktoken::{TiktokenModel, TiktokenTokenizer}; +pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait}; /// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations #[derive(Clone)] diff --git a/sgl-router/src/tokenizer/sequence.rs b/sgl-router/src/tokenizer/sequence.rs index f54f73437..a9b114021 100644 --- a/sgl-router/src/tokenizer/sequence.rs +++ b/sgl-router/src/tokenizer/sequence.rs @@ -1,7 +1,9 @@ -use super::traits::{TokenIdType, Tokenizer as TokenizerTrait}; -use anyhow::Result; use std::sync::Arc; +use anyhow::Result; + +use super::traits::{TokenIdType, Tokenizer as TokenizerTrait}; + /// Maintains state for an ongoing sequence of tokens and their decoded text /// This provides a cleaner abstraction for managing token sequences pub struct Sequence { diff --git a/sgl-router/src/tokenizer/stop.rs b/sgl-router/src/tokenizer/stop.rs index 3122a0e97..14f68793d 100644 --- a/sgl-router/src/tokenizer/stop.rs +++ b/sgl-router/src/tokenizer/stop.rs @@ -1,8 +1,11 @@ -use super::sequence::Sequence; -use super::traits::{self, TokenIdType}; +use std::{collections::HashSet, sync::Arc}; + use anyhow::Result; -use std::collections::HashSet; -use std::sync::Arc; + +use super::{ + sequence::Sequence, + traits::{self, TokenIdType}, +}; /// Output from the sequence decoder #[derive(Debug, Clone, PartialEq)] diff --git a/sgl-router/src/tokenizer/stream.rs b/sgl-router/src/tokenizer/stream.rs index 848be8a8c..978cdcae4 100644 --- a/sgl-router/src/tokenizer/stream.rs +++ b/sgl-router/src/tokenizer/stream.rs @@ -1,9 +1,11 @@ // src/tokenizer/stream.rs -use super::traits::{self, TokenIdType}; -use anyhow::Result; use std::sync::Arc; +use anyhow::Result; + +use super::traits::{self, TokenIdType}; + const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5; /// DecodeStream will keep the state necessary to produce individual chunks of diff --git a/sgl-router/src/tokenizer/tests.rs b/sgl-router/src/tokenizer/tests.rs index 7ad8399df..acd158766 100644 --- a/sgl-router/src/tokenizer/tests.rs +++ b/sgl-router/src/tokenizer/tests.rs @@ -1,8 +1,9 @@ #[cfg(test)] -use super::*; -#[cfg(test)] use std::sync::Arc; +#[cfg(test)] +use super::*; + #[test] fn test_mock_tokenizer_encode() { let tokenizer = mock::MockTokenizer::new(); diff --git a/sgl-router/src/tokenizer/tiktoken.rs b/sgl-router/src/tokenizer/tiktoken.rs index d75c10569..13df755f4 100644 --- a/sgl-router/src/tokenizer/tiktoken.rs +++ b/sgl-router/src/tokenizer/tiktoken.rs @@ -1,8 +1,9 @@ +use anyhow::{Error, Result}; +use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE}; + use super::traits::{ Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait, }; -use anyhow::{Error, Result}; -use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE}; /// Tiktoken tokenizer wrapper for OpenAI GPT models pub struct TiktokenTokenizer { diff --git a/sgl-router/src/tokenizer/traits.rs b/sgl-router/src/tokenizer/traits.rs index 3ef2c4fe0..6e2fa7cb6 100644 --- a/sgl-router/src/tokenizer/traits.rs +++ b/sgl-router/src/tokenizer/traits.rs @@ -1,6 +1,9 @@ +use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, +}; + use anyhow::Result; -use std::collections::hash_map::DefaultHasher; -use std::hash::{Hash, Hasher}; /// Type alias for token IDs pub type TokenIdType = u32; diff --git a/sgl-router/src/tool_parser/factory.rs b/sgl-router/src/tool_parser/factory.rs index 43c40b6e8..a9fc9fd5f 100644 --- a/sgl-router/src/tool_parser/factory.rs +++ b/sgl-router/src/tool_parser/factory.rs @@ -1,14 +1,19 @@ // Factory and pool for creating model-specific tool parsers with pooling support. -use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, +}; + use tokio::sync::Mutex; -use crate::tool_parser::parsers::{ - DeepSeekParser, Glm4MoeParser, GptOssHarmonyParser, GptOssParser, JsonParser, KimiK2Parser, - LlamaParser, MistralParser, PassthroughParser, PythonicParser, QwenParser, Step3Parser, +use crate::tool_parser::{ + parsers::{ + DeepSeekParser, Glm4MoeParser, GptOssHarmonyParser, GptOssParser, JsonParser, KimiK2Parser, + LlamaParser, MistralParser, PassthroughParser, PythonicParser, QwenParser, Step3Parser, + }, + traits::ToolParser, }; -use crate::tool_parser::traits::ToolParser; /// Type alias for pooled parser instances. pub type PooledParser = Arc>>; diff --git a/sgl-router/src/tool_parser/mod.rs b/sgl-router/src/tool_parser/mod.rs index d4521b10c..11950e19d 100644 --- a/sgl-router/src/tool_parser/mod.rs +++ b/sgl-router/src/tool_parser/mod.rs @@ -18,11 +18,10 @@ mod tests; // Re-export commonly used types pub use errors::{ParserError, ParserResult}; pub use factory::{ParserFactory, ParserRegistry, PooledParser}; -pub use traits::{PartialJsonParser, ToolParser}; -pub use types::{FunctionCall, PartialToolCall, StreamingParseResult, ToolCall}; - // Re-export parsers for convenience pub use parsers::{ DeepSeekParser, Glm4MoeParser, GptOssParser, JsonParser, KimiK2Parser, LlamaParser, MistralParser, PythonicParser, QwenParser, Step3Parser, }; +pub use traits::{PartialJsonParser, ToolParser}; +pub use types::{FunctionCall, PartialToolCall, StreamingParseResult, ToolCall}; diff --git a/sgl-router/src/tool_parser/parsers/deepseek_parser.rs b/sgl-router/src/tool_parser/parsers/deepseek_parser.rs index bb6306043..4ae789422 100644 --- a/sgl-router/src/tool_parser/parsers/deepseek_parser.rs +++ b/sgl-router/src/tool_parser/parsers/deepseek_parser.rs @@ -2,13 +2,14 @@ use async_trait::async_trait; use regex::Regex; use serde_json::Value; -use crate::protocols::common::Tool; - -use crate::tool_parser::{ - errors::{ParserError, ParserResult}, - parsers::helpers, - traits::ToolParser, - types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, +use crate::{ + protocols::common::Tool, + tool_parser::{ + errors::{ParserError, ParserResult}, + parsers::helpers, + traits::ToolParser, + types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, + }, }; /// DeepSeek V3 format parser for tool calls diff --git a/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs b/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs index 38cd86558..8b9dc5024 100644 --- a/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs +++ b/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs @@ -2,13 +2,14 @@ use async_trait::async_trait; use regex::Regex; use serde_json::Value; -use crate::protocols::common::Tool; - -use crate::tool_parser::{ - errors::{ParserError, ParserResult}, - parsers::helpers, - traits::ToolParser, - types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, +use crate::{ + protocols::common::Tool, + tool_parser::{ + errors::{ParserError, ParserResult}, + parsers::helpers, + traits::ToolParser, + types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, + }, }; /// GLM-4 MoE format parser for tool calls diff --git a/sgl-router/src/tool_parser/parsers/gpt_oss_harmony_parser.rs b/sgl-router/src/tool_parser/parsers/gpt_oss_harmony_parser.rs index 5c66d08ae..79e998750 100644 --- a/sgl-router/src/tool_parser/parsers/gpt_oss_harmony_parser.rs +++ b/sgl-router/src/tool_parser/parsers/gpt_oss_harmony_parser.rs @@ -1,11 +1,12 @@ use async_trait::async_trait; -use crate::protocols::common::Tool; - -use crate::tool_parser::{ - errors::ParserResult, - traits::{TokenToolParser, ToolParser}, - types::{StreamingParseResult, ToolCall}, +use crate::{ + protocols::common::Tool, + tool_parser::{ + errors::ParserResult, + traits::{TokenToolParser, ToolParser}, + types::{StreamingParseResult, ToolCall}, + }, }; /// Placeholder for the Harmony-backed GPT-OSS parser. diff --git a/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs b/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs index dcd9afb26..b38fa52f2 100644 --- a/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs +++ b/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs @@ -2,14 +2,15 @@ use async_trait::async_trait; use regex::Regex; use serde_json::Value; -use crate::protocols::common::Tool; - -use crate::tool_parser::{ - errors::{ParserError, ParserResult}, - parsers::helpers, - partial_json::PartialJson, - traits::ToolParser, - types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, +use crate::{ + protocols::common::Tool, + tool_parser::{ + errors::{ParserError, ParserResult}, + parsers::helpers, + partial_json::PartialJson, + traits::ToolParser, + types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, + }, }; /// GPT-OSS format parser for tool calls diff --git a/sgl-router/src/tool_parser/parsers/helpers.rs b/sgl-router/src/tool_parser/parsers/helpers.rs index 109cea53f..fa8799cf1 100644 --- a/sgl-router/src/tool_parser/parsers/helpers.rs +++ b/sgl-router/src/tool_parser/parsers/helpers.rs @@ -1,9 +1,14 @@ -use crate::protocols::common::Tool; -use serde_json::Value; use std::collections::HashMap; -use crate::tool_parser::errors::{ParserError, ParserResult}; -use crate::tool_parser::types::{StreamingParseResult, ToolCallItem}; +use serde_json::Value; + +use crate::{ + protocols::common::Tool, + tool_parser::{ + errors::{ParserError, ParserResult}, + types::{StreamingParseResult, ToolCallItem}, + }, +}; /// Get a mapping of tool names to their indices pub fn get_tool_indices(tools: &[Tool]) -> HashMap { diff --git a/sgl-router/src/tool_parser/parsers/json_parser.rs b/sgl-router/src/tool_parser/parsers/json_parser.rs index 1c7b481ab..3af6518c2 100644 --- a/sgl-router/src/tool_parser/parsers/json_parser.rs +++ b/sgl-router/src/tool_parser/parsers/json_parser.rs @@ -1,14 +1,15 @@ use async_trait::async_trait; use serde_json::Value; -use crate::protocols::common::Tool; - -use crate::tool_parser::{ - errors::{ParserError, ParserResult}, - parsers::helpers, - partial_json::PartialJson, - traits::ToolParser, - types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, +use crate::{ + protocols::common::Tool, + tool_parser::{ + errors::{ParserError, ParserResult}, + parsers::helpers, + partial_json::PartialJson, + traits::ToolParser, + types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, + }, }; /// JSON format parser for tool calls diff --git a/sgl-router/src/tool_parser/parsers/kimik2_parser.rs b/sgl-router/src/tool_parser/parsers/kimik2_parser.rs index 9cc11437b..e68aa422e 100644 --- a/sgl-router/src/tool_parser/parsers/kimik2_parser.rs +++ b/sgl-router/src/tool_parser/parsers/kimik2_parser.rs @@ -2,13 +2,14 @@ use async_trait::async_trait; use regex::Regex; use serde_json::Value; -use crate::protocols::common::Tool; - -use crate::tool_parser::{ - errors::ParserResult, - parsers::helpers, - traits::ToolParser, - types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, +use crate::{ + protocols::common::Tool, + tool_parser::{ + errors::ParserResult, + parsers::helpers, + traits::ToolParser, + types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, + }, }; /// Kimi K2 format parser for tool calls diff --git a/sgl-router/src/tool_parser/parsers/llama_parser.rs b/sgl-router/src/tool_parser/parsers/llama_parser.rs index e42c2d679..1bdb64315 100644 --- a/sgl-router/src/tool_parser/parsers/llama_parser.rs +++ b/sgl-router/src/tool_parser/parsers/llama_parser.rs @@ -1,14 +1,15 @@ use async_trait::async_trait; use serde_json::Value; -use crate::protocols::common::Tool; - -use crate::tool_parser::{ - errors::{ParserError, ParserResult}, - parsers::helpers, - partial_json::PartialJson, - traits::ToolParser, - types::{FunctionCall, StreamingParseResult, ToolCall}, +use crate::{ + protocols::common::Tool, + tool_parser::{ + errors::{ParserError, ParserResult}, + parsers::helpers, + partial_json::PartialJson, + traits::ToolParser, + types::{FunctionCall, StreamingParseResult, ToolCall}, + }, }; /// Llama 3.2 format parser for tool calls diff --git a/sgl-router/src/tool_parser/parsers/mistral_parser.rs b/sgl-router/src/tool_parser/parsers/mistral_parser.rs index 151e7fccf..54aa31e7f 100644 --- a/sgl-router/src/tool_parser/parsers/mistral_parser.rs +++ b/sgl-router/src/tool_parser/parsers/mistral_parser.rs @@ -1,14 +1,15 @@ use async_trait::async_trait; use serde_json::Value; -use crate::protocols::common::Tool; - -use crate::tool_parser::{ - errors::{ParserError, ParserResult}, - parsers::helpers, - partial_json::PartialJson, - traits::ToolParser, - types::{FunctionCall, StreamingParseResult, ToolCall}, +use crate::{ + protocols::common::Tool, + tool_parser::{ + errors::{ParserError, ParserResult}, + parsers::helpers, + partial_json::PartialJson, + traits::ToolParser, + types::{FunctionCall, StreamingParseResult, ToolCall}, + }, }; /// Mistral format parser for tool calls diff --git a/sgl-router/src/tool_parser/parsers/passthrough_parser.rs b/sgl-router/src/tool_parser/parsers/passthrough_parser.rs index b718bff58..11170f9d3 100644 --- a/sgl-router/src/tool_parser/parsers/passthrough_parser.rs +++ b/sgl-router/src/tool_parser/parsers/passthrough_parser.rs @@ -4,12 +4,17 @@ //! tool call parsing should be performed. It simply returns the input text //! with no tool calls detected. -use crate::protocols::common::Tool; -use crate::tool_parser::errors::ParserResult; -use crate::tool_parser::traits::ToolParser; -use crate::tool_parser::types::{StreamingParseResult, ToolCall, ToolCallItem}; use async_trait::async_trait; +use crate::{ + protocols::common::Tool, + tool_parser::{ + errors::ParserResult, + traits::ToolParser, + types::{StreamingParseResult, ToolCall, ToolCallItem}, + }, +}; + /// Passthrough parser that returns text unchanged with no tool calls #[derive(Default)] pub struct PassthroughParser; diff --git a/sgl-router/src/tool_parser/parsers/pythonic_parser.rs b/sgl-router/src/tool_parser/parsers/pythonic_parser.rs index 317e5836d..175b4fdcf 100644 --- a/sgl-router/src/tool_parser/parsers/pythonic_parser.rs +++ b/sgl-router/src/tool_parser/parsers/pythonic_parser.rs @@ -1,3 +1,5 @@ +use std::sync::OnceLock; + /// Pythonic format parser for tool calls /// /// Handles Python function call syntax within square brackets: @@ -10,18 +12,20 @@ use async_trait::async_trait; use num_traits::ToPrimitive; use regex::Regex; -use rustpython_parser::ast::{Constant, Expr, Mod, UnaryOp}; -use rustpython_parser::{parse, Mode}; +use rustpython_parser::{ + ast::{Constant, Expr, Mod, UnaryOp}, + parse, Mode, +}; use serde_json::{Map, Number, Value}; -use std::sync::OnceLock; -use crate::protocols::common::Tool; - -use crate::tool_parser::{ - errors::{ParserError, ParserResult}, - parsers::helpers, - traits::ToolParser, - types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, +use crate::{ + protocols::common::Tool, + tool_parser::{ + errors::{ParserError, ParserResult}, + parsers::helpers, + traits::ToolParser, + types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, + }, }; static PYTHONIC_BLOCK_REGEX: OnceLock = OnceLock::new(); diff --git a/sgl-router/src/tool_parser/parsers/qwen_parser.rs b/sgl-router/src/tool_parser/parsers/qwen_parser.rs index a3f5d965a..9267fddda 100644 --- a/sgl-router/src/tool_parser/parsers/qwen_parser.rs +++ b/sgl-router/src/tool_parser/parsers/qwen_parser.rs @@ -2,14 +2,15 @@ use async_trait::async_trait; use regex::Regex; use serde_json::Value; -use crate::protocols::common::Tool; - -use crate::tool_parser::{ - errors::{ParserError, ParserResult}, - parsers::helpers, - partial_json::PartialJson, - traits::ToolParser, - types::{FunctionCall, StreamingParseResult, ToolCall}, +use crate::{ + protocols::common::Tool, + tool_parser::{ + errors::{ParserError, ParserResult}, + parsers::helpers, + partial_json::PartialJson, + traits::ToolParser, + types::{FunctionCall, StreamingParseResult, ToolCall}, + }, }; /// Qwen format parser for tool calls diff --git a/sgl-router/src/tool_parser/parsers/step3_parser.rs b/sgl-router/src/tool_parser/parsers/step3_parser.rs index 1b311cc67..d53f81d56 100644 --- a/sgl-router/src/tool_parser/parsers/step3_parser.rs +++ b/sgl-router/src/tool_parser/parsers/step3_parser.rs @@ -1,15 +1,17 @@ +use std::collections::HashMap; + use async_trait::async_trait; use regex::Regex; use serde_json::Value; -use std::collections::HashMap; -use crate::protocols::common::Tool; - -use crate::tool_parser::{ - errors::{ParserError, ParserResult}, - parsers::helpers, - traits::ToolParser, - types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, +use crate::{ + protocols::common::Tool, + tool_parser::{ + errors::{ParserError, ParserResult}, + parsers::helpers, + traits::ToolParser, + types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, + }, }; /// Step3 format parser for tool calls diff --git a/sgl-router/src/tool_parser/partial_json.rs b/sgl-router/src/tool_parser/partial_json.rs index c6d474d6a..0764572e8 100644 --- a/sgl-router/src/tool_parser/partial_json.rs +++ b/sgl-router/src/tool_parser/partial_json.rs @@ -1,8 +1,9 @@ +use serde_json::{Map, Value}; + use crate::tool_parser::{ errors::{ParserError, ParserResult}, traits::PartialJsonParser, }; -use serde_json::{Map, Value}; /// Parser for incomplete JSON pub struct PartialJson { diff --git a/sgl-router/src/tool_parser/tests.rs b/sgl-router/src/tool_parser/tests.rs index b440382b6..bf43f9553 100644 --- a/sgl-router/src/tool_parser/tests.rs +++ b/sgl-router/src/tool_parser/tests.rs @@ -1,9 +1,9 @@ use super::*; -use crate::tool_parser::parsers::JsonParser; -use crate::tool_parser::partial_json::{ - compute_diff, find_common_prefix, is_complete_json, PartialJson, +use crate::tool_parser::{ + parsers::JsonParser, + partial_json::{compute_diff, find_common_prefix, is_complete_json, PartialJson}, + traits::ToolParser, }; -use crate::tool_parser::traits::ToolParser; #[tokio::test] async fn test_tool_parser_factory() { diff --git a/sgl-router/src/tool_parser/traits.rs b/sgl-router/src/tool_parser/traits.rs index 482f11dea..51421f20f 100644 --- a/sgl-router/src/tool_parser/traits.rs +++ b/sgl-router/src/tool_parser/traits.rs @@ -1,10 +1,13 @@ -use crate::protocols::common::Tool; -use crate::tool_parser::{ - errors::ParserResult, - types::{StreamingParseResult, ToolCall}, -}; use async_trait::async_trait; +use crate::{ + protocols::common::Tool, + tool_parser::{ + errors::ParserResult, + types::{StreamingParseResult, ToolCall}, + }, +}; + /// Core trait for all tool parsers #[async_trait] pub trait ToolParser: Send + Sync { diff --git a/sgl-router/src/tree.rs b/sgl-router/src/tree.rs index e511620a6..e4661620c 100644 --- a/sgl-router/src/tree.rs +++ b/sgl-router/src/tree.rs @@ -1,17 +1,13 @@ -use dashmap::mapref::entry::Entry; -use dashmap::DashMap; +use std::{ + cmp::Reverse, + collections::{BinaryHeap, HashMap, VecDeque}, + sync::{Arc, RwLock}, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; + +use dashmap::{mapref::entry::Entry, DashMap}; use tracing::info; -use std::cmp::Reverse; -use std::collections::BinaryHeap; -use std::collections::HashMap; -use std::collections::VecDeque; -use std::sync::Arc; -use std::sync::RwLock; - -use std::time::Duration; -use std::time::{SystemTime, UNIX_EPOCH}; - type NodeRef = Arc; #[derive(Debug)] @@ -666,12 +662,12 @@ impl Tree { // Unit tests #[cfg(test)] mod tests { - use rand::distr::Alphanumeric; - use rand::distr::SampleString; - use rand::rng as thread_rng; - use rand::Rng; - use std::thread; - use std::time::Instant; + use std::{thread, time::Instant}; + + use rand::{ + distr::{Alphanumeric, SampleString}, + rng as thread_rng, Rng, + }; use super::*; diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index ffdb4997c..a94b416f0 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -1,5 +1,7 @@ mod common; +use std::sync::Arc; + use axum::{ body::Body, extract::Request, @@ -8,13 +10,14 @@ use axum::{ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; use reqwest::Client; use serde_json::json; -use sglang_router_rs::config::{ - CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, +use sglang_router_rs::{ + config::{ + CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, + }, + core::WorkerManager, + routers::{RouterFactory, RouterTrait}, + server::AppContext, }; -use sglang_router_rs::core::WorkerManager; -use sglang_router_rs::routers::{RouterFactory, RouterTrait}; -use sglang_router_rs::server::AppContext; -use std::sync::Arc; use tower::ServiceExt; /// Test context that manages mock workers @@ -995,9 +998,10 @@ mod router_policy_tests { #[cfg(test)] mod responses_endpoint_tests { - use super::*; use reqwest::Client as HttpClient; + use super::*; + #[tokio::test] async fn test_v1_responses_non_streaming() { let ctx = TestContext::new(vec![MockWorkerConfig { diff --git a/sgl-router/tests/cache_aware_backward_compat_test.rs b/sgl-router/tests/cache_aware_backward_compat_test.rs index 6ff62b10b..9cafd6240 100644 --- a/sgl-router/tests/cache_aware_backward_compat_test.rs +++ b/sgl-router/tests/cache_aware_backward_compat_test.rs @@ -1,7 +1,9 @@ -use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType}; -use sglang_router_rs::policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy}; -use std::collections::HashMap; -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; + +use sglang_router_rs::{ + core::{BasicWorkerBuilder, Worker, WorkerType}, + policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy}, +}; #[test] fn test_backward_compatibility_with_empty_model_id() { diff --git a/sgl-router/tests/chat_template_format_detection.rs b/sgl-router/tests/chat_template_format_detection.rs index 64ef20f02..b54785b4e 100644 --- a/sgl-router/tests/chat_template_format_detection.rs +++ b/sgl-router/tests/chat_template_format_detection.rs @@ -1,7 +1,9 @@ -use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent}; -use sglang_router_rs::tokenizer::chat_template::{ - detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, - ChatTemplateProcessor, +use sglang_router_rs::{ + protocols::chat::{ChatMessage, UserMessageContent}, + tokenizer::chat_template::{ + detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, + ChatTemplateProcessor, + }, }; #[test] diff --git a/sgl-router/tests/chat_template_integration.rs b/sgl-router/tests/chat_template_integration.rs index 30a0b146a..2f345a4f1 100644 --- a/sgl-router/tests/chat_template_integration.rs +++ b/sgl-router/tests/chat_template_integration.rs @@ -1,8 +1,12 @@ -use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent}; -use sglang_router_rs::protocols::common::{ContentPart, ImageUrl}; -use sglang_router_rs::tokenizer::chat_template::{ - detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, - ChatTemplateProcessor, +use sglang_router_rs::{ + protocols::{ + chat::{ChatMessage, UserMessageContent}, + common::{ContentPart, ImageUrl}, + }, + tokenizer::chat_template::{ + detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, + ChatTemplateProcessor, + }, }; #[test] diff --git a/sgl-router/tests/chat_template_loading.rs b/sgl-router/tests/chat_template_loading.rs index 5297198c7..4c537012b 100644 --- a/sgl-router/tests/chat_template_loading.rs +++ b/sgl-router/tests/chat_template_loading.rs @@ -1,9 +1,11 @@ #[cfg(test)] mod tests { - use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent}; - use sglang_router_rs::tokenizer::chat_template::ChatTemplateParams; - use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer; use std::fs; + + use sglang_router_rs::{ + protocols::chat::{ChatMessage, UserMessageContent}, + tokenizer::{chat_template::ChatTemplateParams, huggingface::HuggingFaceTokenizer}, + }; use tempfile::TempDir; #[test] diff --git a/sgl-router/tests/common/mock_mcp_server.rs b/sgl-router/tests/common/mock_mcp_server.rs index f5dfea738..02e80d5e1 100644 --- a/sgl-router/tests/common/mock_mcp_server.rs +++ b/sgl-router/tests/common/mock_mcp_server.rs @@ -148,8 +148,7 @@ mod tests { async fn test_mock_server_with_rmcp_client() { let mut server = MockMCPServer::start().await.unwrap(); - use rmcp::transport::StreamableHttpClientTransport; - use rmcp::ServiceExt; + use rmcp::{transport::StreamableHttpClientTransport, ServiceExt}; let transport = StreamableHttpClientTransport::from_uri(server.url().as_str()); let client = ().serve(transport).await; diff --git a/sgl-router/tests/common/mock_openai_server.rs b/sgl-router/tests/common/mock_openai_server.rs index 643fd5e98..36fac0543 100644 --- a/sgl-router/tests/common/mock_openai_server.rs +++ b/sgl-router/tests/common/mock_openai_server.rs @@ -2,19 +2,21 @@ #![allow(dead_code)] +use std::{net::SocketAddr, sync::Arc}; + use axum::{ body::Body, extract::{Request, State}, http::{HeaderValue, StatusCode}, - response::sse::{Event, KeepAlive}, - response::{IntoResponse, Response, Sse}, + response::{ + sse::{Event, KeepAlive}, + IntoResponse, Response, Sse, + }, routing::post, Json, Router, }; use futures_util::stream::{self, StreamExt}; use serde_json::json; -use std::net::SocketAddr; -use std::sync::Arc; use tokio::net::TcpListener; /// Mock OpenAI API server for testing diff --git a/sgl-router/tests/common/mock_worker.rs b/sgl-router/tests/common/mock_worker.rs index 384048f1a..c1e9142c5 100755 --- a/sgl-router/tests/common/mock_worker.rs +++ b/sgl-router/tests/common/mock_worker.rs @@ -1,20 +1,25 @@ // Mock worker for testing - these functions are used by integration tests #![allow(dead_code)] +use std::{ + collections::{HashMap, HashSet}, + convert::Infallible, + sync::{Arc, Mutex, OnceLock}, + time::{SystemTime, UNIX_EPOCH}, +}; + use axum::{ extract::{Json, Path, State}, http::StatusCode, - response::sse::{Event, KeepAlive}, - response::{IntoResponse, Response, Sse}, + response::{ + sse::{Event, KeepAlive}, + IntoResponse, Response, Sse, + }, routing::{get, post}, Router, }; use futures_util::stream::{self, StreamExt}; use serde_json::json; -use std::collections::{HashMap, HashSet}; -use std::convert::Infallible; -use std::sync::{Arc, Mutex, OnceLock}; -use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; use uuid::Uuid; diff --git a/sgl-router/tests/common/mod.rs b/sgl-router/tests/common/mod.rs index 8b9a9bd75..ab92d4ed8 100644 --- a/sgl-router/tests/common/mod.rs +++ b/sgl-router/tests/common/mod.rs @@ -7,19 +7,24 @@ pub mod mock_worker; pub mod streaming_helpers; pub mod test_app; -use serde_json::json; -use sglang_router_rs::config::RouterConfig; -use sglang_router_rs::core::{LoadMonitor, WorkerRegistry}; -use sglang_router_rs::data_connector::{ - MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage, +use std::{ + fs, + path::PathBuf, + sync::{Arc, Mutex, OnceLock}, +}; + +use serde_json::json; +use sglang_router_rs::{ + config::RouterConfig, + core::{LoadMonitor, WorkerRegistry}, + data_connector::{ + MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage, + }, + middleware::TokenBucket, + policies::PolicyRegistry, + protocols::common::{Function, Tool}, + server::AppContext, }; -use sglang_router_rs::middleware::TokenBucket; -use sglang_router_rs::policies::PolicyRegistry; -use sglang_router_rs::protocols::common::{Function, Tool}; -use sglang_router_rs::server::AppContext; -use std::fs; -use std::path::PathBuf; -use std::sync::{Arc, Mutex, OnceLock}; /// Helper function to create AppContext for tests pub fn create_test_context(config: RouterConfig) -> Arc { diff --git a/sgl-router/tests/common/test_app.rs b/sgl-router/tests/common/test_app.rs index c293b7074..76a04c086 100644 --- a/sgl-router/tests/common/test_app.rs +++ b/sgl-router/tests/common/test_app.rs @@ -1,3 +1,5 @@ +use std::sync::{Arc, OnceLock}; + use axum::Router; use reqwest::Client; use sglang_router_rs::{ @@ -11,7 +13,6 @@ use sglang_router_rs::{ routers::RouterTrait, server::{build_app, AppContext, AppState}, }; -use std::sync::{Arc, OnceLock}; /// Create a test Axum application using the actual server's build_app function #[allow(dead_code)] diff --git a/sgl-router/tests/mcp_test.rs b/sgl-router/tests/mcp_test.rs index 720164340..fb1c4404c 100644 --- a/sgl-router/tests/mcp_test.rs +++ b/sgl-router/tests/mcp_test.rs @@ -9,10 +9,11 @@ mod common; +use std::collections::HashMap; + use common::mock_mcp_server::MockMCPServer; use serde_json::json; use sglang_router_rs::mcp::{McpClientManager, McpConfig, McpError, McpServerConfig, McpTransport}; -use std::collections::HashMap; /// Create a new mock server for testing (each test gets its own) async fn create_mock_server() -> MockMCPServer { diff --git a/sgl-router/tests/policy_registry_integration.rs b/sgl-router/tests/policy_registry_integration.rs index 48d79bf42..2f38ade3e 100644 --- a/sgl-router/tests/policy_registry_integration.rs +++ b/sgl-router/tests/policy_registry_integration.rs @@ -1,12 +1,11 @@ //! Integration tests for PolicyRegistry with RouterManager -use sglang_router_rs::config::PolicyConfig; -use sglang_router_rs::core::WorkerRegistry; -use sglang_router_rs::policies::PolicyRegistry; -use sglang_router_rs::protocols::worker_spec::WorkerConfigRequest; -use sglang_router_rs::routers::router_manager::RouterManager; -use std::collections::HashMap; -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; + +use sglang_router_rs::{ + config::PolicyConfig, core::WorkerRegistry, policies::PolicyRegistry, + protocols::worker_spec::WorkerConfigRequest, routers::router_manager::RouterManager, +}; #[tokio::test] async fn test_policy_registry_with_router_manager() { @@ -95,8 +94,7 @@ async fn test_policy_registry_with_router_manager() { #[test] fn test_policy_registry_cleanup() { - use sglang_router_rs::config::PolicyConfig; - use sglang_router_rs::policies::PolicyRegistry; + use sglang_router_rs::{config::PolicyConfig, policies::PolicyRegistry}; let registry = PolicyRegistry::new(PolicyConfig::RoundRobin); @@ -123,8 +121,7 @@ fn test_policy_registry_cleanup() { #[test] fn test_policy_registry_multiple_models() { - use sglang_router_rs::config::PolicyConfig; - use sglang_router_rs::policies::PolicyRegistry; + use sglang_router_rs::{config::PolicyConfig, policies::PolicyRegistry}; let registry = PolicyRegistry::new(PolicyConfig::RoundRobin); diff --git a/sgl-router/tests/request_formats_test.rs b/sgl-router/tests/request_formats_test.rs index 589be6171..823ca5fbe 100644 --- a/sgl-router/tests/request_formats_test.rs +++ b/sgl-router/tests/request_formats_test.rs @@ -1,12 +1,15 @@ mod common; +use std::sync::Arc; + use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; use reqwest::Client; use serde_json::json; -use sglang_router_rs::config::{RouterConfig, RoutingMode}; -use sglang_router_rs::core::WorkerManager; -use sglang_router_rs::routers::{RouterFactory, RouterTrait}; -use std::sync::Arc; +use sglang_router_rs::{ + config::{RouterConfig, RoutingMode}, + core::WorkerManager, + routers::{RouterFactory, RouterTrait}, +}; /// Test context that manages mock workers struct TestContext { diff --git a/sgl-router/tests/responses_api_test.rs b/sgl-router/tests/responses_api_test.rs index 896d4e484..86cd154d0 100644 --- a/sgl-router/tests/responses_api_test.rs +++ b/sgl-router/tests/responses_api_test.rs @@ -1,22 +1,26 @@ // Integration test for Responses API use axum::http::StatusCode; -use sglang_router_rs::protocols::common::{ - GenerationRequest, ToolChoice, ToolChoiceValue, UsageInfo, -}; -use sglang_router_rs::protocols::responses::{ - ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseTool, ResponseToolType, - ResponsesRequest, ServiceTier, Truncation, +use sglang_router_rs::protocols::{ + common::{GenerationRequest, ToolChoice, ToolChoiceValue, UsageInfo}, + responses::{ + ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseTool, ResponseToolType, + ResponsesRequest, ServiceTier, Truncation, + }, }; mod common; -use common::mock_mcp_server::MockMCPServer; -use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; -use sglang_router_rs::config::{ - CircuitBreakerConfig, ConnectionMode, HealthCheckConfig, PolicyConfig, RetryConfig, - RouterConfig, RoutingMode, +use common::{ + mock_mcp_server::MockMCPServer, + mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}, +}; +use sglang_router_rs::{ + config::{ + CircuitBreakerConfig, ConnectionMode, HealthCheckConfig, PolicyConfig, RetryConfig, + RouterConfig, RoutingMode, + }, + routers::RouterFactory, }; -use sglang_router_rs::routers::RouterFactory; #[tokio::test] async fn test_non_streaming_mcp_minimal_e2e_with_persistence() { diff --git a/sgl-router/tests/spec/chat_completion.rs b/sgl-router/tests/spec/chat_completion.rs index 66e417ea0..8994af114 100644 --- a/sgl-router/tests/spec/chat_completion.rs +++ b/sgl-router/tests/spec/chat_completion.rs @@ -1,10 +1,12 @@ use serde_json::json; -use sglang_router_rs::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}; -use sglang_router_rs::protocols::common::{ - Function, FunctionCall, FunctionChoice, StreamOptions, Tool, ToolChoice, ToolChoiceValue, - ToolReference, +use sglang_router_rs::protocols::{ + chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}, + common::{ + Function, FunctionCall, FunctionChoice, StreamOptions, Tool, ToolChoice, ToolChoiceValue, + ToolReference, + }, + validated::Normalizable, }; -use sglang_router_rs::protocols::validated::Normalizable; use validator::Validate; // Deprecated fields normalization tests diff --git a/sgl-router/tests/spec/embedding.rs b/sgl-router/tests/spec/embedding.rs index e7c832884..2a55d88af 100644 --- a/sgl-router/tests/spec/embedding.rs +++ b/sgl-router/tests/spec/embedding.rs @@ -1,6 +1,5 @@ use serde_json::{from_str, json, to_string}; -use sglang_router_rs::protocols::common::GenerationRequest; -use sglang_router_rs::protocols::embedding::EmbeddingRequest; +use sglang_router_rs::protocols::{common::GenerationRequest, embedding::EmbeddingRequest}; #[test] fn test_embedding_request_serialization_string_input() { diff --git a/sgl-router/tests/spec/rerank.rs b/sgl-router/tests/spec/rerank.rs index 790ab49df..3f6524a0c 100644 --- a/sgl-router/tests/spec/rerank.rs +++ b/sgl-router/tests/spec/rerank.rs @@ -1,9 +1,10 @@ -use serde_json::{from_str, to_string, Number, Value}; -use sglang_router_rs::protocols::common::{GenerationRequest, StringOrArray, UsageInfo}; -use sglang_router_rs::protocols::rerank::{ - RerankRequest, RerankResponse, RerankResult, V1RerankReqInput, -}; use std::collections::HashMap; + +use serde_json::{from_str, to_string, Number, Value}; +use sglang_router_rs::protocols::{ + common::{GenerationRequest, StringOrArray, UsageInfo}, + rerank::{RerankRequest, RerankResponse, RerankResult, V1RerankReqInput}, +}; use validator::Validate; #[test] diff --git a/sgl-router/tests/streaming_tests.rs b/sgl-router/tests/streaming_tests.rs index 5e1bcf876..81b7443e5 100644 --- a/sgl-router/tests/streaming_tests.rs +++ b/sgl-router/tests/streaming_tests.rs @@ -1,13 +1,16 @@ mod common; +use std::sync::Arc; + use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; use futures_util::StreamExt; use reqwest::Client; use serde_json::json; -use sglang_router_rs::config::{RouterConfig, RoutingMode}; -use sglang_router_rs::core::WorkerManager; -use sglang_router_rs::routers::{RouterFactory, RouterTrait}; -use std::sync::Arc; +use sglang_router_rs::{ + config::{RouterConfig, RoutingMode}, + core::WorkerManager, + routers::{RouterFactory, RouterTrait}, +}; /// Test context that manages mock workers struct TestContext { diff --git a/sgl-router/tests/test_openai_routing.rs b/sgl-router/tests/test_openai_routing.rs index fa249cebb..9e3e995c5 100644 --- a/sgl-router/tests/test_openai_routing.rs +++ b/sgl-router/tests/test_openai_routing.rs @@ -1,5 +1,13 @@ //! Comprehensive integration tests for OpenAI backend functionality +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + use axum::{ body::Body, extract::Request, @@ -26,13 +34,10 @@ use sglang_router_rs::{ }, routers::{openai::OpenAIRouter, RouterTrait}, }; -use std::collections::HashMap; -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, +use tokio::{ + net::TcpListener, + time::{sleep, Duration}, }; -use tokio::net::TcpListener; -use tokio::time::{sleep, Duration}; use tower::ServiceExt; mod common; @@ -962,8 +967,7 @@ fn oracle_config_validation_accepts_wallet_alias() { /// Test that RouterManager delegates /v1/models to OpenAI router in single-router mode #[tokio::test] async fn test_router_manager_delegates_models_to_openai_router() { - use sglang_router_rs::routers::router_manager::RouterManager; - use sglang_router_rs::server::ServerConfig; + use sglang_router_rs::{routers::router_manager::RouterManager, server::ServerConfig}; // Start a mock OpenAI server let mock_server = MockOpenAIServer::new().await; diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index ae50693b9..dfe9cdcdc 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -1,12 +1,14 @@ #[cfg(test)] mod test_pd_routing { use serde_json::json; - use sglang_router_rs::config::{ - CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, + use sglang_router_rs::{ + config::{ + CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, + RoutingMode, + }, + core::{BasicWorkerBuilder, Worker, WorkerType}, + routers::{http::pd_types::PDSelectionPolicy, RouterFactory}, }; - use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType}; - use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy; - use sglang_router_rs::routers::RouterFactory; #[derive(Debug)] struct PDRequest { @@ -201,14 +203,18 @@ mod test_pd_routing { }; let app_context = { - use sglang_router_rs::core::{LoadMonitor, WorkerRegistry}; - use sglang_router_rs::data_connector::{ - MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage, - }; - use sglang_router_rs::middleware::TokenBucket; - use sglang_router_rs::policies::PolicyRegistry; use std::sync::{Arc, OnceLock}; + use sglang_router_rs::{ + core::{LoadMonitor, WorkerRegistry}, + data_connector::{ + MemoryConversationItemStorage, MemoryConversationStorage, + MemoryResponseStorage, + }, + middleware::TokenBucket, + policies::PolicyRegistry, + }; + let client = reqwest::Client::new(); // Initialize rate limiter @@ -421,6 +427,7 @@ mod test_pd_routing { #[tokio::test] async fn test_background_load_monitoring() { use std::collections::HashMap; + use tokio::sync::watch; let (tx, rx) = watch::channel(HashMap::new()); @@ -466,6 +473,7 @@ mod test_pd_routing { #[tokio::test] async fn test_watch_channel_behavior() { use std::collections::HashMap; + use tokio::sync::watch; let (tx, rx1) = watch::channel(HashMap::new()); diff --git a/sgl-router/tests/tokenizer_integration.rs b/sgl-router/tests/tokenizer_integration.rs index 6e4a87ea9..b6fb68232 100644 --- a/sgl-router/tests/tokenizer_integration.rs +++ b/sgl-router/tests/tokenizer_integration.rs @@ -4,13 +4,13 @@ //! implementation works correctly with real-world tokenizer files. mod common; -use common::{ensure_tokenizer_cached, EXPECTED_HASHES, TEST_PROMPTS}; +use std::sync::Arc; +use common::{ensure_tokenizer_cached, EXPECTED_HASHES, TEST_PROMPTS}; use sglang_router_rs::tokenizer::{ factory, huggingface::HuggingFaceTokenizer, sequence::Sequence, stop::*, stream::DecodeStream, traits::*, }; -use std::sync::Arc; const LONG_TEST_PROMPTS: [(&str, &str); 6] = [ ("Tell me about the following text.", "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat."), @@ -318,6 +318,7 @@ fn test_thread_safety() { #[test] fn test_chat_template_discovery() { use std::fs; + use tempfile::TempDir; // Create a temporary directory with test files @@ -366,6 +367,7 @@ fn test_chat_template_discovery() { #[test] fn test_load_chat_template_from_local_file() { use std::fs; + use tempfile::TempDir; // Test 1: Load tokenizer with explicit chat template path