[router] Add rustfmt and set group imports by default (#11732)

This commit is contained in:
Chang Su
2025-10-16 17:33:29 -07:00
committed by GitHub
parent 7a7f99beb7
commit dc01313da1
126 changed files with 1127 additions and 813 deletions

View File

@@ -54,7 +54,9 @@ jobs:
run: | run: |
source "$HOME/.cargo/env" source "$HOME/.cargo/env"
cd sgl-router/ cd sgl-router/
cargo fmt -- --check rustup component add --toolchain nightly-x86_64-unknown-linux-gnu rustfmt
rustup toolchain install nightly --profile minimal
cargo +nightly fmt -- --check
- name: Run Rust tests - name: Run Rust tests
timeout-minutes: 20 timeout-minutes: 20

View File

@@ -1,14 +1,18 @@
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use serde_json::{from_str, to_string, to_value, to_vec};
use std::time::Instant; use std::time::Instant;
use sglang_router_rs::core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType}; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use sglang_router_rs::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}; use serde_json::{from_str, to_string, to_value, to_vec};
use sglang_router_rs::protocols::common::StringOrArray; use sglang_router_rs::{
use sglang_router_rs::protocols::completion::CompletionRequest; core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType},
use sglang_router_rs::protocols::generate::GenerateRequest; protocols::{
use sglang_router_rs::protocols::sampling_params::SamplingParams; chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
use sglang_router_rs::routers::http::pd_types::{generate_room_id, RequestWithBootstrap}; common::StringOrArray,
completion::CompletionRequest,
generate::GenerateRequest,
sampling_params::SamplingParams,
},
routers::http::pd_types::{generate_room_id, RequestWithBootstrap},
};
fn create_test_worker() -> BasicWorker { fn create_test_worker() -> BasicWorker {
BasicWorkerBuilder::new("http://test-server:8000") BasicWorkerBuilder::new("http://test-server:8000")

View File

@@ -1,16 +1,21 @@
//! Comprehensive tokenizer benchmark with clean summary output //! Comprehensive tokenizer benchmark with clean summary output
//! Each test adds a row to the final summary table //! Each test adds a row to the final summary table
use std::{
collections::BTreeMap,
path::PathBuf,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Mutex, OnceLock,
},
thread,
time::{Duration, Instant},
};
use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput}; use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput};
use sglang_router_rs::tokenizer::{ use sglang_router_rs::tokenizer::{
huggingface::HuggingFaceTokenizer, sequence::Sequence, stop::*, stream::DecodeStream, traits::*, huggingface::HuggingFaceTokenizer, sequence::Sequence, stop::*, stream::DecodeStream, traits::*,
}; };
use std::collections::BTreeMap;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::thread;
use std::time::{Duration, Instant};
// Include the common test utilities // Include the common test utilities
#[path = "../tests/common/mod.rs"] #[path = "../tests/common/mod.rs"]

View File

@@ -7,15 +7,22 @@
//! - Streaming vs complete parsing //! - Streaming vs complete parsing
//! - Different model formats (JSON, Mistral, Qwen, Pythonic, etc.) //! - Different model formats (JSON, Mistral, Qwen, Pythonic, etc.)
use std::{
collections::BTreeMap,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Mutex,
},
thread,
time::{Duration, Instant},
};
use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput}; use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput};
use serde_json::json; use serde_json::json;
use sglang_router_rs::protocols::common::{Function, Tool}; use sglang_router_rs::{
use sglang_router_rs::tool_parser::{JsonParser, ParserFactory as ToolParserFactory, ToolParser}; protocols::common::{Function, Tool},
use std::collections::BTreeMap; tool_parser::{JsonParser, ParserFactory as ToolParserFactory, ToolParser},
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; };
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
// Test data for different parser formats - realistic complex examples // Test data for different parser formats - realistic complex examples

8
sgl-router/rustfmt.toml Normal file
View File

@@ -0,0 +1,8 @@
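# Rust formatting configuration
# NOTE: `imports_granularity` and `group_imports` are unstable rustfmt options,
# so formatting must be run with the nightly toolchain (`cargo +nightly fmt`).
# Merge all imports from the same crate into a single `use` statement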
imports_granularity = "Crate"
# Group std, external crates, and local crate imports separately
group_imports = "StdExternalCrate"
reorder_imports = true
reorder_modules = true

View File

@@ -1,7 +1,9 @@
use super::ConfigResult;
use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::ConfigResult;
/// Main router configuration /// Main router configuration
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig { pub struct RouterConfig {

View File

@@ -1,6 +1,11 @@
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; use std::{
use std::sync::{Arc, RwLock}; sync::{
use std::time::{Duration, Instant}; atomic::{AtomicU32, AtomicU64, Ordering},
Arc, RwLock,
},
time::{Duration, Instant},
};
use tracing::info; use tracing::info;
/// Circuit breaker configuration /// Circuit breaker configuration
@@ -316,9 +321,10 @@ pub struct CircuitBreakerStats {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use std::thread; use std::thread;
use super::*;
#[test] #[test]
fn test_circuit_breaker_initial_state() { fn test_circuit_breaker_initial_state() {
let cb = CircuitBreaker::new(); let cb = CircuitBreaker::new();

View File

@@ -68,9 +68,10 @@ impl From<reqwest::Error> for WorkerError {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use std::error::Error; use std::error::Error;
use super::*;
#[test] #[test]
fn test_health_check_failed_display() { fn test_health_check_failed_display() {
let error = WorkerError::HealthCheckFailed { let error = WorkerError::HealthCheckFailed {

View File

@@ -3,16 +3,22 @@
//! Provides non-blocking worker management by queuing operations and processing //! Provides non-blocking worker management by queuing operations and processing
//! them asynchronously in background worker tasks. //! them asynchronously in background worker tasks.
use crate::core::WorkerManager; use std::{
use crate::protocols::worker_spec::{JobStatus, WorkerConfigRequest}; sync::{Arc, Weak},
use crate::server::AppContext; time::{Duration, SystemTime},
};
use dashmap::DashMap; use dashmap::DashMap;
use metrics::{counter, gauge, histogram}; use metrics::{counter, gauge, histogram};
use std::sync::{Arc, Weak};
use std::time::{Duration, SystemTime};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use crate::{
core::WorkerManager,
protocols::worker_spec::{JobStatus, WorkerConfigRequest},
server::AppContext,
};
/// Job types for control plane operations /// Job types for control plane operations
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Job { pub enum Job {

View File

@@ -1,10 +1,11 @@
use crate::config::types::RetryConfig;
use axum::http::StatusCode;
use axum::response::Response;
use rand::Rng;
use std::time::Duration; use std::time::Duration;
use axum::{http::StatusCode, response::Response};
use rand::Rng;
use tracing::debug; use tracing::debug;
use crate::config::types::RetryConfig;
/// Check if an HTTP status code indicates a retryable error /// Check if an HTTP status code indicates a retryable error
pub fn is_retryable_status(status: StatusCode) -> bool { pub fn is_retryable_status(status: StatusCode) -> bool {
matches!( matches!(
@@ -162,11 +163,14 @@ impl RetryExecutor {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::{
atomic::{AtomicU32, Ordering},
Arc,
};
use axum::{http::StatusCode, response::IntoResponse};
use super::*; use super::*;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
fn base_retry_config() -> RetryConfig { fn base_retry_config() -> RetryConfig {
RetryConfig { RetryConfig {

View File

@@ -1,5 +1,8 @@
use std::sync::Arc; use std::{
use std::time::{Duration, Instant}; sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::{Mutex, Notify}; use tokio::sync::{Mutex, Notify};
use tracing::{debug, trace}; use tracing::{debug, trace};

View File

@@ -1,19 +1,27 @@
use super::{CircuitBreaker, WorkerError, WorkerResult}; use std::{
use crate::core::CircuitState; fmt,
use crate::core::{BasicWorkerBuilder, DPAwareWorkerBuilder}; sync::{
use crate::grpc_client::SglangSchedulerClient; atomic::{AtomicBool, AtomicUsize, Ordering},
use crate::metrics::RouterMetrics; Arc, LazyLock,
use crate::protocols::worker_spec::WorkerInfo; },
time::{Duration, Instant},
};
use async_trait::async_trait; use async_trait::async_trait;
use futures; use futures;
use serde_json; use serde_json;
use std::fmt; use tokio::{
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; sync::{Mutex, RwLock},
use std::sync::{Arc, LazyLock}; time,
use std::time::Duration; };
use std::time::Instant;
use tokio::sync::{Mutex, RwLock}; use super::{CircuitBreaker, WorkerError, WorkerResult};
use tokio::time; use crate::{
core::{BasicWorkerBuilder, CircuitState, DPAwareWorkerBuilder},
grpc_client::SglangSchedulerClient,
metrics::RouterMetrics,
protocols::worker_spec::WorkerInfo,
};
static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| { static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
reqwest::Client::builder() reqwest::Client::builder()
@@ -1024,10 +1032,10 @@ pub fn worker_to_info(worker: &Arc<dyn Worker>) -> WorkerInfo {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::{thread, time::Duration};
use super::*; use super::*;
use crate::core::CircuitBreakerConfig; use crate::core::CircuitBreakerConfig;
use std::thread;
use std::time::Duration;
#[test] #[test]
fn test_worker_type_display() { fn test_worker_type_display() {
@@ -1502,9 +1510,10 @@ mod tests {
#[test] #[test]
fn test_load_counter_performance() { fn test_load_counter_performance() {
use crate::core::BasicWorkerBuilder;
use std::time::Instant; use std::time::Instant;
use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080") let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.build(); .build();

View File

@@ -1,9 +1,12 @@
use super::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig}; use std::collections::HashMap;
use super::worker::{
BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerMetadata, WorkerType, use super::{
circuit_breaker::{CircuitBreaker, CircuitBreakerConfig},
worker::{
BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerMetadata, WorkerType,
},
}; };
use crate::grpc_client::SglangSchedulerClient; use crate::grpc_client::SglangSchedulerClient;
use std::collections::HashMap;
/// Builder for creating BasicWorker instances with fluent API /// Builder for creating BasicWorker instances with fluent API
pub struct BasicWorkerBuilder { pub struct BasicWorkerBuilder {
@@ -100,6 +103,7 @@ impl BasicWorkerBuilder {
atomic::{AtomicBool, AtomicUsize}, atomic::{AtomicBool, AtomicUsize},
Arc, Arc,
}; };
use tokio::sync::{Mutex, RwLock}; use tokio::sync::{Mutex, RwLock};
let bootstrap_host = match url::Url::parse(&self.url) { let bootstrap_host = match url::Url::parse(&self.url) {
@@ -282,9 +286,10 @@ impl DPAwareWorkerBuilder {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::time::Duration;
use super::*; use super::*;
use crate::core::worker::Worker; use crate::core::worker::Worker;
use std::time::Duration;
#[test] #[test]
fn test_basic_worker_builder_minimal() { fn test_basic_worker_builder_minimal() {

View File

@@ -3,31 +3,35 @@
//! Handles all aspects of worker lifecycle including discovery, initialization, //! Handles all aspects of worker lifecycle including discovery, initialization,
//! runtime management, and health monitoring. //! runtime management, and health monitoring.
use crate::config::types::{ use std::{collections::HashMap, sync::Arc, time::Duration};
CircuitBreakerConfig as ConfigCircuitBreakerConfig, ConnectionMode as ConfigConnectionMode,
HealthCheckConfig, RouterConfig, RoutingMode,
};
use crate::core::{
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder, HealthConfig,
Worker, WorkerFactory, WorkerRegistry, WorkerType,
};
use crate::grpc_client::SglangSchedulerClient;
use crate::policies::PolicyRegistry;
use crate::protocols::worker_spec::{
FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
};
use crate::server::AppContext;
use futures::future; use futures::future;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap; use tokio::{
use std::sync::Arc; sync::{watch, Mutex},
use std::time::Duration; task::JoinHandle,
use tokio::sync::{watch, Mutex}; };
use tokio::task::JoinHandle;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use crate::{
config::types::{
CircuitBreakerConfig as ConfigCircuitBreakerConfig, ConnectionMode as ConfigConnectionMode,
HealthCheckConfig, RouterConfig, RoutingMode,
},
core::{
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder,
HealthConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType,
},
grpc_client::SglangSchedulerClient,
policies::PolicyRegistry,
protocols::worker_spec::{
FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
},
server::AppContext,
};
static HTTP_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| { static HTTP_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
reqwest::Client::builder() reqwest::Client::builder()
.timeout(Duration::from_secs(10)) .timeout(Duration::from_secs(10))
@@ -1803,9 +1807,10 @@ impl Drop for LoadMonitor {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use std::collections::HashMap; use std::collections::HashMap;
use super::*;
#[test] #[test]
fn test_parse_server_info() { fn test_parse_server_info() {
let json = serde_json::json!({ let json = serde_json::json!({

View File

@@ -2,11 +2,13 @@
//! //!
//! Provides centralized registry for workers with model-based indexing //! Provides centralized registry for workers with model-based indexing
use crate::core::{ConnectionMode, Worker, WorkerType};
use dashmap::DashMap;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use dashmap::DashMap;
use uuid::Uuid; use uuid::Uuid;
use crate::core::{ConnectionMode, Worker, WorkerType};
/// Unique identifier for a worker /// Unique identifier for a worker
#[derive(Debug, Clone, Hash, Eq, PartialEq)] #[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct WorkerId(String); pub struct WorkerId(String);
@@ -363,8 +365,10 @@ impl WorkerRegistry {
/// Start a health checker for all workers in the registry /// Start a health checker for all workers in the registry
/// This should be called once after the registry is populated with workers /// This should be called once after the registry is populated with workers
pub fn start_health_checker(&self, check_interval_secs: u64) -> crate::core::HealthChecker { pub fn start_health_checker(&self, check_interval_secs: u64) -> crate::core::HealthChecker {
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{
use std::sync::Arc; atomic::{AtomicBool, Ordering},
Arc,
};
let shutdown = Arc::new(AtomicBool::new(false)); let shutdown = Arc::new(AtomicBool::new(false));
let shutdown_clone = shutdown.clone(); let shutdown_clone = shutdown.clone();
@@ -433,9 +437,10 @@ pub struct WorkerRegistryStats {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::collections::HashMap;
use super::*; use super::*;
use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig}; use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig};
use std::collections::HashMap;
#[test] #[test]
fn test_worker_registry() { fn test_worker_registry() {

View File

@@ -1,14 +1,18 @@
use std::collections::{BTreeMap, HashMap}; use std::{
use std::sync::RwLock; collections::{BTreeMap, HashMap},
sync::RwLock,
};
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use super::conversation_items::{ use super::{
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, ListParams, conversation_items::{
Result, SortOrder, make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, ListParams,
Result, SortOrder,
},
conversations::ConversationId,
}; };
use super::conversations::ConversationId;
#[derive(Default)] #[derive(Default)]
pub struct MemoryConversationItemStorage { pub struct MemoryConversationItemStorage {
@@ -190,9 +194,10 @@ impl ConversationItemStorage for MemoryConversationItemStorage {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use chrono::{TimeZone, Utc}; use chrono::{TimeZone, Utc};
use super::*;
fn make_item( fn make_item(
item_type: &str, item_type: &str,
role: Option<&str>, role: Option<&str>,

View File

@@ -1,18 +1,21 @@
use crate::config::OracleConfig; use std::{path::Path, sync::Arc, time::Duration};
use crate::data_connector::conversation_items::{
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage,
ConversationItemStorageError, ListParams, Result as ItemResult, SortOrder,
};
use crate::data_connector::conversations::ConversationId;
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult}; use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
use oracle::sql_type::ToSql; use oracle::{sql_type::ToSql, Connection};
use oracle::Connection;
use serde_json::Value; use serde_json::Value;
use std::path::Path;
use std::sync::Arc; use crate::{
use std::time::Duration; config::OracleConfig,
data_connector::{
conversation_items::{
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage,
ConversationItemStorageError, ListParams, Result as ItemResult, SortOrder,
},
conversations::ConversationId,
},
};
#[derive(Clone)] #[derive(Clone)]
pub struct OracleConversationItemStorage { pub struct OracleConversationItemStorage {

View File

@@ -1,10 +1,13 @@
use std::{
fmt::{Display, Formatter},
sync::Arc,
};
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use rand::RngCore; use rand::RngCore;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::fmt::{Display, Formatter};
use std::sync::Arc;
use super::conversations::ConversationId; use super::conversations::ConversationId;

View File

@@ -1,7 +1,7 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait; use async_trait::async_trait;
use parking_lot::RwLock; use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use super::conversations::{ use super::conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage, NewConversation, Conversation, ConversationId, ConversationMetadata, ConversationStorage, NewConversation,

View File

@@ -1,16 +1,18 @@
use crate::config::OracleConfig; use std::{path::Path, sync::Arc, time::Duration};
use crate::data_connector::conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage,
ConversationStorageError, NewConversation, Result,
};
use async_trait::async_trait; use async_trait::async_trait;
use chrono::Utc; use chrono::Utc;
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult}; use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
use oracle::{sql_type::OracleType, Connection}; use oracle::{sql_type::OracleType, Connection};
use serde_json::Value; use serde_json::Value;
use std::path::Path;
use std::sync::Arc; use crate::{
use std::time::Duration; config::OracleConfig,
data_connector::conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage,
ConversationStorageError, NewConversation, Result,
},
};
#[derive(Clone)] #[derive(Clone)]
pub struct OracleConversationStorage { pub struct OracleConversationStorage {

View File

@@ -1,10 +1,13 @@
use std::{
fmt::{Display, Formatter},
sync::Arc,
};
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use rand::RngCore; use rand::RngCore;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{Map as JsonMap, Value}; use serde_json::{Map as JsonMap, Value};
use std::fmt::{Display, Formatter};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct ConversationId(pub String); pub struct ConversationId(pub String);

View File

@@ -1,7 +1,7 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait; use async_trait::async_trait;
use parking_lot::RwLock; use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use super::responses::{ResponseChain, ResponseId, ResponseStorage, Result, StoredResponse}; use super::responses::{ResponseChain, ResponseId, ResponseStorage, Result, StoredResponse};

View File

@@ -1,16 +1,17 @@
use crate::config::OracleConfig; use std::{collections::HashMap, path::Path, sync::Arc, time::Duration};
use crate::data_connector::responses::{
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult,
StoredResponse,
};
use async_trait::async_trait; use async_trait::async_trait;
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult}; use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
use oracle::{Connection, Row}; use oracle::{Connection, Row};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap;
use std::path::Path; use crate::{
use std::sync::Arc; config::OracleConfig,
use std::time::Duration; data_connector::responses::{
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult,
StoredResponse,
},
};
const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \ const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \
tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response FROM responses"; tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response FROM responses";
@@ -510,9 +511,10 @@ impl OracleErrorExt for ResponseStorageError {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use serde_json::json; use serde_json::json;
use super::*;
#[test] #[test]
fn parse_tool_calls_handles_empty_input() { fn parse_tool_calls_handles_empty_input() {
assert!(parse_tool_calls(None).unwrap().is_empty()); assert!(parse_tool_calls(None).unwrap().is_empty());

View File

@@ -1,8 +1,8 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
/// Response identifier /// Response identifier
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]

View File

@@ -1,16 +1,23 @@
use std::convert::TryFrom; use std::{
use std::pin::Pin; convert::TryFrom,
use std::sync::atomic::{AtomicBool, Ordering}; pin::Pin,
use std::sync::Arc; sync::{
use std::task::{Context, Poll}; atomic::{AtomicBool, Ordering},
use std::time::Duration; Arc,
},
task::{Context, Poll},
time::Duration,
};
use tonic::{transport::Channel, Request, Streaming}; use tonic::{transport::Channel, Request, Streaming};
use tracing::{debug, warn}; use tracing::{debug, warn};
use crate::protocols::chat::ChatCompletionRequest; use crate::protocols::{
use crate::protocols::common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue}; chat::ChatCompletionRequest,
use crate::protocols::generate::GenerateRequest; common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue},
use crate::protocols::sampling_params::SamplingParams as GenerateSamplingParams; generate::GenerateRequest,
sampling_params::SamplingParams as GenerateSamplingParams,
};
// Include the generated protobuf code // Include the generated protobuf code
pub mod proto { pub mod proto {

View File

@@ -1,12 +1,14 @@
use std::path::PathBuf; use std::path::PathBuf;
use tracing::Level; use tracing::Level;
use tracing_appender::non_blocking::WorkerGuard; use tracing_appender::{
use tracing_appender::rolling::{RollingFileAppender, Rotation}; non_blocking::WorkerGuard,
rolling::{RollingFileAppender, Rotation},
};
use tracing_log::LogTracer; use tracing_log::LogTracer;
use tracing_subscriber::fmt::time::ChronoUtc; use tracing_subscriber::{
use tracing_subscriber::layer::SubscriberExt; fmt::time::ChronoUtc, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer,
use tracing_subscriber::util::SubscriberInitExt; };
use tracing_subscriber::{EnvFilter, Layer};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct LoggingConfig { pub struct LoggingConfig {

View File

@@ -1,14 +1,17 @@
use clap::{ArgAction, Parser, ValueEnum};
use sglang_router_rs::config::{
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig,
RouterConfig, RoutingMode,
};
use sglang_router_rs::metrics::PrometheusConfig;
use sglang_router_rs::server::{self, ServerConfig};
use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
use std::collections::HashMap; use std::collections::HashMap;
use clap::{ArgAction, Parser, ValueEnum};
use sglang_router_rs::{
config::{
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig,
RouterConfig, RoutingMode,
},
metrics::PrometheusConfig,
server::{self, ServerConfig},
service_discovery::ServiceDiscoveryConfig,
};
fn parse_prefill_args() -> Vec<(String, Option<u16>)> { fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
let args: Vec<String> = std::env::args().collect(); let args: Vec<String> = std::env::args().collect();
let mut prefill_entries = Vec::new(); let mut prefill_entries = Vec::new();

View File

@@ -1,3 +1,5 @@
use std::{borrow::Cow, collections::HashMap, time::Duration};
use backoff::ExponentialBackoffBuilder; use backoff::ExponentialBackoffBuilder;
use dashmap::DashMap; use dashmap::DashMap;
use rmcp::{ use rmcp::{
@@ -13,7 +15,6 @@ use rmcp::{
RoleClient, ServiceExt, RoleClient, ServiceExt,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{borrow::Cow, collections::HashMap, time::Duration};
use crate::mcp::{ use crate::mcp::{
config::{McpConfig, McpServerConfig, McpTransport}, config::{McpConfig, McpServerConfig, McpTransport},

View File

@@ -1,6 +1,7 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpConfig { pub struct McpConfig {
pub servers: Vec<McpServerConfig>, pub servers: Vec<McpServerConfig>,

View File

@@ -1,5 +1,7 @@
// OAuth authentication support for MCP servers // OAuth authentication support for MCP servers
use std::{net::SocketAddr, sync::Arc};
use axum::{ use axum::{
extract::{Query, State}, extract::{Query, State},
response::Html, response::Html,
@@ -8,7 +10,6 @@ use axum::{
}; };
use rmcp::transport::auth::OAuthState; use rmcp::transport::auth::OAuthState;
use serde::Deserialize; use serde::Deserialize;
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::{oneshot, Mutex}; use tokio::sync::{oneshot, Mutex};
use crate::mcp::error::{McpError, McpResult}; use crate::mcp::error::{McpError, McpResult};

View File

@@ -1,7 +1,10 @@
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
time::Duration,
};
use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram}; use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram};
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct PrometheusConfig { pub struct PrometheusConfig {
@@ -620,9 +623,10 @@ impl TokenizerMetrics {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use std::net::TcpListener; use std::net::TcpListener;
use super::*;
#[test] #[test]
fn test_prometheus_config_default() { fn test_prometheus_config_default() {
let config = PrometheusConfig::default(); let config = PrometheusConfig::default();
@@ -912,9 +916,13 @@ mod tests {
#[test] #[test]
fn test_concurrent_metric_updates() { fn test_concurrent_metric_updates() {
use std::sync::atomic::{AtomicBool, Ordering}; use std::{
use std::sync::Arc; sync::{
use std::thread; atomic::{AtomicBool, Ordering},
Arc,
},
thread,
};
let done = Arc::new(AtomicBool::new(false)); let done = Arc::new(AtomicBool::new(false));
let mut handles = vec![]; let mut handles = vec![];

View File

@@ -1,12 +1,19 @@
use std::{
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::{Duration, Instant},
};
use axum::{ use axum::{
body::Body, extract::Request, extract::State, http::header, http::HeaderValue, body::Body,
http::StatusCode, middleware::Next, response::IntoResponse, response::Response, extract::{Request, State},
http::{header, HeaderValue, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
}; };
use rand::Rng; use rand::Rng;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use subtle::ConstantTimeEq; use subtle::ConstantTimeEq;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use tower::{Layer, Service}; use tower::{Layer, Service};
@@ -14,9 +21,7 @@ use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
use tracing::{debug, error, field::Empty, info, info_span, warn, Span}; use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
pub use crate::core::token_bucket::TokenBucket; pub use crate::core::token_bucket::TokenBucket;
use crate::{metrics::RouterMetrics, server::AppState};
use crate::metrics::RouterMetrics;
use crate::server::AppState;
#[derive(Clone)] #[derive(Clone)]
pub struct AuthConfig { pub struct AuthConfig {

View File

@@ -59,17 +59,15 @@
during the next eviction cycle. during the next eviction cycle.
*/ */
use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy}; use std::{sync::Arc, thread, time::Duration};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use crate::tree::Tree;
use dashmap::DashMap; use dashmap::DashMap;
use rand::Rng; use rand::Rng;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use tracing::debug; use tracing::debug;
use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
use crate::{core::Worker, metrics::RouterMetrics, tree::Tree};
/// Cache-aware routing policy /// Cache-aware routing policy
/// ///
/// Routes requests based on cache affinity when load is balanced, /// Routes requests based on cache affinity when load is balanced,

View File

@@ -1,11 +1,12 @@
//! Factory for creating load balancing policies //! Factory for creating load balancing policies
use std::sync::Arc;
use super::{ use super::{
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy, CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
RoundRobinPolicy, RoundRobinPolicy,
}; };
use crate::config::PolicyConfig; use crate::config::PolicyConfig;
use std::sync::Arc;
/// Factory for creating policy instances /// Factory for creating policy instances
pub struct PolicyFactory; pub struct PolicyFactory;

View File

@@ -3,9 +3,9 @@
//! This module provides a unified abstraction for routing policies that work //! This module provides a unified abstraction for routing policies that work
//! across both regular and prefill-decode (PD) routing modes. //! across both regular and prefill-decode (PD) routing modes.
use std::{fmt::Debug, sync::Arc};
use crate::core::Worker; use crate::core::Worker;
use std::fmt::Debug;
use std::sync::Arc;
mod cache_aware; mod cache_aware;
mod factory; mod factory;

View File

@@ -1,13 +1,16 @@
//! Power-of-two choices load balancing policy //! Power-of-two choices load balancing policy
use super::{get_healthy_worker_indices, LoadBalancingPolicy}; use std::{
use crate::core::Worker; collections::HashMap,
use crate::metrics::RouterMetrics; sync::{Arc, RwLock},
};
use rand::Rng; use rand::Rng;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::info; use tracing::info;
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::{core::Worker, metrics::RouterMetrics};
/// Power-of-two choices policy /// Power-of-two choices policy
/// ///
/// Randomly selects two workers and routes to the one with lower load. /// Randomly selects two workers and routes to the one with lower load.

View File

@@ -1,11 +1,12 @@
//! Random load balancing policy //! Random load balancing policy
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use rand::Rng;
use std::sync::Arc; use std::sync::Arc;
use rand::Rng;
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::{core::Worker, metrics::RouterMetrics};
/// Random selection policy /// Random selection policy
/// ///
/// Selects workers randomly with uniform distribution among healthy workers. /// Selects workers randomly with uniform distribution among healthy workers.
@@ -50,9 +51,10 @@ impl LoadBalancingPolicy for RandomPolicy {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::collections::HashMap;
use super::*; use super::*;
use crate::core::{BasicWorkerBuilder, WorkerType}; use crate::core::{BasicWorkerBuilder, WorkerType};
use std::collections::HashMap;
#[test] #[test]
fn test_random_selection() { fn test_random_selection() {

View File

@@ -1,3 +1,10 @@
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use tracing::{debug, info, warn};
/// Policy Registry for managing model-to-policy mappings /// Policy Registry for managing model-to-policy mappings
/// ///
/// This registry manages the dynamic assignment of load balancing policies to models. /// This registry manages the dynamic assignment of load balancing policies to models.
@@ -8,11 +15,7 @@ use super::{
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy, CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
RoundRobinPolicy, RoundRobinPolicy,
}; };
use crate::config::types::PolicyConfig; use crate::{config::types::PolicyConfig, core::Worker};
use crate::core::Worker;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::{debug, info, warn};
/// Registry for managing model-to-policy mappings /// Registry for managing model-to-policy mappings
#[derive(Clone)] #[derive(Clone)]

View File

@@ -1,10 +1,12 @@
//! Round-robin load balancing policy //! Round-robin load balancing policy
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use super::{get_healthy_worker_indices, LoadBalancingPolicy}; use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker; use crate::{core::Worker, metrics::RouterMetrics};
use crate::metrics::RouterMetrics;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// Round-robin selection policy /// Round-robin selection policy
/// ///

View File

@@ -1,10 +1,13 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap;
use validator::Validate; use validator::Validate;
use super::common::*; use super::{
use super::sampling_params::{validate_top_k_value, validate_top_p_value}; common::*,
sampling_params::{validate_top_k_value, validate_top_p_value},
};
use crate::protocols::validated::Normalizable; use crate::protocols::validated::Normalizable;
// ============================================================================ // ============================================================================
@@ -532,11 +535,12 @@ impl Normalizable for ChatCompletionRequest {
// Apply tool_choice defaults // Apply tool_choice defaults
if self.tool_choice.is_none() { if self.tool_choice.is_none() {
if let Some(tools) = &self.tools { if let Some(tools) = &self.tools {
self.tool_choice = if !tools.is_empty() { let choice_value = if !tools.is_empty() {
Some(ToolChoice::Value(ToolChoiceValue::Auto)) ToolChoiceValue::Auto
} else { } else {
Some(ToolChoice::Value(ToolChoiceValue::None)) ToolChoiceValue::None
}; };
self.tool_choice = Some(ToolChoice::Value(choice_value));
} }
// If tools is None, leave tool_choice as None (don't set it) // If tools is None, leave tool_choice as None (don't set it)
} }

View File

@@ -1,6 +1,7 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap;
// ============================================================================ // ============================================================================
// Default value helpers // Default value helpers

View File

@@ -1,6 +1,7 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{Map, Value}; use serde_json::{Map, Value};
use std::collections::HashMap;
use super::common::*; use super::common::*;

View File

@@ -1,10 +1,13 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap;
use validator::Validate; use validator::Validate;
use super::common::{default_true, GenerationRequest, InputIds}; use super::{
use super::sampling_params::SamplingParams; common::{default_true, GenerationRequest, InputIds},
sampling_params::SamplingParams,
};
use crate::protocols::validated::Normalizable; use crate::protocols::validated::Normalizable;
// ============================================================================ // ============================================================================

View File

@@ -1,6 +1,7 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap;
use validator::Validate; use validator::Validate;
use super::common::{default_model, default_true, GenerationRequest, StringOrArray, UsageInfo}; use super::common::{default_model, default_true, GenerationRequest, StringOrArray, UsageInfo};

View File

@@ -1,9 +1,10 @@
// OpenAI Responses API types // OpenAI Responses API types
// https://platform.openai.com/docs/api-reference/responses // https://platform.openai.com/docs/api-reference/responses
use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap;
// Import shared types from common module // Import shared types from common module
use super::common::{ use super::common::{

View File

@@ -117,10 +117,11 @@ impl<T> std::ops::DerefMut for ValidatedJson<T> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use validator::Validate; use validator::Validate;
use super::*;
#[derive(Debug, Deserialize, Serialize, Validate)] #[derive(Debug, Deserialize, Serialize, Validate)]
struct TestRequest { struct TestRequest {
#[validate(range(min = 0.0, max = 1.0))] #[validate(range(min = 0.0, max = 1.0))]

View File

@@ -2,9 +2,10 @@
//! //!
//! Defines the request/response structures for worker management endpoints //! Defines the request/response structures for worker management endpoints
use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Worker configuration for API requests /// Worker configuration for API requests
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerConfigRequest { pub struct WorkerConfigRequest {

View File

@@ -1,16 +1,20 @@
// Factory and registry for creating model-specific reasoning parsers. // Factory and registry for creating model-specific reasoning parsers.
// Now with parser pooling support for efficient reuse across requests. // Now with parser pooling support for efficient reuse across requests.
use std::collections::HashMap; use std::{
use std::sync::{Arc, RwLock}; collections::HashMap,
sync::{Arc, RwLock},
};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::reasoning_parser::parsers::{ use crate::reasoning_parser::{
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser, parsers::{
QwenThinkingParser, Step3Parser, BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
QwenThinkingParser, Step3Parser,
},
traits::{ParseError, ParserConfig, ReasoningParser},
}; };
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};
/// Type alias for pooled parser instances. /// Type alias for pooled parser instances.
/// Uses tokio::Mutex to avoid blocking the async executor. /// Uses tokio::Mutex to avoid blocking the async executor.
@@ -402,8 +406,10 @@ mod tests {
#[tokio::test(flavor = "multi_thread", worker_threads = 8)] #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
async fn test_high_concurrency_parser_access() { async fn test_high_concurrency_parser_access() {
use std::sync::atomic::{AtomicUsize, Ordering}; use std::{
use std::time::Instant; sync::atomic::{AtomicUsize, Ordering},
time::Instant,
};
let factory = ParserFactory::new(); let factory = ParserFactory::new();
let num_tasks = 100; let num_tasks = 100;

View File

@@ -2,8 +2,10 @@
// This parser starts with in_reasoning=true, assuming all text is reasoning // This parser starts with in_reasoning=true, assuming all text is reasoning
// until an end token is encountered. // until an end token is encountered.
use crate::reasoning_parser::parsers::BaseReasoningParser; use crate::reasoning_parser::{
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser}; parsers::BaseReasoningParser,
traits::{ParseError, ParserConfig, ParserResult, ReasoningParser},
};
/// DeepSeek-R1 reasoning parser. /// DeepSeek-R1 reasoning parser.
/// ///

View File

@@ -1,8 +1,10 @@
// GLM45 specific reasoning parser. // GLM45 specific reasoning parser.
// Uses the same format as Qwen3 but has its own implementation for debugging. // Uses the same format as Qwen3 but has its own implementation for debugging.
use crate::reasoning_parser::parsers::BaseReasoningParser; use crate::reasoning_parser::{
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser}; parsers::BaseReasoningParser,
traits::{ParseError, ParserConfig, ParserResult, ReasoningParser},
};
/// GLM45 reasoning parser. /// GLM45 reasoning parser.
/// ///

View File

@@ -1,8 +1,10 @@
// Kimi specific reasoning parser. // Kimi specific reasoning parser.
// This parser uses Unicode tokens and starts with in_reasoning=false. // This parser uses Unicode tokens and starts with in_reasoning=false.
use crate::reasoning_parser::parsers::BaseReasoningParser; use crate::reasoning_parser::{
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser}; parsers::BaseReasoningParser,
traits::{ParseError, ParserConfig, ParserResult, ReasoningParser},
};
/// Kimi reasoning parser. /// Kimi reasoning parser.
/// ///

View File

@@ -2,8 +2,10 @@
// This parser starts with in_reasoning=false, requiring an explicit // This parser starts with in_reasoning=false, requiring an explicit
// start token to enter reasoning mode. // start token to enter reasoning mode.
use crate::reasoning_parser::parsers::BaseReasoningParser; use crate::reasoning_parser::{
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser}; parsers::BaseReasoningParser,
traits::{ParseError, ParserConfig, ParserResult, ReasoningParser},
};
/// Qwen3 reasoning parser. /// Qwen3 reasoning parser.
/// ///

View File

@@ -1,8 +1,10 @@
// Step3 specific reasoning parser. // Step3 specific reasoning parser.
// Uses the same format as DeepSeek-R1 but has its own implementation for debugging. // Uses the same format as DeepSeek-R1 but has its own implementation for debugging.
use crate::reasoning_parser::parsers::BaseReasoningParser; use crate::reasoning_parser::{
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser}; parsers::BaseReasoningParser,
traits::{ParseError, ParserConfig, ParserResult, ReasoningParser},
};
/// Step3 reasoning parser. /// Step3 reasoning parser.
/// ///

View File

@@ -1,16 +1,18 @@
//! Factory for creating router instances //! Factory for creating router instances
use super::grpc::pd_router::GrpcPDRouter; use std::sync::Arc;
use super::grpc::router::GrpcRouter;
use super::{ use super::{
grpc::{pd_router::GrpcPDRouter, router::GrpcRouter},
http::{pd_router::PDRouter, router::Router}, http::{pd_router::PDRouter, router::Router},
openai::OpenAIRouter, openai::OpenAIRouter,
RouterTrait, RouterTrait,
}; };
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode}; use crate::{
use crate::policies::PolicyFactory; config::{ConnectionMode, PolicyConfig, RoutingMode},
use crate::server::AppContext; policies::PolicyFactory,
use std::sync::Arc; server::AppContext,
};
/// Factory for creating router instances based on configuration /// Factory for creating router instances based on configuration
pub struct RouterFactory; pub struct RouterFactory;

View File

@@ -4,20 +4,22 @@
//! eliminating deep parameter passing chains and providing a single source of truth //! eliminating deep parameter passing chains and providing a single source of truth
//! for request state. //! for request state.
use std::collections::HashMap; use std::{collections::HashMap, sync::Arc};
use std::sync::Arc;
use axum::http::HeaderMap; use axum::http::HeaderMap;
use serde_json::Value; use serde_json::Value;
use crate::core::Worker; use crate::{
use crate::grpc_client::{proto, SglangSchedulerClient}; core::Worker,
use crate::protocols::chat::{ChatCompletionRequest, ChatCompletionResponse}; grpc_client::{proto, SglangSchedulerClient},
use crate::protocols::generate::{GenerateRequest, GenerateResponse}; protocols::{
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; chat::{ChatCompletionRequest, ChatCompletionResponse},
use crate::tokenizer::stop::StopSequenceDecoder; generate::{GenerateRequest, GenerateResponse},
use crate::tokenizer::traits::Tokenizer; },
use crate::tool_parser::ParserFactory as ToolParserFactory; reasoning_parser::ParserFactory as ReasoningParserFactory,
tokenizer::{stop::StopSequenceDecoder, traits::Tokenizer},
tool_parser::ParserFactory as ToolParserFactory,
};
// ============================================================================ // ============================================================================
// Core Context Types // Core Context Types

View File

@@ -1,7 +1,6 @@
//! gRPC router implementations //! gRPC router implementations
use crate::grpc_client::proto; use crate::{grpc_client::proto, protocols::common::StringOrArray};
use crate::protocols::common::StringOrArray;
pub mod context; pub mod context;
pub mod pd_router; pub mod pd_router;

View File

@@ -1,19 +1,7 @@
// PD (Prefill-Decode) gRPC Router Implementation // PD (Prefill-Decode) gRPC Router Implementation
use crate::config::types::RetryConfig; use std::sync::Arc;
use crate::core::{ConnectionMode, WorkerRegistry, WorkerType};
use crate::policies::PolicyRegistry;
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
use crate::routers::RouterTrait;
use crate::server::AppContext;
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserFactory as ToolParserFactory;
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
body::Body, body::Body,
@@ -21,12 +9,27 @@ use axum::{
http::{HeaderMap, StatusCode}, http::{HeaderMap, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use std::sync::Arc;
use tracing::debug; use tracing::debug;
use super::context::SharedComponents; use super::{context::SharedComponents, pipeline::RequestPipeline};
use super::pipeline::RequestPipeline; use crate::{
config::types::RetryConfig,
core::{ConnectionMode, WorkerRegistry, WorkerType},
policies::PolicyRegistry,
protocols::{
chat::ChatCompletionRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
routers::RouterTrait,
server::AppContext,
tokenizer::traits::Tokenizer,
tool_parser::ParserFactory as ToolParserFactory,
};
/// gRPC PD (Prefill-Decode) router implementation for SGLang /// gRPC PD (Prefill-Decode) router implementation for SGLang
#[derive(Clone)] #[derive(Clone)]

View File

@@ -3,29 +3,29 @@
//! This module defines the core pipeline abstraction and individual processing stages //! This module defines the core pipeline abstraction and individual processing stages
//! that transform a RequestContext through its lifecycle. //! that transform a RequestContext through its lifecycle.
use std::{
sync::Arc,
time::{Instant, SystemTime, UNIX_EPOCH},
};
use async_trait::async_trait; use async_trait::async_trait;
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
use tracing::{debug, error, warn};
use super::context::*;
use super::processing;
use super::streaming;
use super::utils;
use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
use crate::grpc_client::proto;
use crate::policies::PolicyRegistry;
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::common::InputIds;
use crate::protocols::generate::GenerateRequest;
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserFactory as ToolParserFactory;
use proto::DisaggregatedParams; use proto::DisaggregatedParams;
use rand::Rng; use rand::Rng;
use std::sync::Arc; use tracing::{debug, error, warn};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use uuid::Uuid; use uuid::Uuid;
use super::{context::*, processing, streaming, utils};
use crate::{
core::{ConnectionMode, Worker, WorkerRegistry, WorkerType},
grpc_client::proto,
policies::PolicyRegistry,
protocols::{chat::ChatCompletionRequest, common::InputIds, generate::GenerateRequest},
reasoning_parser::ParserFactory as ReasoningParserFactory,
tokenizer::traits::Tokenizer,
tool_parser::ParserFactory as ToolParserFactory,
};
// ============================================================================ // ============================================================================
// Pipeline Trait // Pipeline Trait
// ============================================================================ // ============================================================================

View File

@@ -3,28 +3,30 @@
//! This module contains response processing functions that are shared between //! This module contains response processing functions that are shared between
//! the regular router and PD router, eliminating ~1,200 lines of exact duplicates. //! the regular router and PD router, eliminating ~1,200 lines of exact duplicates.
use std::sync::Arc; use std::{sync::Arc, time::Instant};
use proto::generate_complete::MatchedStop;
use serde_json::Value; use serde_json::Value;
use tracing::error; use tracing::error;
use crate::grpc_client::proto; use super::{
use crate::protocols::chat::{ context::{DispatchMetadata, ExecutionResult},
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, utils,
}; };
use crate::protocols::common::{ use crate::{
FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue, Usage, grpc_client::proto,
protocols::{
chat::{ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse},
common::{FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue, Usage},
generate::{GenerateMetaInfo, GenerateRequest, GenerateResponse},
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
tokenizer::{
stop::{SequenceDecoderOutput, StopSequenceDecoder},
traits::Tokenizer,
},
tool_parser::ParserFactory as ToolParserFactory,
}; };
use crate::protocols::generate::{GenerateMetaInfo, GenerateRequest, GenerateResponse};
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserFactory as ToolParserFactory;
use proto::generate_complete::MatchedStop;
use std::time::Instant;
use super::context::{DispatchMetadata, ExecutionResult};
use super::utils;
// ============================================================================ // ============================================================================
// Response Processor - Main Entry Point // Response Processor - Main Entry Point

View File

@@ -11,23 +11,25 @@ use axum::{
}; };
use tracing::debug; use tracing::debug;
use crate::config::types::RetryConfig; use super::{context::SharedComponents, pipeline::RequestPipeline};
use crate::core::WorkerRegistry; use crate::{
use crate::policies::PolicyRegistry; config::types::RetryConfig,
use crate::protocols::chat::ChatCompletionRequest; core::WorkerRegistry,
use crate::protocols::completion::CompletionRequest; policies::PolicyRegistry,
use crate::protocols::embedding::EmbeddingRequest; protocols::{
use crate::protocols::generate::GenerateRequest; chat::ChatCompletionRequest,
use crate::protocols::rerank::RerankRequest; completion::CompletionRequest,
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest}; embedding::EmbeddingRequest,
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; generate::GenerateRequest,
use crate::routers::RouterTrait; rerank::RerankRequest,
use crate::server::AppContext; responses::{ResponsesGetParams, ResponsesRequest},
use crate::tokenizer::traits::Tokenizer; },
use crate::tool_parser::ParserFactory as ToolParserFactory; reasoning_parser::ParserFactory as ReasoningParserFactory,
routers::RouterTrait,
use super::context::SharedComponents; server::AppContext,
use super::pipeline::RequestPipeline; tokenizer::traits::Tokenizer,
tool_parser::ParserFactory as ToolParserFactory,
};
/// gRPC router implementation for SGLang /// gRPC router implementation for SGLang
#[derive(Clone)] #[derive(Clone)]

View File

@@ -3,38 +3,40 @@
//! This module contains shared streaming logic for both Regular and PD routers, //! This module contains shared streaming logic for both Regular and PD routers,
//! eliminating ~600 lines of duplication. //! eliminating ~600 lines of duplication.
use axum::response::Response; use std::{collections::HashMap, io, sync::Arc, time::Instant};
use axum::{body::Body, http::StatusCode};
use axum::{body::Body, http::StatusCode, response::Response};
use bytes::Bytes; use bytes::Bytes;
use http::header::{HeaderValue, CONTENT_TYPE}; use http::header::{HeaderValue, CONTENT_TYPE};
use proto::{
generate_complete::MatchedStop::{MatchedStopStr, MatchedTokenId},
generate_response::Response::{Chunk, Complete, Error},
};
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::collections::HashMap; use tokio::sync::{mpsc, mpsc::UnboundedSender};
use std::io; use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};
use std::sync::Arc;
use tokio::sync::mpsc::UnboundedSender;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::{debug, error, warn}; use tracing::{debug, error, warn};
use super::context; use super::{context, utils};
use super::utils; use crate::{
use crate::grpc_client::proto; grpc_client::proto,
use crate::protocols::chat::{ protocols::{
ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice, chat::{
ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice,
},
common::{
ChatLogProbs, FunctionCallDelta, StringOrArray, Tool, ToolCallDelta, ToolChoice,
ToolChoiceValue, Usage,
},
generate::GenerateRequest,
},
reasoning_parser::ReasoningParser,
tokenizer::{
stop::{SequenceDecoderOutput, StopSequenceDecoder},
traits::Tokenizer,
},
tool_parser::ToolParser,
}; };
use crate::protocols::common::{
ChatLogProbs, FunctionCallDelta, StringOrArray, Tool, ToolCallDelta, ToolChoice,
ToolChoiceValue, Usage,
};
use crate::protocols::generate::GenerateRequest;
use crate::reasoning_parser::ReasoningParser;
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParser;
use proto::generate_complete::MatchedStop::{MatchedStopStr, MatchedTokenId};
use proto::generate_response::Response::{Chunk, Complete, Error};
use std::time::Instant;
use tokio::sync::mpsc;
/// Shared streaming processor for both single and dual dispatch modes /// Shared streaming processor for both single and dual dispatch modes
#[derive(Clone)] #[derive(Clone)]

View File

@@ -1,19 +1,7 @@
//! Shared utilities for gRPC routers //! Shared utilities for gRPC routers
use super::ProcessedMessages; use std::{collections::HashMap, sync::Arc};
use crate::core::Worker;
use crate::grpc_client::sglang_scheduler::AbortOnDropStream;
use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::protocols::chat::{ChatCompletionRequest, ChatMessage};
use crate::protocols::common::{
ChatLogProbs, ChatLogProbsContent, FunctionCallResponse, StringOrArray, Tool, ToolCall,
ToolChoice, ToolChoiceValue, TopLogProb,
};
use crate::protocols::generate::GenerateFinishReason;
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
use crate::tokenizer::traits::Tokenizer;
use crate::tokenizer::HuggingFaceTokenizer;
pub use crate::tokenizer::StopSequenceDecoder;
use axum::{ use axum::{
http::StatusCode, http::StatusCode,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
@@ -21,11 +9,29 @@ use axum::{
}; };
use futures::StreamExt; use futures::StreamExt;
use serde_json::{json, Map, Value}; use serde_json::{json, Map, Value};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{error, warn}; use tracing::{error, warn};
use uuid::Uuid; use uuid::Uuid;
use super::ProcessedMessages;
pub use crate::tokenizer::StopSequenceDecoder;
use crate::{
core::Worker,
grpc_client::{proto, sglang_scheduler::AbortOnDropStream, SglangSchedulerClient},
protocols::{
chat::{ChatCompletionRequest, ChatMessage},
common::{
ChatLogProbs, ChatLogProbsContent, FunctionCallResponse, StringOrArray, Tool, ToolCall,
ToolChoice, ToolChoiceValue, TopLogProb,
},
generate::GenerateFinishReason,
},
tokenizer::{
chat_template::{ChatTemplateContentFormat, ChatTemplateParams},
traits::Tokenizer,
HuggingFaceTokenizer,
},
};
/// Get gRPC client from worker, returning appropriate error response on failure /// Get gRPC client from worker, returning appropriate error response on failure
pub async fn get_grpc_client_from_worker( pub async fn get_grpc_client_from_worker(
worker: &Arc<dyn Worker>, worker: &Arc<dyn Worker>,
@@ -953,12 +959,17 @@ pub fn parse_finish_reason(reason_str: &str, completion_tokens: i32) -> Generate
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use crate::protocols::chat::{ChatMessage, UserMessageContent};
use crate::protocols::common::{ContentPart, ImageUrl};
use crate::tokenizer::chat_template::ChatTemplateContentFormat;
use serde_json::json; use serde_json::json;
use super::*;
use crate::{
protocols::{
chat::{ChatMessage, UserMessageContent},
common::{ContentPart, ImageUrl},
},
tokenizer::chat_template::ChatTemplateContentFormat,
};
#[test] #[test]
fn test_transform_messages_string_format() { fn test_transform_messages_string_format() {
let messages = vec![ChatMessage::User { let messages = vec![ChatMessage::User {

View File

@@ -1,6 +1,4 @@
use axum::body::Body; use axum::{body::Body, extract::Request, http::HeaderMap};
use axum::extract::Request;
use axum::http::HeaderMap;
/// Copy request headers to a Vec of name-value string pairs /// Copy request headers to a Vec of name-value string pairs
/// Used for forwarding headers to backend workers /// Used for forwarding headers to backend workers

View File

@@ -1,19 +1,5 @@
use super::pd_types::api_path; use std::{sync::Arc, time::Instant};
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent};
use crate::protocols::common::{InputIds, StringOrArray};
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::routers::header_utils;
use crate::routers::RouterTrait;
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
body::Body, body::Body,
@@ -25,11 +11,29 @@ use futures_util::StreamExt;
use reqwest::Client; use reqwest::Client;
use serde::Serialize; use serde::Serialize;
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::sync::Arc;
use std::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, warn}; use tracing::{debug, error, warn};
use super::pd_types::api_path;
use crate::{
config::types::RetryConfig,
core::{
is_retryable_status, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
},
metrics::RouterMetrics,
policies::{LoadBalancingPolicy, PolicyRegistry},
protocols::{
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
common::{InputIds, StringOrArray},
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
},
routers::{header_utils, RouterTrait},
};
#[derive(Debug)] #[derive(Debug)]
pub struct PDRouter { pub struct PDRouter {
pub worker_registry: Arc<WorkerRegistry>, pub worker_registry: Arc<WorkerRegistry>,

View File

@@ -1,35 +1,39 @@
use crate::config::types::RetryConfig; use std::{sync::Arc, time::Instant};
use crate::core::{
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::PolicyRegistry;
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::common::GenerationRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::{RerankRequest, RerankResponse, RerankResult};
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::routers::header_utils;
use crate::routers::RouterTrait;
use axum::body::to_bytes;
use axum::{ use axum::{
body::Body, body::{to_bytes, Body},
extract::Request, extract::Request,
http::{ http::{
header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, Method, StatusCode, header::{CONTENT_LENGTH, CONTENT_TYPE},
HeaderMap, HeaderValue, Method, StatusCode,
}, },
response::{IntoResponse, Response}, response::{IntoResponse, Response},
Json, Json,
}; };
use futures_util::StreamExt; use futures_util::StreamExt;
use reqwest::Client; use reqwest::Client;
use std::sync::Arc;
use std::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error}; use tracing::{debug, error};
use crate::{
config::types::RetryConfig,
core::{
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
},
metrics::RouterMetrics,
policies::PolicyRegistry,
protocols::{
chat::ChatCompletionRequest,
common::GenerationRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::{RerankRequest, RerankResponse, RerankResult},
responses::{ResponsesGetParams, ResponsesRequest},
},
routers::{header_utils, RouterTrait},
};
/// Regular router that uses injected load balancing policies /// Regular router that uses injected load balancing policies
#[derive(Debug)] #[derive(Debug)]
pub struct Router { pub struct Router {

View File

@@ -1,5 +1,7 @@
//! Router implementations //! Router implementations
use std::fmt::Debug;
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
body::Body, body::Body,
@@ -7,16 +9,17 @@ use axum::{
http::{HeaderMap, StatusCode}, http::{HeaderMap, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use std::fmt::Debug;
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use serde_json::Value; use serde_json::Value;
use crate::protocols::{
chat::ChatCompletionRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
};
pub mod factory; pub mod factory;
pub mod grpc; pub mod grpc;
pub mod header_utils; pub mod header_utils;
@@ -25,7 +28,6 @@ pub mod openai; // New refactored OpenAI router module
pub mod router_manager; pub mod router_manager;
pub use factory::RouterFactory; pub use factory::RouterFactory;
// Re-export HTTP routers for convenience // Re-export HTTP routers for convenience
pub use http::{pd_router, pd_types, router}; pub use http::{pd_router, pd_types, router};

View File

@@ -1,22 +1,26 @@
//! Conversation CRUD operations and persistence //! Conversation CRUD operations and persistence
use crate::data_connector::{ use std::{collections::HashMap, sync::Arc};
conversation_items::ListParams, conversation_items::SortOrder, Conversation, ConversationId,
ConversationItemId, ConversationItemStorage, ConversationStorage, NewConversation, use axum::{
NewConversationItem, ResponseId, ResponseStorage, SharedConversationItemStorage, http::StatusCode,
SharedConversationStorage, response::{IntoResponse, Response},
Json,
}; };
use crate::protocols::responses::{ResponseInput, ResponseInputOutputItem, ResponsesRequest};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use chrono::Utc; use chrono::Utc;
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use super::responses::build_stored_response; use super::responses::build_stored_response;
use crate::{
data_connector::{
conversation_items::{ListParams, SortOrder},
Conversation, ConversationId, ConversationItemId, ConversationItemStorage,
ConversationStorage, NewConversation, NewConversationItem, ResponseId, ResponseStorage,
SharedConversationItemStorage, SharedConversationStorage,
},
protocols::responses::{ResponseInput, ResponseInputOutputItem, ResponsesRequest},
};
/// Maximum number of properties allowed in conversation metadata /// Maximum number of properties allowed in conversation metadata
pub(crate) const MAX_METADATA_PROPERTIES: usize = 16; pub(crate) const MAX_METADATA_PROPERTIES: usize = 16;

View File

@@ -8,19 +8,20 @@
//! - Payload transformation for MCP tool interception //! - Payload transformation for MCP tool interception
//! - Metadata injection for MCP operations //! - Metadata injection for MCP operations
use crate::mcp::McpClientManager; use std::{io, sync::Arc};
use crate::protocols::responses::{
ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest,
};
use crate::routers::header_utils::apply_request_headers;
use axum::http::HeaderMap; use axum::http::HeaderMap;
use bytes::Bytes; use bytes::Bytes;
use serde_json::{json, to_value, Value}; use serde_json::{json, to_value, Value};
use std::{io, sync::Arc};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tracing::{info, warn}; use tracing::{info, warn};
use super::utils::event_types; use super::utils::event_types;
use crate::{
mcp::McpClientManager,
protocols::responses::{ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest},
routers::header_utils::apply_request_headers,
};
// ============================================================================ // ============================================================================
// Configuration and State Types // Configuration and State Types

View File

@@ -1,12 +1,15 @@
//! Response storage, patching, and extraction utilities //! Response storage, patching, and extraction utilities
use crate::data_connector::{ResponseId, StoredResponse};
use crate::protocols::responses::{ResponseInput, ResponseToolType, ResponsesRequest};
use serde_json::{json, Value};
use std::collections::HashMap; use std::collections::HashMap;
use serde_json::{json, Value};
use tracing::warn; use tracing::warn;
use super::utils::event_types; use super::utils::event_types;
use crate::{
data_connector::{ResponseId, StoredResponse},
protocols::responses::{ResponseInput, ResponseToolType, ResponsesRequest},
};
// ============================================================================ // ============================================================================
// Response Storage Operations // Response Storage Operations

View File

@@ -1,21 +1,10 @@
//! OpenAI router - main coordinator that delegates to specialized modules //! OpenAI router - main coordinator that delegates to specialized modules
use crate::config::CircuitBreakerConfig; use std::{
use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig}; any::Any,
use crate::data_connector::{ sync::{atomic::AtomicBool, Arc},
conversation_items::ListParams, conversation_items::SortOrder, ConversationId, ResponseId,
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
}; };
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams,
ResponsesRequest,
};
use crate::routers::header_utils::apply_request_headers;
use axum::{ use axum::{
body::Body, body::Body,
extract::Request, extract::Request,
@@ -25,10 +14,6 @@ use axum::{
}; };
use futures_util::StreamExt; use futures_util::StreamExt;
use serde_json::{json, to_value, Value}; use serde_json::{json, to_value, Value};
use std::{
any::Any,
sync::{atomic::AtomicBool, Arc},
};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::warn; use tracing::warn;
@@ -39,12 +24,35 @@ use super::conversations::{
get_conversation, get_conversation_item, list_conversation_items, persist_conversation_items, get_conversation, get_conversation_item, list_conversation_items, persist_conversation_items,
update_conversation, update_conversation,
}; };
use super::mcp::{ use super::{
execute_tool_loop, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming, mcp::{
McpLoopConfig, execute_tool_loop, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming,
McpLoopConfig,
},
responses::{mask_tools_as_mcp, patch_streaming_response_json},
streaming::handle_streaming_response,
};
use crate::{
config::CircuitBreakerConfig,
core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig},
data_connector::{
conversation_items::{ListParams, SortOrder},
ConversationId, ResponseId, SharedConversationItemStorage, SharedConversationStorage,
SharedResponseStorage,
},
protocols::{
chat::ChatCompletionRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams,
ResponsesRequest,
},
},
routers::header_utils::apply_request_headers,
}; };
use super::responses::{mask_tools_as_mcp, patch_streaming_response_json};
use super::streaming::handle_streaming_response;
// ============================================================================ // ============================================================================
// OpenAIRouter Struct // OpenAIRouter Struct

View File

@@ -7,11 +7,8 @@
//! - MCP tool execution loops within streaming responses //! - MCP tool execution loops within streaming responses
//! - Event transformation and output index remapping //! - Event transformation and output index remapping
use crate::data_connector::{ use std::{borrow::Cow, io, sync::Arc};
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
};
use crate::protocols::responses::{ResponseToolType, ResponsesRequest};
use crate::routers::header_utils::{apply_request_headers, preserve_response_headers};
use axum::{ use axum::{
body::Body, body::Body,
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
@@ -20,20 +17,28 @@ use axum::{
use bytes::Bytes; use bytes::Bytes;
use futures_util::StreamExt; use futures_util::StreamExt;
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::{borrow::Cow, io, sync::Arc};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::warn; use tracing::warn;
// Import from sibling modules // Import from sibling modules
use super::conversations::persist_conversation_items; use super::conversations::persist_conversation_items;
use super::mcp::{ use super::{
build_resume_payload, execute_streaming_tool_calls, inject_mcp_metadata_streaming, mcp::{
mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming, send_mcp_list_tools_events, build_resume_payload, execute_streaming_tool_calls, inject_mcp_metadata_streaming,
McpLoopConfig, ToolLoopState, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming,
send_mcp_list_tools_events, McpLoopConfig, ToolLoopState,
},
responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block},
utils::{event_types, FunctionCallInProgress, OutputIndexMapper, StreamAction},
};
use crate::{
data_connector::{
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
},
protocols::responses::{ResponseToolType, ResponsesRequest},
routers::header_utils::{apply_request_headers, preserve_response_headers},
}; };
use super::responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block};
use super::utils::{event_types, FunctionCallInProgress, OutputIndexMapper, StreamAction};
// ============================================================================ // ============================================================================
// Streaming Response Accumulator // Streaming Response Accumulator

View File

@@ -4,16 +4,8 @@
//! - Single Router Mode (enable_igw=false): Router owns workers directly //! - Single Router Mode (enable_igw=false): Router owns workers directly
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything //! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
use crate::config::{ConnectionMode, RoutingMode}; use std::sync::Arc;
use crate::core::{WorkerRegistry, WorkerType};
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::routers::RouterTrait;
use crate::server::{AppContext, ServerConfig};
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
body::Body, body::Body,
@@ -23,9 +15,23 @@ use axum::{
}; };
use dashmap::DashMap; use dashmap::DashMap;
use serde_json::Value; use serde_json::Value;
use std::sync::Arc;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::{
config::{ConnectionMode, RoutingMode},
core::{WorkerRegistry, WorkerType},
protocols::{
chat::ChatCompletionRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
},
routers::RouterTrait,
server::{AppContext, ServerConfig},
};
#[derive(Debug, Clone, Hash, Eq, PartialEq)] #[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct RouterId(String); pub struct RouterId(String);

View File

@@ -1,3 +1,24 @@
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc, OnceLock,
},
time::Duration,
};
use axum::{
extract::{Path, Query, Request, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{delete, get, post},
serve, Json, Router,
};
use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Value};
use tokio::{net::TcpListener, signal, spawn};
use tracing::{error, info, warn, Level};
use crate::{ use crate::{
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode}, config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
core::{ core::{
@@ -30,24 +51,6 @@ use crate::{
tokenizer::{factory as tokenizer_factory, traits::Tokenizer}, tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
tool_parser::ParserFactory as ToolParserFactory, tool_parser::ParserFactory as ToolParserFactory,
}; };
use axum::{
extract::{Path, Query, Request, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{delete, get, post},
serve, Json, Router,
};
use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Value};
use std::sync::OnceLock;
use std::{
sync::atomic::{AtomicBool, Ordering},
sync::Arc,
time::Duration,
};
use tokio::{net::TcpListener, signal, spawn};
use tracing::{error, info, warn, Level};
// //

View File

@@ -1,24 +1,25 @@
use crate::core::WorkerManager; use std::{
use crate::protocols::worker_spec::WorkerConfigRequest; collections::{HashMap, HashSet},
use crate::server::AppContext; sync::{Arc, Mutex},
time::Duration,
};
use futures::{StreamExt, TryStreamExt}; use futures::{StreamExt, TryStreamExt};
use k8s_openapi::api::core::v1::Pod; use k8s_openapi::api::core::v1::Pod;
use kube::{ use kube::{
api::Api, api::Api,
runtime::watcher::{watcher, Config}, runtime::{
runtime::WatchStreamExt, watcher::{watcher, Config},
WatchStreamExt,
},
Client, Client,
}; };
use std::collections::{HashMap, HashSet};
use rustls; use rustls;
use std::sync::{Arc, Mutex}; use tokio::{task, time};
use std::time::Duration;
use tokio::task;
use tokio::time;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use crate::{core::WorkerManager, protocols::worker_spec::WorkerConfigRequest, server::AppContext};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ServiceDiscoveryConfig { pub struct ServiceDiscoveryConfig {
pub enabled: bool, pub enabled: bool,
@@ -452,10 +453,12 @@ async fn handle_pod_deletion(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use k8s_openapi::{
api::core::v1::{Pod, PodCondition, PodSpec, PodStatus},
apimachinery::pkg::apis::meta::v1::{ObjectMeta, Time},
};
use super::*; use super::*;
use k8s_openapi::api::core::v1::{Pod, PodCondition, PodSpec, PodStatus};
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time;
fn create_k8s_pod( fn create_k8s_pod(
name: Option<&str>, name: Option<&str>,
@@ -535,8 +538,7 @@ mod tests {
} }
async fn create_test_app_context() -> Arc<AppContext> { async fn create_test_app_context() -> Arc<AppContext> {
use crate::config::RouterConfig; use crate::{config::RouterConfig, middleware::TokenBucket};
use crate::middleware::TokenBucket;
let router_config = RouterConfig { let router_config = RouterConfig {
worker_startup_timeout_secs: 1, worker_startup_timeout_secs: 1,

View File

@@ -3,12 +3,16 @@
//! This module provides functionality to apply chat templates to messages, //! This module provides functionality to apply chat templates to messages,
//! similar to HuggingFace transformers' apply_chat_template method. //! similar to HuggingFace transformers' apply_chat_template method.
use anyhow::{anyhow, Result};
use minijinja::machinery::ast::{Expr, Stmt};
use minijinja::{context, Environment, Value};
use serde_json;
use std::collections::HashMap; use std::collections::HashMap;
use anyhow::{anyhow, Result};
use minijinja::{
context,
machinery::ast::{Expr, Stmt},
Environment, Value,
};
use serde_json;
/// Chat template content format /// Chat template content format
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChatTemplateContentFormat { pub enum ChatTemplateContentFormat {
@@ -319,8 +323,10 @@ impl<'a> Detector<'a> {
/// AST-based detection using minijinja's unstable machinery /// AST-based detection using minijinja's unstable machinery
/// Single-pass detector with scope tracking /// Single-pass detector with scope tracking
fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> { fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
use minijinja::machinery::{parse, WhitespaceConfig}; use minijinja::{
use minijinja::syntax::SyntaxConfig; machinery::{parse, WhitespaceConfig},
syntax::SyntaxConfig,
};
let ast = match parse( let ast = match parse(
template, template,

View File

@@ -1,13 +1,9 @@
use super::traits; use std::{fs::File, io::Read, path::Path, sync::Arc};
use anyhow::{Error, Result}; use anyhow::{Error, Result};
use std::fs::File;
use std::io::Read;
use std::path::Path;
use std::sync::Arc;
use tracing::{debug, info}; use tracing::{debug, info};
use super::huggingface::HuggingFaceTokenizer; use super::{huggingface::HuggingFaceTokenizer, tiktoken::TiktokenTokenizer, traits};
use super::tiktoken::TiktokenTokenizer;
use crate::tokenizer::hub::download_tokenizer_from_hf; use crate::tokenizer::hub::download_tokenizer_from_hf;
/// Represents the type of tokenizer being used /// Represents the type of tokenizer being used
@@ -379,8 +375,7 @@ pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())), Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
_ => { _ => {
// Try auto-detection // Try auto-detection
use std::fs::File; use std::{fs::File, io::Read};
use std::io::Read;
let mut file = File::open(file_path)?; let mut file = File::open(file_path)?;
let mut buffer = vec![0u8; 512]; let mut buffer = vec![0u8; 512];

View File

@@ -1,6 +1,9 @@
use std::{
env,
path::{Path, PathBuf},
};
use hf_hub::api::tokio::ApiBuilder; use hf_hub::api::tokio::ApiBuilder;
use std::env;
use std::path::{Path, PathBuf};
const IGNORED: [&str; 5] = [ const IGNORED: [&str; 5] = [
".gitattributes", ".gitattributes",

View File

@@ -3,12 +3,12 @@ use std::collections::HashMap;
use anyhow::{Error, Result}; use anyhow::{Error, Result};
use tokenizers::tokenizer::Tokenizer as HfTokenizer; use tokenizers::tokenizer::Tokenizer as HfTokenizer;
use super::chat_template::{ use super::{
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, chat_template::{
ChatTemplateProcessor, detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
}; ChatTemplateProcessor,
use super::traits::{ },
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait, traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
}; };
/// HuggingFace tokenizer wrapper /// HuggingFace tokenizer wrapper

View File

@@ -1,9 +1,11 @@
//! Mock tokenizer implementation for testing //! Mock tokenizer implementation for testing
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
use anyhow::Result;
use std::collections::HashMap; use std::collections::HashMap;
use anyhow::Result;
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
/// Mock tokenizer for testing purposes /// Mock tokenizer for testing purposes
pub struct MockTokenizer { pub struct MockTokenizer {
vocab: HashMap<String, u32>, vocab: HashMap<String, u32>,

View File

@@ -1,6 +1,6 @@
use std::{ops::Deref, sync::Arc};
use anyhow::Result; use anyhow::Result;
use std::ops::Deref;
use std::sync::Arc;
pub mod factory; pub mod factory;
pub mod hub; pub mod hub;
@@ -27,14 +27,12 @@ pub use factory::{
create_tokenizer_from_file, create_tokenizer_with_chat_template, create_tokenizer_from_file, create_tokenizer_with_chat_template,
create_tokenizer_with_chat_template_blocking, TokenizerType, create_tokenizer_with_chat_template_blocking, TokenizerType,
}; };
pub use huggingface::HuggingFaceTokenizer;
pub use sequence::Sequence; pub use sequence::Sequence;
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder}; pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
pub use stream::DecodeStream; pub use stream::DecodeStream;
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
pub use huggingface::HuggingFaceTokenizer;
pub use tiktoken::{TiktokenModel, TiktokenTokenizer}; pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations /// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
#[derive(Clone)] #[derive(Clone)]

View File

@@ -1,7 +1,9 @@
use super::traits::{TokenIdType, Tokenizer as TokenizerTrait};
use anyhow::Result;
use std::sync::Arc; use std::sync::Arc;
use anyhow::Result;
use super::traits::{TokenIdType, Tokenizer as TokenizerTrait};
/// Maintains state for an ongoing sequence of tokens and their decoded text /// Maintains state for an ongoing sequence of tokens and their decoded text
/// This provides a cleaner abstraction for managing token sequences /// This provides a cleaner abstraction for managing token sequences
pub struct Sequence { pub struct Sequence {

View File

@@ -1,8 +1,11 @@
use super::sequence::Sequence; use std::{collections::HashSet, sync::Arc};
use super::traits::{self, TokenIdType};
use anyhow::Result; use anyhow::Result;
use std::collections::HashSet;
use std::sync::Arc; use super::{
sequence::Sequence,
traits::{self, TokenIdType},
};
/// Output from the sequence decoder /// Output from the sequence decoder
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]

View File

@@ -1,9 +1,11 @@
// src/tokenizer/stream.rs // src/tokenizer/stream.rs
use super::traits::{self, TokenIdType};
use anyhow::Result;
use std::sync::Arc; use std::sync::Arc;
use anyhow::Result;
use super::traits::{self, TokenIdType};
const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5; const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;
/// DecodeStream will keep the state necessary to produce individual chunks of /// DecodeStream will keep the state necessary to produce individual chunks of

View File

@@ -1,8 +1,9 @@
#[cfg(test)] #[cfg(test)]
use super::*;
#[cfg(test)]
use std::sync::Arc; use std::sync::Arc;
#[cfg(test)]
use super::*;
#[test] #[test]
fn test_mock_tokenizer_encode() { fn test_mock_tokenizer_encode() {
let tokenizer = mock::MockTokenizer::new(); let tokenizer = mock::MockTokenizer::new();

View File

@@ -1,8 +1,9 @@
use anyhow::{Error, Result};
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
use super::traits::{ use super::traits::{
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait, Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
}; };
use anyhow::{Error, Result};
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
/// Tiktoken tokenizer wrapper for OpenAI GPT models /// Tiktoken tokenizer wrapper for OpenAI GPT models
pub struct TiktokenTokenizer { pub struct TiktokenTokenizer {

View File

@@ -1,6 +1,9 @@
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
};
use anyhow::Result; use anyhow::Result;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// Type alias for token IDs /// Type alias for token IDs
pub type TokenIdType = u32; pub type TokenIdType = u32;

View File

@@ -1,14 +1,19 @@
// Factory and pool for creating model-specific tool parsers with pooling support. // Factory and pool for creating model-specific tool parsers with pooling support.
use std::collections::HashMap; use std::{
use std::sync::{Arc, RwLock}; collections::HashMap,
sync::{Arc, RwLock},
};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::tool_parser::parsers::{ use crate::tool_parser::{
DeepSeekParser, Glm4MoeParser, GptOssHarmonyParser, GptOssParser, JsonParser, KimiK2Parser, parsers::{
LlamaParser, MistralParser, PassthroughParser, PythonicParser, QwenParser, Step3Parser, DeepSeekParser, Glm4MoeParser, GptOssHarmonyParser, GptOssParser, JsonParser, KimiK2Parser,
LlamaParser, MistralParser, PassthroughParser, PythonicParser, QwenParser, Step3Parser,
},
traits::ToolParser,
}; };
use crate::tool_parser::traits::ToolParser;
/// Type alias for pooled parser instances. /// Type alias for pooled parser instances.
pub type PooledParser = Arc<Mutex<Box<dyn ToolParser>>>; pub type PooledParser = Arc<Mutex<Box<dyn ToolParser>>>;

View File

@@ -18,11 +18,10 @@ mod tests;
// Re-export commonly used types // Re-export commonly used types
pub use errors::{ParserError, ParserResult}; pub use errors::{ParserError, ParserResult};
pub use factory::{ParserFactory, ParserRegistry, PooledParser}; pub use factory::{ParserFactory, ParserRegistry, PooledParser};
pub use traits::{PartialJsonParser, ToolParser};
pub use types::{FunctionCall, PartialToolCall, StreamingParseResult, ToolCall};
// Re-export parsers for convenience // Re-export parsers for convenience
pub use parsers::{ pub use parsers::{
DeepSeekParser, Glm4MoeParser, GptOssParser, JsonParser, KimiK2Parser, LlamaParser, DeepSeekParser, Glm4MoeParser, GptOssParser, JsonParser, KimiK2Parser, LlamaParser,
MistralParser, PythonicParser, QwenParser, Step3Parser, MistralParser, PythonicParser, QwenParser, Step3Parser,
}; };
pub use traits::{PartialJsonParser, ToolParser};
pub use types::{FunctionCall, PartialToolCall, StreamingParseResult, ToolCall};

View File

@@ -2,13 +2,14 @@ use async_trait::async_trait;
use regex::Regex; use regex::Regex;
use serde_json::Value; use serde_json::Value;
use crate::protocols::common::Tool; use crate::{
protocols::common::Tool,
use crate::tool_parser::{ tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
traits::ToolParser, traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
},
}; };
/// DeepSeek V3 format parser for tool calls /// DeepSeek V3 format parser for tool calls

View File

@@ -2,13 +2,14 @@ use async_trait::async_trait;
use regex::Regex; use regex::Regex;
use serde_json::Value; use serde_json::Value;
use crate::protocols::common::Tool; use crate::{
protocols::common::Tool,
use crate::tool_parser::{ tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
traits::ToolParser, traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
},
}; };
/// GLM-4 MoE format parser for tool calls /// GLM-4 MoE format parser for tool calls

View File

@@ -1,11 +1,12 @@
use async_trait::async_trait; use async_trait::async_trait;
use crate::protocols::common::Tool; use crate::{
protocols::common::Tool,
use crate::tool_parser::{ tool_parser::{
errors::ParserResult, errors::ParserResult,
traits::{TokenToolParser, ToolParser}, traits::{TokenToolParser, ToolParser},
types::{StreamingParseResult, ToolCall}, types::{StreamingParseResult, ToolCall},
},
}; };
/// Placeholder for the Harmony-backed GPT-OSS parser. /// Placeholder for the Harmony-backed GPT-OSS parser.

View File

@@ -2,14 +2,15 @@ use async_trait::async_trait;
use regex::Regex; use regex::Regex;
use serde_json::Value; use serde_json::Value;
use crate::protocols::common::Tool; use crate::{
protocols::common::Tool,
use crate::tool_parser::{ tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
partial_json::PartialJson, partial_json::PartialJson,
traits::ToolParser, traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
},
}; };
/// GPT-OSS format parser for tool calls /// GPT-OSS format parser for tool calls

View File

@@ -1,9 +1,14 @@
use crate::protocols::common::Tool;
use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
use crate::tool_parser::errors::{ParserError, ParserResult}; use serde_json::Value;
use crate::tool_parser::types::{StreamingParseResult, ToolCallItem};
use crate::{
protocols::common::Tool,
tool_parser::{
errors::{ParserError, ParserResult},
types::{StreamingParseResult, ToolCallItem},
},
};
/// Get a mapping of tool names to their indices /// Get a mapping of tool names to their indices
pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> { pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {

View File

@@ -1,14 +1,15 @@
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::Value; use serde_json::Value;
use crate::protocols::common::Tool; use crate::{
protocols::common::Tool,
use crate::tool_parser::{ tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
partial_json::PartialJson, partial_json::PartialJson,
traits::ToolParser, traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
},
}; };
/// JSON format parser for tool calls /// JSON format parser for tool calls

View File

@@ -2,13 +2,14 @@ use async_trait::async_trait;
use regex::Regex; use regex::Regex;
use serde_json::Value; use serde_json::Value;
use crate::protocols::common::Tool; use crate::{
protocols::common::Tool,
use crate::tool_parser::{ tool_parser::{
errors::ParserResult, errors::ParserResult,
parsers::helpers, parsers::helpers,
traits::ToolParser, traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
},
}; };
/// Kimi K2 format parser for tool calls /// Kimi K2 format parser for tool calls

View File

@@ -1,14 +1,15 @@
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::Value; use serde_json::Value;
use crate::protocols::common::Tool; use crate::{
protocols::common::Tool,
use crate::tool_parser::{ tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
partial_json::PartialJson, partial_json::PartialJson,
traits::ToolParser, traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall}, types::{FunctionCall, StreamingParseResult, ToolCall},
},
}; };
/// Llama 3.2 format parser for tool calls /// Llama 3.2 format parser for tool calls

View File

@@ -1,14 +1,15 @@
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::Value; use serde_json::Value;
use crate::protocols::common::Tool; use crate::{
protocols::common::Tool,
use crate::tool_parser::{ tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
partial_json::PartialJson, partial_json::PartialJson,
traits::ToolParser, traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall}, types::{FunctionCall, StreamingParseResult, ToolCall},
},
}; };
/// Mistral format parser for tool calls /// Mistral format parser for tool calls

View File

@@ -4,12 +4,17 @@
//! tool call parsing should be performed. It simply returns the input text //! tool call parsing should be performed. It simply returns the input text
//! with no tool calls detected. //! with no tool calls detected.
use crate::protocols::common::Tool;
use crate::tool_parser::errors::ParserResult;
use crate::tool_parser::traits::ToolParser;
use crate::tool_parser::types::{StreamingParseResult, ToolCall, ToolCallItem};
use async_trait::async_trait; use async_trait::async_trait;
use crate::{
protocols::common::Tool,
tool_parser::{
errors::ParserResult,
traits::ToolParser,
types::{StreamingParseResult, ToolCall, ToolCallItem},
},
};
/// Passthrough parser that returns text unchanged with no tool calls /// Passthrough parser that returns text unchanged with no tool calls
#[derive(Default)] #[derive(Default)]
pub struct PassthroughParser; pub struct PassthroughParser;

View File

@@ -1,3 +1,5 @@
use std::sync::OnceLock;
/// Pythonic format parser for tool calls /// Pythonic format parser for tool calls
/// ///
/// Handles Python function call syntax within square brackets: /// Handles Python function call syntax within square brackets:
@@ -10,18 +12,20 @@
use async_trait::async_trait; use async_trait::async_trait;
use num_traits::ToPrimitive; use num_traits::ToPrimitive;
use regex::Regex; use regex::Regex;
use rustpython_parser::ast::{Constant, Expr, Mod, UnaryOp}; use rustpython_parser::{
use rustpython_parser::{parse, Mode}; ast::{Constant, Expr, Mod, UnaryOp},
parse, Mode,
};
use serde_json::{Map, Number, Value}; use serde_json::{Map, Number, Value};
use std::sync::OnceLock;
use crate::protocols::common::Tool; use crate::{
protocols::common::Tool,
use crate::tool_parser::{ tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
traits::ToolParser, traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
},
}; };
static PYTHONIC_BLOCK_REGEX: OnceLock<Regex> = OnceLock::new(); static PYTHONIC_BLOCK_REGEX: OnceLock<Regex> = OnceLock::new();

View File

@@ -2,14 +2,15 @@ use async_trait::async_trait;
use regex::Regex; use regex::Regex;
use serde_json::Value; use serde_json::Value;
use crate::protocols::common::Tool; use crate::{
protocols::common::Tool,
use crate::tool_parser::{ tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
partial_json::PartialJson, partial_json::PartialJson,
traits::ToolParser, traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall}, types::{FunctionCall, StreamingParseResult, ToolCall},
},
}; };
/// Qwen format parser for tool calls /// Qwen format parser for tool calls

Some files were not shown because too many files have changed in this diff Show More