[router] Add rustfmt and set group imports by default (#11732)
This commit is contained in:
4
.github/workflows/pr-test-rust.yml
vendored
4
.github/workflows/pr-test-rust.yml
vendored
@@ -54,7 +54,9 @@ jobs:
|
||||
run: |
|
||||
source "$HOME/.cargo/env"
|
||||
cd sgl-router/
|
||||
cargo fmt -- --check
|
||||
rustup component add --toolchain nightly-x86_64-unknown-linux-gnu rustfmt
|
||||
rustup toolchain install nightly --profile minimal
|
||||
cargo +nightly fmt -- --check
|
||||
|
||||
- name: Run Rust tests
|
||||
timeout-minutes: 20
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use serde_json::{from_str, to_string, to_value, to_vec};
|
||||
use std::time::Instant;
|
||||
|
||||
use sglang_router_rs::core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType};
|
||||
use sglang_router_rs::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent};
|
||||
use sglang_router_rs::protocols::common::StringOrArray;
|
||||
use sglang_router_rs::protocols::completion::CompletionRequest;
|
||||
use sglang_router_rs::protocols::generate::GenerateRequest;
|
||||
use sglang_router_rs::protocols::sampling_params::SamplingParams;
|
||||
use sglang_router_rs::routers::http::pd_types::{generate_room_id, RequestWithBootstrap};
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use serde_json::{from_str, to_string, to_value, to_vec};
|
||||
use sglang_router_rs::{
|
||||
core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType},
|
||||
protocols::{
|
||||
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
|
||||
common::StringOrArray,
|
||||
completion::CompletionRequest,
|
||||
generate::GenerateRequest,
|
||||
sampling_params::SamplingParams,
|
||||
},
|
||||
routers::http::pd_types::{generate_room_id, RequestWithBootstrap},
|
||||
};
|
||||
|
||||
fn create_test_worker() -> BasicWorker {
|
||||
BasicWorkerBuilder::new("http://test-server:8000")
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
//! Comprehensive tokenizer benchmark with clean summary output
|
||||
//! Each test adds a row to the final summary table
|
||||
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
path::PathBuf,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU64, Ordering},
|
||||
Arc, Mutex, OnceLock,
|
||||
},
|
||||
thread,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput};
|
||||
use sglang_router_rs::tokenizer::{
|
||||
huggingface::HuggingFaceTokenizer, sequence::Sequence, stop::*, stream::DecodeStream, traits::*,
|
||||
};
|
||||
use std::collections::BTreeMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use std::thread;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
// Include the common test utilities
|
||||
#[path = "../tests/common/mod.rs"]
|
||||
|
||||
@@ -7,15 +7,22 @@
|
||||
//! - Streaming vs complete parsing
|
||||
//! - Different model formats (JSON, Mistral, Qwen, Pythonic, etc.)
|
||||
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU64, Ordering},
|
||||
Arc, Mutex,
|
||||
},
|
||||
thread,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput};
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::protocols::common::{Function, Tool};
|
||||
use sglang_router_rs::tool_parser::{JsonParser, ParserFactory as ToolParserFactory, ToolParser};
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread;
|
||||
use std::time::{Duration, Instant};
|
||||
use sglang_router_rs::{
|
||||
protocols::common::{Function, Tool},
|
||||
tool_parser::{JsonParser, ParserFactory as ToolParserFactory, ToolParser},
|
||||
};
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
// Test data for different parser formats - realistic complex examples
|
||||
|
||||
8
sgl-router/rustfmt.toml
Normal file
8
sgl-router/rustfmt.toml
Normal file
@@ -0,0 +1,8 @@
|
||||
# Rust formatting configuration
|
||||
|
||||
# Merge all imports from the same crate into a single `use` statement (unstable option: requires nightly rustfmt)
|
||||
imports_granularity = "Crate"
|
||||
# Split imports into three blocks, in order: std/core/alloc, external crates, then local (self/super/crate) imports (unstable option: requires nightly rustfmt)
|
||||
group_imports = "StdExternalCrate"
|
||||
reorder_imports = true
|
||||
reorder_modules = true
|
||||
@@ -1,7 +1,9 @@
|
||||
use super::ConfigResult;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::ConfigResult;
|
||||
|
||||
/// Main router configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RouterConfig {
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{
|
||||
sync::{
|
||||
atomic::{AtomicU32, AtomicU64, Ordering},
|
||||
Arc, RwLock,
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use tracing::info;
|
||||
|
||||
/// Circuit breaker configuration
|
||||
@@ -316,9 +321,10 @@ pub struct CircuitBreakerStats {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::thread;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_circuit_breaker_initial_state() {
|
||||
let cb = CircuitBreaker::new();
|
||||
|
||||
@@ -68,9 +68,10 @@ impl From<reqwest::Error> for WorkerError {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::error::Error;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_health_check_failed_display() {
|
||||
let error = WorkerError::HealthCheckFailed {
|
||||
|
||||
@@ -3,16 +3,22 @@
|
||||
//! Provides non-blocking worker management by queuing operations and processing
|
||||
//! them asynchronously in background worker tasks.
|
||||
|
||||
use crate::core::WorkerManager;
|
||||
use crate::protocols::worker_spec::{JobStatus, WorkerConfigRequest};
|
||||
use crate::server::AppContext;
|
||||
use std::{
|
||||
sync::{Arc, Weak},
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use metrics::{counter, gauge, histogram};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::time::{Duration, SystemTime};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::{
|
||||
core::WorkerManager,
|
||||
protocols::worker_spec::{JobStatus, WorkerConfigRequest},
|
||||
server::AppContext,
|
||||
};
|
||||
|
||||
/// Job types for control plane operations
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Job {
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use crate::config::types::RetryConfig;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::Response;
|
||||
use rand::Rng;
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::{http::StatusCode, response::Response};
|
||||
use rand::Rng;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::config::types::RetryConfig;
|
||||
|
||||
/// Check if an HTTP status code indicates a retryable error
|
||||
pub fn is_retryable_status(status: StatusCode) -> bool {
|
||||
matches!(
|
||||
@@ -162,11 +163,14 @@ impl RetryExecutor {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::{
|
||||
atomic::{AtomicU32, Ordering},
|
||||
Arc,
|
||||
};
|
||||
|
||||
use axum::{http::StatusCode, response::IntoResponse};
|
||||
|
||||
use super::*;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::IntoResponse;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn base_retry_config() -> RetryConfig {
|
||||
RetryConfig {
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use tokio::sync::{Mutex, Notify};
|
||||
use tracing::{debug, trace};
|
||||
|
||||
|
||||
@@ -1,19 +1,27 @@
|
||||
use super::{CircuitBreaker, WorkerError, WorkerResult};
|
||||
use crate::core::CircuitState;
|
||||
use crate::core::{BasicWorkerBuilder, DPAwareWorkerBuilder};
|
||||
use crate::grpc_client::SglangSchedulerClient;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::protocols::worker_spec::WorkerInfo;
|
||||
use std::{
|
||||
fmt,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicUsize, Ordering},
|
||||
Arc, LazyLock,
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures;
|
||||
use serde_json;
|
||||
use std::fmt;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
use tokio::time;
|
||||
use tokio::{
|
||||
sync::{Mutex, RwLock},
|
||||
time,
|
||||
};
|
||||
|
||||
use super::{CircuitBreaker, WorkerError, WorkerResult};
|
||||
use crate::{
|
||||
core::{BasicWorkerBuilder, CircuitState, DPAwareWorkerBuilder},
|
||||
grpc_client::SglangSchedulerClient,
|
||||
metrics::RouterMetrics,
|
||||
protocols::worker_spec::WorkerInfo,
|
||||
};
|
||||
|
||||
static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||
reqwest::Client::builder()
|
||||
@@ -1024,10 +1032,10 @@ pub fn worker_to_info(worker: &Arc<dyn Worker>) -> WorkerInfo {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{thread, time::Duration};
|
||||
|
||||
use super::*;
|
||||
use crate::core::CircuitBreakerConfig;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
#[test]
|
||||
fn test_worker_type_display() {
|
||||
@@ -1502,9 +1510,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_load_counter_performance() {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
|
||||
let worker = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build();
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
use super::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig};
|
||||
use super::worker::{
|
||||
BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerMetadata, WorkerType,
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{
|
||||
circuit_breaker::{CircuitBreaker, CircuitBreakerConfig},
|
||||
worker::{
|
||||
BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerMetadata, WorkerType,
|
||||
},
|
||||
};
|
||||
use crate::grpc_client::SglangSchedulerClient;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Builder for creating BasicWorker instances with fluent API
|
||||
pub struct BasicWorkerBuilder {
|
||||
@@ -100,6 +103,7 @@ impl BasicWorkerBuilder {
|
||||
atomic::{AtomicBool, AtomicUsize},
|
||||
Arc,
|
||||
};
|
||||
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
|
||||
let bootstrap_host = match url::Url::parse(&self.url) {
|
||||
@@ -282,9 +286,10 @@ impl DPAwareWorkerBuilder {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use super::*;
|
||||
use crate::core::worker::Worker;
|
||||
use std::time::Duration;
|
||||
|
||||
#[test]
|
||||
fn test_basic_worker_builder_minimal() {
|
||||
|
||||
@@ -3,31 +3,35 @@
|
||||
//! Handles all aspects of worker lifecycle including discovery, initialization,
|
||||
//! runtime management, and health monitoring.
|
||||
|
||||
use crate::config::types::{
|
||||
CircuitBreakerConfig as ConfigCircuitBreakerConfig, ConnectionMode as ConfigConnectionMode,
|
||||
HealthCheckConfig, RouterConfig, RoutingMode,
|
||||
};
|
||||
use crate::core::{
|
||||
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder, HealthConfig,
|
||||
Worker, WorkerFactory, WorkerRegistry, WorkerType,
|
||||
};
|
||||
use crate::grpc_client::SglangSchedulerClient;
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::worker_spec::{
|
||||
FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
|
||||
};
|
||||
use crate::server::AppContext;
|
||||
use std::{collections::HashMap, sync::Arc, time::Duration};
|
||||
|
||||
use futures::future;
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{watch, Mutex};
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::{
|
||||
sync::{watch, Mutex},
|
||||
task::JoinHandle,
|
||||
};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::{
|
||||
config::types::{
|
||||
CircuitBreakerConfig as ConfigCircuitBreakerConfig, ConnectionMode as ConfigConnectionMode,
|
||||
HealthCheckConfig, RouterConfig, RoutingMode,
|
||||
},
|
||||
core::{
|
||||
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder,
|
||||
HealthConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType,
|
||||
},
|
||||
grpc_client::SglangSchedulerClient,
|
||||
policies::PolicyRegistry,
|
||||
protocols::worker_spec::{
|
||||
FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
|
||||
},
|
||||
server::AppContext,
|
||||
};
|
||||
|
||||
static HTTP_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
|
||||
reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(10))
|
||||
@@ -1803,9 +1807,10 @@ impl Drop for LoadMonitor {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_server_info() {
|
||||
let json = serde_json::json!({
|
||||
|
||||
@@ -2,11 +2,13 @@
|
||||
//!
|
||||
//! Provides centralized registry for workers with model-based indexing
|
||||
|
||||
use crate::core::{ConnectionMode, Worker, WorkerType};
|
||||
use dashmap::DashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::core::{ConnectionMode, Worker, WorkerType};
|
||||
|
||||
/// Unique identifier for a worker
|
||||
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||
pub struct WorkerId(String);
|
||||
@@ -363,8 +365,10 @@ impl WorkerRegistry {
|
||||
/// Start a health checker for all workers in the registry
|
||||
/// This should be called once after the registry is populated with workers
|
||||
pub fn start_health_checker(&self, check_interval_secs: u64) -> crate::core::HealthChecker {
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
};
|
||||
|
||||
let shutdown = Arc::new(AtomicBool::new(false));
|
||||
let shutdown_clone = shutdown.clone();
|
||||
@@ -433,9 +437,10 @@ pub struct WorkerRegistryStats {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::*;
|
||||
use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn test_worker_registry() {
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::sync::RwLock;
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
sync::RwLock,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
|
||||
use super::conversation_items::{
|
||||
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, ListParams,
|
||||
Result, SortOrder,
|
||||
use super::{
|
||||
conversation_items::{
|
||||
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, ListParams,
|
||||
Result, SortOrder,
|
||||
},
|
||||
conversations::ConversationId,
|
||||
};
|
||||
use super::conversations::ConversationId;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct MemoryConversationItemStorage {
|
||||
@@ -190,9 +194,10 @@ impl ConversationItemStorage for MemoryConversationItemStorage {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::{TimeZone, Utc};
|
||||
|
||||
use super::*;
|
||||
|
||||
fn make_item(
|
||||
item_type: &str,
|
||||
role: Option<&str>,
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
use crate::config::OracleConfig;
|
||||
use crate::data_connector::conversation_items::{
|
||||
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage,
|
||||
ConversationItemStorageError, ListParams, Result as ItemResult, SortOrder,
|
||||
};
|
||||
use crate::data_connector::conversations::ConversationId;
|
||||
use std::{path::Path, sync::Arc, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
|
||||
use oracle::sql_type::ToSql;
|
||||
use oracle::Connection;
|
||||
use oracle::{sql_type::ToSql, Connection};
|
||||
use serde_json::Value;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::{
|
||||
config::OracleConfig,
|
||||
data_connector::{
|
||||
conversation_items::{
|
||||
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage,
|
||||
ConversationItemStorageError, ListParams, Result as ItemResult, SortOrder,
|
||||
},
|
||||
conversations::ConversationId,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OracleConversationItemStorage {
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
use std::{
|
||||
fmt::{Display, Formatter},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use rand::RngCore;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::conversations::ConversationId;
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use parking_lot::RwLock;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::conversations::{
|
||||
Conversation, ConversationId, ConversationMetadata, ConversationStorage, NewConversation,
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
use crate::config::OracleConfig;
|
||||
use crate::data_connector::conversations::{
|
||||
Conversation, ConversationId, ConversationMetadata, ConversationStorage,
|
||||
ConversationStorageError, NewConversation, Result,
|
||||
};
|
||||
use std::{path::Path, sync::Arc, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
|
||||
use oracle::{sql_type::OracleType, Connection};
|
||||
use serde_json::Value;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::{
|
||||
config::OracleConfig,
|
||||
data_connector::conversations::{
|
||||
Conversation, ConversationId, ConversationMetadata, ConversationStorage,
|
||||
ConversationStorageError, NewConversation, Result,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OracleConversationStorage {
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
use std::{
|
||||
fmt::{Display, Formatter},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use rand::RngCore;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Map as JsonMap, Value};
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
pub struct ConversationId(pub String);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use parking_lot::RwLock;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::responses::{ResponseChain, ResponseId, ResponseStorage, Result, StoredResponse};
|
||||
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
use crate::config::OracleConfig;
|
||||
use crate::data_connector::responses::{
|
||||
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult,
|
||||
StoredResponse,
|
||||
};
|
||||
use std::{collections::HashMap, path::Path, sync::Arc, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
|
||||
use oracle::{Connection, Row};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::{
|
||||
config::OracleConfig,
|
||||
data_connector::responses::{
|
||||
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult,
|
||||
StoredResponse,
|
||||
},
|
||||
};
|
||||
|
||||
const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \
|
||||
tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response FROM responses";
|
||||
@@ -510,9 +511,10 @@ impl OracleErrorExt for ResponseStorageError {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_tool_calls_handles_empty_input() {
|
||||
assert!(parse_tool_calls(None).unwrap().is_empty());
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Response identifier
|
||||
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
use std::convert::TryFrom;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use tonic::{transport::Channel, Request, Streaming};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::protocols::chat::ChatCompletionRequest;
|
||||
use crate::protocols::common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue};
|
||||
use crate::protocols::generate::GenerateRequest;
|
||||
use crate::protocols::sampling_params::SamplingParams as GenerateSamplingParams;
|
||||
use crate::protocols::{
|
||||
chat::ChatCompletionRequest,
|
||||
common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue},
|
||||
generate::GenerateRequest,
|
||||
sampling_params::SamplingParams as GenerateSamplingParams,
|
||||
};
|
||||
|
||||
// Include the generated protobuf code
|
||||
pub mod proto {
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use tracing::Level;
|
||||
use tracing_appender::non_blocking::WorkerGuard;
|
||||
use tracing_appender::rolling::{RollingFileAppender, Rotation};
|
||||
use tracing_appender::{
|
||||
non_blocking::WorkerGuard,
|
||||
rolling::{RollingFileAppender, Rotation},
|
||||
};
|
||||
use tracing_log::LogTracer;
|
||||
use tracing_subscriber::fmt::time::ChronoUtc;
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use tracing_subscriber::{EnvFilter, Layer};
|
||||
use tracing_subscriber::{
|
||||
fmt::time::ChronoUtc, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LoggingConfig {
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
use clap::{ArgAction, Parser, ValueEnum};
|
||||
use sglang_router_rs::config::{
|
||||
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
|
||||
HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig,
|
||||
RouterConfig, RoutingMode,
|
||||
};
|
||||
use sglang_router_rs::metrics::PrometheusConfig;
|
||||
use sglang_router_rs::server::{self, ServerConfig};
|
||||
use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use clap::{ArgAction, Parser, ValueEnum};
|
||||
use sglang_router_rs::{
|
||||
config::{
|
||||
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
|
||||
HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig,
|
||||
RouterConfig, RoutingMode,
|
||||
},
|
||||
metrics::PrometheusConfig,
|
||||
server::{self, ServerConfig},
|
||||
service_discovery::ServiceDiscoveryConfig,
|
||||
};
|
||||
|
||||
fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let mut prefill_entries = Vec::new();
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::{borrow::Cow, collections::HashMap, time::Duration};
|
||||
|
||||
use backoff::ExponentialBackoffBuilder;
|
||||
use dashmap::DashMap;
|
||||
use rmcp::{
|
||||
@@ -13,7 +15,6 @@ use rmcp::{
|
||||
RoleClient, ServiceExt,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{borrow::Cow, collections::HashMap, time::Duration};
|
||||
|
||||
use crate::mcp::{
|
||||
config::{McpConfig, McpServerConfig, McpTransport},
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct McpConfig {
|
||||
pub servers: Vec<McpServerConfig>,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
// OAuth authentication support for MCP servers
|
||||
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::Html,
|
||||
@@ -8,7 +10,6 @@ use axum::{
|
||||
};
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use serde::Deserialize;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use tokio::sync::{oneshot, Mutex};
|
||||
|
||||
use crate::mcp::error::{McpError, McpResult};
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
use std::{
|
||||
net::{IpAddr, Ipv4Addr, SocketAddr},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram};
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PrometheusConfig {
|
||||
@@ -620,9 +623,10 @@ impl TokenizerMetrics {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::TcpListener;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_prometheus_config_default() {
|
||||
let config = PrometheusConfig::default();
|
||||
@@ -912,9 +916,13 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_concurrent_metric_updates() {
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::{
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
thread,
|
||||
};
|
||||
|
||||
let done = Arc::new(AtomicBool::new(false));
|
||||
let mut handles = vec![];
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
use std::{
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use axum::{
|
||||
body::Body, extract::Request, extract::State, http::header, http::HeaderValue,
|
||||
http::StatusCode, middleware::Next, response::IntoResponse, response::Response,
|
||||
body::Body,
|
||||
extract::{Request, State},
|
||||
http::{header, HeaderValue, StatusCode},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use rand::Rng;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use subtle::ConstantTimeEq;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tower::{Layer, Service};
|
||||
@@ -14,9 +21,7 @@ use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
|
||||
use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
|
||||
|
||||
pub use crate::core::token_bucket::TokenBucket;
|
||||
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::server::AppState;
|
||||
use crate::{metrics::RouterMetrics, server::AppState};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthConfig {
|
||||
|
||||
@@ -59,17 +59,15 @@
|
||||
during the next eviction cycle.
|
||||
*/
|
||||
|
||||
use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::tree::Tree;
|
||||
use std::{sync::Arc, thread, time::Duration};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
use tracing::debug;
|
||||
|
||||
use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
|
||||
use crate::{core::Worker, metrics::RouterMetrics, tree::Tree};
|
||||
|
||||
/// Cache-aware routing policy
|
||||
///
|
||||
/// Routes requests based on cache affinity when load is balanced,
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
//! Factory for creating load balancing policies
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::{
|
||||
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
|
||||
RoundRobinPolicy,
|
||||
};
|
||||
use crate::config::PolicyConfig;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Factory for creating policy instances
|
||||
pub struct PolicyFactory;
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
//! This module provides a unified abstraction for routing policies that work
|
||||
//! across both regular and prefill-decode (PD) routing modes.
|
||||
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::core::Worker;
|
||||
use std::fmt::Debug;
|
||||
use std::sync::Arc;
|
||||
|
||||
mod cache_aware;
|
||||
mod factory;
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
//! Power-of-two choices load balancing policy
|
||||
|
||||
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use rand::Rng;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use tracing::info;
|
||||
|
||||
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::{core::Worker, metrics::RouterMetrics};
|
||||
|
||||
/// Power-of-two choices policy
|
||||
///
|
||||
/// Randomly selects two workers and routes to the one with lower load.
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
//! Random load balancing policy
|
||||
|
||||
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
|
||||
use rand::Rng;
|
||||
|
||||
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::{core::Worker, metrics::RouterMetrics};
|
||||
|
||||
/// Random selection policy
|
||||
///
|
||||
/// Selects workers randomly with uniform distribution among healthy workers.
|
||||
@@ -50,9 +51,10 @@ impl LoadBalancingPolicy for RandomPolicy {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::*;
|
||||
use crate::core::{BasicWorkerBuilder, WorkerType};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn test_random_selection() {
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Policy Registry for managing model-to-policy mappings
|
||||
///
|
||||
/// This registry manages the dynamic assignment of load balancing policies to models.
|
||||
@@ -8,11 +15,7 @@ use super::{
|
||||
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
|
||||
RoundRobinPolicy,
|
||||
};
|
||||
use crate::config::types::PolicyConfig;
|
||||
use crate::core::Worker;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use tracing::{debug, info, warn};
|
||||
use crate::{config::types::PolicyConfig, core::Worker};
|
||||
|
||||
/// Registry for managing model-to-policy mappings
|
||||
#[derive(Clone)]
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
//! Round-robin load balancing policy
|
||||
|
||||
use std::sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
};
|
||||
|
||||
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use crate::{core::Worker, metrics::RouterMetrics};
|
||||
|
||||
/// Round-robin selection policy
|
||||
///
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use validator::Validate;
|
||||
|
||||
use super::common::*;
|
||||
use super::sampling_params::{validate_top_k_value, validate_top_p_value};
|
||||
use super::{
|
||||
common::*,
|
||||
sampling_params::{validate_top_k_value, validate_top_p_value},
|
||||
};
|
||||
use crate::protocols::validated::Normalizable;
|
||||
|
||||
// ============================================================================
|
||||
@@ -532,11 +535,12 @@ impl Normalizable for ChatCompletionRequest {
|
||||
// Apply tool_choice defaults
|
||||
if self.tool_choice.is_none() {
|
||||
if let Some(tools) = &self.tools {
|
||||
self.tool_choice = if !tools.is_empty() {
|
||||
Some(ToolChoice::Value(ToolChoiceValue::Auto))
|
||||
let choice_value = if !tools.is_empty() {
|
||||
ToolChoiceValue::Auto
|
||||
} else {
|
||||
Some(ToolChoice::Value(ToolChoiceValue::None))
|
||||
ToolChoiceValue::None
|
||||
};
|
||||
self.tool_choice = Some(ToolChoice::Value(choice_value));
|
||||
}
|
||||
// If tools is None, leave tool_choice as None (don't set it)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
// ============================================================================
|
||||
// Default value helpers
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Map, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::common::*;
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use validator::Validate;
|
||||
|
||||
use super::common::{default_true, GenerationRequest, InputIds};
|
||||
use super::sampling_params::SamplingParams;
|
||||
use super::{
|
||||
common::{default_true, GenerationRequest, InputIds},
|
||||
sampling_params::SamplingParams,
|
||||
};
|
||||
use crate::protocols::validated::Normalizable;
|
||||
|
||||
// ============================================================================
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use validator::Validate;
|
||||
|
||||
use super::common::{default_model, default_true, GenerationRequest, StringOrArray, UsageInfo};
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
// OpenAI Responses API types
|
||||
// https://platform.openai.com/docs/api-reference/responses
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
// Import shared types from common module
|
||||
use super::common::{
|
||||
|
||||
@@ -117,10 +117,11 @@ impl<T> std::ops::DerefMut for ValidatedJson<T> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use validator::Validate;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Validate)]
|
||||
struct TestRequest {
|
||||
#[validate(range(min = 0.0, max = 1.0))]
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
//!
|
||||
//! Defines the request/response structures for worker management endpoints
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Worker configuration for API requests
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct WorkerConfigRequest {
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
// Factory and registry for creating model-specific reasoning parsers.
|
||||
// Now with parser pooling support for efficient reuse across requests.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::reasoning_parser::parsers::{
|
||||
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
|
||||
QwenThinkingParser, Step3Parser,
|
||||
use crate::reasoning_parser::{
|
||||
parsers::{
|
||||
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
|
||||
QwenThinkingParser, Step3Parser,
|
||||
},
|
||||
traits::{ParseError, ParserConfig, ReasoningParser},
|
||||
};
|
||||
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};
|
||||
|
||||
/// Type alias for pooled parser instances.
|
||||
/// Uses tokio::Mutex to avoid blocking the async executor.
|
||||
@@ -402,8 +406,10 @@ mod tests {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
|
||||
async fn test_high_concurrency_parser_access() {
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::time::Instant;
|
||||
use std::{
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
let factory = ParserFactory::new();
|
||||
let num_tasks = 100;
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
// This parser starts with in_reasoning=true, assuming all text is reasoning
|
||||
// until an end token is encountered.
|
||||
|
||||
use crate::reasoning_parser::parsers::BaseReasoningParser;
|
||||
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
|
||||
use crate::reasoning_parser::{
|
||||
parsers::BaseReasoningParser,
|
||||
traits::{ParseError, ParserConfig, ParserResult, ReasoningParser},
|
||||
};
|
||||
|
||||
/// DeepSeek-R1 reasoning parser.
|
||||
///
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
// GLM45 specific reasoning parser.
|
||||
// Uses the same format as Qwen3 but has its own implementation for debugging.
|
||||
|
||||
use crate::reasoning_parser::parsers::BaseReasoningParser;
|
||||
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
|
||||
use crate::reasoning_parser::{
|
||||
parsers::BaseReasoningParser,
|
||||
traits::{ParseError, ParserConfig, ParserResult, ReasoningParser},
|
||||
};
|
||||
|
||||
/// GLM45 reasoning parser.
|
||||
///
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
// Kimi specific reasoning parser.
|
||||
// This parser uses Unicode tokens and starts with in_reasoning=false.
|
||||
|
||||
use crate::reasoning_parser::parsers::BaseReasoningParser;
|
||||
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
|
||||
use crate::reasoning_parser::{
|
||||
parsers::BaseReasoningParser,
|
||||
traits::{ParseError, ParserConfig, ParserResult, ReasoningParser},
|
||||
};
|
||||
|
||||
/// Kimi reasoning parser.
|
||||
///
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
// This parser starts with in_reasoning=false, requiring an explicit
|
||||
// start token to enter reasoning mode.
|
||||
|
||||
use crate::reasoning_parser::parsers::BaseReasoningParser;
|
||||
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
|
||||
use crate::reasoning_parser::{
|
||||
parsers::BaseReasoningParser,
|
||||
traits::{ParseError, ParserConfig, ParserResult, ReasoningParser},
|
||||
};
|
||||
|
||||
/// Qwen3 reasoning parser.
|
||||
///
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
// Step3 specific reasoning parser.
|
||||
// Uses the same format as DeepSeek-R1 but has its own implementation for debugging.
|
||||
|
||||
use crate::reasoning_parser::parsers::BaseReasoningParser;
|
||||
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
|
||||
use crate::reasoning_parser::{
|
||||
parsers::BaseReasoningParser,
|
||||
traits::{ParseError, ParserConfig, ParserResult, ReasoningParser},
|
||||
};
|
||||
|
||||
/// Step3 reasoning parser.
|
||||
///
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
//! Factory for creating router instances
|
||||
|
||||
use super::grpc::pd_router::GrpcPDRouter;
|
||||
use super::grpc::router::GrpcRouter;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::{
|
||||
grpc::{pd_router::GrpcPDRouter, router::GrpcRouter},
|
||||
http::{pd_router::PDRouter, router::Router},
|
||||
openai::OpenAIRouter,
|
||||
RouterTrait,
|
||||
};
|
||||
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
|
||||
use crate::policies::PolicyFactory;
|
||||
use crate::server::AppContext;
|
||||
use std::sync::Arc;
|
||||
use crate::{
|
||||
config::{ConnectionMode, PolicyConfig, RoutingMode},
|
||||
policies::PolicyFactory,
|
||||
server::AppContext,
|
||||
};
|
||||
|
||||
/// Factory for creating router instances based on configuration
|
||||
pub struct RouterFactory;
|
||||
|
||||
@@ -4,20 +4,22 @@
|
||||
//! eliminating deep parameter passing chains and providing a single source of truth
|
||||
//! for request state.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use axum::http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::core::Worker;
|
||||
use crate::grpc_client::{proto, SglangSchedulerClient};
|
||||
use crate::protocols::chat::{ChatCompletionRequest, ChatCompletionResponse};
|
||||
use crate::protocols::generate::{GenerateRequest, GenerateResponse};
|
||||
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
|
||||
use crate::tokenizer::stop::StopSequenceDecoder;
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ParserFactory as ToolParserFactory;
|
||||
use crate::{
|
||||
core::Worker,
|
||||
grpc_client::{proto, SglangSchedulerClient},
|
||||
protocols::{
|
||||
chat::{ChatCompletionRequest, ChatCompletionResponse},
|
||||
generate::{GenerateRequest, GenerateResponse},
|
||||
},
|
||||
reasoning_parser::ParserFactory as ReasoningParserFactory,
|
||||
tokenizer::{stop::StopSequenceDecoder, traits::Tokenizer},
|
||||
tool_parser::ParserFactory as ToolParserFactory,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Core Context Types
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
//! gRPC router implementations
|
||||
|
||||
use crate::grpc_client::proto;
|
||||
use crate::protocols::common::StringOrArray;
|
||||
use crate::{grpc_client::proto, protocols::common::StringOrArray};
|
||||
|
||||
pub mod context;
|
||||
pub mod pd_router;
|
||||
|
||||
@@ -1,19 +1,7 @@
|
||||
// PD (Prefill-Decode) gRPC Router Implementation
|
||||
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{ConnectionMode, WorkerRegistry, WorkerType};
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::chat::ChatCompletionRequest;
|
||||
use crate::protocols::completion::CompletionRequest;
|
||||
use crate::protocols::embedding::EmbeddingRequest;
|
||||
use crate::protocols::generate::GenerateRequest;
|
||||
use crate::protocols::rerank::RerankRequest;
|
||||
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
|
||||
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
|
||||
use crate::routers::RouterTrait;
|
||||
use crate::server::AppContext;
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ParserFactory as ToolParserFactory;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
@@ -21,12 +9,27 @@ use axum::{
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
use tracing::debug;
|
||||
|
||||
use super::context::SharedComponents;
|
||||
use super::pipeline::RequestPipeline;
|
||||
use super::{context::SharedComponents, pipeline::RequestPipeline};
|
||||
use crate::{
|
||||
config::types::RetryConfig,
|
||||
core::{ConnectionMode, WorkerRegistry, WorkerType},
|
||||
policies::PolicyRegistry,
|
||||
protocols::{
|
||||
chat::ChatCompletionRequest,
|
||||
completion::CompletionRequest,
|
||||
embedding::EmbeddingRequest,
|
||||
generate::GenerateRequest,
|
||||
rerank::RerankRequest,
|
||||
responses::{ResponsesGetParams, ResponsesRequest},
|
||||
},
|
||||
reasoning_parser::ParserFactory as ReasoningParserFactory,
|
||||
routers::RouterTrait,
|
||||
server::AppContext,
|
||||
tokenizer::traits::Tokenizer,
|
||||
tool_parser::ParserFactory as ToolParserFactory,
|
||||
};
|
||||
|
||||
/// gRPC PD (Prefill-Decode) router implementation for SGLang
|
||||
#[derive(Clone)]
|
||||
|
||||
@@ -3,29 +3,29 @@
|
||||
//! This module defines the core pipeline abstraction and individual processing stages
|
||||
//! that transform a RequestContext through its lifecycle.
|
||||
|
||||
use std::{
|
||||
sync::Arc,
|
||||
time::{Instant, SystemTime, UNIX_EPOCH},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
use super::context::*;
|
||||
use super::processing;
|
||||
use super::streaming;
|
||||
use super::utils;
|
||||
use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
|
||||
use crate::grpc_client::proto;
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::chat::ChatCompletionRequest;
|
||||
use crate::protocols::common::InputIds;
|
||||
use crate::protocols::generate::GenerateRequest;
|
||||
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ParserFactory as ToolParserFactory;
|
||||
use proto::DisaggregatedParams;
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||
use tracing::{debug, error, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{context::*, processing, streaming, utils};
|
||||
use crate::{
|
||||
core::{ConnectionMode, Worker, WorkerRegistry, WorkerType},
|
||||
grpc_client::proto,
|
||||
policies::PolicyRegistry,
|
||||
protocols::{chat::ChatCompletionRequest, common::InputIds, generate::GenerateRequest},
|
||||
reasoning_parser::ParserFactory as ReasoningParserFactory,
|
||||
tokenizer::traits::Tokenizer,
|
||||
tool_parser::ParserFactory as ToolParserFactory,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Pipeline Trait
|
||||
// ============================================================================
|
||||
|
||||
@@ -3,28 +3,30 @@
|
||||
//! This module contains response processing functions that are shared between
|
||||
//! the regular router and PD router, eliminating ~1,200 lines of exact duplicates.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::{sync::Arc, time::Instant};
|
||||
|
||||
use proto::generate_complete::MatchedStop;
|
||||
use serde_json::Value;
|
||||
use tracing::error;
|
||||
|
||||
use crate::grpc_client::proto;
|
||||
use crate::protocols::chat::{
|
||||
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
|
||||
use super::{
|
||||
context::{DispatchMetadata, ExecutionResult},
|
||||
utils,
|
||||
};
|
||||
use crate::protocols::common::{
|
||||
FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue, Usage,
|
||||
use crate::{
|
||||
grpc_client::proto,
|
||||
protocols::{
|
||||
chat::{ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse},
|
||||
common::{FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue, Usage},
|
||||
generate::{GenerateMetaInfo, GenerateRequest, GenerateResponse},
|
||||
},
|
||||
reasoning_parser::ParserFactory as ReasoningParserFactory,
|
||||
tokenizer::{
|
||||
stop::{SequenceDecoderOutput, StopSequenceDecoder},
|
||||
traits::Tokenizer,
|
||||
},
|
||||
tool_parser::ParserFactory as ToolParserFactory,
|
||||
};
|
||||
use crate::protocols::generate::{GenerateMetaInfo, GenerateRequest, GenerateResponse};
|
||||
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
|
||||
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ParserFactory as ToolParserFactory;
|
||||
use proto::generate_complete::MatchedStop;
|
||||
use std::time::Instant;
|
||||
|
||||
use super::context::{DispatchMetadata, ExecutionResult};
|
||||
use super::utils;
|
||||
|
||||
// ============================================================================
|
||||
// Response Processor - Main Entry Point
|
||||
|
||||
@@ -11,23 +11,25 @@ use axum::{
|
||||
};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::WorkerRegistry;
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::chat::ChatCompletionRequest;
|
||||
use crate::protocols::completion::CompletionRequest;
|
||||
use crate::protocols::embedding::EmbeddingRequest;
|
||||
use crate::protocols::generate::GenerateRequest;
|
||||
use crate::protocols::rerank::RerankRequest;
|
||||
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
|
||||
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
|
||||
use crate::routers::RouterTrait;
|
||||
use crate::server::AppContext;
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ParserFactory as ToolParserFactory;
|
||||
|
||||
use super::context::SharedComponents;
|
||||
use super::pipeline::RequestPipeline;
|
||||
use super::{context::SharedComponents, pipeline::RequestPipeline};
|
||||
use crate::{
|
||||
config::types::RetryConfig,
|
||||
core::WorkerRegistry,
|
||||
policies::PolicyRegistry,
|
||||
protocols::{
|
||||
chat::ChatCompletionRequest,
|
||||
completion::CompletionRequest,
|
||||
embedding::EmbeddingRequest,
|
||||
generate::GenerateRequest,
|
||||
rerank::RerankRequest,
|
||||
responses::{ResponsesGetParams, ResponsesRequest},
|
||||
},
|
||||
reasoning_parser::ParserFactory as ReasoningParserFactory,
|
||||
routers::RouterTrait,
|
||||
server::AppContext,
|
||||
tokenizer::traits::Tokenizer,
|
||||
tool_parser::ParserFactory as ToolParserFactory,
|
||||
};
|
||||
|
||||
/// gRPC router implementation for SGLang
|
||||
#[derive(Clone)]
|
||||
|
||||
@@ -3,38 +3,40 @@
|
||||
//! This module contains shared streaming logic for both Regular and PD routers,
|
||||
//! eliminating ~600 lines of duplication.
|
||||
|
||||
use axum::response::Response;
|
||||
use axum::{body::Body, http::StatusCode};
|
||||
use std::{collections::HashMap, io, sync::Arc, time::Instant};
|
||||
|
||||
use axum::{body::Body, http::StatusCode, response::Response};
|
||||
use bytes::Bytes;
|
||||
use http::header::{HeaderValue, CONTENT_TYPE};
|
||||
use proto::{
|
||||
generate_complete::MatchedStop::{MatchedStopStr, MatchedTokenId},
|
||||
generate_response::Response::{Chunk, Complete, Error},
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc::UnboundedSender;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio::sync::{mpsc, mpsc::UnboundedSender};
|
||||
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
use super::context;
|
||||
use super::utils;
|
||||
use crate::grpc_client::proto;
|
||||
use crate::protocols::chat::{
|
||||
ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice,
|
||||
use super::{context, utils};
|
||||
use crate::{
|
||||
grpc_client::proto,
|
||||
protocols::{
|
||||
chat::{
|
||||
ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice,
|
||||
},
|
||||
common::{
|
||||
ChatLogProbs, FunctionCallDelta, StringOrArray, Tool, ToolCallDelta, ToolChoice,
|
||||
ToolChoiceValue, Usage,
|
||||
},
|
||||
generate::GenerateRequest,
|
||||
},
|
||||
reasoning_parser::ReasoningParser,
|
||||
tokenizer::{
|
||||
stop::{SequenceDecoderOutput, StopSequenceDecoder},
|
||||
traits::Tokenizer,
|
||||
},
|
||||
tool_parser::ToolParser,
|
||||
};
|
||||
use crate::protocols::common::{
|
||||
ChatLogProbs, FunctionCallDelta, StringOrArray, Tool, ToolCallDelta, ToolChoice,
|
||||
ToolChoiceValue, Usage,
|
||||
};
|
||||
use crate::protocols::generate::GenerateRequest;
|
||||
use crate::reasoning_parser::ReasoningParser;
|
||||
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ToolParser;
|
||||
use proto::generate_complete::MatchedStop::{MatchedStopStr, MatchedTokenId};
|
||||
use proto::generate_response::Response::{Chunk, Complete, Error};
|
||||
use std::time::Instant;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Shared streaming processor for both single and dual dispatch modes
|
||||
#[derive(Clone)]
|
||||
|
||||
@@ -1,19 +1,7 @@
|
||||
//! Shared utilities for gRPC routers
|
||||
|
||||
use super::ProcessedMessages;
|
||||
use crate::core::Worker;
|
||||
use crate::grpc_client::sglang_scheduler::AbortOnDropStream;
|
||||
use crate::grpc_client::{proto, SglangSchedulerClient};
|
||||
use crate::protocols::chat::{ChatCompletionRequest, ChatMessage};
|
||||
use crate::protocols::common::{
|
||||
ChatLogProbs, ChatLogProbsContent, FunctionCallResponse, StringOrArray, Tool, ToolCall,
|
||||
ToolChoice, ToolChoiceValue, TopLogProb,
|
||||
};
|
||||
use crate::protocols::generate::GenerateFinishReason;
|
||||
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tokenizer::HuggingFaceTokenizer;
|
||||
pub use crate::tokenizer::StopSequenceDecoder;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
@@ -21,11 +9,29 @@ use axum::{
|
||||
};
|
||||
use futures::StreamExt;
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tracing::{error, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::ProcessedMessages;
|
||||
pub use crate::tokenizer::StopSequenceDecoder;
|
||||
use crate::{
|
||||
core::Worker,
|
||||
grpc_client::{proto, sglang_scheduler::AbortOnDropStream, SglangSchedulerClient},
|
||||
protocols::{
|
||||
chat::{ChatCompletionRequest, ChatMessage},
|
||||
common::{
|
||||
ChatLogProbs, ChatLogProbsContent, FunctionCallResponse, StringOrArray, Tool, ToolCall,
|
||||
ToolChoice, ToolChoiceValue, TopLogProb,
|
||||
},
|
||||
generate::GenerateFinishReason,
|
||||
},
|
||||
tokenizer::{
|
||||
chat_template::{ChatTemplateContentFormat, ChatTemplateParams},
|
||||
traits::Tokenizer,
|
||||
HuggingFaceTokenizer,
|
||||
},
|
||||
};
|
||||
|
||||
/// Get gRPC client from worker, returning appropriate error response on failure
|
||||
pub async fn get_grpc_client_from_worker(
|
||||
worker: &Arc<dyn Worker>,
|
||||
@@ -953,12 +959,17 @@ pub fn parse_finish_reason(reason_str: &str, completion_tokens: i32) -> Generate
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::protocols::chat::{ChatMessage, UserMessageContent};
|
||||
use crate::protocols::common::{ContentPart, ImageUrl};
|
||||
use crate::tokenizer::chat_template::ChatTemplateContentFormat;
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
use crate::{
|
||||
protocols::{
|
||||
chat::{ChatMessage, UserMessageContent},
|
||||
common::{ContentPart, ImageUrl},
|
||||
},
|
||||
tokenizer::chat_template::ChatTemplateContentFormat,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_transform_messages_string_format() {
|
||||
let messages = vec![ChatMessage::User {
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
use axum::body::Body;
|
||||
use axum::extract::Request;
|
||||
use axum::http::HeaderMap;
|
||||
use axum::{body::Body, extract::Request, http::HeaderMap};
|
||||
|
||||
/// Copy request headers to a Vec of name-value string pairs
|
||||
/// Used for forwarding headers to backend workers
|
||||
|
||||
@@ -1,19 +1,5 @@
|
||||
use super::pd_types::api_path;
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
is_retryable_status, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
|
||||
};
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
||||
use crate::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent};
|
||||
use crate::protocols::common::{InputIds, StringOrArray};
|
||||
use crate::protocols::completion::CompletionRequest;
|
||||
use crate::protocols::embedding::EmbeddingRequest;
|
||||
use crate::protocols::generate::GenerateRequest;
|
||||
use crate::protocols::rerank::RerankRequest;
|
||||
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
|
||||
use crate::routers::header_utils;
|
||||
use crate::routers::RouterTrait;
|
||||
use std::{sync::Arc, time::Instant};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
@@ -25,11 +11,29 @@ use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use serde::Serialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
use super::pd_types::api_path;
|
||||
use crate::{
|
||||
config::types::RetryConfig,
|
||||
core::{
|
||||
is_retryable_status, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
|
||||
},
|
||||
metrics::RouterMetrics,
|
||||
policies::{LoadBalancingPolicy, PolicyRegistry},
|
||||
protocols::{
|
||||
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
|
||||
common::{InputIds, StringOrArray},
|
||||
completion::CompletionRequest,
|
||||
embedding::EmbeddingRequest,
|
||||
generate::GenerateRequest,
|
||||
rerank::RerankRequest,
|
||||
responses::{ResponsesGetParams, ResponsesRequest},
|
||||
},
|
||||
routers::{header_utils, RouterTrait},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PDRouter {
|
||||
pub worker_registry: Arc<WorkerRegistry>,
|
||||
|
||||
@@ -1,35 +1,39 @@
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
|
||||
};
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::chat::ChatCompletionRequest;
|
||||
use crate::protocols::common::GenerationRequest;
|
||||
use crate::protocols::completion::CompletionRequest;
|
||||
use crate::protocols::embedding::EmbeddingRequest;
|
||||
use crate::protocols::generate::GenerateRequest;
|
||||
use crate::protocols::rerank::{RerankRequest, RerankResponse, RerankResult};
|
||||
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
|
||||
use crate::routers::header_utils;
|
||||
use crate::routers::RouterTrait;
|
||||
use axum::body::to_bytes;
|
||||
use std::{sync::Arc, time::Instant};
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
body::{to_bytes, Body},
|
||||
extract::Request,
|
||||
http::{
|
||||
header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, Method, StatusCode,
|
||||
header::{CONTENT_LENGTH, CONTENT_TYPE},
|
||||
HeaderMap, HeaderValue, Method, StatusCode,
|
||||
},
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{debug, error};
|
||||
|
||||
use crate::{
|
||||
config::types::RetryConfig,
|
||||
core::{
|
||||
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
|
||||
},
|
||||
metrics::RouterMetrics,
|
||||
policies::PolicyRegistry,
|
||||
protocols::{
|
||||
chat::ChatCompletionRequest,
|
||||
common::GenerationRequest,
|
||||
completion::CompletionRequest,
|
||||
embedding::EmbeddingRequest,
|
||||
generate::GenerateRequest,
|
||||
rerank::{RerankRequest, RerankResponse, RerankResult},
|
||||
responses::{ResponsesGetParams, ResponsesRequest},
|
||||
},
|
||||
routers::{header_utils, RouterTrait},
|
||||
};
|
||||
|
||||
/// Regular router that uses injected load balancing policies
|
||||
#[derive(Debug)]
|
||||
pub struct Router {
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
//! Router implementations
|
||||
|
||||
use std::fmt::Debug;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
@@ -7,16 +9,17 @@ use axum::{
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::protocols::chat::ChatCompletionRequest;
|
||||
use crate::protocols::completion::CompletionRequest;
|
||||
use crate::protocols::embedding::EmbeddingRequest;
|
||||
use crate::protocols::generate::GenerateRequest;
|
||||
use crate::protocols::rerank::RerankRequest;
|
||||
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::{
|
||||
chat::ChatCompletionRequest,
|
||||
completion::CompletionRequest,
|
||||
embedding::EmbeddingRequest,
|
||||
generate::GenerateRequest,
|
||||
rerank::RerankRequest,
|
||||
responses::{ResponsesGetParams, ResponsesRequest},
|
||||
};
|
||||
|
||||
pub mod factory;
|
||||
pub mod grpc;
|
||||
pub mod header_utils;
|
||||
@@ -25,7 +28,6 @@ pub mod openai; // New refactored OpenAI router module
|
||||
pub mod router_manager;
|
||||
|
||||
pub use factory::RouterFactory;
|
||||
|
||||
// Re-export HTTP routers for convenience
|
||||
pub use http::{pd_router, pd_types, router};
|
||||
|
||||
|
||||
@@ -1,22 +1,26 @@
|
||||
//! Conversation CRUD operations and persistence
|
||||
|
||||
use crate::data_connector::{
|
||||
conversation_items::ListParams, conversation_items::SortOrder, Conversation, ConversationId,
|
||||
ConversationItemId, ConversationItemStorage, ConversationStorage, NewConversation,
|
||||
NewConversationItem, ResponseId, ResponseStorage, SharedConversationItemStorage,
|
||||
SharedConversationStorage,
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use crate::protocols::responses::{ResponseInput, ResponseInputOutputItem, ResponsesRequest};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::Json;
|
||||
use chrono::Utc;
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::responses::build_stored_response;
|
||||
use crate::{
|
||||
data_connector::{
|
||||
conversation_items::{ListParams, SortOrder},
|
||||
Conversation, ConversationId, ConversationItemId, ConversationItemStorage,
|
||||
ConversationStorage, NewConversation, NewConversationItem, ResponseId, ResponseStorage,
|
||||
SharedConversationItemStorage, SharedConversationStorage,
|
||||
},
|
||||
protocols::responses::{ResponseInput, ResponseInputOutputItem, ResponsesRequest},
|
||||
};
|
||||
|
||||
/// Maximum number of properties allowed in conversation metadata
|
||||
pub(crate) const MAX_METADATA_PROPERTIES: usize = 16;
|
||||
|
||||
@@ -8,19 +8,20 @@
|
||||
//! - Payload transformation for MCP tool interception
|
||||
//! - Metadata injection for MCP operations
|
||||
|
||||
use crate::mcp::McpClientManager;
|
||||
use crate::protocols::responses::{
|
||||
ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest,
|
||||
};
|
||||
use crate::routers::header_utils::apply_request_headers;
|
||||
use std::{io, sync::Arc};
|
||||
|
||||
use axum::http::HeaderMap;
|
||||
use bytes::Bytes;
|
||||
use serde_json::{json, to_value, Value};
|
||||
use std::{io, sync::Arc};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use super::utils::event_types;
|
||||
use crate::{
|
||||
mcp::McpClientManager,
|
||||
protocols::responses::{ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest},
|
||||
routers::header_utils::apply_request_headers,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Configuration and State Types
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
//! Response storage, patching, and extraction utilities
|
||||
|
||||
use crate::data_connector::{ResponseId, StoredResponse};
|
||||
use crate::protocols::responses::{ResponseInput, ResponseToolType, ResponsesRequest};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde_json::{json, Value};
|
||||
use tracing::warn;
|
||||
|
||||
use super::utils::event_types;
|
||||
use crate::{
|
||||
data_connector::{ResponseId, StoredResponse},
|
||||
protocols::responses::{ResponseInput, ResponseToolType, ResponsesRequest},
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Response Storage Operations
|
||||
|
||||
@@ -1,21 +1,10 @@
|
||||
//! OpenAI router - main coordinator that delegates to specialized modules
|
||||
|
||||
use crate::config::CircuitBreakerConfig;
|
||||
use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig};
|
||||
use crate::data_connector::{
|
||||
conversation_items::ListParams, conversation_items::SortOrder, ConversationId, ResponseId,
|
||||
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
|
||||
use std::{
|
||||
any::Any,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
};
|
||||
use crate::protocols::chat::ChatCompletionRequest;
|
||||
use crate::protocols::completion::CompletionRequest;
|
||||
use crate::protocols::embedding::EmbeddingRequest;
|
||||
use crate::protocols::generate::GenerateRequest;
|
||||
use crate::protocols::rerank::RerankRequest;
|
||||
use crate::protocols::responses::{
|
||||
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams,
|
||||
ResponsesRequest,
|
||||
};
|
||||
use crate::routers::header_utils::apply_request_headers;
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
@@ -25,10 +14,6 @@ use axum::{
|
||||
};
|
||||
use futures_util::StreamExt;
|
||||
use serde_json::{json, to_value, Value};
|
||||
use std::{
|
||||
any::Any,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::warn;
|
||||
@@ -39,12 +24,35 @@ use super::conversations::{
|
||||
get_conversation, get_conversation_item, list_conversation_items, persist_conversation_items,
|
||||
update_conversation,
|
||||
};
|
||||
use super::mcp::{
|
||||
execute_tool_loop, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming,
|
||||
McpLoopConfig,
|
||||
use super::{
|
||||
mcp::{
|
||||
execute_tool_loop, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming,
|
||||
McpLoopConfig,
|
||||
},
|
||||
responses::{mask_tools_as_mcp, patch_streaming_response_json},
|
||||
streaming::handle_streaming_response,
|
||||
};
|
||||
use crate::{
|
||||
config::CircuitBreakerConfig,
|
||||
core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig},
|
||||
data_connector::{
|
||||
conversation_items::{ListParams, SortOrder},
|
||||
ConversationId, ResponseId, SharedConversationItemStorage, SharedConversationStorage,
|
||||
SharedResponseStorage,
|
||||
},
|
||||
protocols::{
|
||||
chat::ChatCompletionRequest,
|
||||
completion::CompletionRequest,
|
||||
embedding::EmbeddingRequest,
|
||||
generate::GenerateRequest,
|
||||
rerank::RerankRequest,
|
||||
responses::{
|
||||
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams,
|
||||
ResponsesRequest,
|
||||
},
|
||||
},
|
||||
routers::header_utils::apply_request_headers,
|
||||
};
|
||||
use super::responses::{mask_tools_as_mcp, patch_streaming_response_json};
|
||||
use super::streaming::handle_streaming_response;
|
||||
|
||||
// ============================================================================
|
||||
// OpenAIRouter Struct
|
||||
|
||||
@@ -7,11 +7,8 @@
|
||||
//! - MCP tool execution loops within streaming responses
|
||||
//! - Event transformation and output index remapping
|
||||
|
||||
use crate::data_connector::{
|
||||
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
|
||||
};
|
||||
use crate::protocols::responses::{ResponseToolType, ResponsesRequest};
|
||||
use crate::routers::header_utils::{apply_request_headers, preserve_response_headers};
|
||||
use std::{borrow::Cow, io, sync::Arc};
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
|
||||
@@ -20,20 +17,28 @@ use axum::{
|
||||
use bytes::Bytes;
|
||||
use futures_util::StreamExt;
|
||||
use serde_json::{json, Value};
|
||||
use std::{borrow::Cow, io, sync::Arc};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::warn;
|
||||
|
||||
// Import from sibling modules
|
||||
use super::conversations::persist_conversation_items;
|
||||
use super::mcp::{
|
||||
build_resume_payload, execute_streaming_tool_calls, inject_mcp_metadata_streaming,
|
||||
mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming, send_mcp_list_tools_events,
|
||||
McpLoopConfig, ToolLoopState,
|
||||
use super::{
|
||||
mcp::{
|
||||
build_resume_payload, execute_streaming_tool_calls, inject_mcp_metadata_streaming,
|
||||
mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming,
|
||||
send_mcp_list_tools_events, McpLoopConfig, ToolLoopState,
|
||||
},
|
||||
responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block},
|
||||
utils::{event_types, FunctionCallInProgress, OutputIndexMapper, StreamAction},
|
||||
};
|
||||
use crate::{
|
||||
data_connector::{
|
||||
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
|
||||
},
|
||||
protocols::responses::{ResponseToolType, ResponsesRequest},
|
||||
routers::header_utils::{apply_request_headers, preserve_response_headers},
|
||||
};
|
||||
use super::responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block};
|
||||
use super::utils::{event_types, FunctionCallInProgress, OutputIndexMapper, StreamAction};
|
||||
|
||||
// ============================================================================
|
||||
// Streaming Response Accumulator
|
||||
|
||||
@@ -4,16 +4,8 @@
|
||||
//! - Single Router Mode (enable_igw=false): Router owns workers directly
|
||||
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
|
||||
|
||||
use crate::config::{ConnectionMode, RoutingMode};
|
||||
use crate::core::{WorkerRegistry, WorkerType};
|
||||
use crate::protocols::chat::ChatCompletionRequest;
|
||||
use crate::protocols::completion::CompletionRequest;
|
||||
use crate::protocols::embedding::EmbeddingRequest;
|
||||
use crate::protocols::generate::GenerateRequest;
|
||||
use crate::protocols::rerank::RerankRequest;
|
||||
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
|
||||
use crate::routers::RouterTrait;
|
||||
use crate::server::{AppContext, ServerConfig};
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
@@ -23,9 +15,23 @@ use axum::{
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::{
|
||||
config::{ConnectionMode, RoutingMode},
|
||||
core::{WorkerRegistry, WorkerType},
|
||||
protocols::{
|
||||
chat::ChatCompletionRequest,
|
||||
completion::CompletionRequest,
|
||||
embedding::EmbeddingRequest,
|
||||
generate::GenerateRequest,
|
||||
rerank::RerankRequest,
|
||||
responses::{ResponsesGetParams, ResponsesRequest},
|
||||
},
|
||||
routers::RouterTrait,
|
||||
server::{AppContext, ServerConfig},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||
pub struct RouterId(String);
|
||||
|
||||
|
||||
@@ -1,3 +1,24 @@
|
||||
use std::{
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc, OnceLock,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, Request, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::{delete, get, post},
|
||||
serve, Json, Router,
|
||||
};
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use tokio::{net::TcpListener, signal, spawn};
|
||||
use tracing::{error, info, warn, Level};
|
||||
|
||||
use crate::{
|
||||
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
|
||||
core::{
|
||||
@@ -30,24 +51,6 @@ use crate::{
|
||||
tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
|
||||
tool_parser::ParserFactory as ToolParserFactory,
|
||||
};
|
||||
use axum::{
|
||||
extract::{Path, Query, Request, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::{delete, get, post},
|
||||
serve, Json, Router,
|
||||
};
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::OnceLock;
|
||||
use std::{
|
||||
sync::atomic::{AtomicBool, Ordering},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{net::TcpListener, signal, spawn};
|
||||
use tracing::{error, info, warn, Level};
|
||||
|
||||
//
|
||||
|
||||
|
||||
@@ -1,24 +1,25 @@
|
||||
use crate::core::WorkerManager;
|
||||
use crate::protocols::worker_spec::WorkerConfigRequest;
|
||||
use crate::server::AppContext;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
sync::{Arc, Mutex},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use k8s_openapi::api::core::v1::Pod;
|
||||
use kube::{
|
||||
api::Api,
|
||||
runtime::watcher::{watcher, Config},
|
||||
runtime::WatchStreamExt,
|
||||
runtime::{
|
||||
watcher::{watcher, Config},
|
||||
WatchStreamExt,
|
||||
},
|
||||
Client,
|
||||
};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use rustls;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
use tokio::task;
|
||||
use tokio::time;
|
||||
use tokio::{task, time};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::{core::WorkerManager, protocols::worker_spec::WorkerConfigRequest, server::AppContext};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ServiceDiscoveryConfig {
|
||||
pub enabled: bool,
|
||||
@@ -452,10 +453,12 @@ async fn handle_pod_deletion(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use k8s_openapi::{
|
||||
api::core::v1::{Pod, PodCondition, PodSpec, PodStatus},
|
||||
apimachinery::pkg::apis::meta::v1::{ObjectMeta, Time},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use k8s_openapi::api::core::v1::{Pod, PodCondition, PodSpec, PodStatus};
|
||||
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
|
||||
use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time;
|
||||
|
||||
fn create_k8s_pod(
|
||||
name: Option<&str>,
|
||||
@@ -535,8 +538,7 @@ mod tests {
|
||||
}
|
||||
|
||||
async fn create_test_app_context() -> Arc<AppContext> {
|
||||
use crate::config::RouterConfig;
|
||||
use crate::middleware::TokenBucket;
|
||||
use crate::{config::RouterConfig, middleware::TokenBucket};
|
||||
|
||||
let router_config = RouterConfig {
|
||||
worker_startup_timeout_secs: 1,
|
||||
|
||||
@@ -3,12 +3,16 @@
|
||||
//! This module provides functionality to apply chat templates to messages,
|
||||
//! similar to HuggingFace transformers' apply_chat_template method.
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use minijinja::machinery::ast::{Expr, Stmt};
|
||||
use minijinja::{context, Environment, Value};
|
||||
use serde_json;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use minijinja::{
|
||||
context,
|
||||
machinery::ast::{Expr, Stmt},
|
||||
Environment, Value,
|
||||
};
|
||||
use serde_json;
|
||||
|
||||
/// Chat template content format
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ChatTemplateContentFormat {
|
||||
@@ -319,8 +323,10 @@ impl<'a> Detector<'a> {
|
||||
/// AST-based detection using minijinja's unstable machinery
|
||||
/// Single-pass detector with scope tracking
|
||||
fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
|
||||
use minijinja::machinery::{parse, WhitespaceConfig};
|
||||
use minijinja::syntax::SyntaxConfig;
|
||||
use minijinja::{
|
||||
machinery::{parse, WhitespaceConfig},
|
||||
syntax::SyntaxConfig,
|
||||
};
|
||||
|
||||
let ast = match parse(
|
||||
template,
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
use super::traits;
|
||||
use std::{fs::File, io::Read, path::Path, sync::Arc};
|
||||
|
||||
use anyhow::{Error, Result};
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::huggingface::HuggingFaceTokenizer;
|
||||
use super::tiktoken::TiktokenTokenizer;
|
||||
use super::{huggingface::HuggingFaceTokenizer, tiktoken::TiktokenTokenizer, traits};
|
||||
use crate::tokenizer::hub::download_tokenizer_from_hf;
|
||||
|
||||
/// Represents the type of tokenizer being used
|
||||
@@ -379,8 +375,7 @@ pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
|
||||
Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
|
||||
_ => {
|
||||
// Try auto-detection
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
use std::{fs::File, io::Read};
|
||||
|
||||
let mut file = File::open(file_path)?;
|
||||
let mut buffer = vec![0u8; 512];
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
use std::{
|
||||
env,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use hf_hub::api::tokio::ApiBuilder;
|
||||
use std::env;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
const IGNORED: [&str; 5] = [
|
||||
".gitattributes",
|
||||
|
||||
@@ -3,12 +3,12 @@ use std::collections::HashMap;
|
||||
use anyhow::{Error, Result};
|
||||
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
||||
|
||||
use super::chat_template::{
|
||||
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
|
||||
ChatTemplateProcessor,
|
||||
};
|
||||
use super::traits::{
|
||||
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
||||
use super::{
|
||||
chat_template::{
|
||||
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
|
||||
ChatTemplateProcessor,
|
||||
},
|
||||
traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
|
||||
};
|
||||
|
||||
/// HuggingFace tokenizer wrapper
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
//! Mock tokenizer implementation for testing
|
||||
|
||||
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||
use anyhow::Result;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||
|
||||
/// Mock tokenizer for testing purposes
|
||||
pub struct MockTokenizer {
|
||||
vocab: HashMap<String, u32>,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
|
||||
use anyhow::Result;
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub mod factory;
|
||||
pub mod hub;
|
||||
@@ -27,14 +27,12 @@ pub use factory::{
|
||||
create_tokenizer_from_file, create_tokenizer_with_chat_template,
|
||||
create_tokenizer_with_chat_template_blocking, TokenizerType,
|
||||
};
|
||||
pub use huggingface::HuggingFaceTokenizer;
|
||||
pub use sequence::Sequence;
|
||||
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
|
||||
pub use stream::DecodeStream;
|
||||
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||
|
||||
pub use huggingface::HuggingFaceTokenizer;
|
||||
|
||||
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
|
||||
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||
|
||||
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
|
||||
#[derive(Clone)]
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use super::traits::{TokenIdType, Tokenizer as TokenizerTrait};
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
use super::traits::{TokenIdType, Tokenizer as TokenizerTrait};
|
||||
|
||||
/// Maintains state for an ongoing sequence of tokens and their decoded text
|
||||
/// This provides a cleaner abstraction for managing token sequences
|
||||
pub struct Sequence {
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
use super::sequence::Sequence;
|
||||
use super::traits::{self, TokenIdType};
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
use anyhow::Result;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::{
|
||||
sequence::Sequence,
|
||||
traits::{self, TokenIdType},
|
||||
};
|
||||
|
||||
/// Output from the sequence decoder
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
// src/tokenizer/stream.rs
|
||||
|
||||
use super::traits::{self, TokenIdType};
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
use super::traits::{self, TokenIdType};
|
||||
|
||||
const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;
|
||||
|
||||
/// DecodeStream will keep the state necessary to produce individual chunks of
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
#[cfg(test)]
|
||||
use super::*;
|
||||
#[cfg(test)]
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(test)]
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_mock_tokenizer_encode() {
|
||||
let tokenizer = mock::MockTokenizer::new();
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use anyhow::{Error, Result};
|
||||
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
|
||||
|
||||
use super::traits::{
|
||||
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
||||
};
|
||||
use anyhow::{Error, Result};
|
||||
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
|
||||
|
||||
/// Tiktoken tokenizer wrapper for OpenAI GPT models
|
||||
pub struct TiktokenTokenizer {
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
hash::{Hash, Hasher},
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
/// Type alias for token IDs
|
||||
pub type TokenIdType = u32;
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
// Factory and pool for creating model-specific tool parsers with pooling support.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::tool_parser::parsers::{
|
||||
DeepSeekParser, Glm4MoeParser, GptOssHarmonyParser, GptOssParser, JsonParser, KimiK2Parser,
|
||||
LlamaParser, MistralParser, PassthroughParser, PythonicParser, QwenParser, Step3Parser,
|
||||
use crate::tool_parser::{
|
||||
parsers::{
|
||||
DeepSeekParser, Glm4MoeParser, GptOssHarmonyParser, GptOssParser, JsonParser, KimiK2Parser,
|
||||
LlamaParser, MistralParser, PassthroughParser, PythonicParser, QwenParser, Step3Parser,
|
||||
},
|
||||
traits::ToolParser,
|
||||
};
|
||||
use crate::tool_parser::traits::ToolParser;
|
||||
|
||||
/// Type alias for pooled parser instances.
|
||||
pub type PooledParser = Arc<Mutex<Box<dyn ToolParser>>>;
|
||||
|
||||
@@ -18,11 +18,10 @@ mod tests;
|
||||
// Re-export commonly used types
|
||||
pub use errors::{ParserError, ParserResult};
|
||||
pub use factory::{ParserFactory, ParserRegistry, PooledParser};
|
||||
pub use traits::{PartialJsonParser, ToolParser};
|
||||
pub use types::{FunctionCall, PartialToolCall, StreamingParseResult, ToolCall};
|
||||
|
||||
// Re-export parsers for convenience
|
||||
pub use parsers::{
|
||||
DeepSeekParser, Glm4MoeParser, GptOssParser, JsonParser, KimiK2Parser, LlamaParser,
|
||||
MistralParser, PythonicParser, QwenParser, Step3Parser,
|
||||
};
|
||||
pub use traits::{PartialJsonParser, ToolParser};
|
||||
pub use types::{FunctionCall, PartialToolCall, StreamingParseResult, ToolCall};
|
||||
|
||||
@@ -2,13 +2,14 @@ use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::common::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
use crate::{
|
||||
protocols::common::Tool,
|
||||
tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
},
|
||||
};
|
||||
|
||||
/// DeepSeek V3 format parser for tool calls
|
||||
|
||||
@@ -2,13 +2,14 @@ use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::common::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
use crate::{
|
||||
protocols::common::Tool,
|
||||
tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
},
|
||||
};
|
||||
|
||||
/// GLM-4 MoE format parser for tool calls
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::protocols::common::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::ParserResult,
|
||||
traits::{TokenToolParser, ToolParser},
|
||||
types::{StreamingParseResult, ToolCall},
|
||||
use crate::{
|
||||
protocols::common::Tool,
|
||||
tool_parser::{
|
||||
errors::ParserResult,
|
||||
traits::{TokenToolParser, ToolParser},
|
||||
types::{StreamingParseResult, ToolCall},
|
||||
},
|
||||
};
|
||||
|
||||
/// Placeholder for the Harmony-backed GPT-OSS parser.
|
||||
|
||||
@@ -2,14 +2,15 @@ use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::common::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
use crate::{
|
||||
protocols::common::Tool,
|
||||
tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
},
|
||||
};
|
||||
|
||||
/// GPT-OSS format parser for tool calls
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
use crate::protocols::common::Tool;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::tool_parser::errors::{ParserError, ParserResult};
|
||||
use crate::tool_parser::types::{StreamingParseResult, ToolCallItem};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{
|
||||
protocols::common::Tool,
|
||||
tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
types::{StreamingParseResult, ToolCallItem},
|
||||
},
|
||||
};
|
||||
|
||||
/// Get a mapping of tool names to their indices
|
||||
pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::common::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
use crate::{
|
||||
protocols::common::Tool,
|
||||
tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
},
|
||||
};
|
||||
|
||||
/// JSON format parser for tool calls
|
||||
|
||||
@@ -2,13 +2,14 @@ use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::common::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::ParserResult,
|
||||
parsers::helpers,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
use crate::{
|
||||
protocols::common::Tool,
|
||||
tool_parser::{
|
||||
errors::ParserResult,
|
||||
parsers::helpers,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
},
|
||||
};
|
||||
|
||||
/// Kimi K2 format parser for tool calls
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::common::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall},
|
||||
use crate::{
|
||||
protocols::common::Tool,
|
||||
tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall},
|
||||
},
|
||||
};
|
||||
|
||||
/// Llama 3.2 format parser for tool calls
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::common::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall},
|
||||
use crate::{
|
||||
protocols::common::Tool,
|
||||
tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall},
|
||||
},
|
||||
};
|
||||
|
||||
/// Mistral format parser for tool calls
|
||||
|
||||
@@ -4,12 +4,17 @@
|
||||
//! tool call parsing should be performed. It simply returns the input text
|
||||
//! with no tool calls detected.
|
||||
|
||||
use crate::protocols::common::Tool;
|
||||
use crate::tool_parser::errors::ParserResult;
|
||||
use crate::tool_parser::traits::ToolParser;
|
||||
use crate::tool_parser::types::{StreamingParseResult, ToolCall, ToolCallItem};
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::{
|
||||
protocols::common::Tool,
|
||||
tool_parser::{
|
||||
errors::ParserResult,
|
||||
traits::ToolParser,
|
||||
types::{StreamingParseResult, ToolCall, ToolCallItem},
|
||||
},
|
||||
};
|
||||
|
||||
/// Passthrough parser that returns text unchanged with no tool calls
|
||||
#[derive(Default)]
|
||||
pub struct PassthroughParser;
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::sync::OnceLock;
|
||||
|
||||
/// Pythonic format parser for tool calls
|
||||
///
|
||||
/// Handles Python function call syntax within square brackets:
|
||||
@@ -10,18 +12,20 @@
|
||||
use async_trait::async_trait;
|
||||
use num_traits::ToPrimitive;
|
||||
use regex::Regex;
|
||||
use rustpython_parser::ast::{Constant, Expr, Mod, UnaryOp};
|
||||
use rustpython_parser::{parse, Mode};
|
||||
use rustpython_parser::{
|
||||
ast::{Constant, Expr, Mod, UnaryOp},
|
||||
parse, Mode,
|
||||
};
|
||||
use serde_json::{Map, Number, Value};
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use crate::protocols::common::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
use crate::{
|
||||
protocols::common::Tool,
|
||||
tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
},
|
||||
};
|
||||
|
||||
static PYTHONIC_BLOCK_REGEX: OnceLock<Regex> = OnceLock::new();
|
||||
|
||||
@@ -2,14 +2,15 @@ use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::common::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall},
|
||||
use crate::{
|
||||
protocols::common::Tool,
|
||||
tool_parser::{
|
||||
errors::{ParserError, ParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall},
|
||||
},
|
||||
};
|
||||
|
||||
/// Qwen format parser for tool calls
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user