[router] Add rustfmt and set group imports by default (#11732)

This commit is contained in:
Chang Su
2025-10-16 17:33:29 -07:00
committed by GitHub
parent 7a7f99beb7
commit dc01313da1
126 changed files with 1127 additions and 813 deletions

View File

@@ -54,7 +54,9 @@ jobs:
run: |
source "$HOME/.cargo/env"
cd sgl-router/
cargo fmt -- --check
# Install nightly together with its rustfmt component in a single step: a
# component cannot be added to a toolchain that is not installed yet, and
# omitting the host triple keeps the step working on non-x86_64 runners.
rustup toolchain install nightly --profile minimal --component rustfmt
# rustfmt.toml uses unstable import options, so formatting must run on nightly.
cargo +nightly fmt -- --check
- name: Run Rust tests
timeout-minutes: 20

View File

@@ -1,14 +1,18 @@
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use serde_json::{from_str, to_string, to_value, to_vec};
use std::time::Instant;
use sglang_router_rs::core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType};
use sglang_router_rs::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent};
use sglang_router_rs::protocols::common::StringOrArray;
use sglang_router_rs::protocols::completion::CompletionRequest;
use sglang_router_rs::protocols::generate::GenerateRequest;
use sglang_router_rs::protocols::sampling_params::SamplingParams;
use sglang_router_rs::routers::http::pd_types::{generate_room_id, RequestWithBootstrap};
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use serde_json::{from_str, to_string, to_value, to_vec};
use sglang_router_rs::{
core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType},
protocols::{
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
common::StringOrArray,
completion::CompletionRequest,
generate::GenerateRequest,
sampling_params::SamplingParams,
},
routers::http::pd_types::{generate_room_id, RequestWithBootstrap},
};
fn create_test_worker() -> BasicWorker {
BasicWorkerBuilder::new("http://test-server:8000")

View File

@@ -1,16 +1,21 @@
//! Comprehensive tokenizer benchmark with clean summary output
//! Each test adds a row to the final summary table
use std::{
collections::BTreeMap,
path::PathBuf,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Mutex, OnceLock,
},
thread,
time::{Duration, Instant},
};
use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput};
use sglang_router_rs::tokenizer::{
huggingface::HuggingFaceTokenizer, sequence::Sequence, stop::*, stream::DecodeStream, traits::*,
};
use std::collections::BTreeMap;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::thread;
use std::time::{Duration, Instant};
// Include the common test utilities
#[path = "../tests/common/mod.rs"]

View File

@@ -7,15 +7,22 @@
//! - Streaming vs complete parsing
//! - Different model formats (JSON, Mistral, Qwen, Pythonic, etc.)
use std::{
collections::BTreeMap,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Mutex,
},
thread,
time::{Duration, Instant},
};
use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput};
use serde_json::json;
use sglang_router_rs::protocols::common::{Function, Tool};
use sglang_router_rs::tool_parser::{JsonParser, ParserFactory as ToolParserFactory, ToolParser};
use std::collections::BTreeMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};
use sglang_router_rs::{
protocols::common::{Function, Tool},
tool_parser::{JsonParser, ParserFactory as ToolParserFactory, ToolParser},
};
use tokio::runtime::Runtime;
// Test data for different parser formats - realistic complex examples

8
sgl-router/rustfmt.toml Normal file
View File

@@ -0,0 +1,8 @@
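# rustfmt configuration. `imports_granularity` and `group_imports` are unstable
# options, so format with a nightly toolchain: `cargo +nightly fmt`.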
# Merge all imports from the same crate into a single `use` statement
imports_granularity = "Crate"
# Group std, external crates, and local crate imports separately
group_imports = "StdExternalCrate"
reorder_imports = true
reorder_modules = true

View File

@@ -1,7 +1,9 @@
use super::ConfigResult;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::ConfigResult;
/// Main router configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig {

View File

@@ -1,6 +1,11 @@
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use std::{
sync::{
atomic::{AtomicU32, AtomicU64, Ordering},
Arc, RwLock,
},
time::{Duration, Instant},
};
use tracing::info;
/// Circuit breaker configuration
@@ -316,9 +321,10 @@ pub struct CircuitBreakerStats {
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
use super::*;
#[test]
fn test_circuit_breaker_initial_state() {
let cb = CircuitBreaker::new();

View File

@@ -68,9 +68,10 @@ impl From<reqwest::Error> for WorkerError {
#[cfg(test)]
mod tests {
use super::*;
use std::error::Error;
use super::*;
#[test]
fn test_health_check_failed_display() {
let error = WorkerError::HealthCheckFailed {

View File

@@ -3,16 +3,22 @@
//! Provides non-blocking worker management by queuing operations and processing
//! them asynchronously in background worker tasks.
use crate::core::WorkerManager;
use crate::protocols::worker_spec::{JobStatus, WorkerConfigRequest};
use crate::server::AppContext;
use std::{
sync::{Arc, Weak},
time::{Duration, SystemTime},
};
use dashmap::DashMap;
use metrics::{counter, gauge, histogram};
use std::sync::{Arc, Weak};
use std::time::{Duration, SystemTime};
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
use crate::{
core::WorkerManager,
protocols::worker_spec::{JobStatus, WorkerConfigRequest},
server::AppContext,
};
/// Job types for control plane operations
#[derive(Debug, Clone)]
pub enum Job {

View File

@@ -1,10 +1,11 @@
use crate::config::types::RetryConfig;
use axum::http::StatusCode;
use axum::response::Response;
use rand::Rng;
use std::time::Duration;
use axum::{http::StatusCode, response::Response};
use rand::Rng;
use tracing::debug;
use crate::config::types::RetryConfig;
/// Check if an HTTP status code indicates a retryable error
pub fn is_retryable_status(status: StatusCode) -> bool {
matches!(
@@ -162,11 +163,14 @@ impl RetryExecutor {
#[cfg(test)]
mod tests {
use std::sync::{
atomic::{AtomicU32, Ordering},
Arc,
};
use axum::{http::StatusCode, response::IntoResponse};
use super::*;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
fn base_retry_config() -> RetryConfig {
RetryConfig {

View File

@@ -1,5 +1,8 @@
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::{
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::{Mutex, Notify};
use tracing::{debug, trace};

View File

@@ -1,19 +1,27 @@
use super::{CircuitBreaker, WorkerError, WorkerResult};
use crate::core::CircuitState;
use crate::core::{BasicWorkerBuilder, DPAwareWorkerBuilder};
use crate::grpc_client::SglangSchedulerClient;
use crate::metrics::RouterMetrics;
use crate::protocols::worker_spec::WorkerInfo;
use std::{
fmt,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, LazyLock,
},
time::{Duration, Instant},
};
use async_trait::async_trait;
use futures;
use serde_json;
use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, LazyLock};
use std::time::Duration;
use std::time::Instant;
use tokio::sync::{Mutex, RwLock};
use tokio::time;
use tokio::{
sync::{Mutex, RwLock},
time,
};
use super::{CircuitBreaker, WorkerError, WorkerResult};
use crate::{
core::{BasicWorkerBuilder, CircuitState, DPAwareWorkerBuilder},
grpc_client::SglangSchedulerClient,
metrics::RouterMetrics,
protocols::worker_spec::WorkerInfo,
};
static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
reqwest::Client::builder()
@@ -1024,10 +1032,10 @@ pub fn worker_to_info(worker: &Arc<dyn Worker>) -> WorkerInfo {
#[cfg(test)]
mod tests {
use std::{thread, time::Duration};
use super::*;
use crate::core::CircuitBreakerConfig;
use std::thread;
use std::time::Duration;
#[test]
fn test_worker_type_display() {
@@ -1502,9 +1510,10 @@ mod tests {
#[test]
fn test_load_counter_performance() {
use crate::core::BasicWorkerBuilder;
use std::time::Instant;
use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.build();

View File

@@ -1,9 +1,12 @@
use super::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig};
use super::worker::{
BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerMetadata, WorkerType,
use std::collections::HashMap;
use super::{
circuit_breaker::{CircuitBreaker, CircuitBreakerConfig},
worker::{
BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerMetadata, WorkerType,
},
};
use crate::grpc_client::SglangSchedulerClient;
use std::collections::HashMap;
/// Builder for creating BasicWorker instances with fluent API
pub struct BasicWorkerBuilder {
@@ -100,6 +103,7 @@ impl BasicWorkerBuilder {
atomic::{AtomicBool, AtomicUsize},
Arc,
};
use tokio::sync::{Mutex, RwLock};
let bootstrap_host = match url::Url::parse(&self.url) {
@@ -282,9 +286,10 @@ impl DPAwareWorkerBuilder {
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::*;
use crate::core::worker::Worker;
use std::time::Duration;
#[test]
fn test_basic_worker_builder_minimal() {

View File

@@ -3,31 +3,35 @@
//! Handles all aspects of worker lifecycle including discovery, initialization,
//! runtime management, and health monitoring.
use crate::config::types::{
CircuitBreakerConfig as ConfigCircuitBreakerConfig, ConnectionMode as ConfigConnectionMode,
HealthCheckConfig, RouterConfig, RoutingMode,
};
use crate::core::{
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder, HealthConfig,
Worker, WorkerFactory, WorkerRegistry, WorkerType,
};
use crate::grpc_client::SglangSchedulerClient;
use crate::policies::PolicyRegistry;
use crate::protocols::worker_spec::{
FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
};
use crate::server::AppContext;
use std::{collections::HashMap, sync::Arc, time::Duration};
use futures::future;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{watch, Mutex};
use tokio::task::JoinHandle;
use tokio::{
sync::{watch, Mutex},
task::JoinHandle,
};
use tracing::{debug, error, info, warn};
use crate::{
config::types::{
CircuitBreakerConfig as ConfigCircuitBreakerConfig, ConnectionMode as ConfigConnectionMode,
HealthCheckConfig, RouterConfig, RoutingMode,
},
core::{
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder,
HealthConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType,
},
grpc_client::SglangSchedulerClient,
policies::PolicyRegistry,
protocols::worker_spec::{
FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
},
server::AppContext,
};
static HTTP_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
reqwest::Client::builder()
.timeout(Duration::from_secs(10))
@@ -1803,9 +1807,10 @@ impl Drop for LoadMonitor {
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
use super::*;
#[test]
fn test_parse_server_info() {
let json = serde_json::json!({

View File

@@ -2,11 +2,13 @@
//!
//! Provides centralized registry for workers with model-based indexing
use crate::core::{ConnectionMode, Worker, WorkerType};
use dashmap::DashMap;
use std::sync::{Arc, RwLock};
use dashmap::DashMap;
use uuid::Uuid;
use crate::core::{ConnectionMode, Worker, WorkerType};
/// Unique identifier for a worker
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct WorkerId(String);
@@ -363,8 +365,10 @@ impl WorkerRegistry {
/// Start a health checker for all workers in the registry
/// This should be called once after the registry is populated with workers
pub fn start_health_checker(&self, check_interval_secs: u64) -> crate::core::HealthChecker {
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
let shutdown = Arc::new(AtomicBool::new(false));
let shutdown_clone = shutdown.clone();
@@ -433,9 +437,10 @@ pub struct WorkerRegistryStats {
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::*;
use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig};
use std::collections::HashMap;
#[test]
fn test_worker_registry() {

View File

@@ -1,14 +1,18 @@
use std::collections::{BTreeMap, HashMap};
use std::sync::RwLock;
use std::{
collections::{BTreeMap, HashMap},
sync::RwLock,
};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use super::conversation_items::{
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, ListParams,
Result, SortOrder,
use super::{
conversation_items::{
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, ListParams,
Result, SortOrder,
},
conversations::ConversationId,
};
use super::conversations::ConversationId;
#[derive(Default)]
pub struct MemoryConversationItemStorage {
@@ -190,9 +194,10 @@ impl ConversationItemStorage for MemoryConversationItemStorage {
#[cfg(test)]
mod tests {
use super::*;
use chrono::{TimeZone, Utc};
use super::*;
fn make_item(
item_type: &str,
role: Option<&str>,

View File

@@ -1,18 +1,21 @@
use crate::config::OracleConfig;
use crate::data_connector::conversation_items::{
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage,
ConversationItemStorageError, ListParams, Result as ItemResult, SortOrder,
};
use crate::data_connector::conversations::ConversationId;
use std::{path::Path, sync::Arc, time::Duration};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
use oracle::sql_type::ToSql;
use oracle::Connection;
use oracle::{sql_type::ToSql, Connection};
use serde_json::Value;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use crate::{
config::OracleConfig,
data_connector::{
conversation_items::{
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage,
ConversationItemStorageError, ListParams, Result as ItemResult, SortOrder,
},
conversations::ConversationId,
},
};
#[derive(Clone)]
pub struct OracleConversationItemStorage {

View File

@@ -1,10 +1,13 @@
use std::{
fmt::{Display, Formatter},
sync::Arc,
};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::fmt::{Display, Formatter};
use std::sync::Arc;
use super::conversations::ConversationId;

View File

@@ -1,7 +1,7 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use super::conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage, NewConversation,

View File

@@ -1,16 +1,18 @@
use crate::config::OracleConfig;
use crate::data_connector::conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage,
ConversationStorageError, NewConversation, Result,
};
use std::{path::Path, sync::Arc, time::Duration};
use async_trait::async_trait;
use chrono::Utc;
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
use oracle::{sql_type::OracleType, Connection};
use serde_json::Value;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use crate::{
config::OracleConfig,
data_connector::conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage,
ConversationStorageError, NewConversation, Result,
},
};
#[derive(Clone)]
pub struct OracleConversationStorage {

View File

@@ -1,10 +1,13 @@
use std::{
fmt::{Display, Formatter},
sync::Arc,
};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use serde_json::{Map as JsonMap, Value};
use std::fmt::{Display, Formatter};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct ConversationId(pub String);

View File

@@ -1,7 +1,7 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use super::responses::{ResponseChain, ResponseId, ResponseStorage, Result, StoredResponse};

View File

@@ -1,16 +1,17 @@
use crate::config::OracleConfig;
use crate::data_connector::responses::{
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult,
StoredResponse,
};
use std::{collections::HashMap, path::Path, sync::Arc, time::Duration};
use async_trait::async_trait;
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
use oracle::{Connection, Row};
use serde_json::Value;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use crate::{
config::OracleConfig,
data_connector::responses::{
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult,
StoredResponse,
},
};
const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \
tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response FROM responses";
@@ -510,9 +511,10 @@ impl OracleErrorExt for ResponseStorageError {
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use super::*;
#[test]
fn parse_tool_calls_handles_empty_input() {
assert!(parse_tool_calls(None).unwrap().is_empty());

View File

@@ -1,8 +1,8 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
/// Response identifier
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]

View File

@@ -1,16 +1,23 @@
use std::convert::TryFrom;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{
convert::TryFrom,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::{Context, Poll},
time::Duration,
};
use tonic::{transport::Channel, Request, Streaming};
use tracing::{debug, warn};
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue};
use crate::protocols::generate::GenerateRequest;
use crate::protocols::sampling_params::SamplingParams as GenerateSamplingParams;
use crate::protocols::{
chat::ChatCompletionRequest,
common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue},
generate::GenerateRequest,
sampling_params::SamplingParams as GenerateSamplingParams,
};
// Include the generated protobuf code
pub mod proto {

View File

@@ -1,12 +1,14 @@
use std::path::PathBuf;
use tracing::Level;
use tracing_appender::non_blocking::WorkerGuard;
use tracing_appender::rolling::{RollingFileAppender, Rotation};
use tracing_appender::{
non_blocking::WorkerGuard,
rolling::{RollingFileAppender, Rotation},
};
use tracing_log::LogTracer;
use tracing_subscriber::fmt::time::ChronoUtc;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
use tracing_subscriber::{
fmt::time::ChronoUtc, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer,
};
#[derive(Debug, Clone)]
pub struct LoggingConfig {

View File

@@ -1,14 +1,17 @@
use clap::{ArgAction, Parser, ValueEnum};
use sglang_router_rs::config::{
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig,
RouterConfig, RoutingMode,
};
use sglang_router_rs::metrics::PrometheusConfig;
use sglang_router_rs::server::{self, ServerConfig};
use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
use std::collections::HashMap;
use clap::{ArgAction, Parser, ValueEnum};
use sglang_router_rs::{
config::{
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig,
RouterConfig, RoutingMode,
},
metrics::PrometheusConfig,
server::{self, ServerConfig},
service_discovery::ServiceDiscoveryConfig,
};
fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
let args: Vec<String> = std::env::args().collect();
let mut prefill_entries = Vec::new();

View File

@@ -1,3 +1,5 @@
use std::{borrow::Cow, collections::HashMap, time::Duration};
use backoff::ExponentialBackoffBuilder;
use dashmap::DashMap;
use rmcp::{
@@ -13,7 +15,6 @@ use rmcp::{
RoleClient, ServiceExt,
};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, collections::HashMap, time::Duration};
use crate::mcp::{
config::{McpConfig, McpServerConfig, McpTransport},

View File

@@ -1,6 +1,7 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpConfig {
pub servers: Vec<McpServerConfig>,

View File

@@ -1,5 +1,7 @@
// OAuth authentication support for MCP servers
use std::{net::SocketAddr, sync::Arc};
use axum::{
extract::{Query, State},
response::Html,
@@ -8,7 +10,6 @@ use axum::{
};
use rmcp::transport::auth::OAuthState;
use serde::Deserialize;
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::{oneshot, Mutex};
use crate::mcp::error::{McpError, McpResult};

View File

@@ -1,7 +1,10 @@
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
time::Duration,
};
use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram};
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;
#[derive(Debug, Clone)]
pub struct PrometheusConfig {
@@ -620,9 +623,10 @@ impl TokenizerMetrics {
#[cfg(test)]
mod tests {
use super::*;
use std::net::TcpListener;
use super::*;
#[test]
fn test_prometheus_config_default() {
let config = PrometheusConfig::default();
@@ -912,9 +916,13 @@ mod tests {
#[test]
fn test_concurrent_metric_updates() {
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
thread,
};
let done = Arc::new(AtomicBool::new(false));
let mut handles = vec![];

View File

@@ -1,12 +1,19 @@
use std::{
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::{Duration, Instant},
};
use axum::{
body::Body, extract::Request, extract::State, http::header, http::HeaderValue,
http::StatusCode, middleware::Next, response::IntoResponse, response::Response,
body::Body,
extract::{Request, State},
http::{header, HeaderValue, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use rand::Rng;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use subtle::ConstantTimeEq;
use tokio::sync::{mpsc, oneshot};
use tower::{Layer, Service};
@@ -14,9 +21,7 @@ use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
pub use crate::core::token_bucket::TokenBucket;
use crate::metrics::RouterMetrics;
use crate::server::AppState;
use crate::{metrics::RouterMetrics, server::AppState};
#[derive(Clone)]
pub struct AuthConfig {

View File

@@ -59,17 +59,15 @@
during the next eviction cycle.
*/
use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use crate::tree::Tree;
use std::{sync::Arc, thread, time::Duration};
use dashmap::DashMap;
use rand::Rng;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use tracing::debug;
use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
use crate::{core::Worker, metrics::RouterMetrics, tree::Tree};
/// Cache-aware routing policy
///
/// Routes requests based on cache affinity when load is balanced,

View File

@@ -1,11 +1,12 @@
//! Factory for creating load balancing policies
use std::sync::Arc;
use super::{
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
RoundRobinPolicy,
};
use crate::config::PolicyConfig;
use std::sync::Arc;
/// Factory for creating policy instances
pub struct PolicyFactory;

View File

@@ -3,9 +3,9 @@
//! This module provides a unified abstraction for routing policies that work
//! across both regular and prefill-decode (PD) routing modes.
use std::{fmt::Debug, sync::Arc};
use crate::core::Worker;
use std::fmt::Debug;
use std::sync::Arc;
mod cache_aware;
mod factory;

View File

@@ -1,13 +1,16 @@
//! Power-of-two choices load balancing policy
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use rand::Rng;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::info;
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::{core::Worker, metrics::RouterMetrics};
/// Power-of-two choices policy
///
/// Randomly selects two workers and routes to the one with lower load.

View File

@@ -1,11 +1,12 @@
//! Random load balancing policy
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use rand::Rng;
use std::sync::Arc;
use rand::Rng;
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::{core::Worker, metrics::RouterMetrics};
/// Random selection policy
///
/// Selects workers randomly with uniform distribution among healthy workers.
@@ -50,9 +51,10 @@ impl LoadBalancingPolicy for RandomPolicy {
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::*;
use crate::core::{BasicWorkerBuilder, WorkerType};
use std::collections::HashMap;
#[test]
fn test_random_selection() {

View File

@@ -1,3 +1,10 @@
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use tracing::{debug, info, warn};
/// Policy Registry for managing model-to-policy mappings
///
/// This registry manages the dynamic assignment of load balancing policies to models.
@@ -8,11 +15,7 @@ use super::{
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
RoundRobinPolicy,
};
use crate::config::types::PolicyConfig;
use crate::core::Worker;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::{debug, info, warn};
use crate::{config::types::PolicyConfig, core::Worker};
/// Registry for managing model-to-policy mappings
#[derive(Clone)]

View File

@@ -1,10 +1,12 @@
//! Round-robin load balancing policy
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use crate::{core::Worker, metrics::RouterMetrics};
/// Round-robin selection policy
///

View File

@@ -1,10 +1,13 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use validator::Validate;
use super::common::*;
use super::sampling_params::{validate_top_k_value, validate_top_p_value};
use super::{
common::*,
sampling_params::{validate_top_k_value, validate_top_p_value},
};
use crate::protocols::validated::Normalizable;
// ============================================================================
@@ -532,11 +535,12 @@ impl Normalizable for ChatCompletionRequest {
// Apply tool_choice defaults
if self.tool_choice.is_none() {
if let Some(tools) = &self.tools {
self.tool_choice = if !tools.is_empty() {
Some(ToolChoice::Value(ToolChoiceValue::Auto))
let choice_value = if !tools.is_empty() {
ToolChoiceValue::Auto
} else {
Some(ToolChoice::Value(ToolChoiceValue::None))
ToolChoiceValue::None
};
self.tool_choice = Some(ToolChoice::Value(choice_value));
}
// If tools is None, leave tool_choice as None (don't set it)
}

View File

@@ -1,6 +1,7 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
// ============================================================================
// Default value helpers

View File

@@ -1,6 +1,7 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use std::collections::HashMap;
use super::common::*;

View File

@@ -1,10 +1,13 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use validator::Validate;
use super::common::{default_true, GenerationRequest, InputIds};
use super::sampling_params::SamplingParams;
use super::{
common::{default_true, GenerationRequest, InputIds},
sampling_params::SamplingParams,
};
use crate::protocols::validated::Normalizable;
// ============================================================================

View File

@@ -1,6 +1,7 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use validator::Validate;
use super::common::{default_model, default_true, GenerationRequest, StringOrArray, UsageInfo};

View File

@@ -1,9 +1,10 @@
// OpenAI Responses API types
// https://platform.openai.com/docs/api-reference/responses
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
// Import shared types from common module
use super::common::{

View File

@@ -117,10 +117,11 @@ impl<T> std::ops::DerefMut for ValidatedJson<T> {
#[cfg(test)]
mod tests {
use super::*;
use serde::{Deserialize, Serialize};
use validator::Validate;
use super::*;
#[derive(Debug, Deserialize, Serialize, Validate)]
struct TestRequest {
#[validate(range(min = 0.0, max = 1.0))]

View File

@@ -2,9 +2,10 @@
//!
//! Defines the request/response structures for worker management endpoints
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Worker configuration for API requests
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerConfigRequest {

View File

@@ -1,16 +1,20 @@
// Factory and registry for creating model-specific reasoning parsers.
// Now with parser pooling support for efficient reuse across requests.
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use tokio::sync::Mutex;
use crate::reasoning_parser::parsers::{
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
QwenThinkingParser, Step3Parser,
use crate::reasoning_parser::{
parsers::{
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
QwenThinkingParser, Step3Parser,
},
traits::{ParseError, ParserConfig, ReasoningParser},
};
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};
/// Type alias for pooled parser instances.
/// Uses tokio::Mutex to avoid blocking the async executor.
@@ -402,8 +406,10 @@ mod tests {
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
async fn test_high_concurrency_parser_access() {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Instant;
use std::{
sync::atomic::{AtomicUsize, Ordering},
time::Instant,
};
let factory = ParserFactory::new();
let num_tasks = 100;

View File

@@ -2,8 +2,10 @@
// This parser starts with in_reasoning=true, assuming all text is reasoning
// until an end token is encountered.
use crate::reasoning_parser::parsers::BaseReasoningParser;
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
use crate::reasoning_parser::{
parsers::BaseReasoningParser,
traits::{ParseError, ParserConfig, ParserResult, ReasoningParser},
};
/// DeepSeek-R1 reasoning parser.
///

View File

@@ -1,8 +1,10 @@
// GLM45 specific reasoning parser.
// Uses the same format as Qwen3 but has its own implementation for debugging.
use crate::reasoning_parser::parsers::BaseReasoningParser;
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
use crate::reasoning_parser::{
parsers::BaseReasoningParser,
traits::{ParseError, ParserConfig, ParserResult, ReasoningParser},
};
/// GLM45 reasoning parser.
///

View File

@@ -1,8 +1,10 @@
// Kimi specific reasoning parser.
// This parser uses Unicode tokens and starts with in_reasoning=false.
use crate::reasoning_parser::parsers::BaseReasoningParser;
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
use crate::reasoning_parser::{
parsers::BaseReasoningParser,
traits::{ParseError, ParserConfig, ParserResult, ReasoningParser},
};
/// Kimi reasoning parser.
///

View File

@@ -2,8 +2,10 @@
// This parser starts with in_reasoning=false, requiring an explicit
// start token to enter reasoning mode.
use crate::reasoning_parser::parsers::BaseReasoningParser;
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
use crate::reasoning_parser::{
parsers::BaseReasoningParser,
traits::{ParseError, ParserConfig, ParserResult, ReasoningParser},
};
/// Qwen3 reasoning parser.
///

View File

@@ -1,8 +1,10 @@
// Step3 specific reasoning parser.
// Uses the same format as DeepSeek-R1 but has its own implementation for debugging.
use crate::reasoning_parser::parsers::BaseReasoningParser;
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
use crate::reasoning_parser::{
parsers::BaseReasoningParser,
traits::{ParseError, ParserConfig, ParserResult, ReasoningParser},
};
/// Step3 reasoning parser.
///

View File

@@ -1,16 +1,18 @@
//! Factory for creating router instances
use super::grpc::pd_router::GrpcPDRouter;
use super::grpc::router::GrpcRouter;
use std::sync::Arc;
use super::{
grpc::{pd_router::GrpcPDRouter, router::GrpcRouter},
http::{pd_router::PDRouter, router::Router},
openai::OpenAIRouter,
RouterTrait,
};
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
use crate::policies::PolicyFactory;
use crate::server::AppContext;
use std::sync::Arc;
use crate::{
config::{ConnectionMode, PolicyConfig, RoutingMode},
policies::PolicyFactory,
server::AppContext,
};
/// Factory for creating router instances based on configuration
pub struct RouterFactory;

View File

@@ -4,20 +4,22 @@
//! eliminating deep parameter passing chains and providing a single source of truth
//! for request state.
use std::collections::HashMap;
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};
use axum::http::HeaderMap;
use serde_json::Value;
use crate::core::Worker;
use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::protocols::chat::{ChatCompletionRequest, ChatCompletionResponse};
use crate::protocols::generate::{GenerateRequest, GenerateResponse};
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
use crate::tokenizer::stop::StopSequenceDecoder;
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserFactory as ToolParserFactory;
use crate::{
core::Worker,
grpc_client::{proto, SglangSchedulerClient},
protocols::{
chat::{ChatCompletionRequest, ChatCompletionResponse},
generate::{GenerateRequest, GenerateResponse},
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
tokenizer::{stop::StopSequenceDecoder, traits::Tokenizer},
tool_parser::ParserFactory as ToolParserFactory,
};
// ============================================================================
// Core Context Types

View File

@@ -1,7 +1,6 @@
//! gRPC router implementations
use crate::grpc_client::proto;
use crate::protocols::common::StringOrArray;
use crate::{grpc_client::proto, protocols::common::StringOrArray};
pub mod context;
pub mod pd_router;

View File

@@ -1,19 +1,7 @@
// PD (Prefill-Decode) gRPC Router Implementation
use crate::config::types::RetryConfig;
use crate::core::{ConnectionMode, WorkerRegistry, WorkerType};
use crate::policies::PolicyRegistry;
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
use crate::routers::RouterTrait;
use crate::server::AppContext;
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserFactory as ToolParserFactory;
use std::sync::Arc;
use async_trait::async_trait;
use axum::{
body::Body,
@@ -21,12 +9,27 @@ use axum::{
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use std::sync::Arc;
use tracing::debug;
use super::context::SharedComponents;
use super::pipeline::RequestPipeline;
use super::{context::SharedComponents, pipeline::RequestPipeline};
use crate::{
config::types::RetryConfig,
core::{ConnectionMode, WorkerRegistry, WorkerType},
policies::PolicyRegistry,
protocols::{
chat::ChatCompletionRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
routers::RouterTrait,
server::AppContext,
tokenizer::traits::Tokenizer,
tool_parser::ParserFactory as ToolParserFactory,
};
/// gRPC PD (Prefill-Decode) router implementation for SGLang
#[derive(Clone)]

View File

@@ -3,29 +3,29 @@
//! This module defines the core pipeline abstraction and individual processing stages
//! that transform a RequestContext through its lifecycle.
use std::{
sync::Arc,
time::{Instant, SystemTime, UNIX_EPOCH},
};
use async_trait::async_trait;
use axum::response::{IntoResponse, Response};
use tracing::{debug, error, warn};
use super::context::*;
use super::processing;
use super::streaming;
use super::utils;
use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
use crate::grpc_client::proto;
use crate::policies::PolicyRegistry;
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::common::InputIds;
use crate::protocols::generate::GenerateRequest;
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserFactory as ToolParserFactory;
use proto::DisaggregatedParams;
use rand::Rng;
use std::sync::Arc;
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use tracing::{debug, error, warn};
use uuid::Uuid;
use super::{context::*, processing, streaming, utils};
use crate::{
core::{ConnectionMode, Worker, WorkerRegistry, WorkerType},
grpc_client::proto,
policies::PolicyRegistry,
protocols::{chat::ChatCompletionRequest, common::InputIds, generate::GenerateRequest},
reasoning_parser::ParserFactory as ReasoningParserFactory,
tokenizer::traits::Tokenizer,
tool_parser::ParserFactory as ToolParserFactory,
};
// ============================================================================
// Pipeline Trait
// ============================================================================

View File

@@ -3,28 +3,30 @@
//! This module contains response processing functions that are shared between
//! the regular router and PD router, eliminating ~1,200 lines of exact duplicates.
use std::sync::Arc;
use std::{sync::Arc, time::Instant};
use proto::generate_complete::MatchedStop;
use serde_json::Value;
use tracing::error;
use crate::grpc_client::proto;
use crate::protocols::chat::{
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
use super::{
context::{DispatchMetadata, ExecutionResult},
utils,
};
use crate::protocols::common::{
FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue, Usage,
use crate::{
grpc_client::proto,
protocols::{
chat::{ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse},
common::{FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue, Usage},
generate::{GenerateMetaInfo, GenerateRequest, GenerateResponse},
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
tokenizer::{
stop::{SequenceDecoderOutput, StopSequenceDecoder},
traits::Tokenizer,
},
tool_parser::ParserFactory as ToolParserFactory,
};
use crate::protocols::generate::{GenerateMetaInfo, GenerateRequest, GenerateResponse};
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserFactory as ToolParserFactory;
use proto::generate_complete::MatchedStop;
use std::time::Instant;
use super::context::{DispatchMetadata, ExecutionResult};
use super::utils;
// ============================================================================
// Response Processor - Main Entry Point

View File

@@ -11,23 +11,25 @@ use axum::{
};
use tracing::debug;
use crate::config::types::RetryConfig;
use crate::core::WorkerRegistry;
use crate::policies::PolicyRegistry;
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
use crate::routers::RouterTrait;
use crate::server::AppContext;
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserFactory as ToolParserFactory;
use super::context::SharedComponents;
use super::pipeline::RequestPipeline;
use super::{context::SharedComponents, pipeline::RequestPipeline};
use crate::{
config::types::RetryConfig,
core::WorkerRegistry,
policies::PolicyRegistry,
protocols::{
chat::ChatCompletionRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
routers::RouterTrait,
server::AppContext,
tokenizer::traits::Tokenizer,
tool_parser::ParserFactory as ToolParserFactory,
};
/// gRPC router implementation for SGLang
#[derive(Clone)]

View File

@@ -3,38 +3,40 @@
//! This module contains shared streaming logic for both Regular and PD routers,
//! eliminating ~600 lines of duplication.
use axum::response::Response;
use axum::{body::Body, http::StatusCode};
use std::{collections::HashMap, io, sync::Arc, time::Instant};
use axum::{body::Body, http::StatusCode, response::Response};
use bytes::Bytes;
use http::header::{HeaderValue, CONTENT_TYPE};
use proto::{
generate_complete::MatchedStop::{MatchedStopStr, MatchedTokenId},
generate_response::Response::{Chunk, Complete, Error},
};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::io;
use std::sync::Arc;
use tokio::sync::mpsc::UnboundedSender;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tokio::sync::{mpsc, mpsc::UnboundedSender};
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};
use tracing::{debug, error, warn};
use super::context;
use super::utils;
use crate::grpc_client::proto;
use crate::protocols::chat::{
ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice,
use super::{context, utils};
use crate::{
grpc_client::proto,
protocols::{
chat::{
ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice,
},
common::{
ChatLogProbs, FunctionCallDelta, StringOrArray, Tool, ToolCallDelta, ToolChoice,
ToolChoiceValue, Usage,
},
generate::GenerateRequest,
},
reasoning_parser::ReasoningParser,
tokenizer::{
stop::{SequenceDecoderOutput, StopSequenceDecoder},
traits::Tokenizer,
},
tool_parser::ToolParser,
};
use crate::protocols::common::{
ChatLogProbs, FunctionCallDelta, StringOrArray, Tool, ToolCallDelta, ToolChoice,
ToolChoiceValue, Usage,
};
use crate::protocols::generate::GenerateRequest;
use crate::reasoning_parser::ReasoningParser;
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParser;
use proto::generate_complete::MatchedStop::{MatchedStopStr, MatchedTokenId};
use proto::generate_response::Response::{Chunk, Complete, Error};
use std::time::Instant;
use tokio::sync::mpsc;
/// Shared streaming processor for both single and dual dispatch modes
#[derive(Clone)]

View File

@@ -1,19 +1,7 @@
//! Shared utilities for gRPC routers
use super::ProcessedMessages;
use crate::core::Worker;
use crate::grpc_client::sglang_scheduler::AbortOnDropStream;
use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::protocols::chat::{ChatCompletionRequest, ChatMessage};
use crate::protocols::common::{
ChatLogProbs, ChatLogProbsContent, FunctionCallResponse, StringOrArray, Tool, ToolCall,
ToolChoice, ToolChoiceValue, TopLogProb,
};
use crate::protocols::generate::GenerateFinishReason;
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
use crate::tokenizer::traits::Tokenizer;
use crate::tokenizer::HuggingFaceTokenizer;
pub use crate::tokenizer::StopSequenceDecoder;
use std::{collections::HashMap, sync::Arc};
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
@@ -21,11 +9,29 @@ use axum::{
};
use futures::StreamExt;
use serde_json::{json, Map, Value};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{error, warn};
use uuid::Uuid;
use super::ProcessedMessages;
pub use crate::tokenizer::StopSequenceDecoder;
use crate::{
core::Worker,
grpc_client::{proto, sglang_scheduler::AbortOnDropStream, SglangSchedulerClient},
protocols::{
chat::{ChatCompletionRequest, ChatMessage},
common::{
ChatLogProbs, ChatLogProbsContent, FunctionCallResponse, StringOrArray, Tool, ToolCall,
ToolChoice, ToolChoiceValue, TopLogProb,
},
generate::GenerateFinishReason,
},
tokenizer::{
chat_template::{ChatTemplateContentFormat, ChatTemplateParams},
traits::Tokenizer,
HuggingFaceTokenizer,
},
};
/// Get gRPC client from worker, returning appropriate error response on failure
pub async fn get_grpc_client_from_worker(
worker: &Arc<dyn Worker>,
@@ -953,12 +959,17 @@ pub fn parse_finish_reason(reason_str: &str, completion_tokens: i32) -> Generate
#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::chat::{ChatMessage, UserMessageContent};
use crate::protocols::common::{ContentPart, ImageUrl};
use crate::tokenizer::chat_template::ChatTemplateContentFormat;
use serde_json::json;
use super::*;
use crate::{
protocols::{
chat::{ChatMessage, UserMessageContent},
common::{ContentPart, ImageUrl},
},
tokenizer::chat_template::ChatTemplateContentFormat,
};
#[test]
fn test_transform_messages_string_format() {
let messages = vec![ChatMessage::User {

View File

@@ -1,6 +1,4 @@
use axum::body::Body;
use axum::extract::Request;
use axum::http::HeaderMap;
use axum::{body::Body, extract::Request, http::HeaderMap};
/// Copy request headers to a Vec of name-value string pairs
/// Used for forwarding headers to backend workers

View File

@@ -1,19 +1,5 @@
use super::pd_types::api_path;
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent};
use crate::protocols::common::{InputIds, StringOrArray};
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::routers::header_utils;
use crate::routers::RouterTrait;
use std::{sync::Arc, time::Instant};
use async_trait::async_trait;
use axum::{
body::Body,
@@ -25,11 +11,29 @@ use futures_util::StreamExt;
use reqwest::Client;
use serde::Serialize;
use serde_json::{json, Value};
use std::sync::Arc;
use std::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, warn};
use super::pd_types::api_path;
use crate::{
config::types::RetryConfig,
core::{
is_retryable_status, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
},
metrics::RouterMetrics,
policies::{LoadBalancingPolicy, PolicyRegistry},
protocols::{
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
common::{InputIds, StringOrArray},
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
},
routers::{header_utils, RouterTrait},
};
#[derive(Debug)]
pub struct PDRouter {
pub worker_registry: Arc<WorkerRegistry>,

View File

@@ -1,35 +1,39 @@
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::PolicyRegistry;
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::common::GenerationRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::{RerankRequest, RerankResponse, RerankResult};
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::routers::header_utils;
use crate::routers::RouterTrait;
use axum::body::to_bytes;
use std::{sync::Arc, time::Instant};
use axum::{
body::Body,
body::{to_bytes, Body},
extract::Request,
http::{
header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, Method, StatusCode,
header::{CONTENT_LENGTH, CONTENT_TYPE},
HeaderMap, HeaderValue, Method, StatusCode,
},
response::{IntoResponse, Response},
Json,
};
use futures_util::StreamExt;
use reqwest::Client;
use std::sync::Arc;
use std::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error};
use crate::{
config::types::RetryConfig,
core::{
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
},
metrics::RouterMetrics,
policies::PolicyRegistry,
protocols::{
chat::ChatCompletionRequest,
common::GenerationRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::{RerankRequest, RerankResponse, RerankResult},
responses::{ResponsesGetParams, ResponsesRequest},
},
routers::{header_utils, RouterTrait},
};
/// Regular router that uses injected load balancing policies
#[derive(Debug)]
pub struct Router {

View File

@@ -1,5 +1,7 @@
//! Router implementations
use std::fmt::Debug;
use async_trait::async_trait;
use axum::{
body::Body,
@@ -7,16 +9,17 @@ use axum::{
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use std::fmt::Debug;
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use serde_json::Value;
use crate::protocols::{
chat::ChatCompletionRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
};
pub mod factory;
pub mod grpc;
pub mod header_utils;
@@ -25,7 +28,6 @@ pub mod openai; // New refactored OpenAI router module
pub mod router_manager;
pub use factory::RouterFactory;
// Re-export HTTP routers for convenience
pub use http::{pd_router, pd_types, router};

View File

@@ -1,22 +1,26 @@
//! Conversation CRUD operations and persistence
use crate::data_connector::{
conversation_items::ListParams, conversation_items::SortOrder, Conversation, ConversationId,
ConversationItemId, ConversationItemStorage, ConversationStorage, NewConversation,
NewConversationItem, ResponseId, ResponseStorage, SharedConversationItemStorage,
SharedConversationStorage,
use std::{collections::HashMap, sync::Arc};
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use crate::protocols::responses::{ResponseInput, ResponseInputOutputItem, ResponsesRequest};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use chrono::Utc;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info, warn};
use super::responses::build_stored_response;
use crate::{
data_connector::{
conversation_items::{ListParams, SortOrder},
Conversation, ConversationId, ConversationItemId, ConversationItemStorage,
ConversationStorage, NewConversation, NewConversationItem, ResponseId, ResponseStorage,
SharedConversationItemStorage, SharedConversationStorage,
},
protocols::responses::{ResponseInput, ResponseInputOutputItem, ResponsesRequest},
};
/// Maximum number of properties allowed in conversation metadata
pub(crate) const MAX_METADATA_PROPERTIES: usize = 16;

View File

@@ -8,19 +8,20 @@
//! - Payload transformation for MCP tool interception
//! - Metadata injection for MCP operations
use crate::mcp::McpClientManager;
use crate::protocols::responses::{
ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest,
};
use crate::routers::header_utils::apply_request_headers;
use std::{io, sync::Arc};
use axum::http::HeaderMap;
use bytes::Bytes;
use serde_json::{json, to_value, Value};
use std::{io, sync::Arc};
use tokio::sync::mpsc;
use tracing::{info, warn};
use super::utils::event_types;
use crate::{
mcp::McpClientManager,
protocols::responses::{ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest},
routers::header_utils::apply_request_headers,
};
// ============================================================================
// Configuration and State Types

View File

@@ -1,12 +1,15 @@
//! Response storage, patching, and extraction utilities
use crate::data_connector::{ResponseId, StoredResponse};
use crate::protocols::responses::{ResponseInput, ResponseToolType, ResponsesRequest};
use serde_json::{json, Value};
use std::collections::HashMap;
use serde_json::{json, Value};
use tracing::warn;
use super::utils::event_types;
use crate::{
data_connector::{ResponseId, StoredResponse},
protocols::responses::{ResponseInput, ResponseToolType, ResponsesRequest},
};
// ============================================================================
// Response Storage Operations

View File

@@ -1,21 +1,10 @@
//! OpenAI router - main coordinator that delegates to specialized modules
use crate::config::CircuitBreakerConfig;
use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig};
use crate::data_connector::{
conversation_items::ListParams, conversation_items::SortOrder, ConversationId, ResponseId,
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
use std::{
any::Any,
sync::{atomic::AtomicBool, Arc},
};
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams,
ResponsesRequest,
};
use crate::routers::header_utils::apply_request_headers;
use axum::{
body::Body,
extract::Request,
@@ -25,10 +14,6 @@ use axum::{
};
use futures_util::StreamExt;
use serde_json::{json, to_value, Value};
use std::{
any::Any,
sync::{atomic::AtomicBool, Arc},
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::warn;
@@ -39,12 +24,35 @@ use super::conversations::{
get_conversation, get_conversation_item, list_conversation_items, persist_conversation_items,
update_conversation,
};
use super::mcp::{
execute_tool_loop, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming,
McpLoopConfig,
use super::{
mcp::{
execute_tool_loop, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming,
McpLoopConfig,
},
responses::{mask_tools_as_mcp, patch_streaming_response_json},
streaming::handle_streaming_response,
};
use crate::{
config::CircuitBreakerConfig,
core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig},
data_connector::{
conversation_items::{ListParams, SortOrder},
ConversationId, ResponseId, SharedConversationItemStorage, SharedConversationStorage,
SharedResponseStorage,
},
protocols::{
chat::ChatCompletionRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams,
ResponsesRequest,
},
},
routers::header_utils::apply_request_headers,
};
use super::responses::{mask_tools_as_mcp, patch_streaming_response_json};
use super::streaming::handle_streaming_response;
// ============================================================================
// OpenAIRouter Struct

View File

@@ -7,11 +7,8 @@
//! - MCP tool execution loops within streaming responses
//! - Event transformation and output index remapping
use crate::data_connector::{
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
};
use crate::protocols::responses::{ResponseToolType, ResponsesRequest};
use crate::routers::header_utils::{apply_request_headers, preserve_response_headers};
use std::{borrow::Cow, io, sync::Arc};
use axum::{
body::Body,
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
@@ -20,20 +17,28 @@ use axum::{
use bytes::Bytes;
use futures_util::StreamExt;
use serde_json::{json, Value};
use std::{borrow::Cow, io, sync::Arc};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::warn;
// Import from sibling modules
use super::conversations::persist_conversation_items;
use super::mcp::{
build_resume_payload, execute_streaming_tool_calls, inject_mcp_metadata_streaming,
mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming, send_mcp_list_tools_events,
McpLoopConfig, ToolLoopState,
use super::{
mcp::{
build_resume_payload, execute_streaming_tool_calls, inject_mcp_metadata_streaming,
mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming,
send_mcp_list_tools_events, McpLoopConfig, ToolLoopState,
},
responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block},
utils::{event_types, FunctionCallInProgress, OutputIndexMapper, StreamAction},
};
use crate::{
data_connector::{
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
},
protocols::responses::{ResponseToolType, ResponsesRequest},
routers::header_utils::{apply_request_headers, preserve_response_headers},
};
use super::responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block};
use super::utils::{event_types, FunctionCallInProgress, OutputIndexMapper, StreamAction};
// ============================================================================
// Streaming Response Accumulator

View File

@@ -4,16 +4,8 @@
//! - Single Router Mode (enable_igw=false): Router owns workers directly
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
use crate::config::{ConnectionMode, RoutingMode};
use crate::core::{WorkerRegistry, WorkerType};
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::routers::RouterTrait;
use crate::server::{AppContext, ServerConfig};
use std::sync::Arc;
use async_trait::async_trait;
use axum::{
body::Body,
@@ -23,9 +15,23 @@ use axum::{
};
use dashmap::DashMap;
use serde_json::Value;
use std::sync::Arc;
use tracing::{debug, info, warn};
use crate::{
config::{ConnectionMode, RoutingMode},
core::{WorkerRegistry, WorkerType},
protocols::{
chat::ChatCompletionRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
},
routers::RouterTrait,
server::{AppContext, ServerConfig},
};
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct RouterId(String);

View File

@@ -1,3 +1,24 @@
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc, OnceLock,
},
time::Duration,
};
use axum::{
extract::{Path, Query, Request, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{delete, get, post},
serve, Json, Router,
};
use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Value};
use tokio::{net::TcpListener, signal, spawn};
use tracing::{error, info, warn, Level};
use crate::{
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
core::{
@@ -30,24 +51,6 @@ use crate::{
tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
tool_parser::ParserFactory as ToolParserFactory,
};
use axum::{
extract::{Path, Query, Request, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{delete, get, post},
serve, Json, Router,
};
use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Value};
use std::sync::OnceLock;
use std::{
sync::atomic::{AtomicBool, Ordering},
sync::Arc,
time::Duration,
};
use tokio::{net::TcpListener, signal, spawn};
use tracing::{error, info, warn, Level};
//

View File

@@ -1,24 +1,25 @@
use crate::core::WorkerManager;
use crate::protocols::worker_spec::WorkerConfigRequest;
use crate::server::AppContext;
use std::{
collections::{HashMap, HashSet},
sync::{Arc, Mutex},
time::Duration,
};
use futures::{StreamExt, TryStreamExt};
use k8s_openapi::api::core::v1::Pod;
use kube::{
api::Api,
runtime::watcher::{watcher, Config},
runtime::WatchStreamExt,
runtime::{
watcher::{watcher, Config},
WatchStreamExt,
},
Client,
};
use std::collections::{HashMap, HashSet};
use rustls;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::task;
use tokio::time;
use tokio::{task, time};
use tracing::{debug, error, info, warn};
use crate::{core::WorkerManager, protocols::worker_spec::WorkerConfigRequest, server::AppContext};
#[derive(Debug, Clone)]
pub struct ServiceDiscoveryConfig {
pub enabled: bool,
@@ -452,10 +453,12 @@ async fn handle_pod_deletion(
#[cfg(test)]
mod tests {
use k8s_openapi::{
api::core::v1::{Pod, PodCondition, PodSpec, PodStatus},
apimachinery::pkg::apis::meta::v1::{ObjectMeta, Time},
};
use super::*;
use k8s_openapi::api::core::v1::{Pod, PodCondition, PodSpec, PodStatus};
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time;
fn create_k8s_pod(
name: Option<&str>,
@@ -535,8 +538,7 @@ mod tests {
}
async fn create_test_app_context() -> Arc<AppContext> {
use crate::config::RouterConfig;
use crate::middleware::TokenBucket;
use crate::{config::RouterConfig, middleware::TokenBucket};
let router_config = RouterConfig {
worker_startup_timeout_secs: 1,

View File

@@ -3,12 +3,16 @@
//! This module provides functionality to apply chat templates to messages,
//! similar to HuggingFace transformers' apply_chat_template method.
use anyhow::{anyhow, Result};
use minijinja::machinery::ast::{Expr, Stmt};
use minijinja::{context, Environment, Value};
use serde_json;
use std::collections::HashMap;
use anyhow::{anyhow, Result};
use minijinja::{
context,
machinery::ast::{Expr, Stmt},
Environment, Value,
};
use serde_json;
/// Chat template content format
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChatTemplateContentFormat {
@@ -319,8 +323,10 @@ impl<'a> Detector<'a> {
/// AST-based detection using minijinja's unstable machinery
/// Single-pass detector with scope tracking
fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
use minijinja::machinery::{parse, WhitespaceConfig};
use minijinja::syntax::SyntaxConfig;
use minijinja::{
machinery::{parse, WhitespaceConfig},
syntax::SyntaxConfig,
};
let ast = match parse(
template,

View File

@@ -1,13 +1,9 @@
use super::traits;
use std::{fs::File, io::Read, path::Path, sync::Arc};
use anyhow::{Error, Result};
use std::fs::File;
use std::io::Read;
use std::path::Path;
use std::sync::Arc;
use tracing::{debug, info};
use super::huggingface::HuggingFaceTokenizer;
use super::tiktoken::TiktokenTokenizer;
use super::{huggingface::HuggingFaceTokenizer, tiktoken::TiktokenTokenizer, traits};
use crate::tokenizer::hub::download_tokenizer_from_hf;
/// Represents the type of tokenizer being used
@@ -379,8 +375,7 @@ pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
_ => {
// Try auto-detection
use std::fs::File;
use std::io::Read;
use std::{fs::File, io::Read};
let mut file = File::open(file_path)?;
let mut buffer = vec![0u8; 512];

View File

@@ -1,6 +1,9 @@
use std::{
env,
path::{Path, PathBuf},
};
use hf_hub::api::tokio::ApiBuilder;
use std::env;
use std::path::{Path, PathBuf};
const IGNORED: [&str; 5] = [
".gitattributes",

View File

@@ -3,12 +3,12 @@ use std::collections::HashMap;
use anyhow::{Error, Result};
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
use super::chat_template::{
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
ChatTemplateProcessor,
};
use super::traits::{
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
use super::{
chat_template::{
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
ChatTemplateProcessor,
},
traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
};
/// HuggingFace tokenizer wrapper

View File

@@ -1,9 +1,11 @@
//! Mock tokenizer implementation for testing
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
use anyhow::Result;
use std::collections::HashMap;
use anyhow::Result;
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
/// Mock tokenizer for testing purposes
pub struct MockTokenizer {
vocab: HashMap<String, u32>,

View File

@@ -1,6 +1,6 @@
use std::{ops::Deref, sync::Arc};
use anyhow::Result;
use std::ops::Deref;
use std::sync::Arc;
pub mod factory;
pub mod hub;
@@ -27,14 +27,12 @@ pub use factory::{
create_tokenizer_from_file, create_tokenizer_with_chat_template,
create_tokenizer_with_chat_template_blocking, TokenizerType,
};
pub use huggingface::HuggingFaceTokenizer;
pub use sequence::Sequence;
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
pub use stream::DecodeStream;
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
pub use huggingface::HuggingFaceTokenizer;
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
#[derive(Clone)]

View File

@@ -1,7 +1,9 @@
use super::traits::{TokenIdType, Tokenizer as TokenizerTrait};
use anyhow::Result;
use std::sync::Arc;
use anyhow::Result;
use super::traits::{TokenIdType, Tokenizer as TokenizerTrait};
/// Maintains state for an ongoing sequence of tokens and their decoded text
/// This provides a cleaner abstraction for managing token sequences
pub struct Sequence {

View File

@@ -1,8 +1,11 @@
use super::sequence::Sequence;
use super::traits::{self, TokenIdType};
use std::{collections::HashSet, sync::Arc};
use anyhow::Result;
use std::collections::HashSet;
use std::sync::Arc;
use super::{
sequence::Sequence,
traits::{self, TokenIdType},
};
/// Output from the sequence decoder
#[derive(Debug, Clone, PartialEq)]

View File

@@ -1,9 +1,11 @@
// src/tokenizer/stream.rs
use super::traits::{self, TokenIdType};
use anyhow::Result;
use std::sync::Arc;
use anyhow::Result;
use super::traits::{self, TokenIdType};
const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;
/// DecodeStream will keep the state necessary to produce individual chunks of

View File

@@ -1,8 +1,9 @@
#[cfg(test)]
use super::*;
#[cfg(test)]
use std::sync::Arc;
#[cfg(test)]
use super::*;
#[test]
fn test_mock_tokenizer_encode() {
let tokenizer = mock::MockTokenizer::new();

View File

@@ -1,8 +1,9 @@
use anyhow::{Error, Result};
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
use super::traits::{
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
};
use anyhow::{Error, Result};
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
/// Tiktoken tokenizer wrapper for OpenAI GPT models
pub struct TiktokenTokenizer {

View File

@@ -1,6 +1,9 @@
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
};
use anyhow::Result;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// Type alias for token IDs
pub type TokenIdType = u32;

View File

@@ -1,14 +1,19 @@
// Factory and pool for creating model-specific tool parsers with pooling support.
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use tokio::sync::Mutex;
use crate::tool_parser::parsers::{
DeepSeekParser, Glm4MoeParser, GptOssHarmonyParser, GptOssParser, JsonParser, KimiK2Parser,
LlamaParser, MistralParser, PassthroughParser, PythonicParser, QwenParser, Step3Parser,
use crate::tool_parser::{
parsers::{
DeepSeekParser, Glm4MoeParser, GptOssHarmonyParser, GptOssParser, JsonParser, KimiK2Parser,
LlamaParser, MistralParser, PassthroughParser, PythonicParser, QwenParser, Step3Parser,
},
traits::ToolParser,
};
use crate::tool_parser::traits::ToolParser;
/// Type alias for pooled parser instances.
pub type PooledParser = Arc<Mutex<Box<dyn ToolParser>>>;

View File

@@ -18,11 +18,10 @@ mod tests;
// Re-export commonly used types
pub use errors::{ParserError, ParserResult};
pub use factory::{ParserFactory, ParserRegistry, PooledParser};
pub use traits::{PartialJsonParser, ToolParser};
pub use types::{FunctionCall, PartialToolCall, StreamingParseResult, ToolCall};
// Re-export parsers for convenience
pub use parsers::{
DeepSeekParser, Glm4MoeParser, GptOssParser, JsonParser, KimiK2Parser, LlamaParser,
MistralParser, PythonicParser, QwenParser, Step3Parser,
};
pub use traits::{PartialJsonParser, ToolParser};
pub use types::{FunctionCall, PartialToolCall, StreamingParseResult, ToolCall};

View File

@@ -2,13 +2,14 @@ use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
use crate::{
protocols::common::Tool,
tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
},
};
/// DeepSeek V3 format parser for tool calls

View File

@@ -2,13 +2,14 @@ use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
use crate::{
protocols::common::Tool,
tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
},
};
/// GLM-4 MoE format parser for tool calls

View File

@@ -1,11 +1,12 @@
use async_trait::async_trait;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::ParserResult,
traits::{TokenToolParser, ToolParser},
types::{StreamingParseResult, ToolCall},
use crate::{
protocols::common::Tool,
tool_parser::{
errors::ParserResult,
traits::{TokenToolParser, ToolParser},
types::{StreamingParseResult, ToolCall},
},
};
/// Placeholder for the Harmony-backed GPT-OSS parser.

View File

@@ -2,14 +2,15 @@ use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
use crate::{
protocols::common::Tool,
tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
},
};
/// GPT-OSS format parser for tool calls

View File

@@ -1,9 +1,14 @@
use crate::protocols::common::Tool;
use serde_json::Value;
use std::collections::HashMap;
use crate::tool_parser::errors::{ParserError, ParserResult};
use crate::tool_parser::types::{StreamingParseResult, ToolCallItem};
use serde_json::Value;
use crate::{
protocols::common::Tool,
tool_parser::{
errors::{ParserError, ParserResult},
types::{StreamingParseResult, ToolCallItem},
},
};
/// Get a mapping of tool names to their indices
pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {

View File

@@ -1,14 +1,15 @@
use async_trait::async_trait;
use serde_json::Value;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
use crate::{
protocols::common::Tool,
tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
},
};
/// JSON format parser for tool calls

View File

@@ -2,13 +2,14 @@ use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::ParserResult,
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
use crate::{
protocols::common::Tool,
tool_parser::{
errors::ParserResult,
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
},
};
/// Kimi K2 format parser for tool calls

View File

@@ -1,14 +1,15 @@
use async_trait::async_trait;
use serde_json::Value;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall},
use crate::{
protocols::common::Tool,
tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall},
},
};
/// Llama 3.2 format parser for tool calls

View File

@@ -1,14 +1,15 @@
use async_trait::async_trait;
use serde_json::Value;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall},
use crate::{
protocols::common::Tool,
tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall},
},
};
/// Mistral format parser for tool calls

View File

@@ -4,12 +4,17 @@
//! tool call parsing should be performed. It simply returns the input text
//! with no tool calls detected.
use crate::protocols::common::Tool;
use crate::tool_parser::errors::ParserResult;
use crate::tool_parser::traits::ToolParser;
use crate::tool_parser::types::{StreamingParseResult, ToolCall, ToolCallItem};
use async_trait::async_trait;
use crate::{
protocols::common::Tool,
tool_parser::{
errors::ParserResult,
traits::ToolParser,
types::{StreamingParseResult, ToolCall, ToolCallItem},
},
};
/// Passthrough parser that returns text unchanged with no tool calls
#[derive(Default)]
pub struct PassthroughParser;

View File

@@ -1,3 +1,5 @@
use std::sync::OnceLock;
/// Pythonic format parser for tool calls
///
/// Handles Python function call syntax within square brackets:
@@ -10,18 +12,20 @@
use async_trait::async_trait;
use num_traits::ToPrimitive;
use regex::Regex;
use rustpython_parser::ast::{Constant, Expr, Mod, UnaryOp};
use rustpython_parser::{parse, Mode};
use rustpython_parser::{
ast::{Constant, Expr, Mod, UnaryOp},
parse, Mode,
};
use serde_json::{Map, Number, Value};
use std::sync::OnceLock;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
use crate::{
protocols::common::Tool,
tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
},
};
static PYTHONIC_BLOCK_REGEX: OnceLock<Regex> = OnceLock::new();

View File

@@ -2,14 +2,15 @@ use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall},
use crate::{
protocols::common::Tool,
tool_parser::{
errors::{ParserError, ParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall},
},
};
/// Qwen format parser for tool calls

Some files were not shown because too many files have changed in this diff Show More