sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct

This commit is contained in:
maxiao1
2025-09-13 17:00:20 +08:00
commit 118f1fc726
2037 changed files with 515371 additions and 0 deletions

View File

@@ -0,0 +1,28 @@
pub mod types;
pub mod validation;
pub use types::*;
pub use validation::*;
/// Configuration errors
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
#[error("Validation failed: {reason}")]
ValidationFailed { reason: String },
#[error("Invalid value for field '{field}': {value} - {reason}")]
InvalidValue {
field: String,
value: String,
reason: String,
},
#[error("Incompatible configuration: {reason}")]
IncompatibleConfig { reason: String },
#[error("Missing required field: {field}")]
MissingRequired { field: String },
}
/// Result type for configuration operations
pub type ConfigResult<T> = Result<T, ConfigError>;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,776 @@
use super::*;
/// Configuration validator
pub struct ConfigValidator;
impl ConfigValidator {
/// Validate a complete router configuration
pub fn validate(config: &RouterConfig) -> ConfigResult<()> {
// Check if service discovery is enabled
let has_service_discovery = config.discovery.as_ref().is_some_and(|d| d.enabled);
Self::validate_mode(&config.mode, has_service_discovery)?;
Self::validate_policy(&config.policy)?;
Self::validate_server_settings(config)?;
if let Some(discovery) = &config.discovery {
Self::validate_discovery(discovery, &config.mode)?;
}
if let Some(metrics) = &config.metrics {
Self::validate_metrics(metrics)?;
}
Self::validate_compatibility(config)?;
// Validate effective retry/CB configs (respect disable flags)
let retry_cfg = config.effective_retry_config();
let cb_cfg = config.effective_circuit_breaker_config();
Self::validate_retry(&retry_cfg)?;
Self::validate_circuit_breaker(&cb_cfg)?;
Ok(())
}
/// Validate routing mode configuration
fn validate_mode(mode: &RoutingMode, has_service_discovery: bool) -> ConfigResult<()> {
match mode {
RoutingMode::Regular { worker_urls } => {
// Validate URLs if any are provided
if !worker_urls.is_empty() {
Self::validate_urls(worker_urls)?;
}
// Note: We allow empty worker URLs even without service discovery
// to let the router start and fail at runtime when routing requests.
// This matches legacy behavior and test expectations.
}
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
prefill_policy,
decode_policy,
} => {
// Only require URLs if service discovery is disabled
if !has_service_discovery {
if prefill_urls.is_empty() {
return Err(ConfigError::ValidationFailed {
reason: "PD mode requires at least one prefill worker URL".to_string(),
});
}
if decode_urls.is_empty() {
return Err(ConfigError::ValidationFailed {
reason: "PD mode requires at least one decode worker URL".to_string(),
});
}
}
// Validate URLs if any are provided
if !prefill_urls.is_empty() {
let prefill_url_strings: Vec<String> =
prefill_urls.iter().map(|(url, _)| url.clone()).collect();
Self::validate_urls(&prefill_url_strings)?;
}
if !decode_urls.is_empty() {
Self::validate_urls(decode_urls)?;
}
// Validate bootstrap ports
for (_url, port) in prefill_urls {
if let Some(port) = port {
if *port == 0 {
return Err(ConfigError::InvalidValue {
field: "bootstrap_port".to_string(),
value: port.to_string(),
reason: "Port must be between 1 and 65535".to_string(),
});
}
}
}
// Validate optional prefill and decode policies
if let Some(p_policy) = prefill_policy {
Self::validate_policy(p_policy)?;
}
if let Some(d_policy) = decode_policy {
Self::validate_policy(d_policy)?;
}
}
RoutingMode::OpenAI { worker_urls } => {
// Require exactly one worker URL for OpenAI router
if worker_urls.len() != 1 {
return Err(ConfigError::ValidationFailed {
reason: "OpenAI mode requires exactly one --worker-urls entry".to_string(),
});
}
// Validate URL format
if let Err(e) = url::Url::parse(&worker_urls[0]) {
return Err(ConfigError::ValidationFailed {
reason: format!("Invalid OpenAI worker URL '{}': {}", &worker_urls[0], e),
});
}
}
}
Ok(())
}
/// Validate policy configuration
fn validate_policy(policy: &PolicyConfig) -> ConfigResult<()> {
match policy {
PolicyConfig::Random | PolicyConfig::RoundRobin => {
// No specific validation needed
}
PolicyConfig::CacheAware {
cache_threshold,
balance_abs_threshold: _,
balance_rel_threshold,
eviction_interval_secs,
max_tree_size,
} => {
if !(0.0..=1.0).contains(cache_threshold) {
return Err(ConfigError::InvalidValue {
field: "cache_threshold".to_string(),
value: cache_threshold.to_string(),
reason: "Must be between 0.0 and 1.0".to_string(),
});
}
if *balance_rel_threshold < 1.0 {
return Err(ConfigError::InvalidValue {
field: "balance_rel_threshold".to_string(),
value: balance_rel_threshold.to_string(),
reason: "Must be >= 1.0".to_string(),
});
}
if *eviction_interval_secs == 0 {
return Err(ConfigError::InvalidValue {
field: "eviction_interval_secs".to_string(),
value: eviction_interval_secs.to_string(),
reason: "Must be > 0".to_string(),
});
}
if *max_tree_size == 0 {
return Err(ConfigError::InvalidValue {
field: "max_tree_size".to_string(),
value: max_tree_size.to_string(),
reason: "Must be > 0".to_string(),
});
}
}
PolicyConfig::PowerOfTwo {
load_check_interval_secs,
} => {
if *load_check_interval_secs == 0 {
return Err(ConfigError::InvalidValue {
field: "load_check_interval_secs".to_string(),
value: load_check_interval_secs.to_string(),
reason: "Must be > 0".to_string(),
});
}
}
}
Ok(())
}
/// Validate server configuration
fn validate_server_settings(config: &RouterConfig) -> ConfigResult<()> {
if config.port == 0 {
return Err(ConfigError::InvalidValue {
field: "port".to_string(),
value: config.port.to_string(),
reason: "Port must be > 0".to_string(),
});
}
if config.max_payload_size == 0 {
return Err(ConfigError::InvalidValue {
field: "max_payload_size".to_string(),
value: config.max_payload_size.to_string(),
reason: "Must be > 0".to_string(),
});
}
if config.request_timeout_secs == 0 {
return Err(ConfigError::InvalidValue {
field: "request_timeout_secs".to_string(),
value: config.request_timeout_secs.to_string(),
reason: "Must be > 0".to_string(),
});
}
if config.worker_startup_timeout_secs == 0 {
return Err(ConfigError::InvalidValue {
field: "worker_startup_timeout_secs".to_string(),
value: config.worker_startup_timeout_secs.to_string(),
reason: "Must be > 0".to_string(),
});
}
if config.worker_startup_check_interval_secs == 0 {
return Err(ConfigError::InvalidValue {
field: "worker_startup_check_interval_secs".to_string(),
value: config.worker_startup_check_interval_secs.to_string(),
reason: "Must be > 0".to_string(),
});
}
Ok(())
}
/// Validate service discovery configuration
fn validate_discovery(discovery: &DiscoveryConfig, mode: &RoutingMode) -> ConfigResult<()> {
if !discovery.enabled {
return Ok(()); // No validation needed if disabled
}
if discovery.port == 0 {
return Err(ConfigError::InvalidValue {
field: "discovery.port".to_string(),
value: discovery.port.to_string(),
reason: "Port must be > 0".to_string(),
});
}
if discovery.check_interval_secs == 0 {
return Err(ConfigError::InvalidValue {
field: "discovery.check_interval_secs".to_string(),
value: discovery.check_interval_secs.to_string(),
reason: "Must be > 0".to_string(),
});
}
// Validate selectors based on mode
match mode {
RoutingMode::Regular { .. } => {
if discovery.selector.is_empty() {
return Err(ConfigError::ValidationFailed {
reason: "Regular mode with service discovery requires a non-empty selector"
.to_string(),
});
}
}
RoutingMode::PrefillDecode { .. } => {
if discovery.prefill_selector.is_empty() && discovery.decode_selector.is_empty() {
return Err(ConfigError::ValidationFailed {
reason: "PD mode with service discovery requires at least one non-empty selector (prefill or decode)".to_string(),
});
}
}
RoutingMode::OpenAI { .. } => {
// OpenAI mode doesn't use service discovery
return Err(ConfigError::ValidationFailed {
reason: "OpenAI mode does not support service discovery".to_string(),
});
}
}
Ok(())
}
/// Validate metrics configuration
fn validate_metrics(metrics: &MetricsConfig) -> ConfigResult<()> {
if metrics.port == 0 {
return Err(ConfigError::InvalidValue {
field: "metrics.port".to_string(),
value: metrics.port.to_string(),
reason: "Port must be > 0".to_string(),
});
}
if metrics.host.is_empty() {
return Err(ConfigError::InvalidValue {
field: "metrics.host".to_string(),
value: metrics.host.clone(),
reason: "Host cannot be empty".to_string(),
});
}
Ok(())
}
/// Validate retry configuration
fn validate_retry(retry: &RetryConfig) -> ConfigResult<()> {
if retry.max_retries < 1 {
return Err(ConfigError::InvalidValue {
field: "retry.max_retries".to_string(),
value: retry.max_retries.to_string(),
reason: "Must be >= 1 (set to 1 to effectively disable retries)".to_string(),
});
}
if retry.initial_backoff_ms == 0 {
return Err(ConfigError::InvalidValue {
field: "retry.initial_backoff_ms".to_string(),
value: retry.initial_backoff_ms.to_string(),
reason: "Must be > 0".to_string(),
});
}
if retry.max_backoff_ms < retry.initial_backoff_ms {
return Err(ConfigError::InvalidValue {
field: "retry.max_backoff_ms".to_string(),
value: retry.max_backoff_ms.to_string(),
reason: "Must be >= initial_backoff_ms".to_string(),
});
}
if retry.backoff_multiplier < 1.0 {
return Err(ConfigError::InvalidValue {
field: "retry.backoff_multiplier".to_string(),
value: retry.backoff_multiplier.to_string(),
reason: "Must be >= 1.0".to_string(),
});
}
if !(0.0..=1.0).contains(&retry.jitter_factor) {
return Err(ConfigError::InvalidValue {
field: "retry.jitter_factor".to_string(),
value: retry.jitter_factor.to_string(),
reason: "Must be between 0.0 and 1.0".to_string(),
});
}
Ok(())
}
/// Validate circuit breaker configuration
fn validate_circuit_breaker(cb: &CircuitBreakerConfig) -> ConfigResult<()> {
if cb.failure_threshold < 1 {
return Err(ConfigError::InvalidValue {
field: "circuit_breaker.failure_threshold".to_string(),
value: cb.failure_threshold.to_string(),
reason: "Must be >= 1 (set to u32::MAX to effectively disable CB)".to_string(),
});
}
if cb.success_threshold < 1 {
return Err(ConfigError::InvalidValue {
field: "circuit_breaker.success_threshold".to_string(),
value: cb.success_threshold.to_string(),
reason: "Must be >= 1".to_string(),
});
}
if cb.timeout_duration_secs == 0 {
return Err(ConfigError::InvalidValue {
field: "circuit_breaker.timeout_duration_secs".to_string(),
value: cb.timeout_duration_secs.to_string(),
reason: "Must be > 0".to_string(),
});
}
if cb.window_duration_secs == 0 {
return Err(ConfigError::InvalidValue {
field: "circuit_breaker.window_duration_secs".to_string(),
value: cb.window_duration_secs.to_string(),
reason: "Must be > 0".to_string(),
});
}
Ok(())
}
/// Validate compatibility between different configuration sections
fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> {
// IGW mode is independent - skip other compatibility checks when enabled
if config.enable_igw {
return Ok(());
}
// Validate gRPC connection mode requires tokenizer configuration
if config.connection_mode == ConnectionMode::Grpc
&& config.tokenizer_path.is_none()
&& config.model_path.is_none()
{
return Err(ConfigError::ValidationFailed {
reason: "gRPC connection mode requires either --tokenizer-path or --model-path to be specified".to_string(),
});
}
// All policies are now supported for both router types thanks to the unified trait design
// No mode/policy restrictions needed anymore
// Check if service discovery is enabled for worker count validation
let has_service_discovery = config.discovery.as_ref().is_some_and(|d| d.enabled);
// Only validate worker counts if service discovery is disabled
if !has_service_discovery {
// Check if power-of-two policy makes sense with insufficient workers
if let PolicyConfig::PowerOfTwo { .. } = &config.policy {
let worker_count = config.mode.worker_count();
if worker_count < 2 {
return Err(ConfigError::IncompatibleConfig {
reason: "Power-of-two policy requires at least 2 workers".to_string(),
});
}
}
// For PD mode, validate that policies have sufficient workers
if let RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
prefill_policy,
decode_policy,
} = &config.mode
{
// Check power-of-two for prefill
if let Some(PolicyConfig::PowerOfTwo { .. }) = prefill_policy {
if prefill_urls.len() < 2 {
return Err(ConfigError::IncompatibleConfig {
reason: "Power-of-two policy for prefill requires at least 2 prefill workers".to_string(),
});
}
}
// Check power-of-two for decode
if let Some(PolicyConfig::PowerOfTwo { .. }) = decode_policy {
if decode_urls.len() < 2 {
return Err(ConfigError::IncompatibleConfig {
reason:
"Power-of-two policy for decode requires at least 2 decode workers"
.to_string(),
});
}
}
}
}
// Service discovery is conflict with dp_aware routing for now
// since it's not fully supported yet
if has_service_discovery && config.dp_aware {
return Err(ConfigError::IncompatibleConfig {
reason: "DP-aware routing is not compatible with service discovery".to_string(),
});
}
Ok(())
}
/// Validate URL format
fn validate_urls(urls: &[String]) -> ConfigResult<()> {
for url in urls {
if url.is_empty() {
return Err(ConfigError::InvalidValue {
field: "worker_url".to_string(),
value: url.clone(),
reason: "URL cannot be empty".to_string(),
});
}
if !url.starts_with("http://")
&& !url.starts_with("https://")
&& !url.starts_with("grpc://")
{
return Err(ConfigError::InvalidValue {
field: "worker_url".to_string(),
value: url.clone(),
reason: "URL must start with http://, https://, or grpc://".to_string(),
});
}
// Basic URL validation
match ::url::Url::parse(url) {
Ok(parsed) => {
// Additional validation
if parsed.host_str().is_none() {
return Err(ConfigError::InvalidValue {
field: "worker_url".to_string(),
value: url.clone(),
reason: "URL must have a valid host".to_string(),
});
}
}
Err(e) => {
return Err(ConfigError::InvalidValue {
field: "worker_url".to_string(),
value: url.clone(),
reason: format!("Invalid URL format: {}", e),
});
}
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_validate_regular_mode() {
let config = RouterConfig::new(
RoutingMode::Regular {
worker_urls: vec!["http://worker:8000".to_string()],
},
PolicyConfig::Random,
);
assert!(ConfigValidator::validate(&config).is_ok());
}
#[test]
fn test_validate_empty_worker_urls() {
let config = RouterConfig::new(
RoutingMode::Regular {
worker_urls: vec![],
},
PolicyConfig::Random,
);
// Empty worker URLs are now allowed to match legacy behavior
assert!(ConfigValidator::validate(&config).is_ok());
}
#[test]
fn test_validate_empty_worker_urls_with_service_discovery() {
let mut config = RouterConfig::new(
RoutingMode::Regular {
worker_urls: vec![],
},
PolicyConfig::Random,
);
// Enable service discovery
config.discovery = Some(DiscoveryConfig {
enabled: true,
selector: vec![("app".to_string(), "test".to_string())]
.into_iter()
.collect(),
..Default::default()
});
// Should pass validation since service discovery is enabled
assert!(ConfigValidator::validate(&config).is_ok());
}
#[test]
fn test_validate_invalid_urls() {
let config = RouterConfig::new(
RoutingMode::Regular {
worker_urls: vec!["invalid-url".to_string()],
},
PolicyConfig::Random,
);
assert!(ConfigValidator::validate(&config).is_err());
}
#[test]
fn test_validate_cache_aware_thresholds() {
let config = RouterConfig::new(
RoutingMode::Regular {
worker_urls: vec![
"http://worker1:8000".to_string(),
"http://worker2:8000".to_string(),
],
},
PolicyConfig::CacheAware {
cache_threshold: 1.5, // Invalid: > 1.0
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
eviction_interval_secs: 60,
max_tree_size: 1000,
},
);
assert!(ConfigValidator::validate(&config).is_err());
}
#[test]
fn test_validate_cache_aware_single_worker() {
// Cache-aware with single worker should be allowed (even if not optimal)
let config = RouterConfig::new(
RoutingMode::Regular {
worker_urls: vec!["http://worker1:8000".to_string()],
},
PolicyConfig::CacheAware {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
eviction_interval_secs: 60,
max_tree_size: 1000,
},
);
assert!(ConfigValidator::validate(&config).is_ok());
}
#[test]
fn test_validate_pd_mode() {
let config = RouterConfig::new(
RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8000".to_string(), Some(8081))],
decode_urls: vec!["http://decode:8000".to_string()],
prefill_policy: None,
decode_policy: None,
},
PolicyConfig::Random,
);
assert!(ConfigValidator::validate(&config).is_ok());
}
#[test]
fn test_validate_roundrobin_with_pd_mode() {
// RoundRobin with PD mode is now supported
let config = RouterConfig::new(
RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
decode_urls: vec!["http://decode:8000".to_string()],
prefill_policy: None,
decode_policy: None,
},
PolicyConfig::RoundRobin,
);
let result = ConfigValidator::validate(&config);
assert!(result.is_ok());
}
#[test]
fn test_validate_cache_aware_with_pd_mode() {
// CacheAware with PD mode is now supported
let config = RouterConfig::new(
RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
decode_urls: vec!["http://decode:8000".to_string()],
prefill_policy: None,
decode_policy: None,
},
PolicyConfig::CacheAware {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
eviction_interval_secs: 60,
max_tree_size: 1000,
},
);
let result = ConfigValidator::validate(&config);
assert!(result.is_ok());
}
#[test]
fn test_validate_power_of_two_with_regular_mode() {
// PowerOfTwo with Regular mode is now supported
let config = RouterConfig::new(
RoutingMode::Regular {
worker_urls: vec![
"http://worker1:8000".to_string(),
"http://worker2:8000".to_string(),
],
},
PolicyConfig::PowerOfTwo {
load_check_interval_secs: 60,
},
);
let result = ConfigValidator::validate(&config);
assert!(result.is_ok());
}
#[test]
fn test_validate_pd_mode_with_separate_policies() {
// Test PD mode with different policies for prefill and decode
let config = RouterConfig::new(
RoutingMode::PrefillDecode {
prefill_urls: vec![
("http://prefill1:8000".to_string(), None),
("http://prefill2:8000".to_string(), None),
],
decode_urls: vec![
"http://decode1:8000".to_string(),
"http://decode2:8000".to_string(),
],
prefill_policy: Some(PolicyConfig::CacheAware {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
eviction_interval_secs: 60,
max_tree_size: 1000,
}),
decode_policy: Some(PolicyConfig::PowerOfTwo {
load_check_interval_secs: 60,
}),
},
PolicyConfig::Random, // Main policy as fallback
);
let result = ConfigValidator::validate(&config);
assert!(result.is_ok());
}
#[test]
fn test_validate_pd_mode_power_of_two_insufficient_workers() {
// Test that power-of-two policy requires at least 2 workers
let config = RouterConfig::new(
RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1:8000".to_string(), None)], // Only 1 prefill
decode_urls: vec![
"http://decode1:8000".to_string(),
"http://decode2:8000".to_string(),
],
prefill_policy: Some(PolicyConfig::PowerOfTwo {
load_check_interval_secs: 60,
}), // Requires 2+ workers
decode_policy: None,
},
PolicyConfig::Random,
);
let result = ConfigValidator::validate(&config);
assert!(result.is_err());
if let Err(e) = result {
assert!(e.to_string().contains("prefill requires at least 2"));
}
}
#[test]
fn test_validate_grpc_requires_tokenizer() {
// Test that gRPC connection mode requires tokenizer configuration
let mut config = RouterConfig::new(
RoutingMode::Regular {
worker_urls: vec!["grpc://worker:50051".to_string()],
},
PolicyConfig::Random,
);
// Set connection mode to gRPC without tokenizer config
config.connection_mode = ConnectionMode::Grpc;
config.tokenizer_path = None;
config.model_path = None;
let result = ConfigValidator::validate(&config);
assert!(result.is_err());
if let Err(e) = result {
assert!(e.to_string().contains("gRPC connection mode requires"));
}
}
#[test]
fn test_validate_grpc_with_model_path() {
// Test that gRPC works with model_path
let mut config = RouterConfig::new(
RoutingMode::Regular {
worker_urls: vec!["grpc://worker:50051".to_string()],
},
PolicyConfig::Random,
);
config.connection_mode = ConnectionMode::Grpc;
config.model_path = Some("meta-llama/Llama-3-8B".to_string());
let result = ConfigValidator::validate(&config);
assert!(result.is_ok());
}
#[test]
fn test_validate_grpc_with_tokenizer_path() {
// Test that gRPC works with tokenizer_path
let mut config = RouterConfig::new(
RoutingMode::Regular {
worker_urls: vec!["grpc://worker:50051".to_string()],
},
PolicyConfig::Random,
);
config.connection_mode = ConnectionMode::Grpc;
config.tokenizer_path = Some("/path/to/tokenizer.json".to_string());
let result = ConfigValidator::validate(&config);
assert!(result.is_ok());
}
}

View File

@@ -0,0 +1,555 @@
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tracing::info;
/// Circuit breaker configuration
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
/// Number of consecutive failures to open the circuit
pub failure_threshold: u32,
/// Success threshold to close circuit from half-open
pub success_threshold: u32,
/// Duration to wait before attempting half-open
pub timeout_duration: Duration,
/// Time window for failure counting
pub window_duration: Duration,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
success_threshold: 2,
timeout_duration: Duration::from_secs(30),
window_duration: Duration::from_secs(60),
}
}
}
/// Circuit breaker state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
/// Normal operation - requests are allowed
Closed,
/// Circuit is open - requests are rejected
Open,
/// Testing if service has recovered - limited requests allowed
HalfOpen,
}
impl std::fmt::Display for CircuitState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CircuitState::Closed => write!(f, "Closed"),
CircuitState::Open => write!(f, "Open"),
CircuitState::HalfOpen => write!(f, "HalfOpen"),
}
}
}
/// Circuit breaker implementation
#[derive(Debug)]
pub struct CircuitBreaker {
state: Arc<RwLock<CircuitState>>,
consecutive_failures: Arc<AtomicU32>,
consecutive_successes: Arc<AtomicU32>,
total_failures: Arc<AtomicU64>,
total_successes: Arc<AtomicU64>,
last_failure_time: Arc<RwLock<Option<Instant>>>,
last_state_change: Arc<RwLock<Instant>>,
config: CircuitBreakerConfig,
}
impl CircuitBreaker {
/// Create a new circuit breaker with default configuration
pub fn new() -> Self {
Self::with_config(CircuitBreakerConfig::default())
}
/// Create a new circuit breaker with custom configuration
pub fn with_config(config: CircuitBreakerConfig) -> Self {
Self {
state: Arc::new(RwLock::new(CircuitState::Closed)),
consecutive_failures: Arc::new(AtomicU32::new(0)),
consecutive_successes: Arc::new(AtomicU32::new(0)),
total_failures: Arc::new(AtomicU64::new(0)),
total_successes: Arc::new(AtomicU64::new(0)),
last_failure_time: Arc::new(RwLock::new(None)),
last_state_change: Arc::new(RwLock::new(Instant::now())),
config,
}
}
/// Check if a request can be executed
pub fn can_execute(&self) -> bool {
// First check if we need to transition from Open to HalfOpen
self.check_and_update_state();
let state = *self.state.read().unwrap();
match state {
CircuitState::Closed => true,
CircuitState::Open => false,
CircuitState::HalfOpen => true, // Allow limited requests in half-open state
}
}
/// Get the current state
pub fn state(&self) -> CircuitState {
self.check_and_update_state();
*self.state.read().unwrap()
}
/// Record the outcome of a request
pub fn record_outcome(&self, success: bool) {
if success {
self.record_success();
} else {
self.record_failure();
}
}
/// Record a successful request
pub fn record_success(&self) {
self.total_successes.fetch_add(1, Ordering::Relaxed);
self.consecutive_failures.store(0, Ordering::Release);
let successes = self.consecutive_successes.fetch_add(1, Ordering::AcqRel) + 1;
// Outcome-level metrics are recorded at the worker level where the worker label is known
let current_state = *self.state.read().unwrap();
match current_state {
CircuitState::HalfOpen => {
// Check if we've reached the success threshold to close the circuit
if successes >= self.config.success_threshold {
self.transition_to(CircuitState::Closed);
}
}
CircuitState::Closed => {
// Already closed, nothing to do
}
CircuitState::Open => {
// Shouldn't happen, but if it does, stay open
tracing::warn!("Success recorded while circuit is open");
}
}
}
/// Record a failed request
pub fn record_failure(&self) {
self.total_failures.fetch_add(1, Ordering::Relaxed);
self.consecutive_successes.store(0, Ordering::Release);
let failures = self.consecutive_failures.fetch_add(1, Ordering::AcqRel) + 1;
// Outcome-level metrics are recorded at the worker level where the worker label is known
// Update last failure time
{
let mut last_failure = self.last_failure_time.write().unwrap();
*last_failure = Some(Instant::now());
}
let current_state = *self.state.read().unwrap();
match current_state {
CircuitState::Closed => {
// Check if we've reached the failure threshold to open the circuit
if failures >= self.config.failure_threshold {
self.transition_to(CircuitState::Open);
}
}
CircuitState::HalfOpen => {
// Single failure in half-open state reopens the circuit
self.transition_to(CircuitState::Open);
}
CircuitState::Open => {
// Already open, nothing to do
}
}
}
/// Check and update state based on timeout
fn check_and_update_state(&self) {
let current_state = *self.state.read().unwrap();
if current_state == CircuitState::Open {
// Check if timeout has expired
let last_change = *self.last_state_change.read().unwrap();
if last_change.elapsed() >= self.config.timeout_duration {
self.transition_to(CircuitState::HalfOpen);
}
}
}
/// Transition to a new state
fn transition_to(&self, new_state: CircuitState) {
let mut state = self.state.write().unwrap();
let old_state = *state;
if old_state != new_state {
*state = new_state;
// Update last state change time
let mut last_change = self.last_state_change.write().unwrap();
*last_change = Instant::now();
// Reset counters based on transition
match new_state {
CircuitState::Closed => {
self.consecutive_failures.store(0, Ordering::Release);
self.consecutive_successes.store(0, Ordering::Release);
}
CircuitState::Open => {
self.consecutive_successes.store(0, Ordering::Release);
}
CircuitState::HalfOpen => {
self.consecutive_failures.store(0, Ordering::Release);
self.consecutive_successes.store(0, Ordering::Release);
}
}
let from = match old_state {
CircuitState::Closed => "closed",
CircuitState::Open => "open",
CircuitState::HalfOpen => "half_open",
};
let to = match new_state {
CircuitState::Closed => "closed",
CircuitState::Open => "open",
CircuitState::HalfOpen => "half_open",
};
info!("Circuit breaker state transition: {} -> {}", from, to);
// Transition metrics are recorded at the worker level where the worker label is known
}
}
/// Get the number of consecutive failures
pub fn failure_count(&self) -> u32 {
self.consecutive_failures.load(Ordering::Acquire)
}
/// Get the number of consecutive successes
pub fn success_count(&self) -> u32 {
self.consecutive_successes.load(Ordering::Acquire)
}
/// Get total failures
pub fn total_failures(&self) -> u64 {
self.total_failures.load(Ordering::Relaxed)
}
/// Get total successes
pub fn total_successes(&self) -> u64 {
self.total_successes.load(Ordering::Relaxed)
}
/// Get time since last failure
pub fn time_since_last_failure(&self) -> Option<Duration> {
self.last_failure_time.read().unwrap().map(|t| t.elapsed())
}
/// Get time since last state change
pub fn time_since_last_state_change(&self) -> Duration {
self.last_state_change.read().unwrap().elapsed()
}
/// Check if the circuit is in a half-open state
pub fn is_half_open(&self) -> bool {
self.state() == CircuitState::HalfOpen
}
/// Record a test success (for health check probing)
pub fn record_test_success(&self) {
if self.is_half_open() {
self.record_success();
}
}
/// Record a test failure (for health check probing)
pub fn record_test_failure(&self) {
if self.is_half_open() {
self.record_failure();
}
}
/// Reset the circuit breaker to closed state
pub fn reset(&self) {
self.transition_to(CircuitState::Closed);
self.consecutive_failures.store(0, Ordering::Release);
self.consecutive_successes.store(0, Ordering::Release);
}
/// Force the circuit to open (for manual intervention)
pub fn force_open(&self) {
self.transition_to(CircuitState::Open);
}
/// Get circuit breaker statistics
pub fn stats(&self) -> CircuitBreakerStats {
CircuitBreakerStats {
state: self.state(),
consecutive_failures: self.failure_count(),
consecutive_successes: self.success_count(),
total_failures: self.total_failures(),
total_successes: self.total_successes(),
time_since_last_failure: self.time_since_last_failure(),
time_since_last_state_change: self.time_since_last_state_change(),
}
}
}
impl Clone for CircuitBreaker {
fn clone(&self) -> Self {
Self {
state: Arc::clone(&self.state),
consecutive_failures: Arc::clone(&self.consecutive_failures),
consecutive_successes: Arc::clone(&self.consecutive_successes),
total_failures: Arc::clone(&self.total_failures),
total_successes: Arc::clone(&self.total_successes),
last_failure_time: Arc::clone(&self.last_failure_time),
last_state_change: Arc::clone(&self.last_state_change),
config: self.config.clone(),
}
}
}
impl Default for CircuitBreaker {
fn default() -> Self {
Self::new()
}
}
/// Circuit breaker statistics
#[derive(Debug, Clone)]
pub struct CircuitBreakerStats {
pub state: CircuitState,
pub consecutive_failures: u32,
pub consecutive_successes: u32,
pub total_failures: u64,
pub total_successes: u64,
pub time_since_last_failure: Option<Duration>,
pub time_since_last_state_change: Duration,
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
#[test]
fn test_circuit_breaker_initial_state() {
let cb = CircuitBreaker::new();
assert_eq!(cb.state(), CircuitState::Closed);
assert!(cb.can_execute());
assert_eq!(cb.failure_count(), 0);
assert_eq!(cb.success_count(), 0);
}
#[test]
fn test_circuit_opens_on_threshold() {
let config = CircuitBreakerConfig {
failure_threshold: 3,
..Default::default()
};
let cb = CircuitBreaker::with_config(config);
// Record failures up to threshold
assert_eq!(cb.state(), CircuitState::Closed);
cb.record_failure();
assert_eq!(cb.state(), CircuitState::Closed);
cb.record_failure();
assert_eq!(cb.state(), CircuitState::Closed);
cb.record_failure();
// Circuit should now be open
assert_eq!(cb.state(), CircuitState::Open);
assert!(!cb.can_execute());
assert_eq!(cb.failure_count(), 3);
}
#[test]
fn test_circuit_half_open_after_timeout() {
let config = CircuitBreakerConfig {
failure_threshold: 1,
timeout_duration: Duration::from_millis(100),
..Default::default()
};
let cb = CircuitBreaker::with_config(config);
// Open the circuit
cb.record_failure();
assert_eq!(cb.state(), CircuitState::Open);
// Wait for timeout
thread::sleep(Duration::from_millis(150));
// Circuit should be half-open
assert_eq!(cb.state(), CircuitState::HalfOpen);
assert!(cb.can_execute());
}
#[test]
fn test_circuit_closes_on_success_threshold() {
let config = CircuitBreakerConfig {
failure_threshold: 1,
success_threshold: 2,
timeout_duration: Duration::from_millis(50),
..Default::default()
};
let cb = CircuitBreaker::with_config(config);
// Open the circuit
cb.record_failure();
assert_eq!(cb.state(), CircuitState::Open);
// Wait for timeout
thread::sleep(Duration::from_millis(100));
assert_eq!(cb.state(), CircuitState::HalfOpen);
// Record successes
cb.record_success();
assert_eq!(cb.state(), CircuitState::HalfOpen);
cb.record_success();
// Circuit should now be closed
assert_eq!(cb.state(), CircuitState::Closed);
assert!(cb.can_execute());
}
#[test]
fn test_circuit_reopens_on_half_open_failure() {
let config = CircuitBreakerConfig {
failure_threshold: 1,
timeout_duration: Duration::from_millis(50),
..Default::default()
};
let cb = CircuitBreaker::with_config(config);
// Open the circuit
cb.record_failure();
assert_eq!(cb.state(), CircuitState::Open);
// Wait for timeout
thread::sleep(Duration::from_millis(100));
assert_eq!(cb.state(), CircuitState::HalfOpen);
// Record a failure in half-open state
cb.record_failure();
// Circuit should reopen immediately
assert_eq!(cb.state(), CircuitState::Open);
assert!(!cb.can_execute());
}
#[test]
fn test_success_resets_failure_count() {
let config = CircuitBreakerConfig {
failure_threshold: 3,
..Default::default()
};
let cb = CircuitBreaker::with_config(config);
// Record some failures
cb.record_failure();
cb.record_failure();
assert_eq!(cb.failure_count(), 2);
// Success should reset failure count
cb.record_success();
assert_eq!(cb.failure_count(), 0);
assert_eq!(cb.success_count(), 1);
// Can now record more failures without opening
cb.record_failure();
cb.record_failure();
assert_eq!(cb.state(), CircuitState::Closed);
}
#[test]
fn test_manual_reset() {
let config = CircuitBreakerConfig {
failure_threshold: 1,
..Default::default()
};
let cb = CircuitBreaker::with_config(config);
// Open the circuit
cb.record_failure();
assert_eq!(cb.state(), CircuitState::Open);
// Manual reset
cb.reset();
assert_eq!(cb.state(), CircuitState::Closed);
assert_eq!(cb.failure_count(), 0);
assert_eq!(cb.success_count(), 0);
}
#[test]
fn test_force_open() {
let cb = CircuitBreaker::new();
assert_eq!(cb.state(), CircuitState::Closed);
cb.force_open();
assert_eq!(cb.state(), CircuitState::Open);
assert!(!cb.can_execute());
}
#[test]
fn test_stats() {
let config = CircuitBreakerConfig {
failure_threshold: 2,
..Default::default()
};
let cb = CircuitBreaker::with_config(config);
cb.record_success();
cb.record_failure();
cb.record_failure();
let stats = cb.stats();
assert_eq!(stats.state, CircuitState::Open);
assert_eq!(stats.consecutive_failures, 2);
assert_eq!(stats.consecutive_successes, 0);
assert_eq!(stats.total_failures, 2);
assert_eq!(stats.total_successes, 1);
}
#[test]
fn test_clone() {
let cb1 = CircuitBreaker::new();
cb1.record_failure();
let cb2 = cb1.clone();
assert_eq!(cb2.failure_count(), 1);
// Changes to cb1 affect cb2 (shared state)
cb1.record_failure();
assert_eq!(cb2.failure_count(), 2);
}
#[test]
fn test_thread_safety() {
use std::sync::Arc;
let cb = Arc::new(CircuitBreaker::new());
let mut handles = vec![];
// Spawn threads that record failures
for _ in 0..10 {
let cb_clone = Arc::clone(&cb);
let handle = thread::spawn(move || {
for _ in 0..100 {
cb_clone.record_failure();
}
});
handles.push(handle);
}
// Wait for all threads
for handle in handles {
handle.join().unwrap();
}
// Should have recorded 1000 failures
assert_eq!(cb.total_failures(), 1000);
}
}

View File

@@ -0,0 +1,240 @@
//! Error types for the SGLang router core
//!
//! This module defines error types used throughout the router for worker operations.
use std::fmt;
/// Worker-related errors
#[derive(Debug)]
pub enum WorkerError {
/// Health check failed
HealthCheckFailed { url: String, reason: String },
/// Worker not found
WorkerNotFound { url: String },
/// Invalid worker configuration
InvalidConfiguration { message: String },
/// Network error
NetworkError { url: String, error: String },
/// Worker is at capacity
WorkerAtCapacity { url: String },
/// Invalid URL format
InvalidUrl { url: String },
}
impl fmt::Display for WorkerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
WorkerError::HealthCheckFailed { url, reason } => {
write!(f, "Health check failed for worker {}: {}", url, reason)
}
WorkerError::WorkerNotFound { url } => {
write!(f, "Worker not found: {}", url)
}
WorkerError::InvalidConfiguration { message } => {
write!(f, "Invalid worker configuration: {}", message)
}
WorkerError::NetworkError { url, error } => {
write!(f, "Network error for worker {}: {}", url, error)
}
WorkerError::WorkerAtCapacity { url } => {
write!(f, "Worker at capacity: {}", url)
}
WorkerError::InvalidUrl { url } => {
write!(f, "Invalid URL format: {}", url)
}
}
}
}
impl std::error::Error for WorkerError {}
/// Result type for worker operations
pub type WorkerResult<T> = Result<T, WorkerError>;
/// Convert from reqwest errors to worker errors
impl From<reqwest::Error> for WorkerError {
fn from(err: reqwest::Error) -> Self {
WorkerError::NetworkError {
url: err.url().map(|u| u.to_string()).unwrap_or_default(),
error: err.to_string(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::error::Error;
#[test]
fn test_health_check_failed_display() {
let error = WorkerError::HealthCheckFailed {
url: "http://worker1:8080".to_string(),
reason: "Connection refused".to_string(),
};
assert_eq!(
error.to_string(),
"Health check failed for worker http://worker1:8080: Connection refused"
);
}
#[test]
fn test_worker_not_found_display() {
let error = WorkerError::WorkerNotFound {
url: "http://worker2:8080".to_string(),
};
assert_eq!(error.to_string(), "Worker not found: http://worker2:8080");
}
#[test]
fn test_invalid_configuration_display() {
let error = WorkerError::InvalidConfiguration {
message: "Missing port number".to_string(),
};
assert_eq!(
error.to_string(),
"Invalid worker configuration: Missing port number"
);
}
#[test]
fn test_network_error_display() {
let error = WorkerError::NetworkError {
url: "http://worker3:8080".to_string(),
error: "Timeout after 30s".to_string(),
};
assert_eq!(
error.to_string(),
"Network error for worker http://worker3:8080: Timeout after 30s"
);
}
#[test]
fn test_worker_at_capacity_display() {
let error = WorkerError::WorkerAtCapacity {
url: "http://worker4:8080".to_string(),
};
assert_eq!(error.to_string(), "Worker at capacity: http://worker4:8080");
}
#[test]
fn test_worker_error_implements_std_error() {
let error = WorkerError::WorkerNotFound {
url: "http://test".to_string(),
};
// Verify it implements Error trait
let _: &dyn Error = &error;
assert!(error.source().is_none());
}
#[test]
fn test_error_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<WorkerError>();
}
#[test]
fn test_worker_result_type_alias() {
// Test Ok variant
let result: WorkerResult<i32> = Ok(42);
assert!(matches!(result, Ok(42)));
// Test Err variant
let error = WorkerError::WorkerNotFound {
url: "test".to_string(),
};
let result: WorkerResult<i32> = Err(error);
assert!(result.is_err());
}
#[test]
fn test_empty_url_handling() {
// Test empty URLs in error variants
let error1 = WorkerError::HealthCheckFailed {
url: "".to_string(),
reason: "No connection".to_string(),
};
assert_eq!(
error1.to_string(),
"Health check failed for worker : No connection"
);
let error2 = WorkerError::NetworkError {
url: "".to_string(),
error: "DNS failure".to_string(),
};
assert_eq!(error2.to_string(), "Network error for worker : DNS failure");
let error3 = WorkerError::WorkerNotFound {
url: "".to_string(),
};
assert_eq!(error3.to_string(), "Worker not found: ");
}
#[test]
fn test_special_characters_in_messages() {
// Test with special characters
let error = WorkerError::InvalidConfiguration {
message: "Invalid JSON: {\"error\": \"test\"}".to_string(),
};
assert_eq!(
error.to_string(),
"Invalid worker configuration: Invalid JSON: {\"error\": \"test\"}"
);
// Test with unicode
let error2 = WorkerError::HealthCheckFailed {
url: "http://测试:8080".to_string(),
reason: "连接被拒绝".to_string(),
};
assert_eq!(
error2.to_string(),
"Health check failed for worker http://测试:8080: 连接被拒绝"
);
}
#[test]
fn test_very_long_error_messages() {
let long_message = "A".repeat(10000);
let error = WorkerError::InvalidConfiguration {
message: long_message.clone(),
};
let display = error.to_string();
assert!(display.contains(&long_message));
assert_eq!(
display.len(),
"Invalid worker configuration: ".len() + long_message.len()
);
}
// Mock reqwest error for testing conversion
#[test]
fn test_reqwest_error_conversion() {
// Test that NetworkError is the correct variant
let network_error = WorkerError::NetworkError {
url: "http://example.com".to_string(),
error: "connection timeout".to_string(),
};
match network_error {
WorkerError::NetworkError { url, error } => {
assert_eq!(url, "http://example.com");
assert_eq!(error, "connection timeout");
}
_ => panic!("Expected NetworkError variant"),
}
}
#[test]
fn test_error_equality() {
// WorkerError doesn't implement PartialEq, but we can test that
// the same error construction produces the same display output
let error1 = WorkerError::WorkerNotFound {
url: "http://test".to_string(),
};
let error2 = WorkerError::WorkerNotFound {
url: "http://test".to_string(),
};
assert_eq!(error1.to_string(), error2.to_string());
}
}

View File

@@ -0,0 +1,24 @@
//! Core abstractions for the SGLang router
//!
//! This module contains the fundamental types and traits used throughout the router:
//! - Worker trait and implementations
//! - Error types
//! - Circuit breaker for reliability
//! - Common utilities
pub mod circuit_breaker;
pub mod error;
pub mod retry;
pub mod token_bucket;
pub mod worker;
// Re-export commonly used types at the module level
pub use circuit_breaker::{
CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState,
};
pub use error::{WorkerError, WorkerResult};
pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor};
pub use worker::{
start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig,
Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType,
};

View File

@@ -0,0 +1,409 @@
use crate::config::types::RetryConfig;
use axum::http::StatusCode;
use axum::response::Response;
use rand::Rng;
use std::time::Duration;
use tracing::debug;
/// Check if an HTTP status code indicates a retryable error
pub fn is_retryable_status(status: StatusCode) -> bool {
matches!(
status,
StatusCode::REQUEST_TIMEOUT
| StatusCode::TOO_MANY_REQUESTS
| StatusCode::INTERNAL_SERVER_ERROR
| StatusCode::BAD_GATEWAY
| StatusCode::SERVICE_UNAVAILABLE
| StatusCode::GATEWAY_TIMEOUT
)
}
/// Computes exponential backoff with optional jitter.
#[derive(Debug, Clone)]
pub struct BackoffCalculator;
impl BackoffCalculator {
/// Calculate backoff delay for a given attempt index (0-based).
pub fn calculate_delay(config: &RetryConfig, attempt: u32) -> Duration {
// Base exponential backoff
let pow = config.backoff_multiplier.powi(attempt as i32);
let mut delay_ms = (config.initial_backoff_ms as f32 * pow) as u64;
if delay_ms > config.max_backoff_ms {
delay_ms = config.max_backoff_ms;
}
// Apply jitter in range [-j, +j]
let jitter = config.jitter_factor.clamp(0.0, 1.0);
if jitter > 0.0 {
let mut rng = rand::rng();
let jitter_scale: f32 = rng.random_range(-jitter..=jitter);
let jitter_ms = (delay_ms as f32 * jitter_scale)
.round()
.max(-(delay_ms as f32));
let adjusted = (delay_ms as i64 + jitter_ms as i64).max(0) as u64;
return Duration::from_millis(adjusted);
}
Duration::from_millis(delay_ms)
}
}
#[derive(Debug, thiserror::Error)]
pub enum RetryError {
#[error("no available workers")]
NoAvailableWorkers,
#[error("maximum retry attempts exceeded")]
MaxRetriesExceeded,
}
/// A thin async retry executor for generic operations.
#[derive(Debug, Clone, Default)]
pub struct RetryExecutor;
impl RetryExecutor {
/// Execute an async operation with retries and backoff.
/// The `operation` closure is invoked each attempt with the attempt index.
pub async fn execute_with_retry<F, Fut, T>(
config: &RetryConfig,
mut operation: F,
) -> Result<T, RetryError>
where
F: FnMut(u32) -> Fut,
Fut: std::future::Future<Output = Result<T, ()>>,
{
let max = config.max_retries.max(1);
let mut attempt: u32 = 0;
loop {
match operation(attempt).await {
Ok(val) => return Ok(val),
Err(_) => {
// Use the number of failures so far (0-indexed) to compute delay,
// so the first retry uses `initial_backoff_ms`.
let is_last = attempt + 1 >= max;
if is_last {
return Err(RetryError::MaxRetriesExceeded);
}
let delay = BackoffCalculator::calculate_delay(config, attempt);
attempt += 1; // advance to the next attempt after computing delay
tokio::time::sleep(delay).await;
}
}
}
}
/// Execute an operation that returns an HTTP Response with retries and backoff.
///
/// Usage pattern:
/// - `operation(attempt)`: perform one attempt (0-based). Construct and send the request,
/// then return the `Response`. Do any per-attempt bookkeeping (e.g., load tracking,
/// circuit-breaker outcome recording) inside this closure.
/// - `should_retry(&response, attempt)`: decide if the given response should be retried
/// (e.g., based on HTTP status). Returning false short-circuits and returns the response.
/// - `on_backoff(delay, next_attempt)`: called before sleeping between attempts.
/// Use this to record metrics.
/// - `on_exhausted()`: called when the executor has exhausted all retry attempts.
///
/// Example:
/// ```ignore
/// let resp = RetryExecutor::execute_response_with_retry(
/// &retry_cfg,
/// |attempt| async move {
/// let worker = select_cb_aware_worker()?;
/// let resp = send_request(worker).await;
/// worker.record_outcome(resp.status().is_success());
/// resp
/// },
/// |res, _| matches!(res.status(), StatusCode::REQUEST_TIMEOUT | StatusCode::TOO_MANY_REQUESTS | StatusCode::INTERNAL_SERVER_ERROR | StatusCode::BAD_GATEWAY | StatusCode::SERVICE_UNAVAILABLE | StatusCode::GATEWAY_TIMEOUT),
/// |delay, attempt| RouterMetrics::record_retry_backoff_duration(delay, attempt),
/// || RouterMetrics::record_retries_exhausted("/route"),
/// ).await;
/// ```
pub async fn execute_response_with_retry<Op, Fut, ShouldRetry, OnBackoff, OnExhausted>(
config: &RetryConfig,
mut operation: Op,
should_retry: ShouldRetry,
on_backoff: OnBackoff,
mut on_exhausted: OnExhausted,
) -> Response
where
Op: FnMut(u32) -> Fut,
Fut: std::future::Future<Output = Response>,
ShouldRetry: Fn(&Response, u32) -> bool,
OnBackoff: Fn(Duration, u32),
OnExhausted: FnMut(),
{
let max = config.max_retries.max(1);
let mut attempt: u32 = 0;
loop {
let response = operation(attempt).await;
let is_last = attempt + 1 >= max;
if !should_retry(&response, attempt) {
return response;
}
if is_last {
// Exhausted retries
on_exhausted();
return response;
}
// Backoff before next attempt
let next_attempt = attempt + 1;
// Compute delay based on the number of failures so far (0-indexed)
let delay = BackoffCalculator::calculate_delay(config, attempt);
debug!(
attempt = attempt,
next_attempt = next_attempt,
delay_ms = delay.as_millis() as u64,
"Retry backoff"
);
on_backoff(delay, next_attempt);
tokio::time::sleep(delay).await;
attempt = next_attempt;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
fn base_retry_config() -> RetryConfig {
RetryConfig {
max_retries: 3,
initial_backoff_ms: 1,
max_backoff_ms: 4,
backoff_multiplier: 2.0,
jitter_factor: 0.0,
}
}
#[test]
fn test_backoff_no_jitter_progression_and_cap() {
let cfg = RetryConfig {
max_retries: 10,
initial_backoff_ms: 100,
max_backoff_ms: 250,
backoff_multiplier: 2.0,
jitter_factor: 0.0,
};
// attempt=0 => 100ms
assert_eq!(
BackoffCalculator::calculate_delay(&cfg, 0),
Duration::from_millis(100)
);
// attempt=1 => 200ms
assert_eq!(
BackoffCalculator::calculate_delay(&cfg, 1),
Duration::from_millis(200)
);
// attempt=2 => 400ms -> capped to 250ms
assert_eq!(
BackoffCalculator::calculate_delay(&cfg, 2),
Duration::from_millis(250)
);
// large attempt still capped
assert_eq!(
BackoffCalculator::calculate_delay(&cfg, 10),
Duration::from_millis(250)
);
}
#[test]
fn test_backoff_with_jitter_within_bounds() {
let cfg = RetryConfig {
max_retries: 5,
initial_backoff_ms: 100,
max_backoff_ms: 10_000,
backoff_multiplier: 2.0,
jitter_factor: 0.5,
};
// attempt=2 => base 400ms, jitter in [0.5x, 1.5x]
let base = 400.0;
for _ in 0..50 {
let d = BackoffCalculator::calculate_delay(&cfg, 2).as_millis() as f32;
assert!(d >= base * 0.5 - 1.0 && d <= base * 1.5 + 1.0);
}
}
#[tokio::test]
async fn test_execute_with_retry_success_after_failures() {
let cfg = base_retry_config();
let remaining = Arc::new(AtomicU32::new(2));
let calls = Arc::new(AtomicU32::new(0));
let res: Result<u32, RetryError> = RetryExecutor::execute_with_retry(&cfg, {
let remaining = remaining.clone();
let calls = calls.clone();
move |_attempt| {
calls.fetch_add(1, Ordering::Relaxed);
let remaining = remaining.clone();
async move {
if remaining
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1))
.is_ok()
{
Err(())
} else {
Ok(42u32)
}
}
}
})
.await;
assert!(res.is_ok());
assert_eq!(res.unwrap(), 42);
assert_eq!(calls.load(Ordering::Relaxed), 3); // 2 fails + 1 success
}
#[tokio::test]
async fn test_execute_with_retry_exhausted() {
let cfg = base_retry_config();
let calls = Arc::new(AtomicU32::new(0));
let res: Result<u32, RetryError> = RetryExecutor::execute_with_retry(&cfg, {
let calls = calls.clone();
move |_attempt| {
calls.fetch_add(1, Ordering::Relaxed);
async move { Err(()) }
}
})
.await;
assert!(matches!(res, Err(RetryError::MaxRetriesExceeded)));
assert_eq!(calls.load(Ordering::Relaxed), cfg.max_retries);
}
#[tokio::test]
async fn test_execute_response_with_retry_success_path_and_hooks() {
let cfg = base_retry_config();
let remaining = Arc::new(AtomicU32::new(2));
let calls = Arc::new(AtomicU32::new(0));
let backoffs = Arc::new(AtomicU32::new(0));
let exhausted = Arc::new(AtomicU32::new(0));
let response = RetryExecutor::execute_response_with_retry(
&cfg,
{
let remaining = remaining.clone();
let calls = calls.clone();
move |_attempt| {
calls.fetch_add(1, Ordering::Relaxed);
let remaining = remaining.clone();
async move {
if remaining
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1))
.is_ok()
{
(StatusCode::SERVICE_UNAVAILABLE, "fail").into_response()
} else {
(StatusCode::OK, "ok").into_response()
}
}
}
},
|res, _attempt| !res.status().is_success(), // retry until success
{
let backoffs = backoffs.clone();
move |_delay, _next_attempt| {
backoffs.fetch_add(1, Ordering::Relaxed);
}
},
{
let exhausted = exhausted.clone();
move || {
exhausted.fetch_add(1, Ordering::Relaxed);
}
},
)
.await;
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(calls.load(Ordering::Relaxed), 3); // 2 fails + 1 success
assert_eq!(backoffs.load(Ordering::Relaxed), 2);
assert_eq!(exhausted.load(Ordering::Relaxed), 0);
}
#[tokio::test]
async fn test_execute_response_with_retry_non_retryable_short_circuit() {
let cfg = base_retry_config();
let calls = Arc::new(AtomicU32::new(0));
let backoffs = Arc::new(AtomicU32::new(0));
let exhausted = Arc::new(AtomicU32::new(0));
let response = RetryExecutor::execute_response_with_retry(
&cfg,
{
let calls = calls.clone();
move |_attempt| {
calls.fetch_add(1, Ordering::Relaxed);
async move { (StatusCode::BAD_REQUEST, "bad").into_response() }
}
},
|_res, _attempt| false, // never retry
{
let backoffs = backoffs.clone();
move |_delay, _next_attempt| {
backoffs.fetch_add(1, Ordering::Relaxed);
}
},
{
let exhausted = exhausted.clone();
move || {
exhausted.fetch_add(1, Ordering::Relaxed);
}
},
)
.await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
assert_eq!(calls.load(Ordering::Relaxed), 1);
assert_eq!(backoffs.load(Ordering::Relaxed), 0);
assert_eq!(exhausted.load(Ordering::Relaxed), 0);
}
#[tokio::test]
async fn test_execute_response_with_retry_exhausted_hooks() {
let cfg = base_retry_config();
let calls = Arc::new(AtomicU32::new(0));
let backoffs = Arc::new(AtomicU32::new(0));
let exhausted = Arc::new(AtomicU32::new(0));
let response = RetryExecutor::execute_response_with_retry(
&cfg,
{
let calls = calls.clone();
move |_attempt| {
calls.fetch_add(1, Ordering::Relaxed);
async move { (StatusCode::SERVICE_UNAVAILABLE, "fail").into_response() }
}
},
|_res, _attempt| true, // keep retrying
{
let backoffs = backoffs.clone();
move |_delay, _next_attempt| {
backoffs.fetch_add(1, Ordering::Relaxed);
}
},
{
let exhausted = exhausted.clone();
move || {
exhausted.fetch_add(1, Ordering::Relaxed);
}
},
)
.await;
assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
assert_eq!(calls.load(Ordering::Relaxed), cfg.max_retries);
assert_eq!(backoffs.load(Ordering::Relaxed), cfg.max_retries - 1);
assert_eq!(exhausted.load(Ordering::Relaxed), 1);
}
}

View File

@@ -0,0 +1,195 @@
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, Notify};
use tracing::{debug, trace};
/// Token bucket for rate limiting
///
/// This implementation provides:
/// - Smooth rate limiting with configurable refill rate
/// - Burst capacity handling
/// - Fair queuing for waiting requests
#[derive(Clone)]
pub struct TokenBucket {
inner: Arc<Mutex<TokenBucketInner>>,
notify: Arc<Notify>,
capacity: f64,
refill_rate: f64, // tokens per second
}
struct TokenBucketInner {
tokens: f64,
last_refill: Instant,
}
impl TokenBucket {
/// Create a new token bucket
///
/// # Arguments
/// * `capacity` - Maximum number of tokens (burst capacity)
/// * `refill_rate` - Tokens added per second
pub fn new(capacity: usize, refill_rate: usize) -> Self {
let capacity = capacity as f64;
let refill_rate = refill_rate as f64;
// Ensure refill_rate is not zero to prevent division by zero
let refill_rate = if refill_rate > 0.0 {
refill_rate
} else {
1.0 // Default to 1 token per second if zero
};
Self {
inner: Arc::new(Mutex::new(TokenBucketInner {
tokens: capacity, // Start full
last_refill: Instant::now(),
})),
notify: Arc::new(Notify::new()),
capacity,
refill_rate,
}
}
/// Try to acquire tokens immediately
pub async fn try_acquire(&self, tokens: f64) -> Result<(), ()> {
let mut inner = self.inner.lock().await;
// Refill tokens based on elapsed time
let now = Instant::now();
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
let refill_amount = elapsed * self.refill_rate;
inner.tokens = (inner.tokens + refill_amount).min(self.capacity);
inner.last_refill = now;
trace!(
"Token bucket: {} tokens available, requesting {}",
inner.tokens,
tokens
);
if inner.tokens >= tokens {
inner.tokens -= tokens;
debug!(
"Token bucket: acquired {} tokens, {} remaining",
tokens, inner.tokens
);
Ok(())
} else {
Err(())
}
}
/// Acquire tokens, waiting if necessary
pub async fn acquire(&self, tokens: f64) -> Result<(), tokio::time::error::Elapsed> {
// First try to acquire immediately
if self.try_acquire(tokens).await.is_ok() {
return Ok(());
}
// Calculate wait time
let wait_time = {
let inner = self.inner.lock().await;
let tokens_needed = tokens - inner.tokens;
let wait_secs = tokens_needed / self.refill_rate;
Duration::from_secs_f64(wait_secs)
};
debug!(
"Token bucket: waiting {:?} for {} tokens",
wait_time, tokens
);
// Wait for tokens to be available
tokio::time::timeout(wait_time, async {
loop {
// Check if we can acquire now
if self.try_acquire(tokens).await.is_ok() {
return;
}
// Wait for notification or small interval
tokio::select! {
_ = self.notify.notified() => {},
_ = tokio::time::sleep(Duration::from_millis(10)) => {},
}
}
})
.await?;
Ok(())
}
/// Acquire tokens with custom timeout
pub async fn acquire_timeout(
&self,
tokens: f64,
timeout: Duration,
) -> Result<(), tokio::time::error::Elapsed> {
tokio::time::timeout(timeout, self.acquire(tokens)).await?
}
/// Return tokens to the bucket (for cancelled requests)
pub async fn return_tokens(&self, tokens: f64) {
let mut inner = self.inner.lock().await;
inner.tokens = (inner.tokens + tokens).min(self.capacity);
self.notify.notify_waiters();
debug!(
"Token bucket: returned {} tokens, {} available",
tokens, inner.tokens
);
}
/// Get current available tokens (for monitoring)
pub async fn available_tokens(&self) -> f64 {
let mut inner = self.inner.lock().await;
// Refill before checking
let now = Instant::now();
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
let refill_amount = elapsed * self.refill_rate;
inner.tokens = (inner.tokens + refill_amount).min(self.capacity);
inner.last_refill = now;
inner.tokens
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_token_bucket_basic() {
let bucket = TokenBucket::new(10, 5); // 10 capacity, 5 per second
// Should succeed - bucket starts full
assert!(bucket.try_acquire(5.0).await.is_ok());
assert!(bucket.try_acquire(5.0).await.is_ok());
// Should fail - no tokens left
assert!(bucket.try_acquire(1.0).await.is_err());
// Wait for refill
tokio::time::sleep(Duration::from_millis(300)).await;
// Should have ~1.5 tokens now
assert!(bucket.try_acquire(1.0).await.is_ok());
}
#[tokio::test]
async fn test_token_bucket_refill() {
let bucket = TokenBucket::new(10, 10); // 10 capacity, 10 per second
// Use all tokens
assert!(bucket.try_acquire(10.0).await.is_ok());
// Wait for partial refill
tokio::time::sleep(Duration::from_millis(500)).await;
// Should have ~5 tokens
let available = bucket.available_tokens().await;
assert!((4.0..=6.0).contains(&available));
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,327 @@
use std::time::Duration;
use tonic::{transport::Channel, Request};
use tracing::debug;
// Include the generated protobuf code
pub mod proto {
tonic::include_proto!("sglang.grpc.scheduler");
}
// The generated module structure depends on the package name in the .proto file
// package sglang.grpc.scheduler; generates a nested module structure
/// gRPC client for SGLang scheduler
pub struct SglangSchedulerClient {
client: proto::sglang_scheduler_client::SglangSchedulerClient<Channel>,
}
impl SglangSchedulerClient {
/// Create a new client and connect to the scheduler
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>> {
debug!("Connecting to SGLang scheduler at {}", endpoint);
// Convert grpc:// to http:// for tonic
let http_endpoint = if endpoint.starts_with("grpc://") {
endpoint.replace("grpc://", "http://")
} else {
endpoint.to_string()
};
let channel = Channel::from_shared(http_endpoint)?
.timeout(Duration::from_secs(30))
.connect()
.await?;
let client = proto::sglang_scheduler_client::SglangSchedulerClient::new(channel);
Ok(Self { client })
}
/// Initialize the connection
pub async fn initialize(
&mut self,
client_id: String,
) -> Result<proto::InitializeResponse, Box<dyn std::error::Error>> {
let request = Request::new(proto::InitializeRequest {
client_id,
client_version: "0.1.0".to_string(),
mode: proto::initialize_request::Mode::Regular as i32,
});
let response = self.client.initialize(request).await?;
Ok(response.into_inner())
}
/// Submit a generation request (returns streaming response)
pub async fn generate_stream(
&mut self,
req: proto::GenerateRequest,
) -> Result<tonic::Streaming<proto::GenerateResponse>, Box<dyn std::error::Error>> {
let request = Request::new(req);
let response = self.client.generate(request).await?;
Ok(response.into_inner())
}
/// Perform health check
pub async fn health_check(
&mut self,
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
debug!("Sending health check request");
let request = Request::new(proto::HealthCheckRequest {
include_detailed_metrics: false,
});
let response = self.client.health_check(request).await?;
debug!("Health check response received");
Ok(response.into_inner())
}
/// Abort a request
pub async fn abort_request(
&mut self,
request_id: String,
reason: String,
) -> Result<(), Box<dyn std::error::Error>> {
let request = Request::new(proto::AbortRequest { request_id, reason });
self.client.abort(request).await?;
Ok(())
}
/// Flush cache
pub async fn flush_cache(
&mut self,
flush_all: bool,
session_ids: &[String],
) -> Result<proto::FlushCacheResponse, Box<dyn std::error::Error>> {
let request = Request::new(proto::FlushCacheRequest {
flush_all,
session_ids: session_ids.to_vec(),
});
let response = self.client.flush_cache(request).await?;
Ok(response.into_inner())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_proto_types_compilation() {
// Test that protobuf types can be constructed
let init_req = proto::InitializeRequest {
client_id: "test-client".to_string(),
client_version: "0.1.0".to_string(),
mode: 0,
};
assert_eq!(init_req.client_id, "test-client");
assert_eq!(init_req.client_version, "0.1.0");
assert_eq!(init_req.mode, 0);
}
#[test]
fn test_generate_request_construction() {
let sampling_params = proto::SamplingParams {
temperature: 0.7,
max_new_tokens: 128,
top_p: 0.9,
top_k: 50,
stop: vec!["</s>".to_string()],
..Default::default()
};
let gen_req = proto::GenerateRequest {
request_id: "test-req-123".to_string(),
input: Some(proto::generate_request::Input::Text(
"Hello world".to_string(),
)),
sampling_params: Some(sampling_params),
return_logprob: true,
logprob_start_len: 0,
top_logprobs_num: 5,
..Default::default()
};
assert_eq!(gen_req.request_id, "test-req-123");
if let Some(proto::generate_request::Input::Text(text)) = &gen_req.input {
assert_eq!(text, "Hello world");
}
assert!(gen_req.return_logprob);
assert_eq!(gen_req.top_logprobs_num, 5);
let params = gen_req.sampling_params.unwrap();
assert_eq!(params.temperature, 0.7);
assert_eq!(params.max_new_tokens, 128);
assert_eq!(params.stop, vec!["</s>"]);
}
#[test]
fn test_health_check_request() {
let health_req = proto::HealthCheckRequest {
include_detailed_metrics: true,
};
assert!(health_req.include_detailed_metrics);
}
#[test]
fn test_abort_request_construction() {
let abort_req = proto::AbortRequest {
request_id: "req-456".to_string(),
reason: "User canceled".to_string(),
};
assert_eq!(abort_req.request_id, "req-456");
assert_eq!(abort_req.reason, "User canceled");
}
#[test]
fn test_flush_cache_request() {
let flush_req = proto::FlushCacheRequest {
flush_all: true,
session_ids: vec!["session1".to_string(), "session2".to_string()],
};
assert!(flush_req.flush_all);
assert_eq!(flush_req.session_ids.len(), 2);
assert_eq!(flush_req.session_ids[0], "session1");
}
#[test]
fn test_sampling_params_defaults() {
let params = proto::SamplingParams::default();
assert_eq!(params.temperature, 0.0);
assert_eq!(params.max_new_tokens, 0);
assert_eq!(params.top_p, 0.0);
assert_eq!(params.top_k, 0);
assert!(params.stop.is_empty());
}
#[test]
fn test_multimodal_inputs() {
let mm_inputs = proto::MultimodalInputs {
image_urls: vec!["http://example.com/image.jpg".to_string()],
video_urls: vec![],
audio_urls: vec![],
image_data: vec![],
video_data: vec![],
audio_data: vec![],
modalities: vec!["image".to_string()],
..Default::default()
};
assert_eq!(mm_inputs.image_urls.len(), 1);
assert_eq!(mm_inputs.image_urls[0], "http://example.com/image.jpg");
assert_eq!(mm_inputs.modalities[0], "image");
}
#[test]
fn test_session_params() {
let session_params = proto::SessionParams {
session_id: "sess-789".to_string(),
request_id: "req-101".to_string(),
offset: 100,
replace: true,
drop_previous_output: false,
};
assert_eq!(session_params.session_id, "sess-789");
assert_eq!(session_params.request_id, "req-101");
assert_eq!(session_params.offset, 100);
assert!(session_params.replace);
assert!(!session_params.drop_previous_output);
}
#[test]
fn test_embed_request() {
let embed_req = proto::EmbedRequest {
request_id: "embed-req-202".to_string(),
input: Some(proto::embed_request::Input::Text(
"This is a test sentence for embedding".to_string(),
)),
log_metrics: true,
data_parallel_rank: 0,
..Default::default()
};
assert_eq!(embed_req.request_id, "embed-req-202");
if let Some(proto::embed_request::Input::Text(text)) = &embed_req.input {
assert_eq!(text, "This is a test sentence for embedding");
}
assert!(embed_req.log_metrics);
assert_eq!(embed_req.data_parallel_rank, 0);
}
#[tokio::test]
async fn test_client_connect_invalid_endpoint() {
// Test connecting to an invalid endpoint should return error
let result = SglangSchedulerClient::connect("invalid://endpoint").await;
assert!(result.is_err());
}
#[test]
fn test_tokenized_input() {
let tokenized = proto::TokenizedInput {
original_text: "Hello world".to_string(),
input_ids: vec![1, 15043, 1917, 2],
};
assert_eq!(tokenized.original_text, "Hello world");
assert_eq!(tokenized.input_ids, vec![1, 15043, 1917, 2]);
}
// Test response type construction
#[test]
fn test_generate_stream_chunk() {
let chunk = proto::GenerateStreamChunk {
token_id: 1234,
text: " world".to_string(),
prompt_tokens: 5,
completion_tokens: 2,
cached_tokens: 3,
generation_time: 0.025,
queue_time: 10,
..Default::default()
};
assert_eq!(chunk.token_id, 1234);
assert_eq!(chunk.text, " world");
assert_eq!(chunk.prompt_tokens, 5);
assert_eq!(chunk.completion_tokens, 2);
assert_eq!(chunk.cached_tokens, 3);
assert_eq!(chunk.generation_time, 0.025);
assert_eq!(chunk.queue_time, 10);
}
#[test]
fn test_model_info() {
let model_info = proto::ModelInfo {
model_name: "Meta-Llama-3-8B-Instruct".to_string(),
max_context_length: 8192,
vocab_size: 128256,
supports_tool_calling: true,
supports_vision: false,
special_tokens: vec![
"<|begin_of_text|>".to_string(),
"<|end_of_text|>".to_string(),
],
model_type: "llama".to_string(),
num_layers: 32,
hidden_size: 4096,
num_attention_heads: 32,
num_key_value_heads: 8,
tokenizer_type: "llama".to_string(),
eos_token_ids: vec![128001, 128009],
pad_token_id: 128001,
bos_token_id: 128000,
};
assert_eq!(model_info.model_name, "Meta-Llama-3-8B-Instruct");
assert_eq!(model_info.max_context_length, 8192);
assert_eq!(model_info.vocab_size, 128256);
assert!(model_info.supports_tool_calling);
assert!(!model_info.supports_vision);
assert_eq!(model_info.special_tokens.len(), 2);
assert_eq!(model_info.num_layers, 32);
assert_eq!(model_info.eos_token_ids, vec![128001, 128009]);
}
}

View File

@@ -0,0 +1,8 @@
//! gRPC client module for communicating with SGLang scheduler
//!
//! This module provides a gRPC client implementation for the SGLang router.
pub mod client;
// Re-export the client
pub use client::{proto, SglangSchedulerClient};

508
sgl-router/src/lib.rs Normal file
View File

@@ -0,0 +1,508 @@
use pyo3::prelude::*;
pub mod config;
pub mod logging;
use std::collections::HashMap;
pub mod core;
#[cfg(feature = "grpc-client")]
pub mod grpc;
pub mod mcp;
pub mod metrics;
pub mod middleware;
pub mod policies;
pub mod protocols;
pub mod reasoning_parser;
pub mod routers;
pub mod server;
pub mod service_discovery;
pub mod tokenizer;
pub mod tool_parser;
pub mod tree;
use crate::metrics::PrometheusConfig;
#[pyclass(eq)]
#[derive(Clone, PartialEq, Debug)]
pub enum PolicyType {
Random,
RoundRobin,
CacheAware,
PowerOfTwo,
}
#[pyclass]
#[derive(Debug, Clone, PartialEq)]
struct Router {
host: String,
port: u16,
worker_urls: Vec<String>,
policy: PolicyType,
worker_startup_timeout_secs: u64,
worker_startup_check_interval: u64,
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
max_payload_size: usize,
dp_aware: bool,
api_key: Option<String>,
log_dir: Option<String>,
log_level: Option<String>,
service_discovery: bool,
selector: HashMap<String, String>,
service_discovery_port: u16,
service_discovery_namespace: Option<String>,
prefill_selector: HashMap<String, String>,
decode_selector: HashMap<String, String>,
bootstrap_port_annotation: String,
prometheus_port: Option<u16>,
prometheus_host: Option<String>,
request_timeout_secs: u64,
request_id_headers: Option<Vec<String>>,
pd_disaggregation: bool,
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
prefill_policy: Option<PolicyType>,
decode_policy: Option<PolicyType>,
max_concurrent_requests: usize,
cors_allowed_origins: Vec<String>,
// Retry configuration
retry_max_retries: u32,
retry_initial_backoff_ms: u64,
retry_max_backoff_ms: u64,
retry_backoff_multiplier: f32,
retry_jitter_factor: f32,
disable_retries: bool,
// Circuit breaker configuration
cb_failure_threshold: u32,
cb_success_threshold: u32,
cb_timeout_duration_secs: u64,
cb_window_duration_secs: u64,
disable_circuit_breaker: bool,
// Health check configuration
health_failure_threshold: u32,
health_success_threshold: u32,
health_check_timeout_secs: u64,
health_check_interval_secs: u64,
health_check_endpoint: String,
// IGW (Inference Gateway) configuration
enable_igw: bool,
queue_size: usize,
queue_timeout_secs: u64,
rate_limit_tokens_per_second: Option<usize>,
// Connection mode (determined from worker URLs)
connection_mode: config::ConnectionMode,
// Model path for tokenizer
model_path: Option<String>,
// Explicit tokenizer path
tokenizer_path: Option<String>,
}
impl Router {
/// Determine connection mode from worker URLs
fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
// Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
for url in worker_urls {
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
return config::ConnectionMode::Grpc;
}
}
// Default to HTTP for all other cases (including http://, https://, or no scheme)
config::ConnectionMode::Http
}
/// Convert PyO3 Router to RouterConfig
pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
use config::{
DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
};
// Convert policy helper function
let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
match policy {
PolicyType::Random => ConfigPolicyConfig::Random,
PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
load_check_interval_secs: 5, // Default value
},
}
};
// Determine routing mode
let mode = if self.enable_igw {
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
RoutingMode::Regular {
worker_urls: vec![],
}
} else if self.pd_disaggregation {
RoutingMode::PrefillDecode {
prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
decode_urls: self.decode_urls.clone().unwrap_or_default(),
prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
decode_policy: self.decode_policy.as_ref().map(convert_policy),
}
} else {
RoutingMode::Regular {
worker_urls: self.worker_urls.clone(),
}
};
// Convert main policy
let policy = convert_policy(&self.policy);
// Service discovery configuration
let discovery = if self.service_discovery {
Some(DiscoveryConfig {
enabled: true,
namespace: self.service_discovery_namespace.clone(),
port: self.service_discovery_port,
check_interval_secs: 60,
selector: self.selector.clone(),
prefill_selector: self.prefill_selector.clone(),
decode_selector: self.decode_selector.clone(),
bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
})
} else {
None
};
// Metrics configuration
let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
(Some(port), Some(host)) => Some(MetricsConfig {
port,
host: host.clone(),
}),
_ => None,
};
Ok(config::RouterConfig {
mode,
policy,
host: self.host.clone(),
port: self.port,
connection_mode: self.connection_mode.clone(),
max_payload_size: self.max_payload_size,
request_timeout_secs: self.request_timeout_secs,
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
worker_startup_check_interval_secs: self.worker_startup_check_interval,
dp_aware: self.dp_aware,
api_key: self.api_key.clone(),
discovery,
metrics,
log_dir: self.log_dir.clone(),
log_level: self.log_level.clone(),
request_id_headers: self.request_id_headers.clone(),
max_concurrent_requests: self.max_concurrent_requests,
queue_size: self.queue_size,
queue_timeout_secs: self.queue_timeout_secs,
rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: config::RetryConfig {
max_retries: self.retry_max_retries,
initial_backoff_ms: self.retry_initial_backoff_ms,
max_backoff_ms: self.retry_max_backoff_ms,
backoff_multiplier: self.retry_backoff_multiplier,
jitter_factor: self.retry_jitter_factor,
},
circuit_breaker: config::CircuitBreakerConfig {
failure_threshold: self.cb_failure_threshold,
success_threshold: self.cb_success_threshold,
timeout_duration_secs: self.cb_timeout_duration_secs,
window_duration_secs: self.cb_window_duration_secs,
},
disable_retries: self.disable_retries,
disable_circuit_breaker: self.disable_circuit_breaker,
health_check: config::HealthCheckConfig {
failure_threshold: self.health_failure_threshold,
success_threshold: self.health_success_threshold,
timeout_secs: self.health_check_timeout_secs,
check_interval_secs: self.health_check_interval_secs,
endpoint: self.health_check_endpoint.clone(),
},
enable_igw: self.enable_igw,
model_path: self.model_path.clone(),
tokenizer_path: self.tokenizer_path.clone(),
})
}
}
#[pymethods]
impl Router {
#[new]
#[pyo3(signature = (
worker_urls,
policy = PolicyType::RoundRobin,
host = String::from("127.0.0.1"),
port = 3001,
worker_startup_timeout_secs = 600,
worker_startup_check_interval = 30,
cache_threshold = 0.3,
balance_abs_threshold = 64,
balance_rel_threshold = 1.5,
eviction_interval_secs = 120,
max_tree_size = 2usize.pow(26),
max_payload_size = 512 * 1024 * 1024, // 512MB default for large batches
dp_aware = false,
api_key = None,
log_dir = None,
log_level = None,
service_discovery = false,
selector = HashMap::new(),
service_discovery_port = 80,
service_discovery_namespace = None,
prefill_selector = HashMap::new(),
decode_selector = HashMap::new(),
bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
prometheus_port = None,
prometheus_host = None,
request_timeout_secs = 1800, // Add configurable request timeout
request_id_headers = None, // Custom request ID headers
pd_disaggregation = false, // New flag for PD mode
prefill_urls = None,
decode_urls = None,
prefill_policy = None,
decode_policy = None,
max_concurrent_requests = 256,
cors_allowed_origins = vec![],
// Retry defaults
retry_max_retries = 5,
retry_initial_backoff_ms = 50,
retry_max_backoff_ms = 30_000,
retry_backoff_multiplier = 1.5,
retry_jitter_factor = 0.2,
disable_retries = false,
// Circuit breaker defaults
cb_failure_threshold = 10,
cb_success_threshold = 3,
cb_timeout_duration_secs = 60,
cb_window_duration_secs = 120,
disable_circuit_breaker = false,
// Health check defaults
health_failure_threshold = 3,
health_success_threshold = 2,
health_check_timeout_secs = 5,
health_check_interval_secs = 60,
health_check_endpoint = String::from("/health"),
// IGW defaults
enable_igw = false,
queue_size = 100,
queue_timeout_secs = 60,
rate_limit_tokens_per_second = None,
// Tokenizer defaults
model_path = None,
tokenizer_path = None,
))]
#[allow(clippy::too_many_arguments)]
fn new(
worker_urls: Vec<String>,
policy: PolicyType,
host: String,
port: u16,
worker_startup_timeout_secs: u64,
worker_startup_check_interval: u64,
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
max_payload_size: usize,
dp_aware: bool,
api_key: Option<String>,
log_dir: Option<String>,
log_level: Option<String>,
service_discovery: bool,
selector: HashMap<String, String>,
service_discovery_port: u16,
service_discovery_namespace: Option<String>,
prefill_selector: HashMap<String, String>,
decode_selector: HashMap<String, String>,
bootstrap_port_annotation: String,
prometheus_port: Option<u16>,
prometheus_host: Option<String>,
request_timeout_secs: u64,
request_id_headers: Option<Vec<String>>,
pd_disaggregation: bool,
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
prefill_policy: Option<PolicyType>,
decode_policy: Option<PolicyType>,
max_concurrent_requests: usize,
cors_allowed_origins: Vec<String>,
retry_max_retries: u32,
retry_initial_backoff_ms: u64,
retry_max_backoff_ms: u64,
retry_backoff_multiplier: f32,
retry_jitter_factor: f32,
disable_retries: bool,
cb_failure_threshold: u32,
cb_success_threshold: u32,
cb_timeout_duration_secs: u64,
cb_window_duration_secs: u64,
disable_circuit_breaker: bool,
health_failure_threshold: u32,
health_success_threshold: u32,
health_check_timeout_secs: u64,
health_check_interval_secs: u64,
health_check_endpoint: String,
enable_igw: bool,
queue_size: usize,
queue_timeout_secs: u64,
rate_limit_tokens_per_second: Option<usize>,
model_path: Option<String>,
tokenizer_path: Option<String>,
) -> PyResult<Self> {
// Determine connection mode from worker URLs
let mut all_urls = worker_urls.clone();
// Add prefill URLs if in PD mode
if let Some(ref prefill_urls) = prefill_urls {
for (url, _) in prefill_urls {
all_urls.push(url.clone());
}
}
// Add decode URLs if in PD mode
if let Some(ref decode_urls) = decode_urls {
all_urls.extend(decode_urls.clone());
}
let connection_mode = Self::determine_connection_mode(&all_urls);
Ok(Router {
host,
port,
worker_urls,
policy,
worker_startup_timeout_secs,
worker_startup_check_interval,
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
eviction_interval_secs,
max_tree_size,
max_payload_size,
dp_aware,
api_key,
log_dir,
log_level,
service_discovery,
selector,
service_discovery_port,
service_discovery_namespace,
prefill_selector,
decode_selector,
bootstrap_port_annotation,
prometheus_port,
prometheus_host,
request_timeout_secs,
request_id_headers,
pd_disaggregation,
prefill_urls,
decode_urls,
prefill_policy,
decode_policy,
max_concurrent_requests,
cors_allowed_origins,
retry_max_retries,
retry_initial_backoff_ms,
retry_max_backoff_ms,
retry_backoff_multiplier,
retry_jitter_factor,
disable_retries,
cb_failure_threshold,
cb_success_threshold,
cb_timeout_duration_secs,
cb_window_duration_secs,
disable_circuit_breaker,
health_failure_threshold,
health_success_threshold,
health_check_timeout_secs,
health_check_interval_secs,
health_check_endpoint,
enable_igw,
queue_size,
queue_timeout_secs,
rate_limit_tokens_per_second,
connection_mode,
model_path,
tokenizer_path,
})
}
fn start(&self) -> PyResult<()> {
// Convert to RouterConfig and validate
let router_config = self.to_router_config().map_err(|e| {
pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
})?;
// Validate the configuration
router_config.validate().map_err(|e| {
pyo3::exceptions::PyValueError::new_err(format!(
"Configuration validation failed: {}",
e
))
})?;
// Create service discovery config if enabled
let service_discovery_config = if self.service_discovery {
Some(service_discovery::ServiceDiscoveryConfig {
enabled: true,
selector: self.selector.clone(),
check_interval: std::time::Duration::from_secs(60),
port: self.service_discovery_port,
namespace: self.service_discovery_namespace.clone(),
pd_mode: self.pd_disaggregation,
prefill_selector: self.prefill_selector.clone(),
decode_selector: self.decode_selector.clone(),
bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
})
} else {
None
};
// Create Prometheus config if enabled
let prometheus_config = Some(PrometheusConfig {
port: self.prometheus_port.unwrap_or(29000),
host: self
.prometheus_host
.clone()
.unwrap_or_else(|| "127.0.0.1".to_string()),
});
// Use tokio runtime instead of actix-web System for better compatibility
let runtime = tokio::runtime::Runtime::new()
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
// Block on the async startup function
runtime.block_on(async move {
server::startup(server::ServerConfig {
host: self.host.clone(),
port: self.port,
router_config,
max_payload_size: self.max_payload_size,
log_dir: self.log_dir.clone(),
log_level: self.log_level.clone(),
service_discovery_config,
prometheus_config,
request_timeout_secs: self.request_timeout_secs,
request_id_headers: self.request_id_headers.clone(),
})
.await
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
})
}
}
#[pymodule]
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PolicyType>()?;
m.add_class::<Router>()?;
Ok(())
}

163
sgl-router/src/logging.rs Normal file
View File

@@ -0,0 +1,163 @@
use std::path::PathBuf;
use tracing::Level;
use tracing_appender::non_blocking::WorkerGuard;
use tracing_appender::rolling::{RollingFileAppender, Rotation};
use tracing_log::LogTracer;
use tracing_subscriber::fmt::time::ChronoUtc;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
/// Configuration for the logging system
#[derive(Debug, Clone)]
pub struct LoggingConfig {
/// Log level for the application (default: INFO)
pub level: Level,
/// Whether to use json format for logs (default: false)
pub json_format: bool,
/// Path to store log files. If None, logs will only go to stdout/stderr
pub log_dir: Option<String>,
/// Whether to colorize logs when output is a terminal (default: true)
pub colorize: bool,
/// Log file name to use if log_dir is specified (default: "sgl-router")
pub log_file_name: String,
/// Custom log targets to filter (default: "sglang_router_rs")
pub log_targets: Option<Vec<String>>,
}
impl Default for LoggingConfig {
fn default() -> Self {
Self {
level: Level::INFO,
json_format: false,
log_dir: None,
colorize: true,
log_file_name: "sgl-router".to_string(),
log_targets: Some(vec!["sglang_router_rs".to_string()]),
}
}
}
/// Guard that keeps the file appender worker thread alive
///
/// This must be kept in scope for the duration of the program
/// to ensure logs are properly written to files
#[allow(dead_code)]
pub struct LogGuard {
_file_guard: Option<WorkerGuard>,
}
/// Initialize the logging system with the given configuration
///
/// # Arguments
/// * `config` - Configuration for the logging system
///
/// # Returns
/// A LogGuard that must be kept alive for the duration of the program
///
/// # Panics
/// Will not panic, as initialization errors are handled gracefully
pub fn init_logging(config: LoggingConfig) -> LogGuard {
// Forward logs to tracing - ignore errors to allow for multiple initialization
let _ = LogTracer::init();
// Convert log level to filter string
let level_filter = match config.level {
Level::TRACE => "trace",
Level::DEBUG => "debug",
Level::INFO => "info",
Level::WARN => "warn",
Level::ERROR => "error",
};
// Create env filter
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| {
// Format: <target>=<level>,<target2>=<level2>,...
let filter_string = if let Some(targets) = &config.log_targets {
targets
.iter()
.enumerate()
.map(|(i, target)| {
if i > 0 {
format!(",{}={}", target, level_filter)
} else {
format!("{}={}", target, level_filter)
}
})
.collect::<String>()
} else {
format!("sglang_router_rs={}", level_filter)
};
EnvFilter::new(filter_string)
});
// Setup stdout/stderr layer
let mut layers = Vec::new();
// Standard timestamp format: YYYY-MM-DD HH:MM:SS
let time_format = "%Y-%m-%d %H:%M:%S".to_string();
// Configure the console stdout layer
let stdout_layer = tracing_subscriber::fmt::layer()
.with_ansi(config.colorize)
.with_file(true)
.with_line_number(true)
.with_timer(ChronoUtc::new(time_format.clone()));
let stdout_layer = if config.json_format {
stdout_layer.json().flatten_event(true).boxed()
} else {
stdout_layer.boxed()
};
layers.push(stdout_layer);
// Create a file appender if log_dir is specified
let mut file_guard = None;
if let Some(log_dir) = &config.log_dir {
let file_name = config.log_file_name.clone();
let log_dir = PathBuf::from(log_dir);
// Create log directory if it doesn't exist
if !log_dir.exists() {
if let Err(e) = std::fs::create_dir_all(&log_dir) {
eprintln!("Failed to create log directory: {}", e);
return LogGuard { _file_guard: None };
}
}
let file_appender = RollingFileAppender::new(Rotation::DAILY, log_dir, file_name);
let (non_blocking, guard) = tracing_appender::non_blocking(file_appender);
file_guard = Some(guard);
let file_layer = tracing_subscriber::fmt::layer()
.with_ansi(false) // Never use ANSI colors in log files
.with_file(true)
.with_line_number(true)
.with_timer(ChronoUtc::new(time_format))
.with_writer(non_blocking);
let file_layer = if config.json_format {
file_layer.json().flatten_event(true).boxed()
} else {
file_layer.boxed()
};
layers.push(file_layer);
}
// Initialize the subscriber with all layers
// Use try_init to handle errors gracefully in case another subscriber is already set
let _ = tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.try_init();
// Return the guard to keep the file appender worker thread alive
LogGuard {
_file_guard: file_guard,
}
}

636
sgl-router/src/main.rs Normal file
View File

@@ -0,0 +1,636 @@
use clap::{ArgAction, Parser, ValueEnum};
use sglang_router_rs::config::{
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
HealthCheckConfig, MetricsConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
};
use sglang_router_rs::metrics::PrometheusConfig;
use sglang_router_rs::server::{self, ServerConfig};
use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
use std::collections::HashMap;
// Helper function to parse prefill arguments from command line
fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
let args: Vec<String> = std::env::args().collect();
let mut prefill_entries = Vec::new();
let mut i = 0;
while i < args.len() {
if args[i] == "--prefill" && i + 1 < args.len() {
let url = args[i + 1].clone();
let bootstrap_port = if i + 2 < args.len() && !args[i + 2].starts_with("--") {
// Check if next arg is a port number
if let Ok(port) = args[i + 2].parse::<u16>() {
i += 1; // Skip the port argument
Some(port)
} else if args[i + 2].to_lowercase() == "none" {
i += 1; // Skip the "none" argument
None
} else {
None
}
} else {
None
};
prefill_entries.push((url, bootstrap_port));
i += 2; // Skip --prefill and URL
} else {
i += 1;
}
}
prefill_entries
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
pub enum Backend {
#[value(name = "sglang")]
Sglang,
#[value(name = "vllm")]
Vllm,
#[value(name = "trtllm")]
Trtllm,
#[value(name = "openai")]
Openai,
#[value(name = "anthropic")]
Anthropic,
}
impl std::fmt::Display for Backend {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
Backend::Sglang => "sglang",
Backend::Vllm => "vllm",
Backend::Trtllm => "trtllm",
Backend::Openai => "openai",
Backend::Anthropic => "anthropic",
};
write!(f, "{}", s)
}
}
#[derive(Parser, Debug)]
#[command(name = "sglang-router")]
#[command(about = "SGLang Router - High-performance request distribution across worker nodes")]
#[command(long_about = r#"
SGLang Router - High-performance request distribution across worker nodes
Usage:
This launcher enables starting a router with individual worker instances. It is useful for
multi-node setups or when you want to start workers and router separately.
Examples:
# Regular mode
sglang-router --worker-urls http://worker1:8000 http://worker2:8000
# PD disaggregated mode with same policy for both
sglang-router --pd-disaggregation \
--prefill http://127.0.0.1:30001 9001 \
--prefill http://127.0.0.2:30002 9002 \
--decode http://127.0.0.3:30003 \
--decode http://127.0.0.4:30004 \
--policy cache_aware
# PD mode with different policies for prefill and decode
sglang-router --pd-disaggregation \
--prefill http://127.0.0.1:30001 9001 \
--prefill http://127.0.0.2:30002 \
--decode http://127.0.0.3:30003 \
--decode http://127.0.0.4:30004 \
--prefill-policy cache_aware --decode-policy power_of_two
"#)]
struct CliArgs {
/// Host address to bind the router server
#[arg(long, default_value = "127.0.0.1")]
host: String,
/// Port number to bind the router server
#[arg(long, default_value_t = 30000)]
port: u16,
/// List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)
#[arg(long, num_args = 0..)]
worker_urls: Vec<String>,
/// Load balancing policy to use
#[arg(long, default_value = "cache_aware", value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
policy: String,
/// Enable PD (Prefill-Decode) disaggregated mode
#[arg(long, default_value_t = false)]
pd_disaggregation: bool,
/// Decode server URL (can be specified multiple times)
#[arg(long, action = ArgAction::Append)]
decode: Vec<String>,
/// Specific policy for prefill nodes in PD mode
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
prefill_policy: Option<String>,
/// Specific policy for decode nodes in PD mode
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
decode_policy: Option<String>,
/// Timeout in seconds for worker startup
#[arg(long, default_value_t = 600)]
worker_startup_timeout_secs: u64,
/// Interval in seconds between checks for worker startup
#[arg(long, default_value_t = 30)]
worker_startup_check_interval: u64,
/// Cache threshold (0.0-1.0) for cache-aware routing
#[arg(long, default_value_t = 0.3)]
cache_threshold: f32,
/// Absolute threshold for load balancing
#[arg(long, default_value_t = 64)]
balance_abs_threshold: usize,
/// Relative threshold for load balancing
#[arg(long, default_value_t = 1.5)]
balance_rel_threshold: f32,
/// Interval in seconds between cache eviction operations
#[arg(long, default_value_t = 120)]
eviction_interval: u64,
/// Maximum size of the approximation tree for cache-aware routing
#[arg(long, default_value_t = 67108864)] // 2^26
max_tree_size: usize,
/// Maximum payload size in bytes
#[arg(long, default_value_t = 536870912)] // 512MB
max_payload_size: usize,
/// Enable data parallelism aware schedule
#[arg(long, default_value_t = false)]
dp_aware: bool,
/// API key for worker authorization
#[arg(long)]
api_key: Option<String>,
/// Backend to route requests to (sglang, vllm, trtllm, openai, anthropic)
#[arg(long, value_enum, default_value_t = Backend::Sglang, alias = "runtime")]
backend: Backend,
/// Directory to store log files
#[arg(long)]
log_dir: Option<String>,
/// Set the logging level
#[arg(long, default_value = "info", value_parser = ["debug", "info", "warn", "error"])]
log_level: String,
/// Enable Kubernetes service discovery
#[arg(long, default_value_t = false)]
service_discovery: bool,
/// Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)
#[arg(long, num_args = 0..)]
selector: Vec<String>,
/// Port to use for discovered worker pods
#[arg(long, default_value_t = 80)]
service_discovery_port: u16,
/// Kubernetes namespace to watch for pods
#[arg(long)]
service_discovery_namespace: Option<String>,
/// Label selector for prefill server pods in PD mode
#[arg(long, num_args = 0..)]
prefill_selector: Vec<String>,
/// Label selector for decode server pods in PD mode
#[arg(long, num_args = 0..)]
decode_selector: Vec<String>,
/// Port to expose Prometheus metrics
#[arg(long, default_value_t = 29000)]
prometheus_port: u16,
/// Host address to bind the Prometheus metrics server
#[arg(long, default_value = "127.0.0.1")]
prometheus_host: String,
/// Custom HTTP headers to check for request IDs
#[arg(long, num_args = 0..)]
request_id_headers: Vec<String>,
/// Request timeout in seconds
#[arg(long, default_value_t = 1800)]
request_timeout_secs: u64,
/// Maximum number of concurrent requests allowed
#[arg(long, default_value_t = 256)]
max_concurrent_requests: usize,
/// CORS allowed origins
#[arg(long, num_args = 0..)]
cors_allowed_origins: Vec<String>,
// Retry configuration
/// Maximum number of retries
#[arg(long, default_value_t = 5)]
retry_max_retries: u32,
/// Initial backoff in milliseconds for retries
#[arg(long, default_value_t = 50)]
retry_initial_backoff_ms: u64,
/// Maximum backoff in milliseconds for retries
#[arg(long, default_value_t = 30000)]
retry_max_backoff_ms: u64,
/// Backoff multiplier for exponential backoff
#[arg(long, default_value_t = 1.5)]
retry_backoff_multiplier: f32,
/// Jitter factor for retry backoff
#[arg(long, default_value_t = 0.2)]
retry_jitter_factor: f32,
/// Disable retries
#[arg(long, default_value_t = false)]
disable_retries: bool,
// Circuit breaker configuration
/// Number of failures before circuit breaker opens
#[arg(long, default_value_t = 10)]
cb_failure_threshold: u32,
/// Number of successes before circuit breaker closes
#[arg(long, default_value_t = 3)]
cb_success_threshold: u32,
/// Timeout duration in seconds for circuit breaker
#[arg(long, default_value_t = 60)]
cb_timeout_duration_secs: u64,
/// Window duration in seconds for circuit breaker
#[arg(long, default_value_t = 120)]
cb_window_duration_secs: u64,
/// Disable circuit breaker
#[arg(long, default_value_t = false)]
disable_circuit_breaker: bool,
// Health check configuration
/// Number of consecutive health check failures before marking worker unhealthy
#[arg(long, default_value_t = 3)]
health_failure_threshold: u32,
/// Number of consecutive health check successes before marking worker healthy
#[arg(long, default_value_t = 2)]
health_success_threshold: u32,
/// Timeout in seconds for health check requests
#[arg(long, default_value_t = 5)]
health_check_timeout_secs: u64,
/// Interval in seconds between runtime health checks
#[arg(long, default_value_t = 60)]
health_check_interval_secs: u64,
/// Health check endpoint path
#[arg(long, default_value = "/health")]
health_check_endpoint: String,
// IGW (Inference Gateway) configuration
/// Enable Inference Gateway mode
#[arg(long, default_value_t = false)]
enable_igw: bool,
// Tokenizer configuration
/// Model path for loading tokenizer (HuggingFace model ID or local path)
#[arg(long)]
model_path: Option<String>,
/// Explicit tokenizer path (overrides model_path tokenizer if provided)
#[arg(long)]
tokenizer_path: Option<String>,
}
impl CliArgs {
/// Determine connection mode from worker URLs
fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode {
// Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
for url in worker_urls {
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
return ConnectionMode::Grpc;
}
}
// Default to HTTP for all other cases (including http://, https://, or no scheme)
ConnectionMode::Http
}
/// Parse selector strings into HashMap
fn parse_selector(selector_list: &[String]) -> HashMap<String, String> {
let mut map = HashMap::new();
for item in selector_list {
if let Some(eq_pos) = item.find('=') {
let key = item[..eq_pos].to_string();
let value = item[eq_pos + 1..].to_string();
map.insert(key, value);
}
}
map
}
/// Convert policy string to PolicyConfig
fn parse_policy(&self, policy_str: &str) -> PolicyConfig {
match policy_str {
"random" => PolicyConfig::Random,
"round_robin" => PolicyConfig::RoundRobin,
"cache_aware" => PolicyConfig::CacheAware {
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
eviction_interval_secs: self.eviction_interval,
max_tree_size: self.max_tree_size,
},
"power_of_two" => PolicyConfig::PowerOfTwo {
load_check_interval_secs: 5, // Default value
},
_ => PolicyConfig::RoundRobin, // Fallback
}
}
/// Convert CLI arguments to RouterConfig
fn to_router_config(
&self,
prefill_urls: Vec<(String, Option<u16>)>,
) -> ConfigResult<RouterConfig> {
// Determine routing mode
let mode = if self.enable_igw {
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
RoutingMode::Regular {
worker_urls: vec![],
}
} else if matches!(self.backend, Backend::Openai) {
// OpenAI backend mode - use worker_urls as base(s)
RoutingMode::OpenAI {
worker_urls: self.worker_urls.clone(),
}
} else if self.pd_disaggregation {
let decode_urls = self.decode.clone();
// Validate PD configuration if not using service discovery
if !self.service_discovery && (prefill_urls.is_empty() || decode_urls.is_empty()) {
return Err(ConfigError::ValidationFailed {
reason: "PD disaggregation mode requires --prefill and --decode URLs when not using service discovery".to_string(),
});
}
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
prefill_policy: self.prefill_policy.as_ref().map(|p| self.parse_policy(p)),
decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)),
}
} else {
// Regular mode
if !self.service_discovery && self.worker_urls.is_empty() {
return Err(ConfigError::ValidationFailed {
reason: "Regular mode requires --worker-urls when not using service discovery"
.to_string(),
});
}
RoutingMode::Regular {
worker_urls: self.worker_urls.clone(),
}
};
// Main policy
let policy = self.parse_policy(&self.policy);
// Service discovery configuration
let discovery = if self.service_discovery {
Some(DiscoveryConfig {
enabled: true,
namespace: self.service_discovery_namespace.clone(),
port: self.service_discovery_port,
check_interval_secs: 60,
selector: Self::parse_selector(&self.selector),
prefill_selector: Self::parse_selector(&self.prefill_selector),
decode_selector: Self::parse_selector(&self.decode_selector),
bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(),
})
} else {
None
};
// Metrics configuration
let metrics = Some(MetricsConfig {
port: self.prometheus_port,
host: self.prometheus_host.clone(),
});
// Determine connection mode from all worker URLs
let mut all_urls = Vec::new();
match &mode {
RoutingMode::Regular { worker_urls } => {
all_urls.extend(worker_urls.clone());
}
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
..
} => {
for (url, _) in prefill_urls {
all_urls.push(url.clone());
}
all_urls.extend(decode_urls.clone());
}
RoutingMode::OpenAI { .. } => {
// For connection-mode detection, skip URLs; OpenAI forces HTTP below.
}
}
let connection_mode = match &mode {
RoutingMode::OpenAI { .. } => ConnectionMode::Http,
_ => Self::determine_connection_mode(&all_urls),
};
// Build RouterConfig
Ok(RouterConfig {
mode,
policy,
connection_mode,
host: self.host.clone(),
port: self.port,
max_payload_size: self.max_payload_size,
request_timeout_secs: self.request_timeout_secs,
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
worker_startup_check_interval_secs: self.worker_startup_check_interval,
dp_aware: self.dp_aware,
api_key: self.api_key.clone(),
discovery,
metrics,
log_dir: self.log_dir.clone(),
log_level: Some(self.log_level.clone()),
request_id_headers: if self.request_id_headers.is_empty() {
None
} else {
Some(self.request_id_headers.clone())
},
max_concurrent_requests: self.max_concurrent_requests,
queue_size: 100, // Default queue size
queue_timeout_secs: 60, // Default timeout
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: RetryConfig {
max_retries: self.retry_max_retries,
initial_backoff_ms: self.retry_initial_backoff_ms,
max_backoff_ms: self.retry_max_backoff_ms,
backoff_multiplier: self.retry_backoff_multiplier,
jitter_factor: self.retry_jitter_factor,
},
circuit_breaker: CircuitBreakerConfig {
failure_threshold: self.cb_failure_threshold,
success_threshold: self.cb_success_threshold,
timeout_duration_secs: self.cb_timeout_duration_secs,
window_duration_secs: self.cb_window_duration_secs,
},
disable_retries: self.disable_retries,
disable_circuit_breaker: self.disable_circuit_breaker,
health_check: HealthCheckConfig {
failure_threshold: self.health_failure_threshold,
success_threshold: self.health_success_threshold,
timeout_secs: self.health_check_timeout_secs,
check_interval_secs: self.health_check_interval_secs,
endpoint: self.health_check_endpoint.clone(),
},
enable_igw: self.enable_igw,
rate_limit_tokens_per_second: None,
model_path: self.model_path.clone(),
tokenizer_path: self.tokenizer_path.clone(),
})
}
/// Create ServerConfig from CLI args and RouterConfig
fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig {
// Create service discovery config if enabled
let service_discovery_config = if self.service_discovery {
Some(ServiceDiscoveryConfig {
enabled: true,
selector: Self::parse_selector(&self.selector),
check_interval: std::time::Duration::from_secs(60),
port: self.service_discovery_port,
namespace: self.service_discovery_namespace.clone(),
pd_mode: self.pd_disaggregation,
prefill_selector: Self::parse_selector(&self.prefill_selector),
decode_selector: Self::parse_selector(&self.decode_selector),
bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(),
})
} else {
None
};
// Create Prometheus config
let prometheus_config = Some(PrometheusConfig {
port: self.prometheus_port,
host: self.prometheus_host.clone(),
});
ServerConfig {
host: self.host.clone(),
port: self.port,
router_config,
max_payload_size: self.max_payload_size,
log_dir: self.log_dir.clone(),
log_level: Some(self.log_level.clone()),
service_discovery_config,
prometheus_config,
request_timeout_secs: self.request_timeout_secs,
request_id_headers: if self.request_id_headers.is_empty() {
None
} else {
Some(self.request_id_headers.clone())
},
}
}
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Parse prefill arguments manually before clap parsing
let prefill_urls = parse_prefill_args();
// Filter out prefill arguments and their values before passing to clap
let mut filtered_args: Vec<String> = Vec::new();
let raw_args: Vec<String> = std::env::args().collect();
let mut i = 0;
while i < raw_args.len() {
if raw_args[i] == "--prefill" && i + 1 < raw_args.len() {
// Skip --prefill and its URL
i += 2;
// Also skip bootstrap port if present
if i < raw_args.len()
&& !raw_args[i].starts_with("--")
&& (raw_args[i].parse::<u16>().is_ok() || raw_args[i].to_lowercase() == "none")
{
i += 1;
}
} else {
filtered_args.push(raw_args[i].clone());
i += 1;
}
}
// Parse CLI arguments with clap using filtered args
let cli_args = CliArgs::parse_from(filtered_args);
// Print startup info
println!("SGLang Router starting...");
println!("Host: {}:{}", cli_args.host, cli_args.port);
let mode_str = if cli_args.enable_igw {
"IGW (Inference Gateway)".to_string()
} else if matches!(cli_args.backend, Backend::Openai) {
"OpenAI Backend".to_string()
} else if cli_args.pd_disaggregation {
"PD Disaggregated".to_string()
} else {
format!("Regular ({})", cli_args.backend)
};
println!("Mode: {}", mode_str);
// Warn for runtimes that are parsed but not yet implemented
match cli_args.backend {
Backend::Vllm | Backend::Trtllm | Backend::Anthropic => {
println!(
"WARNING: runtime '{}' not implemented yet; falling back to regular routing. \
Provide --worker-urls or PD flags as usual.",
cli_args.backend
);
}
Backend::Sglang | Backend::Openai => {}
}
if !cli_args.enable_igw {
println!("Policy: {}", cli_args.policy);
if cli_args.pd_disaggregation && !prefill_urls.is_empty() {
println!("Prefill nodes: {:?}", prefill_urls);
println!("Decode nodes: {:?}", cli_args.decode);
}
}
// Convert to RouterConfig
let router_config = cli_args.to_router_config(prefill_urls)?;
// Validate configuration
router_config.validate()?;
// Create ServerConfig
let server_config = cli_args.to_server_config(router_config);
// Create a new runtime for the server (like Python binding does)
let runtime = tokio::runtime::Runtime::new()?;
// Block on the async startup function
runtime.block_on(async move { server::startup(server_config).await })?;
Ok(())
}

View File

@@ -0,0 +1,535 @@
use backoff::ExponentialBackoffBuilder;
use dashmap::DashMap;
use rmcp::{
model::{
CallToolRequestParam, GetPromptRequestParam, GetPromptResult, Prompt,
ReadResourceRequestParam, ReadResourceResult, Resource, Tool as McpTool,
},
service::RunningService,
transport::{
sse_client::SseClientConfig, streamable_http_client::StreamableHttpClientTransportConfig,
ConfigureCommandExt, SseClientTransport, StreamableHttpClientTransport, TokioChildProcess,
},
RoleClient, ServiceExt,
};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, collections::HashMap, time::Duration};
use crate::mcp::{
config::{McpConfig, McpServerConfig, McpTransport},
error::{McpError, McpResult},
};
/// Information about an available tool
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInfo {
pub name: String,
pub description: String,
pub server: String,
pub parameters: Option<serde_json::Value>,
}
/// Information about an available prompt
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptInfo {
pub name: String,
pub description: Option<String>,
pub server: String,
pub arguments: Option<Vec<serde_json::Value>>,
}
/// Information about an available resource
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceInfo {
pub uri: String,
pub name: String,
pub description: Option<String>,
pub mime_type: Option<String>,
pub server: String,
}
/// Manages MCP client connections and tool execution
pub struct McpClientManager {
/// Map of server_name -> MCP client
clients: HashMap<String, RunningService<RoleClient, ()>>,
/// Map of tool_name -> (server_name, tool_definition)
tools: DashMap<String, (String, McpTool)>,
/// Map of prompt_name -> (server_name, prompt_definition)
prompts: DashMap<String, (String, Prompt)>,
/// Map of resource_uri -> (server_name, resource_definition)
resources: DashMap<String, (String, Resource)>,
}
impl McpClientManager {
/// Create a new manager and connect to all configured servers
pub async fn new(config: McpConfig) -> McpResult<Self> {
let mut mgr = Self {
clients: HashMap::new(),
tools: DashMap::new(),
prompts: DashMap::new(),
resources: DashMap::new(),
};
for server_config in config.servers {
match Self::connect_server(&server_config).await {
Ok(client) => {
mgr.load_server_inventory(&server_config.name, &client)
.await;
mgr.clients.insert(server_config.name.clone(), client);
}
Err(e) => {
tracing::error!(
"Failed to connect to server '{}': {}",
server_config.name,
e
);
}
}
}
if mgr.clients.is_empty() {
return Err(McpError::ConnectionFailed(
"Failed to connect to any MCP servers".to_string(),
));
}
Ok(mgr)
}
/// Discover and cache tools/prompts/resources for a connected server
async fn load_server_inventory(
&self,
server_name: &str,
client: &RunningService<RoleClient, ()>,
) {
// Tools
match client.peer().list_all_tools().await {
Ok(ts) => {
tracing::info!("Discovered {} tools from '{}'", ts.len(), server_name);
for t in ts {
if self.tools.contains_key(t.name.as_ref()) {
tracing::warn!(
"Tool '{}' from server '{}' is overwriting an existing tool.",
&t.name,
server_name
);
}
self.tools
.insert(t.name.to_string(), (server_name.to_string(), t));
}
}
Err(e) => tracing::warn!("Failed to list tools from '{}': {}", server_name, e),
}
// Prompts
match client.peer().list_all_prompts().await {
Ok(ps) => {
tracing::info!("Discovered {} prompts from '{}'", ps.len(), server_name);
for p in ps {
if self.prompts.contains_key(&p.name) {
tracing::warn!(
"Prompt '{}' from server '{}' is overwriting an existing prompt.",
&p.name,
server_name
);
}
self.prompts
.insert(p.name.clone(), (server_name.to_string(), p));
}
}
Err(e) => tracing::debug!("No prompts or failed to list on '{}': {}", server_name, e),
}
// Resources
match client.peer().list_all_resources().await {
Ok(rs) => {
tracing::info!("Discovered {} resources from '{}'", rs.len(), server_name);
for r in rs {
if self.resources.contains_key(&r.uri) {
tracing::warn!(
"Resource '{}' from server '{}' is overwriting an existing resource.",
&r.uri,
server_name
);
}
self.resources
.insert(r.uri.clone(), (server_name.to_string(), r));
}
}
Err(e) => tracing::debug!("No resources or failed to list on '{}': {}", server_name, e),
}
}
/// Connect to a single MCP server with retry logic for remote transports
async fn connect_server(config: &McpServerConfig) -> McpResult<RunningService<RoleClient, ()>> {
let needs_retry = matches!(
&config.transport,
McpTransport::Sse { .. } | McpTransport::Streamable { .. }
);
if needs_retry {
Self::connect_server_with_retry(config).await
} else {
Self::connect_server_impl(config).await
}
}
/// Connect with exponential backoff retry for remote servers
async fn connect_server_with_retry(
config: &McpServerConfig,
) -> McpResult<RunningService<RoleClient, ()>> {
let backoff = ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_secs(1))
.with_max_interval(Duration::from_secs(30))
.with_max_elapsed_time(Some(Duration::from_secs(120)))
.build();
backoff::future::retry(backoff, || async {
match Self::connect_server_impl(config).await {
Ok(client) => Ok(client),
Err(e) => {
tracing::warn!("Failed to connect to '{}', retrying: {}", config.name, e);
Err(backoff::Error::transient(e))
}
}
})
.await
}
/// Internal implementation of server connection
async fn connect_server_impl(
config: &McpServerConfig,
) -> McpResult<RunningService<RoleClient, ()>> {
tracing::info!(
"Connecting to MCP server '{}' via {:?}",
config.name,
config.transport
);
match &config.transport {
McpTransport::Stdio {
command,
args,
envs,
} => {
let transport = TokioChildProcess::new(
tokio::process::Command::new(command).configure(|cmd| {
cmd.args(args)
.envs(envs.iter())
.stderr(std::process::Stdio::inherit());
}),
)
.map_err(|e| McpError::Transport(format!("create stdio transport: {}", e)))?;
let client = ().serve(transport).await.map_err(|e| {
McpError::ConnectionFailed(format!("initialize stdio client: {}", e))
})?;
tracing::info!("Connected to stdio server '{}'", config.name);
Ok(client)
}
McpTransport::Sse { url, token } => {
let transport = if let Some(tok) = token {
let client = reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", tok).parse().map_err(|e| {
McpError::Transport(format!("auth token: {}", e))
})?,
);
headers
})
.build()
.map_err(|e| McpError::Transport(format!("build HTTP client: {}", e)))?;
let cfg = SseClientConfig {
sse_endpoint: url.clone().into(),
..Default::default()
};
SseClientTransport::start_with_client(client, cfg)
.await
.map_err(|e| McpError::Transport(format!("create SSE transport: {}", e)))?
} else {
SseClientTransport::start(url.as_str())
.await
.map_err(|e| McpError::Transport(format!("create SSE transport: {}", e)))?
};
let client = ().serve(transport).await.map_err(|e| {
McpError::ConnectionFailed(format!("initialize SSE client: {}", e))
})?;
tracing::info!("Connected to SSE server '{}' at {}", config.name, url);
Ok(client)
}
McpTransport::Streamable { url, token } => {
let transport = if let Some(tok) = token {
let mut cfg = StreamableHttpClientTransportConfig::with_uri(url.as_str());
cfg.auth_header = Some(format!("Bearer {}", tok));
StreamableHttpClientTransport::from_config(cfg)
} else {
StreamableHttpClientTransport::from_uri(url.as_str())
};
let client = ().serve(transport).await.map_err(|e| {
McpError::ConnectionFailed(format!("initialize streamable client: {}", e))
})?;
tracing::info!(
"Connected to streamable HTTP server '{}' at {}",
config.name,
url
);
Ok(client)
}
}
}
// ===== Helpers =====
fn client_for(&self, server_name: &str) -> McpResult<&RunningService<RoleClient, ()>> {
self.clients
.get(server_name)
.ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))
}
fn tool_entry(&self, name: &str) -> McpResult<(String, McpTool)> {
self.tools
.get(name)
.map(|e| e.value().clone())
.ok_or_else(|| McpError::ToolNotFound(name.to_string()))
}
fn prompt_entry(&self, name: &str) -> McpResult<(String, Prompt)> {
self.prompts
.get(name)
.map(|e| e.value().clone())
.ok_or_else(|| McpError::PromptNotFound(name.to_string()))
}
fn resource_entry(&self, uri: &str) -> McpResult<(String, Resource)> {
self.resources
.get(uri)
.map(|e| e.value().clone())
.ok_or_else(|| McpError::ResourceNotFound(uri.to_string()))
}
// ===== Tool Methods =====
/// Call a tool by name
pub async fn call_tool(
&self,
tool_name: &str,
arguments: Option<serde_json::Map<String, serde_json::Value>>,
) -> McpResult<rmcp::model::CallToolResult> {
let (server_name, _tool) = self.tool_entry(tool_name)?;
let client = self.client_for(&server_name)?;
tracing::debug!("Calling tool '{}' on '{}'", tool_name, server_name);
client
.peer()
.call_tool(CallToolRequestParam {
name: Cow::Owned(tool_name.to_string()),
arguments,
})
.await
.map_err(|e| McpError::ToolExecution(format!("Tool call failed: {}", e)))
}
/// Get all available tools
pub fn list_tools(&self) -> Vec<ToolInfo> {
self.tools
.iter()
.map(|entry| {
let tool_name = entry.key().clone();
let (server_name, tool) = entry.value();
ToolInfo {
name: tool_name,
description: tool.description.as_deref().unwrap_or_default().to_string(),
server: server_name.clone(),
parameters: Some(serde_json::Value::Object((*tool.input_schema).clone())),
}
})
.collect()
}
/// Get a specific tool by name
pub fn get_tool(&self, name: &str) -> Option<ToolInfo> {
self.tools.get(name).map(|entry| {
let (server_name, tool) = entry.value();
ToolInfo {
name: name.to_string(),
description: tool.description.as_deref().unwrap_or_default().to_string(),
server: server_name.clone(),
parameters: Some(serde_json::Value::Object((*tool.input_schema).clone())),
}
})
}
/// Check if a tool exists
pub fn has_tool(&self, name: &str) -> bool {
self.tools.contains_key(name)
}
/// Get list of connected servers
pub fn list_servers(&self) -> Vec<String> {
self.clients.keys().cloned().collect()
}
// ===== Prompt Methods =====
/// Get a prompt by name with arguments
pub async fn get_prompt(
&self,
prompt_name: &str,
arguments: Option<serde_json::Map<String, serde_json::Value>>,
) -> McpResult<GetPromptResult> {
let (server_name, _prompt) = self.prompt_entry(prompt_name)?;
let client = self.client_for(&server_name)?;
tracing::debug!("Getting prompt '{}' from '{}'", prompt_name, server_name);
client
.peer()
.get_prompt(GetPromptRequestParam {
name: prompt_name.to_string(),
arguments,
})
.await
.map_err(|e| McpError::ToolExecution(format!("Failed to get prompt: {}", e)))
}
/// List all available prompts
pub fn list_prompts(&self) -> Vec<PromptInfo> {
self.prompts
.iter()
.map(|entry| {
let name = entry.key().clone();
let (server_name, prompt) = entry.value();
PromptInfo {
name,
description: prompt.description.clone(),
server: server_name.clone(),
arguments: prompt
.arguments
.clone()
.map(|args| args.into_iter().map(|arg| serde_json::json!(arg)).collect()),
}
})
.collect()
}
/// Get a specific prompt info by name
pub fn get_prompt_info(&self, name: &str) -> Option<PromptInfo> {
self.prompts.get(name).map(|entry| {
let (server_name, prompt) = entry.value();
PromptInfo {
name: name.to_string(),
description: prompt.description.clone(),
server: server_name.clone(),
arguments: prompt
.arguments
.clone()
.map(|args| args.into_iter().map(|arg| serde_json::json!(arg)).collect()),
}
})
}
// ===== Resource Methods =====
/// Read a resource by URI
pub async fn read_resource(&self, uri: &str) -> McpResult<ReadResourceResult> {
let (server_name, _resource) = self.resource_entry(uri)?;
let client = self.client_for(&server_name)?;
tracing::debug!("Reading resource '{}' from '{}'", uri, server_name);
client
.peer()
.read_resource(ReadResourceRequestParam {
uri: uri.to_string(),
})
.await
.map_err(|e| McpError::ToolExecution(format!("Failed to read resource: {}", e)))
}
/// List all available resources
pub fn list_resources(&self) -> Vec<ResourceInfo> {
self.resources
.iter()
.map(|entry| {
let uri = entry.key().clone();
let (server_name, resource) = entry.value();
ResourceInfo {
uri,
name: resource.name.clone(),
description: resource.description.clone(),
mime_type: resource.mime_type.clone(),
server: server_name.clone(),
}
})
.collect()
}
/// Get a specific resource info by URI
pub fn get_resource_info(&self, uri: &str) -> Option<ResourceInfo> {
self.resources.get(uri).map(|entry| {
let (server_name, resource) = entry.value();
ResourceInfo {
uri: uri.to_string(),
name: resource.name.clone(),
description: resource.description.clone(),
mime_type: resource.mime_type.clone(),
server: server_name.clone(),
}
})
}
/// Subscribe to resource changes
pub async fn subscribe_resource(&self, uri: &str) -> McpResult<()> {
let (server_name, _resource) = self.resource_entry(uri)?;
let client = self.client_for(&server_name)?;
tracing::debug!("Subscribing to '{}' on '{}'", uri, server_name);
client
.peer()
.subscribe(rmcp::model::SubscribeRequestParam {
uri: uri.to_string(),
})
.await
.map_err(|e| McpError::ToolExecution(format!("Failed to subscribe: {}", e)))
}
/// Unsubscribe from resource changes
pub async fn unsubscribe_resource(&self, uri: &str) -> McpResult<()> {
let (server_name, _resource) = self.resource_entry(uri)?;
let client = self.client_for(&server_name)?;
tracing::debug!("Unsubscribing from '{}' on '{}'", uri, server_name);
client
.peer()
.unsubscribe(rmcp::model::UnsubscribeRequestParam {
uri: uri.to_string(),
})
.await
.map_err(|e| McpError::ToolExecution(format!("Failed to unsubscribe: {}", e)))
}
/// Disconnect from all servers (for cleanup)
pub async fn shutdown(&mut self) {
for (name, client) in self.clients.drain() {
if let Err(e) = client.cancel().await {
tracing::warn!("Error disconnecting from '{}': {}", name, e);
}
}
self.tools.clear();
self.prompts.clear();
self.resources.clear();
}
}

View File

@@ -0,0 +1,52 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpConfig {
pub servers: Vec<McpServerConfig>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpServerConfig {
pub name: String,
#[serde(flatten)]
pub transport: McpTransport,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "protocol", rename_all = "lowercase")]
pub enum McpTransport {
Stdio {
command: String,
#[serde(default)]
args: Vec<String>,
#[serde(default)]
envs: HashMap<String, String>,
},
Sse {
url: String,
#[serde(skip_serializing_if = "Option::is_none")]
token: Option<String>,
},
Streamable {
url: String,
#[serde(skip_serializing_if = "Option::is_none")]
token: Option<String>,
},
}
impl McpConfig {
/// Load configuration from a YAML file
pub async fn from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
let content = tokio::fs::read_to_string(path).await?;
let config: Self = serde_yaml::from_str(&content)?;
Ok(config)
}
/// Load configuration from environment variables (optional)
pub fn from_env() -> Option<Self> {
// This could be expanded to read from env vars
// For now, return None to indicate env config not implemented
None
}
}

View File

@@ -0,0 +1,42 @@
use thiserror::Error;
pub type McpResult<T> = Result<T, McpError>;
#[derive(Debug, Error)]
pub enum McpError {
#[error("Server not found: {0}")]
ServerNotFound(String),
#[error("Tool not found: {0}")]
ToolNotFound(String),
#[error("Transport error: {0}")]
Transport(String),
#[error("Tool execution failed: {0}")]
ToolExecution(String),
#[error("Connection failed: {0}")]
ConnectionFailed(String),
#[error("Configuration error: {0}")]
Config(String),
#[error("Authentication error: {0}")]
Auth(String),
#[error("Resource not found: {0}")]
ResourceNotFound(String),
#[error("Prompt not found: {0}")]
PromptNotFound(String),
#[error(transparent)]
Sdk(#[from] Box<rmcp::RmcpError>),
#[error(transparent)]
Io(#[from] std::io::Error),
#[error(transparent)]
Http(#[from] reqwest::Error),
}

18
sgl-router/src/mcp/mod.rs Normal file
View File

@@ -0,0 +1,18 @@
// MCP Client for SGLang Router
//
// This module provides a complete MCP (Model Context Protocol) client implementation
// supporting multiple transport types (stdio, SSE, HTTP) and all MCP features:
// - Tools: Discovery and execution
// - Prompts: Reusable templates for LLM interactions
// - Resources: File/data access with subscription support
// - OAuth: Secure authentication for remote servers
pub mod client_manager;
pub mod config;
pub mod error;
pub mod oauth;
// Re-export the main types for convenience
pub use client_manager::{McpClientManager, PromptInfo, ResourceInfo, ToolInfo};
pub use config::{McpConfig, McpServerConfig, McpTransport};
pub use error::{McpError, McpResult};

191
sgl-router/src/mcp/oauth.rs Normal file
View File

@@ -0,0 +1,191 @@
// OAuth authentication support for MCP servers
use axum::{
extract::{Query, State},
response::Html,
routing::get,
Router,
};
use rmcp::transport::auth::OAuthState;
use serde::Deserialize;
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::{oneshot, Mutex};
use crate::mcp::error::{McpError, McpResult};
/// OAuth callback parameters
#[derive(Debug, Deserialize)]
struct CallbackParams {
code: String,
#[allow(dead_code)]
state: Option<String>,
}
/// State for the callback server
#[derive(Clone)]
struct CallbackState {
code_receiver: Arc<Mutex<Option<oneshot::Sender<String>>>>,
}
/// HTML page returned after successful OAuth callback
const CALLBACK_HTML: &str = r#"
<!DOCTYPE html>
<html>
<head>
<title>OAuth Success</title>
<style>
body {
font-family: Arial, sans-serif;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}
.container {
background: white;
padding: 40px;
border-radius: 10px;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
text-align: center;
}
h1 { color: #333; }
p { color: #666; margin: 20px 0; }
.success { color: #4CAF50; font-size: 48px; }
</style>
</head>
<body>
<div class="container">
<div class="success">✓</div>
<h1>Authentication Successful!</h1>
<p>You can now close this window and return to your application.</p>
</div>
</body>
</html>
"#;
/// OAuth authentication helper for MCP servers
pub struct OAuthHelper {
server_url: String,
redirect_uri: String,
callback_port: u16,
}
impl OAuthHelper {
/// Create a new OAuth helper
pub fn new(server_url: String, redirect_uri: String, callback_port: u16) -> Self {
Self {
server_url,
redirect_uri,
callback_port,
}
}
/// Perform OAuth authentication flow
pub async fn authenticate(
&self,
scopes: &[&str],
) -> McpResult<rmcp::transport::auth::AuthorizationManager> {
// Initialize OAuth state machine
let mut oauth_state = OAuthState::new(&self.server_url, None)
.await
.map_err(|e| McpError::Auth(format!("Failed to initialize OAuth: {}", e)))?;
oauth_state
.start_authorization(scopes, &self.redirect_uri)
.await
.map_err(|e| McpError::Auth(format!("Failed to start authorization: {}", e)))?;
// Get authorization URL
let auth_url = oauth_state
.get_authorization_url()
.await
.map_err(|e| McpError::Auth(format!("Failed to get authorization URL: {}", e)))?;
tracing::info!("OAuth authorization URL: {}", auth_url);
// Start callback server and wait for code
let auth_code = self.start_callback_server().await?;
// Exchange code for token
oauth_state
.handle_callback(&auth_code)
.await
.map_err(|e| McpError::Auth(format!("Failed to handle OAuth callback: {}", e)))?;
// Get authorization manager
oauth_state
.into_authorization_manager()
.ok_or_else(|| McpError::Auth("Failed to get authorization manager".to_string()))
}
/// Start a local HTTP server to receive the OAuth callback
async fn start_callback_server(&self) -> McpResult<String> {
let (code_sender, code_receiver) = oneshot::channel::<String>();
let state = CallbackState {
code_receiver: Arc::new(Mutex::new(Some(code_sender))),
};
// Create router for callback
let app = Router::new()
.route("/callback", get(Self::callback_handler))
.with_state(state);
let addr = SocketAddr::from(([127, 0, 0, 1], self.callback_port));
// Start server in background
let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
McpError::Auth(format!(
"Failed to bind to callback port {}: {}",
self.callback_port, e
))
})?;
tokio::spawn(async move {
let _ = axum::serve(listener, app).await;
});
tracing::info!(
"OAuth callback server started on port {}",
self.callback_port
);
// Wait for authorization code
code_receiver
.await
.map_err(|_| McpError::Auth("Failed to receive authorization code".to_string()))
}
/// Handle OAuth callback
async fn callback_handler(
Query(params): Query<CallbackParams>,
State(state): State<CallbackState>,
) -> Html<String> {
tracing::debug!("Received OAuth callback with code");
// Send code to waiting task
if let Some(sender) = state.code_receiver.lock().await.take() {
let _ = sender.send(params.code);
}
Html(CALLBACK_HTML.to_string())
}
}
/// Create an OAuth-authenticated client
pub async fn create_oauth_client(
server_url: String,
_sse_url: String,
redirect_uri: String,
callback_port: u16,
scopes: &[&str],
) -> McpResult<rmcp::transport::auth::AuthClient<reqwest::Client>> {
let helper = OAuthHelper::new(server_url, redirect_uri, callback_port);
let auth_manager = helper.authenticate(scopes).await?;
let client = rmcp::transport::auth::AuthClient::new(reqwest::Client::default(), auth_manager);
Ok(client)
}

1011
sgl-router/src/metrics.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,502 @@
use axum::{
extract::Request, extract::State, http::HeaderValue, http::StatusCode, middleware::Next,
response::IntoResponse, response::Response,
};
use rand::Rng;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use tokio::sync::{mpsc, oneshot};
use tower::{Layer, Service};
use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
pub use crate::core::token_bucket::TokenBucket;
use crate::server::AppState;
/// Generate OpenAI-compatible request ID based on endpoint
fn generate_request_id(path: &str) -> String {
let prefix = if path.contains("/chat/completions") {
"chatcmpl-"
} else if path.contains("/completions") {
"cmpl-"
} else if path.contains("/generate") {
"gnt-"
} else {
"req-"
};
// Generate a random string similar to OpenAI's format
let chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
let mut rng = rand::rng();
let random_part: String = (0..24)
.map(|_| {
let idx = rng.random_range(0..chars.len());
chars.chars().nth(idx).unwrap()
})
.collect();
format!("{}{}", prefix, random_part)
}
/// Extension type for storing request ID
#[derive(Clone, Debug)]
pub struct RequestId(pub String);
/// Tower Layer for request ID middleware
#[derive(Clone)]
pub struct RequestIdLayer {
headers: Arc<Vec<String>>,
}
impl RequestIdLayer {
pub fn new(headers: Vec<String>) -> Self {
Self {
headers: Arc::new(headers),
}
}
}
impl<S> Layer<S> for RequestIdLayer {
type Service = RequestIdMiddleware<S>;
fn layer(&self, inner: S) -> Self::Service {
RequestIdMiddleware {
inner,
headers: self.headers.clone(),
}
}
}
/// Tower Service for request ID middleware
#[derive(Clone)]
pub struct RequestIdMiddleware<S> {
inner: S,
headers: Arc<Vec<String>>,
}
impl<S> Service<Request> for RequestIdMiddleware<S>
where
S: Service<Request, Response = Response> + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request) -> Self::Future {
let headers = self.headers.clone();
// Extract request ID from headers or generate new one
let mut request_id = None;
for header_name in headers.iter() {
if let Some(header_value) = req.headers().get(header_name) {
if let Ok(value) = header_value.to_str() {
request_id = Some(value.to_string());
break;
}
}
}
let request_id = request_id.unwrap_or_else(|| generate_request_id(req.uri().path()));
// Insert request ID into request extensions
req.extensions_mut().insert(RequestId(request_id.clone()));
// Create a span with the request ID for this request
let span = tracing::info_span!(
"http_request",
method = %req.method(),
uri = %req.uri(),
version = ?req.version(),
request_id = %request_id
);
// Log within the span
let _enter = span.enter();
tracing::info!(
target: "sglang_router_rs::request",
"started processing request"
);
drop(_enter);
// Capture values we need in the async block
let method = req.method().clone();
let uri = req.uri().clone();
let version = req.version();
// Call the inner service
let future = self.inner.call(req);
Box::pin(async move {
let start_time = Instant::now();
let mut response = future.await?;
let latency = start_time.elapsed();
// Add request ID to response headers
response.headers_mut().insert(
"x-request-id",
HeaderValue::from_str(&request_id)
.unwrap_or_else(|_| HeaderValue::from_static("invalid-request-id")),
);
// Log the response with proper request ID in span
let status = response.status();
let span = tracing::info_span!(
"http_request",
method = %method,
uri = %uri,
version = ?version,
request_id = %request_id,
status = %status,
latency = ?latency
);
let _enter = span.enter();
if status.is_server_error() {
tracing::error!(
target: "sglang_router_rs::response",
"request failed with server error"
);
} else if status.is_client_error() {
tracing::warn!(
target: "sglang_router_rs::response",
"request failed with client error"
);
} else {
tracing::info!(
target: "sglang_router_rs::response",
"finished processing request"
);
}
Ok(response)
})
}
}
// ============= Logging Middleware =============
/// Custom span maker that includes request ID
#[derive(Clone, Debug)]
pub struct RequestSpan;
impl<B> MakeSpan<B> for RequestSpan {
fn make_span(&mut self, request: &Request<B>) -> Span {
// Don't try to extract request ID here - it won't be available yet
// The RequestIdLayer runs after TraceLayer creates the span
info_span!(
"http_request",
method = %request.method(),
uri = %request.uri(),
version = ?request.version(),
request_id = Empty, // Will be set later
status_code = Empty,
latency = Empty,
error = Empty,
)
}
}
/// Custom on_request handler
#[derive(Clone, Debug)]
pub struct RequestLogger;
impl<B> OnRequest<B> for RequestLogger {
fn on_request(&mut self, request: &Request<B>, span: &Span) {
let _enter = span.enter();
// Try to get the request ID from extensions
// This will work if RequestIdLayer has already run
if let Some(request_id) = request.extensions().get::<RequestId>() {
span.record("request_id", request_id.0.as_str());
}
// Don't log here - we already log in RequestIdService with the proper request_id
}
}
/// Custom on_response handler
#[derive(Clone, Debug)]
pub struct ResponseLogger {
_start_time: Instant,
}
impl Default for ResponseLogger {
fn default() -> Self {
Self {
_start_time: Instant::now(),
}
}
}
impl<B> OnResponse<B> for ResponseLogger {
fn on_response(self, response: &Response<B>, latency: std::time::Duration, span: &Span) {
let status = response.status();
// Record these in the span for structured logging/observability tools
span.record("status_code", status.as_u16());
span.record("latency", format!("{:?}", latency));
// Don't log here - RequestIdService handles all logging with proper request IDs
}
}
/// Create a configured TraceLayer for HTTP logging
/// Note: Actual request/response logging with request IDs is done in RequestIdService
pub fn create_logging_layer() -> TraceLayer<
tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>,
RequestSpan,
RequestLogger,
ResponseLogger,
> {
TraceLayer::new_for_http()
.make_span_with(RequestSpan)
.on_request(RequestLogger)
.on_response(ResponseLogger::default())
}
/// Structured logging data for requests
#[derive(Debug, serde::Serialize)]
pub struct RequestLogEntry {
pub timestamp: String,
pub request_id: String,
pub method: String,
pub uri: String,
pub status: u16,
pub latency_ms: u64,
pub user_agent: Option<String>,
pub remote_addr: Option<String>,
pub error: Option<String>,
}
/// Log a request with structured data
pub fn log_request(entry: RequestLogEntry) {
if entry.status >= 500 {
tracing::error!(
target: "sglang_router_rs::http",
request_id = %entry.request_id,
method = %entry.method,
uri = %entry.uri,
status = entry.status,
latency_ms = entry.latency_ms,
user_agent = ?entry.user_agent,
remote_addr = ?entry.remote_addr,
error = ?entry.error,
"HTTP request failed"
);
} else if entry.status >= 400 {
tracing::warn!(
target: "sglang_router_rs::http",
request_id = %entry.request_id,
method = %entry.method,
uri = %entry.uri,
status = entry.status,
latency_ms = entry.latency_ms,
user_agent = ?entry.user_agent,
remote_addr = ?entry.remote_addr,
"HTTP request client error"
);
} else {
tracing::info!(
target: "sglang_router_rs::http",
request_id = %entry.request_id,
method = %entry.method,
uri = %entry.uri,
status = entry.status,
latency_ms = entry.latency_ms,
user_agent = ?entry.user_agent,
remote_addr = ?entry.remote_addr,
"HTTP request completed"
);
}
}
// ============ Concurrency Limiting with Queue Support ============
/// Request queue entry
pub struct QueuedRequest {
/// Time when the request was queued
queued_at: Instant,
/// Channel to send the permit back when acquired
permit_tx: oneshot::Sender<Result<(), StatusCode>>,
}
/// Queue metrics for monitoring
#[derive(Debug, Default)]
pub struct QueueMetrics {
pub total_queued: std::sync::atomic::AtomicU64,
pub current_queued: std::sync::atomic::AtomicU64,
pub total_timeout: std::sync::atomic::AtomicU64,
pub total_rejected: std::sync::atomic::AtomicU64,
}
/// Queue processor that handles queued requests
pub struct QueueProcessor {
token_bucket: Arc<TokenBucket>,
queue_rx: mpsc::Receiver<QueuedRequest>,
queue_timeout: Duration,
}
impl QueueProcessor {
pub fn new(
token_bucket: Arc<TokenBucket>,
queue_rx: mpsc::Receiver<QueuedRequest>,
queue_timeout: Duration,
) -> Self {
Self {
token_bucket,
queue_rx,
queue_timeout,
}
}
pub async fn run(mut self) {
info!("Starting concurrency queue processor");
// Process requests in a single task to reduce overhead
while let Some(queued) = self.queue_rx.recv().await {
// Check timeout immediately
let elapsed = queued.queued_at.elapsed();
if elapsed >= self.queue_timeout {
warn!("Request already timed out in queue");
let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
continue;
}
let remaining_timeout = self.queue_timeout - elapsed;
// Try to acquire token for this request
if self.token_bucket.try_acquire(1.0).await.is_ok() {
// Got token immediately
debug!("Queue: acquired token immediately for queued request");
let _ = queued.permit_tx.send(Ok(()));
} else {
// Need to wait for token
let token_bucket = self.token_bucket.clone();
// Spawn task only when we actually need to wait
tokio::spawn(async move {
if token_bucket
.acquire_timeout(1.0, remaining_timeout)
.await
.is_ok()
{
debug!("Queue: acquired token after waiting");
let _ = queued.permit_tx.send(Ok(()));
} else {
warn!("Queue: request timed out waiting for token");
let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
}
});
}
}
warn!("Concurrency queue processor shutting down");
}
}
/// State for the concurrency limiter
pub struct ConcurrencyLimiter {
pub queue_tx: Option<mpsc::Sender<QueuedRequest>>,
}
impl ConcurrencyLimiter {
/// Create new concurrency limiter with optional queue
pub fn new(
token_bucket: Arc<TokenBucket>,
queue_size: usize,
queue_timeout: Duration,
) -> (Self, Option<QueueProcessor>) {
if queue_size > 0 {
let (queue_tx, queue_rx) = mpsc::channel(queue_size);
let processor = QueueProcessor::new(token_bucket, queue_rx, queue_timeout);
(
Self {
queue_tx: Some(queue_tx),
},
Some(processor),
)
} else {
(Self { queue_tx: None }, None)
}
}
}
/// Middleware function for concurrency limiting with optional queuing
pub async fn concurrency_limit_middleware(
State(app_state): State<Arc<AppState>>,
request: Request<axum::body::Body>,
next: Next,
) -> Response {
let token_bucket = app_state.context.rate_limiter.clone();
// Try to acquire token immediately
if token_bucket.try_acquire(1.0).await.is_ok() {
debug!("Acquired token immediately");
let response = next.run(request).await;
// Return the token to the bucket
token_bucket.return_tokens(1.0).await;
response
} else {
// No tokens available, try to queue if enabled
if let Some(queue_tx) = &app_state.concurrency_queue_tx {
debug!("No tokens available, attempting to queue request");
// Create a channel for the token response
let (permit_tx, permit_rx) = oneshot::channel();
let queued = QueuedRequest {
queued_at: Instant::now(),
permit_tx,
};
// Try to send to queue
match queue_tx.try_send(queued) {
Ok(_) => {
// Wait for token from queue processor
match permit_rx.await {
Ok(Ok(())) => {
debug!("Acquired token from queue");
let response = next.run(request).await;
// Return the token to the bucket
token_bucket.return_tokens(1.0).await;
response
}
Ok(Err(status)) => {
warn!("Queue returned error status: {}", status);
status.into_response()
}
Err(_) => {
error!("Queue response channel closed");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
Err(_) => {
warn!("Request queue is full, returning 429");
StatusCode::TOO_MANY_REQUESTS.into_response()
}
}
} else {
warn!("No tokens available and queuing is disabled, returning 429");
StatusCode::TOO_MANY_REQUESTS.into_response()
}
}
}

View File

@@ -0,0 +1,423 @@
/*
Cache-Aware Load Balancing Router
This router combines two strategies to optimize both cache utilization and request distribution:
1. Cache-Aware Routing (Approximate Tree)
2. Load Balancing (Shortest Queue with Balance Thresholds)
The router dynamically switches between these strategies based on load conditions:
- Uses load balancing when the system is imbalanced
- Uses cache-aware routing when the system is balanced
A system is considered imbalanced if both conditions are met:
1. (max - min) > abs_threshold
2. max > rel_threshold * min
Strategy Details:
1. Cache-Aware Routing (Approximate Tree)
-------------------------------------------
This strategy maintains an approximate radix tree for each worker based on request history,
eliminating the need for direct cache state queries. The tree stores raw text characters
instead of token IDs to avoid tokenization overhead.
Process:
a. For each request, find the worker with the highest prefix match
b. If match rate > cache_threshold:
Route to the worker with highest match (likely has relevant data cached)
c. If match rate ≤ cache_threshold:
Route to the worker with smallest tree size (most available cache capacity)
d. Background maintenance:
Periodically evict least recently used leaf nodes to prevent memory overflow
2. Load Balancing (Shortest Queue)
-------------------------------------------
This strategy tracks pending request counts per worker and routes new requests
to the least busy worker when the system is detected to be imbalanced.
Configuration Parameters:
------------------------
1. cache_threshold: (float, 0.0 to 1.0)
Minimum prefix match ratio to use highest-match routing.
Below this threshold, routes to worker with most available cache space.
2. balance_abs_threshold: (integer)
Absolute difference threshold for load imbalance detection.
System is potentially imbalanced if (max_load - min_load) > abs_threshold
3. balance_rel_threshold: (float)
Relative ratio threshold for load imbalance detection.
System is potentially imbalanced if max_load > min_load * rel_threshold
Used in conjunction with abs_threshold to determine final imbalance state.
4. eviction_interval_secs: (integer)
Interval between LRU eviction cycles for the approximate trees.
5. max_tree_size: (integer)
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
*/
use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use crate::tree::Tree;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
use tracing::debug;
/// Cache-aware routing policy
///
/// Routes requests based on cache affinity when load is balanced,
/// switches to shortest-queue routing when load is imbalanced.
#[derive(Debug)]
pub struct CacheAwarePolicy {
config: CacheAwareConfig,
tree: Arc<Mutex<Tree>>,
eviction_handle: Option<thread::JoinHandle<()>>,
}
impl CacheAwarePolicy {
pub fn new() -> Self {
Self::with_config(CacheAwareConfig::default())
}
pub fn with_config(config: CacheAwareConfig) -> Self {
let tree = Arc::new(Mutex::new(Tree::new()));
// Start background eviction thread if configured
let eviction_handle = if config.eviction_interval_secs > 0 {
let tree_clone = Arc::clone(&tree);
let max_tree_size = config.max_tree_size;
let interval = config.eviction_interval_secs;
Some(thread::spawn(move || loop {
thread::sleep(Duration::from_secs(interval));
if let Ok(tree_guard) = tree_clone.lock() {
tree_guard.evict_tenant_by_size(max_tree_size);
debug!("Cache eviction completed, max_size: {}", max_tree_size);
}
}))
} else {
None
};
Self {
config,
tree,
eviction_handle,
}
}
/// Initialize the tree with worker URLs (used only during initial setup)
pub fn init_workers(&self, workers: &[Box<dyn Worker>]) {
if let Ok(tree) = self.tree.lock() {
for worker in workers {
tree.insert("", worker.url());
}
}
}
/// Add a single worker to the tree (incremental update)
pub fn add_worker(&self, url: &str) {
if let Ok(tree) = self.tree.lock() {
tree.insert("", url);
}
}
/// Remove a worker from the tree
pub fn remove_worker(&self, url: &str) {
if let Ok(tree) = self.tree.lock() {
tree.remove_tenant(url);
}
}
/// Run cache eviction to prevent unbounded growth
pub fn evict_cache(&self, max_size: usize) {
if let Ok(tree) = self.tree.lock() {
tree.evict_tenant_by_size(max_size);
}
}
}
impl LoadBalancingPolicy for CacheAwarePolicy {
fn select_worker(
&self,
workers: &[Box<dyn Worker>],
request_text: Option<&str>,
) -> Option<usize> {
let healthy_indices = get_healthy_worker_indices(workers);
if healthy_indices.is_empty() {
return None;
}
// Get current load statistics
let loads: Vec<usize> = workers.iter().map(|w| w.load()).collect();
let max_load = *loads.iter().max().unwrap_or(&0);
let min_load = *loads.iter().min().unwrap_or(&0);
// Check if load is imbalanced
let is_imbalanced = max_load.saturating_sub(min_load) > self.config.balance_abs_threshold
&& (max_load as f32) > (min_load as f32 * self.config.balance_rel_threshold);
if is_imbalanced {
// Log load balancing trigger
let worker_loads: Vec<(String, usize)> = workers
.iter()
.map(|w| (w.url().to_string(), w.load()))
.collect();
debug!(
"Load balancing triggered | max: {} | min: {} | workers: {:?}",
max_load, min_load, worker_loads
);
RouterMetrics::record_load_balancing_event();
RouterMetrics::set_load_range(max_load, min_load);
// Use shortest queue when imbalanced
let min_load_idx = healthy_indices
.iter()
.min_by_key(|&&idx| workers[idx].load())
.copied()?;
// Even in imbalanced mode, update the tree to maintain cache state
if let Some(text) = request_text {
if let Ok(tree) = self.tree.lock() {
tree.insert(text, workers[min_load_idx].url());
}
}
// Increment processed counter
workers[min_load_idx].increment_processed();
RouterMetrics::record_processed_request(workers[min_load_idx].url());
RouterMetrics::record_policy_decision(self.name(), workers[min_load_idx].url());
return Some(min_load_idx);
}
// Use cache-aware routing when balanced
let text = request_text.unwrap_or("");
if let Ok(tree) = self.tree.lock() {
let (matched_text, matched_worker) = tree.prefix_match(text);
let match_rate = if text.is_empty() {
0.0
} else {
matched_text.chars().count() as f32 / text.chars().count() as f32
};
let selected_url = if match_rate > self.config.cache_threshold {
RouterMetrics::record_cache_hit();
matched_worker.to_string()
} else {
RouterMetrics::record_cache_miss();
tree.get_smallest_tenant()
};
// Find the index of the selected worker
if let Some(selected_idx) = workers.iter().position(|w| w.url() == selected_url) {
// Only proceed if the worker is healthy
if workers[selected_idx].is_healthy() {
// Update the tree with this request
tree.insert(text, &selected_url);
// Increment processed counter
workers[selected_idx].increment_processed();
RouterMetrics::record_processed_request(&selected_url);
return Some(selected_idx);
}
} else {
// Selected worker no longer exists, remove it from tree
tree.remove_tenant(&selected_url);
debug!("Removed stale worker {} from cache tree", selected_url);
}
// Fallback to first healthy worker
return healthy_indices.first().copied();
}
// Fallback to first healthy worker if tree operations fail
healthy_indices.first().copied()
}
fn name(&self) -> &'static str {
"cache_aware"
}
fn needs_request_text(&self) -> bool {
true // Cache-aware policy needs request text for cache affinity
}
fn on_request_complete(&self, worker_url: &str, success: bool) {
// Could track success rates per worker for more intelligent routing
if !success {
// Optionally reduce affinity for failed requests
tracing::debug!(
"Request to {} completed with success={}",
worker_url,
success
);
}
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn select_worker_pair(
&self,
prefill_workers: &[Box<dyn Worker>],
decode_workers: &[Box<dyn Worker>],
request_text: Option<&str>,
) -> Option<(usize, usize)> {
// DEPRECATED: This method is no longer used when separate policies are configured.
// The PD router now uses separate policies for prefill and decode selection.
// This implementation remains for backward compatibility when a single policy is used.
// In PD mode with single policy:
// - Prefill: Use cache-aware routing for better cache utilization
// - Decode: Use least-load routing for better load distribution
// Select prefill worker using cache-aware logic
let prefill_idx = self.select_worker(prefill_workers, request_text)?;
// Select decode worker using least-load logic
let healthy_decode = get_healthy_worker_indices(decode_workers);
if healthy_decode.is_empty() {
return None;
}
let decode_idx = healthy_decode
.iter()
.min_by_key(|&&idx| decode_workers[idx].load())
.copied()?;
Some((prefill_idx, decode_idx))
}
}
impl Default for CacheAwarePolicy {
fn default() -> Self {
Self::new()
}
}
impl Drop for CacheAwarePolicy {
fn drop(&mut self) {
// Note: We can't properly stop the eviction thread since it's in an infinite loop
// In a production system, we'd use a channel or atomic flag to signal shutdown
if let Some(handle) = self.eviction_handle.take() {
// The thread will continue running until the program exits
// This is acceptable for now since the router typically runs for the lifetime of the program
drop(handle);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::{BasicWorker, WorkerType};
#[test]
fn test_cache_aware_with_balanced_load() {
// Create policy without eviction thread for testing
let config = CacheAwareConfig {
eviction_interval_secs: 0, // Disable eviction thread
..Default::default()
};
let policy = CacheAwarePolicy::with_config(config);
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
];
// Initialize the policy with workers
policy.init_workers(&workers);
// First request should be distributed
let idx1 = policy.select_worker(&workers, Some("hello world")).unwrap();
// Same request should go to same worker (cache hit)
let idx2 = policy.select_worker(&workers, Some("hello world")).unwrap();
assert_eq!(idx1, idx2);
// Similar request should also go to same worker
let idx3 = policy.select_worker(&workers, Some("hello")).unwrap();
assert_eq!(idx1, idx3);
}
#[test]
fn test_cache_aware_with_imbalanced_load() {
let policy = CacheAwarePolicy::with_config(CacheAwareConfig {
cache_threshold: 0.5,
balance_abs_threshold: 5,
balance_rel_threshold: 2.0,
eviction_interval_secs: 0, // Disable eviction thread
max_tree_size: 10000,
});
let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular);
let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular);
// Create significant load imbalance
for _ in 0..20 {
worker1.increment_load();
}
// worker2 has load 0
let workers: Vec<Box<dyn Worker>> = vec![Box::new(worker1), Box::new(worker2)];
policy.init_workers(&workers);
// Should select worker2 (lower load) despite cache affinity
for _ in 0..5 {
let idx = policy.select_worker(&workers, Some("test")).unwrap();
assert_eq!(idx, 1); // Should always pick worker2
}
}
#[test]
fn test_cache_aware_worker_removal() {
let config = CacheAwareConfig {
eviction_interval_secs: 0, // Disable eviction thread
..Default::default()
};
let policy = CacheAwarePolicy::with_config(config);
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
];
policy.init_workers(&workers);
// Route some requests
policy.select_worker(&workers, Some("test1"));
policy.select_worker(&workers, Some("test2"));
// Remove a worker
policy.remove_worker("http://w1:8000");
workers[0].set_healthy(false);
// All requests should now go to worker2
let idx = policy.select_worker(&workers, Some("test1")).unwrap();
assert_eq!(idx, 1);
}
}

View File

@@ -0,0 +1,94 @@
//! Factory for creating load balancing policies
use super::{
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
RoundRobinPolicy,
};
use crate::config::PolicyConfig;
use std::sync::Arc;
/// Factory for creating policy instances
pub struct PolicyFactory;
impl PolicyFactory {
/// Create a policy from configuration
pub fn create_from_config(config: &PolicyConfig) -> Arc<dyn LoadBalancingPolicy> {
match config {
PolicyConfig::Random => Arc::new(RandomPolicy::new()),
PolicyConfig::RoundRobin => Arc::new(RoundRobinPolicy::new()),
PolicyConfig::PowerOfTwo { .. } => Arc::new(PowerOfTwoPolicy::new()),
PolicyConfig::CacheAware {
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
eviction_interval_secs,
max_tree_size,
} => {
let config = CacheAwareConfig {
cache_threshold: *cache_threshold,
balance_abs_threshold: *balance_abs_threshold,
balance_rel_threshold: *balance_rel_threshold,
eviction_interval_secs: *eviction_interval_secs,
max_tree_size: *max_tree_size,
};
Arc::new(CacheAwarePolicy::with_config(config))
}
}
}
/// Create a policy by name (for dynamic loading)
pub fn create_by_name(name: &str) -> Option<Arc<dyn LoadBalancingPolicy>> {
match name.to_lowercase().as_str() {
"random" => Some(Arc::new(RandomPolicy::new())),
"round_robin" | "roundrobin" => Some(Arc::new(RoundRobinPolicy::new())),
"power_of_two" | "poweroftwo" => Some(Arc::new(PowerOfTwoPolicy::new())),
"cache_aware" | "cacheaware" => Some(Arc::new(CacheAwarePolicy::new())),
_ => None,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_create_from_config() {
// Test Random
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
assert_eq!(policy.name(), "random");
// Test RoundRobin
let policy = PolicyFactory::create_from_config(&PolicyConfig::RoundRobin);
assert_eq!(policy.name(), "round_robin");
// Test PowerOfTwo
let policy = PolicyFactory::create_from_config(&PolicyConfig::PowerOfTwo {
load_check_interval_secs: 60,
});
assert_eq!(policy.name(), "power_of_two");
// Test CacheAware
let policy = PolicyFactory::create_from_config(&PolicyConfig::CacheAware {
cache_threshold: 0.7,
balance_abs_threshold: 10,
balance_rel_threshold: 1.5,
eviction_interval_secs: 30,
max_tree_size: 1000,
});
assert_eq!(policy.name(), "cache_aware");
}
#[test]
fn test_create_by_name() {
assert!(PolicyFactory::create_by_name("random").is_some());
assert!(PolicyFactory::create_by_name("RANDOM").is_some());
assert!(PolicyFactory::create_by_name("round_robin").is_some());
assert!(PolicyFactory::create_by_name("RoundRobin").is_some());
assert!(PolicyFactory::create_by_name("power_of_two").is_some());
assert!(PolicyFactory::create_by_name("PowerOfTwo").is_some());
assert!(PolicyFactory::create_by_name("cache_aware").is_some());
assert!(PolicyFactory::create_by_name("CacheAware").is_some());
assert!(PolicyFactory::create_by_name("unknown").is_none());
}
}

View File

@@ -0,0 +1,148 @@
//! Load balancing policies for SGLang router
//!
//! This module provides a unified abstraction for routing policies that work
//! across both regular and prefill-decode (PD) routing modes.
use crate::core::Worker;
use std::fmt::Debug;
mod cache_aware;
mod factory;
mod power_of_two;
mod random;
mod round_robin;
pub use cache_aware::CacheAwarePolicy;
pub use factory::PolicyFactory;
pub use power_of_two::PowerOfTwoPolicy;
pub use random::RandomPolicy;
pub use round_robin::RoundRobinPolicy;
/// Core trait for load balancing policies
///
/// This trait provides a unified interface for implementing routing algorithms
/// that can work with both regular single-worker selection and PD dual-worker selection.
pub trait LoadBalancingPolicy: Send + Sync + Debug {
/// Select a single worker from the available workers
///
/// This is used for regular routing mode where requests go to a single worker.
fn select_worker(
&self,
workers: &[Box<dyn Worker>],
request_text: Option<&str>,
) -> Option<usize>;
/// Select a pair of workers (prefill and decode) for PD routing
///
/// Returns indices of (prefill_worker, decode_worker) from their respective arrays.
/// Default implementation uses select_worker for each array independently.
fn select_worker_pair(
&self,
prefill_workers: &[Box<dyn Worker>],
decode_workers: &[Box<dyn Worker>],
request_text: Option<&str>,
) -> Option<(usize, usize)> {
// Default implementation: independently select from each pool
let prefill_idx = self.select_worker(prefill_workers, request_text)?;
let decode_idx = self.select_worker(decode_workers, request_text)?;
Some((prefill_idx, decode_idx))
}
/// Update policy state after request completion
///
/// This is called when a request completes (successfully or not) to allow
/// policies to update their internal state.
fn on_request_complete(&self, _worker_url: &str, _success: bool) {
// Default: no-op for stateless policies
}
/// Get policy name for metrics and debugging
fn name(&self) -> &'static str;
/// Check if this policy needs request text for routing decisions
fn needs_request_text(&self) -> bool {
false // Default: most policies don't need request text
}
/// Update worker load information
///
/// This is called periodically with current load information for load-aware policies.
fn update_loads(&self, _loads: &std::collections::HashMap<String, isize>) {
// Default: no-op for policies that don't use load information
}
/// Reset any internal state
///
/// This is useful for policies that maintain state (e.g., round-robin counters).
fn reset(&self) {
// Default: no-op for stateless policies
}
/// Get as Any for downcasting
fn as_any(&self) -> &dyn std::any::Any;
}
/// Configuration for cache-aware policy
#[derive(Debug, Clone)]
pub struct CacheAwareConfig {
pub cache_threshold: f32,
pub balance_abs_threshold: usize,
pub balance_rel_threshold: f32,
pub eviction_interval_secs: u64,
pub max_tree_size: usize,
}
impl Default for CacheAwareConfig {
fn default() -> Self {
Self {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
eviction_interval_secs: 30,
max_tree_size: 10000,
}
}
}
/// Helper function to filter healthy workers and return their indices
pub(crate) fn get_healthy_worker_indices(workers: &[Box<dyn Worker>]) -> Vec<usize> {
workers
.iter()
.enumerate()
.filter(|(_, w)| w.is_healthy() && w.circuit_breaker().can_execute())
.map(|(idx, _)| idx)
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::{BasicWorker, WorkerType};
#[test]
fn test_get_healthy_worker_indices() {
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
"http://w3:8000".to_string(),
WorkerType::Regular,
)),
];
// All healthy initially
let indices = get_healthy_worker_indices(&workers);
assert_eq!(indices, vec![0, 1, 2]);
// Mark one unhealthy
workers[1].set_healthy(false);
let indices = get_healthy_worker_indices(&workers);
assert_eq!(indices, vec![0, 2]);
}
}

View File

@@ -0,0 +1,201 @@
//! Power-of-two choices load balancing policy
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use rand::Rng;
use std::collections::HashMap;
use std::sync::RwLock;
use tracing::info;
/// Power-of-two choices policy
///
/// Randomly selects two workers and routes to the one with lower load.
/// This provides good load distribution with minimal coordination overhead.
#[derive(Debug)]
pub struct PowerOfTwoPolicy {
/// Cached load information from external monitoring
cached_loads: RwLock<HashMap<String, isize>>,
}
impl PowerOfTwoPolicy {
pub fn new() -> Self {
Self {
cached_loads: RwLock::new(HashMap::new()),
}
}
fn get_worker_load(&self, worker: &dyn Worker) -> isize {
// First check cached loads (from external monitoring)
if let Ok(loads) = self.cached_loads.read() {
if let Some(&load) = loads.get(worker.url()) {
return load;
}
}
// Fall back to local load counter
worker.load() as isize
}
}
impl LoadBalancingPolicy for PowerOfTwoPolicy {
fn select_worker(
&self,
workers: &[Box<dyn Worker>],
_request_text: Option<&str>,
) -> Option<usize> {
let healthy_indices = get_healthy_worker_indices(workers);
if healthy_indices.is_empty() {
return None;
}
if healthy_indices.len() == 1 {
return Some(healthy_indices[0]);
}
// Select two random workers
let mut rng = rand::rng();
let idx1 = rng.random_range(0..healthy_indices.len());
let mut idx2 = rng.random_range(0..healthy_indices.len());
// Ensure we pick two different workers
while idx2 == idx1 {
idx2 = rng.random_range(0..healthy_indices.len());
}
let worker_idx1 = healthy_indices[idx1];
let worker_idx2 = healthy_indices[idx2];
// Compare loads and select the less loaded one
let load1 = self.get_worker_load(workers[worker_idx1].as_ref());
let load2 = self.get_worker_load(workers[worker_idx2].as_ref());
// Log selection for debugging
let selected_idx = if load1 <= load2 {
worker_idx1
} else {
worker_idx2
};
info!(
"Power-of-two selection: {}={} vs {}={} -> selected {}",
workers[worker_idx1].url(),
load1,
workers[worker_idx2].url(),
load2,
workers[selected_idx].url()
);
// Increment processed counter
workers[selected_idx].increment_processed();
RouterMetrics::record_processed_request(workers[selected_idx].url());
RouterMetrics::record_policy_decision(self.name(), workers[selected_idx].url());
Some(selected_idx)
}
fn name(&self) -> &'static str {
"power_of_two"
}
fn update_loads(&self, loads: &HashMap<String, isize>) {
if let Ok(mut cached) = self.cached_loads.write() {
*cached = loads.clone();
}
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
impl Default for PowerOfTwoPolicy {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::{BasicWorker, WorkerType};
#[test]
fn test_power_of_two_selection() {
let policy = PowerOfTwoPolicy::new();
let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular);
let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular);
let worker3 = BasicWorker::new("http://w3:8000".to_string(), WorkerType::Regular);
// Set different loads
for _ in 0..10 {
worker1.increment_load();
}
for _ in 0..5 {
worker2.increment_load();
}
// worker3 has load 0
let workers: Vec<Box<dyn Worker>> =
vec![Box::new(worker1), Box::new(worker2), Box::new(worker3)];
// Run multiple selections
let mut selected_counts = [0; 3];
for _ in 0..100 {
if let Some(idx) = policy.select_worker(&workers, None) {
selected_counts[idx] += 1;
}
}
// Worker with lowest load (worker3) should be selected most often
assert!(selected_counts[2] > selected_counts[1]);
assert!(selected_counts[1] > selected_counts[0]);
}
#[test]
fn test_power_of_two_with_cached_loads() {
let policy = PowerOfTwoPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
];
// Update cached loads
let mut loads = HashMap::new();
loads.insert("http://w1:8000".to_string(), 100);
loads.insert("http://w2:8000".to_string(), 10);
policy.update_loads(&loads);
// Should prefer worker2 with lower cached load
let mut w2_selected = 0;
for _ in 0..50 {
if let Some(idx) = policy.select_worker(&workers, None) {
if idx == 1 {
w2_selected += 1;
}
}
}
// Worker2 should be selected significantly more often
assert!(w2_selected > 35); // Should win most of the time
}
#[test]
fn test_power_of_two_single_worker() {
let policy = PowerOfTwoPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![Box::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
))];
// With single worker, should always select it
assert_eq!(policy.select_worker(&workers, None), Some(0));
}
}

View File

@@ -0,0 +1,121 @@
//! Random load balancing policy
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use rand::Rng;
/// Random selection policy
///
/// Selects workers randomly with uniform distribution among healthy workers.
#[derive(Debug, Default)]
pub struct RandomPolicy;
impl RandomPolicy {
pub fn new() -> Self {
Self
}
}
impl LoadBalancingPolicy for RandomPolicy {
fn select_worker(
&self,
workers: &[Box<dyn Worker>],
_request_text: Option<&str>,
) -> Option<usize> {
let healthy_indices = get_healthy_worker_indices(workers);
if healthy_indices.is_empty() {
return None;
}
let mut rng = rand::rng();
let random_idx = rng.random_range(0..healthy_indices.len());
let worker = workers[healthy_indices[random_idx]].url();
RouterMetrics::record_processed_request(worker);
RouterMetrics::record_policy_decision(self.name(), worker);
Some(healthy_indices[random_idx])
}
fn name(&self) -> &'static str {
"random"
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::{BasicWorker, WorkerType};
use std::collections::HashMap;
#[test]
fn test_random_selection() {
let policy = RandomPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
"http://w3:8000".to_string(),
WorkerType::Regular,
)),
];
// Test multiple selections to ensure randomness
let mut counts = HashMap::new();
for _ in 0..100 {
if let Some(idx) = policy.select_worker(&workers, None) {
*counts.entry(idx).or_insert(0) += 1;
}
}
// All workers should be selected at least once
assert_eq!(counts.len(), 3);
assert!(counts.values().all(|&count| count > 0));
}
#[test]
fn test_random_with_unhealthy_workers() {
let policy = RandomPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
];
// Mark first worker as unhealthy
workers[0].set_healthy(false);
// Should always select the healthy worker (index 1)
for _ in 0..10 {
assert_eq!(policy.select_worker(&workers, None), Some(1));
}
}
#[test]
fn test_random_no_healthy_workers() {
let policy = RandomPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![Box::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
))];
workers[0].set_healthy(false);
assert_eq!(policy.select_worker(&workers, None), None);
}
}

View File

@@ -0,0 +1,140 @@
//! Round-robin load balancing policy
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Round-robin selection policy
///
/// Selects workers in sequential order, cycling through all healthy workers.
#[derive(Debug, Default)]
pub struct RoundRobinPolicy {
counter: AtomicUsize,
}
impl RoundRobinPolicy {
pub fn new() -> Self {
Self {
counter: AtomicUsize::new(0),
}
}
}
impl LoadBalancingPolicy for RoundRobinPolicy {
fn select_worker(
&self,
workers: &[Box<dyn Worker>],
_request_text: Option<&str>,
) -> Option<usize> {
let healthy_indices = get_healthy_worker_indices(workers);
if healthy_indices.is_empty() {
return None;
}
// Get and increment counter atomically
let count = self.counter.fetch_add(1, Ordering::Relaxed);
let selected_idx = count % healthy_indices.len();
let worker = workers[healthy_indices[selected_idx]].url();
RouterMetrics::record_processed_request(worker);
RouterMetrics::record_policy_decision(self.name(), worker);
Some(healthy_indices[selected_idx])
}
fn name(&self) -> &'static str {
"round_robin"
}
fn reset(&self) {
self.counter.store(0, Ordering::Relaxed);
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::{BasicWorker, WorkerType};
#[test]
fn test_round_robin_selection() {
let policy = RoundRobinPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
"http://w3:8000".to_string(),
WorkerType::Regular,
)),
];
// Should select workers in order: 0, 1, 2, 0, 1, 2, ...
assert_eq!(policy.select_worker(&workers, None), Some(0));
assert_eq!(policy.select_worker(&workers, None), Some(1));
assert_eq!(policy.select_worker(&workers, None), Some(2));
assert_eq!(policy.select_worker(&workers, None), Some(0));
assert_eq!(policy.select_worker(&workers, None), Some(1));
}
#[test]
fn test_round_robin_with_unhealthy_workers() {
let policy = RoundRobinPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
"http://w3:8000".to_string(),
WorkerType::Regular,
)),
];
// Mark middle worker as unhealthy
workers[1].set_healthy(false);
// Should skip unhealthy worker: 0, 2, 0, 2, ...
assert_eq!(policy.select_worker(&workers, None), Some(0));
assert_eq!(policy.select_worker(&workers, None), Some(2));
assert_eq!(policy.select_worker(&workers, None), Some(0));
assert_eq!(policy.select_worker(&workers, None), Some(2));
}
#[test]
fn test_round_robin_reset() {
let policy = RoundRobinPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
];
// Advance the counter
assert_eq!(policy.select_worker(&workers, None), Some(0));
assert_eq!(policy.select_worker(&workers, None), Some(1));
// Reset should start from beginning
policy.reset();
assert_eq!(policy.select_worker(&workers, None), Some(0));
}
}

View File

@@ -0,0 +1,541 @@
syntax = "proto3";
package sglang.grpc.scheduler;
import "google/protobuf/timestamp.proto";
import "google/protobuf/struct.proto";
// Service definition for SGLang scheduler communication
// This protocol bridges the Rust router and Python scheduler
service SglangScheduler {
// Initialize connection and get model info
rpc Initialize(InitializeRequest) returns (InitializeResponse);
// Submit a generation request (supports streaming)
rpc Generate(GenerateRequest) returns (stream GenerateResponse);
// Submit an embedding request
rpc Embed(EmbedRequest) returns (EmbedResponse);
// Health check and metrics
rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);
// Abort a running request
rpc Abort(AbortRequest) returns (AbortResponse);
// Flush KV cache
rpc FlushCache(FlushCacheRequest) returns (FlushCacheResponse);
}
// =====================
// Common Types
// =====================
// Sampling parameters matching SGLang's SamplingParams
message SamplingParams {
float temperature = 1;
float top_p = 2;
int32 top_k = 3;
float min_p = 4;
float frequency_penalty = 5;
float presence_penalty = 6;
float repetition_penalty = 7;
int32 max_new_tokens = 8;
repeated string stop = 9;
repeated int32 stop_token_ids = 10;
bool skip_special_tokens = 11;
bool spaces_between_special_tokens = 12;
// Structured generation
oneof constraint {
string regex = 13;
string json_schema = 14;
string ebnf_grammar = 15;
}
// LoRA adapter
string lora_path = 16;
// Speculative decoding
int32 n = 17; // Number of samples
// Token healing
bool token_healing = 18;
// Additional parameters
int32 min_new_tokens = 19;
bool ignore_eos = 20;
bool no_stop_trim = 21;
int32 stream_interval = 22;
map<string, float> logit_bias = 23;
string structural_tag = 24;
// Custom parameters for extensibility
google.protobuf.Struct custom_params = 25;
}
// Session parameters for continual prompting
message SessionParams {
string session_id = 1;
string request_id = 2;
int32 offset = 3;
bool replace = 4;
bool drop_previous_output = 5;
}
// Disaggregated serving parameters
message DisaggregatedParams {
string bootstrap_host = 1;
int32 bootstrap_port = 2;
int32 bootstrap_room = 3;
}
// =====================
// Initialize
// =====================
message InitializeRequest {
string client_id = 1;
string client_version = 2;
// Operating mode
enum Mode {
REGULAR = 0; // Normal mode with local scheduler
PREFILL = 1; // Prefill-only mode for disaggregated serving
DECODE = 2; // Decode-only mode for disaggregated serving
}
Mode mode = 3;
}
message InitializeResponse {
bool success = 1;
string scheduler_version = 2;
// Model information
ModelInfo model_info = 3;
// Server capabilities
ServerCapabilities capabilities = 4;
// Error message if success is false
string error_message = 5;
}
message ModelInfo {
string model_name = 1;
int32 max_context_length = 2;
int32 vocab_size = 3;
bool supports_tool_calling = 4;
bool supports_vision = 5;
repeated string special_tokens = 6;
// Additional model metadata
string model_type = 7;
int32 num_layers = 8;
int32 hidden_size = 9;
int32 num_attention_heads = 10;
int32 num_key_value_heads = 11;
// Tokenizer info
string tokenizer_type = 12;
repeated int32 eos_token_ids = 13;
int32 pad_token_id = 14;
int32 bos_token_id = 15;
}
message ServerCapabilities {
bool continuous_batching = 1;
bool disaggregated_serving = 2;
bool speculative_decoding = 3;
int32 max_batch_size = 4;
int32 max_num_batched_tokens = 5;
int32 max_prefill_tokens = 6;
string attention_backend = 7; // "flashinfer", "triton", "torch"
// Additional capabilities
bool supports_lora = 8;
bool supports_grammar = 9;
bool supports_multimodal = 10;
repeated string supported_modalities = 11; // ["image", "video", "audio"]
bool supports_custom_logit_processor = 12;
bool supports_session = 13;
// Hardware info
int32 num_gpus = 14;
string gpu_type = 15;
int64 total_gpu_memory = 16;
// Parallelism info
int32 tensor_parallel_size = 17;
int32 pipeline_parallel_size = 18;
int32 data_parallel_size = 19;
}
// =====================
// Generate Request
// =====================
message GenerateRequest {
string request_id = 1;
// Input can be either text or tokenized
oneof input {
string text = 2;
TokenizedInput tokenized = 3;
}
// Multimodal inputs
MultimodalInputs mm_inputs = 4;
// Generation parameters
SamplingParams sampling_params = 5;
// Return options
bool return_logprob = 6;
int32 logprob_start_len = 7;
int32 top_logprobs_num = 8;
repeated int32 token_ids_logprob = 9;
bool return_hidden_states = 10;
// Session management
SessionParams session_params = 11;
// For disaggregated serving
DisaggregatedParams disaggregated_params = 12;
// Custom logit processor (serialized)
string custom_logit_processor = 13;
// Request metadata
google.protobuf.Timestamp timestamp = 14;
bool log_metrics = 15;
// Input embeddings (alternative to text/tokens)
repeated float input_embeds = 16;
// LoRA adapter ID (if pre-loaded)
string lora_id = 17;
// Data parallel routing
int32 data_parallel_rank = 18;
// For load balancing
int32 dp_balance_id = 19;
}
message TokenizedInput {
string original_text = 1; // For reference
repeated int32 input_ids = 2;
}
message MultimodalInputs {
// Simplified multimodal handling - actual data processed by tokenizer
repeated string image_urls = 1;
repeated string video_urls = 2;
repeated string audio_urls = 3;
// Pre-processed multimodal features (if available)
google.protobuf.Struct processed_features = 4;
// Raw data for direct processing
repeated bytes image_data = 5;
repeated bytes video_data = 6;
repeated bytes audio_data = 7;
// Modality metadata
repeated string modalities = 8;
}
// =====================
// Generate Response
// =====================
message GenerateResponse {
string request_id = 1;
// Response type
oneof response {
GenerateStreamChunk chunk = 2;
GenerateComplete complete = 3;
GenerateError error = 4;
}
}
message GenerateStreamChunk {
// Generated token
int32 token_id = 1;
string text = 2;
// Cumulative counts
int32 prompt_tokens = 3;
int32 completion_tokens = 4;
int32 cached_tokens = 5;
// Logprobs (if requested)
LogProbs logprobs = 6;
// Hidden states (if requested)
repeated float hidden_states = 7;
// Metadata
float generation_time = 8; // Time to generate this token
int32 queue_time = 9; // Time spent in queue
}
message GenerateComplete {
// Final output
repeated int32 output_ids = 1;
string output_text = 2;
// Finish reason
enum FinishReason {
// The model generated a stop sequence.
STOP = 0;
// The model reached the maximum generation length.
LENGTH = 1;
// The model generated an end-of-sequence (EOS) token.
EOS_TOKEN = 2;
// The model generated a user-provided stop string.
STOP_STR = 3;
// The request was aborted by the user or system.
ABORT = 4;
}
FinishReason finish_reason = 3;
// Final counts
int32 prompt_tokens = 4;
int32 completion_tokens = 5;
int32 cached_tokens = 6;
// Performance metrics
float total_generation_time = 7;
float time_to_first_token = 8;
float tokens_per_second = 9;
// Spec decode metrics
int32 spec_verify_count = 10;
// All logprobs if requested
repeated LogProbs all_logprobs = 11;
// All hidden states if requested
repeated HiddenStates all_hidden_states = 12;
}
message GenerateError {
string message = 1;
string http_status_code = 2;
string details = 3;
}
message LogProbs {
repeated float token_logprobs = 1;
repeated int32 token_ids = 2;
// Top logprobs at each position
repeated TopLogProbs top_logprobs = 3;
// Decoded text for tokens
repeated string token_texts = 4;
}
message TopLogProbs {
repeated float values = 1;
repeated int32 token_ids = 2;
repeated string token_texts = 3;
}
message HiddenStates {
repeated float values = 1;
int32 layer = 2;
int32 position = 3;
}
// =====================
// Embedding Request
// =====================
message EmbedRequest {
string request_id = 1;
oneof input {
string text = 2;
TokenizedInput tokenized = 3;
}
// Multimodal inputs
MultimodalInputs mm_inputs = 4;
// Dummy sampling params for compatibility
// EmbedRequest doesn't use sampling_params
SamplingParams sampling_params = 5;
bool log_metrics = 6;
// Token type IDs for models that require them
repeated int32 token_type_ids = 7;
// Data parallel routing
int32 data_parallel_rank = 8;
// For cross-encoder requests
bool is_cross_encoder = 9;
repeated string texts = 10; // For cross-encoder batch
}
message EmbedResponse {
string request_id = 1;
oneof response {
EmbedComplete complete = 2;
EmbedError error = 3;
}
}
message EmbedComplete {
repeated float embedding = 1;
int32 prompt_tokens = 2;
int32 cached_tokens = 3;
// Additional metadata
int32 embedding_dim = 4;
float generation_time = 5;
// For batch embeddings
repeated Embedding batch_embeddings = 6;
}
message Embedding {
repeated float values = 1;
int32 index = 2;
}
message EmbedError {
string message = 1;
string code = 2;
string details = 3;
}
// =====================
// Management Operations
// =====================
message HealthCheckRequest {
bool include_detailed_metrics = 1;
}
message HealthCheckResponse {
bool healthy = 1;
// Current load metrics
int32 num_requests_running = 2;
int32 num_requests_waiting = 3;
float gpu_cache_usage = 4;
float gpu_memory_usage = 5;
// KV cache metrics
int32 kv_cache_total_blocks = 6;
int32 kv_cache_used_blocks = 7;
float kv_cache_hit_rate = 8;
// Additional metrics
int32 num_grammar_queue_requests = 9;
float generation_throughput = 10; // tokens/sec
float average_queue_time = 11; // seconds
float average_generation_time = 12; // seconds
// System metrics
float cpu_usage = 13;
int64 memory_usage = 14;
// Disaggregation metrics
int32 num_prefill_requests = 15;
int32 num_decode_requests = 16;
// Detailed metrics (optional)
google.protobuf.Struct detailed_metrics = 17;
}
message AbortRequest {
string request_id = 1;
string reason = 2;
}
message AbortResponse {
bool success = 1;
string message = 2;
}
message FlushCacheRequest {
bool flush_all = 1;
repeated string session_ids = 2; // Flush specific sessions
}
message FlushCacheResponse {
bool success = 1;
int32 num_entries_flushed = 2;
int64 memory_freed = 3; // bytes
string message = 4;
}
// =====================
// Additional Operations (Future)
// =====================
// Load LoRA adapter
message LoadLoRARequest {
string adapter_id = 1;
string adapter_path = 2;
int32 rank = 3;
}
message LoadLoRAResponse {
bool success = 1;
string adapter_id = 2;
string message = 3;
}
// Unload LoRA adapter
message UnloadLoRARequest {
string adapter_id = 1;
}
message UnloadLoRAResponse {
bool success = 1;
string message = 2;
}
// Update weights
message UpdateWeightsRequest {
oneof source {
string disk_path = 1;
bytes tensor_data = 2;
string remote_url = 3;
}
string weight_name = 4;
}
message UpdateWeightsResponse {
bool success = 1;
string message = 2;
}
// Get internal state for debugging
message GetInternalStateRequest {
repeated string state_keys = 1;
}
message GetInternalStateResponse {
google.protobuf.Struct state = 1;
}
// Set internal state for testing
message SetInternalStateRequest {
google.protobuf.Struct state = 1;
}
message SetInternalStateResponse {
bool success = 1;
string message = 2;
}

View File

@@ -0,0 +1,5 @@
// Protocol definitions and validation for various LLM APIs
// This module provides a structured approach to handling different API protocols
pub mod spec;
pub mod validation;

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,474 @@
# Reasoning Parser Architecture
## 1. Executive Summary
### High-Level Overview
The reasoning parser layer provides a unified interface for detecting and extracting reasoning content from Large Language Model (LLM) outputs, particularly from models that support Chain-of-Thought (CoT) reasoning with explicit thinking blocks. The architecture follows a trait-based design pattern enabling pluggable parser implementations while maintaining consistent APIs across different model families that use various reasoning token formats.
**Key Components:**
- **Factory Pattern**: Registry-based creation and pooling of model-specific parsers
- **Trait System**: `ReasoningParser` trait for implementation flexibility
- **Parser Pooling**: Efficient reuse of parser instances across concurrent requests
- **Streaming Support**: Incremental parsing with partial token buffering
- **Model Detection**: Pattern-based matching for automatic parser selection
- **State Management**: Stateful parsing for streaming scenarios with buffer management
- **Thread Safety**: Arc<Mutex> based sharing for high-concurrency environments
- **Extensibility**: Easy addition of new model-specific parsers
**Data Flow:**
1. Request → Factory (model detection) → Pooled Parser Retrieval
2. One-Shot: Text → Parser → ParserResult (normal + reasoning text)
3. Streaming: Chunks → Parser (stateful) → Incremental ParserResult
4. Buffer Management: Partial Tokens → Buffer → Complete Token Detection
5. Reset: Parser State → Clear Buffers → Ready for Reuse
### Architecture Highlights
- **Model-Specific Parsers**: DeepSeek-R1, Qwen3, Kimi, GLM45, Step3 variants
- **Parser Pooling**: Singleton instances per model type for memory efficiency
- **High Concurrency**: Mutex-protected parsers handle 1000+ req/sec
- **Buffer Overflow Protection**: Configurable max buffer size (default 64KB)
- **Partial Token Detection**: Intelligent buffering for incomplete delimiters
- **Passthrough Mode**: Graceful fallback for unknown models
- **Zero-Copy Where Possible**: Efficient string handling in hot paths
## 2. Mermaid Diagrams
### Component Flow Diagram
```mermaid
graph TB
subgraph Input
R[Request] --> MID[Model ID]
end
subgraph Factory Layer
MID --> PF[ParserFactory]
PF --> REG[ParserRegistry]
REG --> PM[Pattern Matching]
PM --> PP[Parser Pool]
end
subgraph Parser Pool
PP --> DS[DeepSeek-R1]
PP --> QW[Qwen3]
PP --> QWT[Qwen3-Thinking]
PP --> KM[Kimi]
PP --> GL[GLM45]
PP --> S3[Step3]
PP --> PT[Passthrough]
end
subgraph Parser Instance
DS --> BP[BaseReasoningParser]
QW --> BP
KM --> BP
GL --> BP
S3 --> BP
end
subgraph Processing
BP --> DAP[detect_and_parse]
BP --> PSI[parse_streaming]
BP --> RST[reset]
end
subgraph State Management
BP --> BUF[Buffer]
BP --> IR[in_reasoning flag]
BP --> STS[stripped_think_start]
end
subgraph Output
DAP --> PR[ParserResult]
PSI --> PR
PR --> NT[normal_text]
PR --> RT[reasoning_text]
end
```
### Sequence Flow Diagram
```mermaid
sequenceDiagram
participant C as Client
participant F as ParserFactory
participant R as Registry
participant P as Parser Pool
participant BP as BaseParser
participant PR as ParserResult
C->>F: get_pooled("deepseek-r1-model")
F->>R: find_pooled_parser_for_model()
R->>R: pattern_match("deepseek-r1")
R->>P: get_pooled_parser("deepseek_r1")
alt Parser exists in pool
P-->>F: Arc<Mutex<Parser>>
else Create new parser
P->>BP: new DeepSeekR1Parser()
P->>P: insert into pool
P-->>F: Arc<Mutex<Parser>>
end
F-->>C: PooledParser
C->>BP: lock().parse_reasoning_streaming_incremental()
loop streaming chunks
C->>BP: parse_reasoning_streaming_incremental(chunk)
BP->>BP: buffer.push_str(chunk)
BP->>BP: check partial tokens
alt Complete token found
BP->>PR: create result
BP->>BP: clear buffer
BP-->>C: ParserResult
else Partial token
BP->>BP: keep buffering
BP-->>C: ParserResult::default()
end
end
C->>BP: reset()
BP->>BP: clear buffers & flags
C->>BP: unlock()
```
### Class/Type Diagram
```mermaid
classDiagram
class ReasoningParser {
<<trait>>
+detect_and_parse_reasoning(&mut self, text: &str) Result~ParserResult~
+parse_reasoning_streaming_incremental(&mut self, text: &str) Result~ParserResult~
+reset(&mut self)
+model_type(&self) &str
}
class ParserResult {
+normal_text: String
+reasoning_text: String
+new(normal: String, reasoning: String) Self
+normal(text: String) Self
+reasoning(text: String) Self
+is_empty() bool
}
class ParserConfig {
+think_start_token: String
+think_end_token: String
+stream_reasoning: bool
+max_buffer_size: usize
+initial_in_reasoning: bool
+default() Self
}
class BaseReasoningParser {
-config: ParserConfig
-in_reasoning: bool
-buffer: String
-stripped_think_start: bool
-model_type: String
+new(config: ParserConfig) Self
+with_model_type(model: String) Self
-is_partial_token(&self, text: &str) bool
}
class DeepSeekR1Parser {
-base: BaseReasoningParser
+new() Self
}
class Qwen3Parser {
-base: BaseReasoningParser
+new() Self
}
class QwenThinkingParser {
-base: BaseReasoningParser
+new() Self
}
class KimiParser {
-base: BaseReasoningParser
+new() Self
}
class Glm45Parser {
-base: BaseReasoningParser
+new() Self
}
class Step3Parser {
-base: BaseReasoningParser
+new() Self
}
class ParserFactory {
-registry: ParserRegistry
+new() Self
+get_pooled(model_id: &str) PooledParser
+create(model_id: &str) Result~Box~dyn ReasoningParser~~
+clear_pool()
}
class ParserRegistry {
-creators: Arc~RwLock~HashMap~~
-pool: Arc~RwLock~HashMap~~
-patterns: Arc~RwLock~Vec~~
+register_parser(name: &str, creator: F)
+register_pattern(pattern: &str, parser_name: &str)
+get_pooled_parser(name: &str) Option~PooledParser~
+find_pooled_parser_for_model(model: &str) Option~PooledParser~
}
ReasoningParser <|.. BaseReasoningParser
ReasoningParser <|.. DeepSeekR1Parser
ReasoningParser <|.. Qwen3Parser
ReasoningParser <|.. QwenThinkingParser
ReasoningParser <|.. KimiParser
ReasoningParser <|.. Glm45Parser
ReasoningParser <|.. Step3Parser
DeepSeekR1Parser o-- BaseReasoningParser
Qwen3Parser o-- BaseReasoningParser
QwenThinkingParser o-- BaseReasoningParser
KimiParser o-- BaseReasoningParser
Glm45Parser o-- BaseReasoningParser
Step3Parser o-- BaseReasoningParser
BaseReasoningParser o-- ParserConfig
ParserFactory o-- ParserRegistry
ParserRegistry o-- ReasoningParser
```
## 3. Module-by-Module Deep Dive
### 3.1 mod.rs (Main Module)
**Key Responsibilities:**
- Module organization and public API surface
- Re-exports for convenient access to core types
- Separation of concerns across submodules
**Module Structure:**
- `factory`: Parser creation and pooling logic
- `parsers`: Concrete parser implementations
- `traits`: Core trait definitions and types
### 3.2 traits.rs (Trait Definitions)
**ParserResult Methods**:
- `new()`: Create with both normal and reasoning text
- `normal()`: Create with only normal text (convenience)
- `reasoning()`: Create with only reasoning text (convenience)
- `is_empty()`: Check if result contains any text
**ReasoningParser Trait**:
- **`detect_and_parse_reasoning`**: One-shot parsing for complete text
- **`parse_reasoning_streaming_incremental`**: Stateful streaming parser
- **`reset`**: Clear state for parser reuse
- **`model_type`**: Identify parser variant for debugging
**ParserConfig Defaults**:
- Default tokens: `<think>` and `</think>`
- Stream reasoning: true (immediate output)
- Max buffer: 65536 bytes (64KB)
- Initial state: false (explicit reasoning blocks)
### 3.3 factory.rs (Parser Creation & Pooling)
**ParserRegistry Methods**:
1. **`register_parser`**:
- Register creator function for parser type
- Lazy instantiation when requested
- Thread-safe registration
2. **`register_pattern`**:
- Map model ID patterns to parser names
- First-match-wins ordering
- Case-insensitive matching
3. **`get_pooled_parser`**:
- Check pool for existing instance
- Create and pool if not present
- Return Arc<Mutex> for sharing
4. **`find_pooled_parser_for_model`**:
- Pattern match against model ID
- Delegate to get_pooled_parser
- Case-insensitive comparison
**ParserFactory Methods**:
1. **`new()`**:
- Register all built-in parsers
- Setup model pattern mappings
- Initialize empty pool
2. **`get_pooled`**:
- Primary API for getting parsers
- Automatic passthrough fallback
- Guaranteed non-null return
3. **`create`**:
- Create fresh parser instance
- No pooling (for testing/isolation)
- Returns Result for error handling
**Registered Parsers**:
- `base`: Generic configurable parser
- `deepseek_r1`: DeepSeek-R1 (initial_in_reasoning=true)
- `qwen3`: Qwen3 base model (initial_in_reasoning=false)
- `qwen3_thinking`: Qwen3 thinking variant (initial_in_reasoning=true)
- `kimi`: Kimi with Unicode tokens
- `glm45`: GLM-4.5 parser
- `step3`: Step3 parser
- `passthrough`: No-op fallback parser
**Model Pattern Mappings**:
```
"deepseek-r1" → "deepseek_r1"
"qwen3-thinking" → "qwen3_thinking"
"qwen-thinking" → "qwen3_thinking"
"qwen3" → "qwen3"
"qwen" → "qwen3"
"glm45" → "glm45"
"kimi" → "kimi"
"step3" → "step3"
```
### 3.4 parsers/base.rs (Base Implementation)
**Key Methods:**
**`detect_and_parse_reasoning`**:
```
Algorithm:
1. Check buffer overflow protection
2. Detect reasoning presence (in_reasoning OR contains start_token)
3. If no reasoning → return as normal text
4. Remove start token and trim
5. If no end token → assume truncated reasoning
6. Split on end token
7. Extract reasoning and normal portions
```
**`parse_reasoning_streaming_incremental`**:
```
Algorithm:
1. Check buffer capacity
2. Append text to buffer
3. Check if buffer is partial token prefix
4. If partial → buffer and return empty
5. Strip start token if present
6. Find end token position
7. Handle based on state:
- In reasoning + end found → split and return both
- In reasoning + streaming → return accumulated reasoning
- Not in reasoning → return as normal text
- In reasoning + no end → continue buffering
```
**Critical Features:**
1. **Partial Token Detection**:
- Prevents premature token matching
- Buffers incomplete delimiters
- Essential for streaming correctness
2. **Buffer Management**:
- Overflow protection
- Accumulation for partial content
- Clear on complete token detection
3. **State Tracking**:
- `in_reasoning`: Current parsing state
- `stripped_think_start`: Prevent double processing
- `buffer`: Accumulated partial content
## 4. Extensibility Guide
### Adding a New Parser
**Step 1: Create Parser Implementation**
```rust
// src/reasoning_parser/parsers/mymodel.rs
use crate::reasoning_parser::parsers::BaseReasoningParser;
use crate::reasoning_parser::traits::{ParserConfig, ReasoningParser};
pub struct MyModelParser {
base: BaseReasoningParser,
}
impl MyModelParser {
pub fn new() -> Self {
let config = ParserConfig {
think_start_token: "<reasoning>".to_string(),
think_end_token: "</reasoning>".to_string(),
stream_reasoning: true,
max_buffer_size: 65536,
initial_in_reasoning: false, // or true for implicit
};
Self {
base: BaseReasoningParser::new(config)
.with_model_type("mymodel".to_string()),
}
}
}
impl ReasoningParser for MyModelParser {
// Delegate to base or implement custom logic
fn detect_and_parse_reasoning(&mut self, text: &str)
-> Result<ParserResult, ParseError> {
self.base.detect_and_parse_reasoning(text)
}
// ... other trait methods
}
```
**Step 2: Register in Factory**
```rust
// In factory.rs ParserFactory::new()
registry.register_parser("mymodel", || {
Box::new(MyModelParser::new())
});
// Register patterns
registry.register_pattern("my-model", "mymodel");
registry.register_pattern("mymodel", "mymodel");
```
**Step 3: Export from Module**
```rust
// In parsers/mod.rs
pub use self::mymodel::MyModelParser;
// In reasoning_parser/mod.rs
pub use parsers::MyModelParser;
```
### Custom Parsing Logic
For parsers requiring custom logic beyond configuration:
```rust
impl ReasoningParser for CustomParser {
fn parse_reasoning_streaming_incremental(&mut self, text: &str)
-> Result<ParserResult, ParseError> {
// Custom state machine
// Custom token detection
// Custom buffering strategy
// Return appropriate ParserResult
}
}
```

View File

@@ -0,0 +1,566 @@
// Factory and registry for creating model-specific reasoning parsers.
// Now with parser pooling support for efficient reuse across requests.
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use crate::reasoning_parser::parsers::{
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
QwenThinkingParser, Step3Parser,
};
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};
/// Type alias for pooled parser instances.
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;
/// Type alias for parser creator functions.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;
/// Registry for model-specific parsers with pooling support.
#[derive(Clone)]
pub struct ParserRegistry {
/// Creator functions for parsers (used when pool is empty)
creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
/// Pooled parser instances for reuse
pool: Arc<RwLock<HashMap<String, PooledParser>>>,
/// Model pattern to parser name mappings
patterns: Arc<RwLock<Vec<(String, String)>>>, // (pattern, parser_name)
}
impl ParserRegistry {
/// Create a new empty registry.
pub fn new() -> Self {
Self {
creators: Arc::new(RwLock::new(HashMap::new())),
pool: Arc::new(RwLock::new(HashMap::new())),
patterns: Arc::new(RwLock::new(Vec::new())),
}
}
/// Register a parser creator for a given parser type.
pub fn register_parser<F>(&self, name: &str, creator: F)
where
F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
{
let mut creators = self.creators.write().unwrap();
creators.insert(name.to_string(), Arc::new(creator));
}
/// Register a model pattern to parser mapping.
/// Patterns are checked in order, first match wins.
pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
let mut patterns = self.patterns.write().unwrap();
patterns.push((pattern.to_string(), parser_name.to_string()));
}
/// Get a pooled parser by exact name.
/// Returns a shared parser instance from the pool, creating one if needed.
pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
// First check if we have a pooled instance
{
let pool = self.pool.read().unwrap();
if let Some(parser) = pool.get(name) {
return Some(Arc::clone(parser));
}
}
// If not in pool, create one and add to pool
let creators = self.creators.read().unwrap();
if let Some(creator) = creators.get(name) {
let parser = Arc::new(Mutex::new(creator()));
// Add to pool for future use
let mut pool = self.pool.write().unwrap();
pool.insert(name.to_string(), Arc::clone(&parser));
Some(parser)
} else {
None
}
}
/// Get a parser by exact name (creates new instance, not pooled).
/// Use this for compatibility or when you need a fresh instance.
pub fn get_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
let creators = self.creators.read().unwrap();
creators.get(name).map(|creator| creator())
}
/// Find a pooled parser for a given model ID by pattern matching.
pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
let patterns = self.patterns.read().unwrap();
let model_lower = model_id.to_lowercase();
for (pattern, parser_name) in patterns.iter() {
if model_lower.contains(&pattern.to_lowercase()) {
return self.get_pooled_parser(parser_name);
}
}
None
}
/// Find a parser for a given model ID by pattern matching (creates new instance).
pub fn find_parser_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
let patterns = self.patterns.read().unwrap();
let model_lower = model_id.to_lowercase();
for (pattern, parser_name) in patterns.iter() {
if model_lower.contains(&pattern.to_lowercase()) {
return self.get_parser(parser_name);
}
}
None
}
/// Clear the parser pool, forcing new instances to be created.
/// Useful for testing or when parsers need to be reset globally.
pub fn clear_pool(&self) {
let mut pool = self.pool.write().unwrap();
pool.clear();
}
}
impl Default for ParserRegistry {
fn default() -> Self {
Self::new()
}
}
/// Factory for creating reasoning parsers based on model type.
#[derive(Clone)]
pub struct ParserFactory {
registry: ParserRegistry,
}
impl ParserFactory {
/// Create a new factory with default parsers registered.
pub fn new() -> Self {
let registry = ParserRegistry::new();
// Register base parser
registry.register_parser("base", || {
Box::new(BaseReasoningParser::new(ParserConfig::default()))
});
// Register DeepSeek-R1 parser (starts with in_reasoning=true)
registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));
// Register Qwen3 parser (starts with in_reasoning=false)
registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));
// Register Qwen3-thinking parser (starts with in_reasoning=true)
registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));
// Register Kimi parser with Unicode tokens (starts with in_reasoning=false)
registry.register_parser("kimi", || Box::new(KimiParser::new()));
// Register GLM45 parser (same format as Qwen3 but separate for debugging)
registry.register_parser("glm45", || Box::new(Glm45Parser::new()));
// Register Step3 parser (same format as DeepSeek-R1 but separate for debugging)
registry.register_parser("step3", || Box::new(Step3Parser::new()));
// Register model patterns
registry.register_pattern("deepseek-r1", "deepseek_r1");
registry.register_pattern("qwen3-thinking", "qwen3_thinking");
registry.register_pattern("qwen-thinking", "qwen3_thinking");
registry.register_pattern("qwen3", "qwen3");
registry.register_pattern("qwen", "qwen3");
registry.register_pattern("glm45", "glm45");
registry.register_pattern("kimi", "kimi");
registry.register_pattern("step3", "step3");
Self { registry }
}
/// Get a pooled parser for the given model ID.
/// Returns a shared instance that can be used concurrently.
/// Falls back to a passthrough parser if model is not recognized.
pub fn get_pooled(&self, model_id: &str) -> PooledParser {
// First try to find by pattern
if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
return parser;
}
// Fall back to no-op parser (get or create passthrough in pool)
self.registry
.get_pooled_parser("passthrough")
.unwrap_or_else(|| {
// Register passthrough if not already registered
self.registry.register_parser("passthrough", || {
let config = ParserConfig {
think_start_token: "".to_string(),
think_end_token: "".to_string(),
stream_reasoning: true,
max_buffer_size: 65536,
initial_in_reasoning: false,
};
Box::new(
BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
)
});
self.registry.get_pooled_parser("passthrough").unwrap()
})
}
/// Create a new parser instance for the given model ID.
/// Returns a fresh instance (not pooled).
/// Use this when you need an isolated parser instance.
pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
// First try to find by pattern
if let Some(parser) = self.registry.find_parser_for_model(model_id) {
return Ok(parser);
}
// Fall back to no-op parser (base parser without reasoning detection)
let config = ParserConfig {
think_start_token: "".to_string(),
think_end_token: "".to_string(),
stream_reasoning: true,
max_buffer_size: 65536,
initial_in_reasoning: false,
};
Ok(Box::new(
BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
))
}
/// Get the internal registry for custom registration.
pub fn registry(&self) -> &ParserRegistry {
&self.registry
}
/// Clear the parser pool.
/// Useful for testing or when parsers need to be reset globally.
pub fn clear_pool(&self) {
self.registry.clear_pool();
}
}
impl Default for ParserFactory {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_factory_creates_deepseek_r1() {
let factory = ParserFactory::new();
let parser = factory.create("deepseek-r1-distill").unwrap();
assert_eq!(parser.model_type(), "deepseek_r1");
}
#[test]
fn test_factory_creates_qwen3() {
let factory = ParserFactory::new();
let parser = factory.create("qwen3-7b").unwrap();
assert_eq!(parser.model_type(), "qwen3");
}
#[test]
fn test_factory_creates_kimi() {
let factory = ParserFactory::new();
let parser = factory.create("kimi-chat").unwrap();
assert_eq!(parser.model_type(), "kimi");
}
#[test]
fn test_factory_fallback_to_passthrough() {
let factory = ParserFactory::new();
let parser = factory.create("unknown-model").unwrap();
assert_eq!(parser.model_type(), "passthrough");
}
#[test]
fn test_case_insensitive_matching() {
let factory = ParserFactory::new();
let parser1 = factory.create("DeepSeek-R1").unwrap();
let parser2 = factory.create("QWEN3").unwrap();
let parser3 = factory.create("Kimi").unwrap();
assert_eq!(parser1.model_type(), "deepseek_r1");
assert_eq!(parser2.model_type(), "qwen3");
assert_eq!(parser3.model_type(), "kimi");
}
#[test]
fn test_step3_model() {
let factory = ParserFactory::new();
let step3 = factory.create("step3-model").unwrap();
assert_eq!(step3.model_type(), "step3");
}
#[test]
fn test_glm45_model() {
let factory = ParserFactory::new();
let glm45 = factory.create("glm45-v2").unwrap();
assert_eq!(glm45.model_type(), "glm45");
}
#[test]
fn test_pooled_parser_reuse() {
let factory = ParserFactory::new();
// Get the same parser twice - should be the same instance
let parser1 = factory.get_pooled("deepseek-r1");
let parser2 = factory.get_pooled("deepseek-r1");
// Both should point to the same Arc
assert!(Arc::ptr_eq(&parser1, &parser2));
// Different models should get different parsers
let parser3 = factory.get_pooled("qwen3");
assert!(!Arc::ptr_eq(&parser1, &parser3));
}
#[test]
fn test_pooled_parser_concurrent_access() {
use std::thread;
let factory = ParserFactory::new();
let parser = factory.get_pooled("deepseek-r1");
// Spawn multiple threads that use the same parser
let mut handles = vec![];
for i in 0..3 {
let parser_clone = Arc::clone(&parser);
let handle = thread::spawn(move || {
let mut parser = parser_clone.lock().unwrap();
let input = format!("thread {} reasoning</think>answer", i);
let result = parser.detect_and_parse_reasoning(&input).unwrap();
assert_eq!(result.normal_text, "answer");
assert!(result.reasoning_text.contains("reasoning"));
});
handles.push(handle);
}
// Wait for all threads to complete
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_pool_clearing() {
let factory = ParserFactory::new();
// Get a pooled parser
let parser1 = factory.get_pooled("deepseek-r1");
// Clear the pool
factory.clear_pool();
// Get another parser - should be a new instance
let parser2 = factory.get_pooled("deepseek-r1");
// They should be different instances (different Arc pointers)
assert!(!Arc::ptr_eq(&parser1, &parser2));
}
#[test]
fn test_passthrough_parser_pooling() {
let factory = ParserFactory::new();
// Unknown models should get passthrough parser
let parser1 = factory.get_pooled("unknown-model-1");
let parser2 = factory.get_pooled("unknown-model-2");
// Both should use the same passthrough parser instance
assert!(Arc::ptr_eq(&parser1, &parser2));
// Verify it's actually a passthrough parser
let parser = parser1.lock().unwrap();
assert_eq!(parser.model_type(), "passthrough");
}
#[test]
fn test_high_concurrency_parser_access() {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;
use std::time::Instant;
let factory = ParserFactory::new();
let num_threads = 100;
let requests_per_thread = 50;
let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];
// Track successful operations
let success_count = Arc::new(AtomicUsize::new(0));
let error_count = Arc::new(AtomicUsize::new(0));
let start = Instant::now();
let mut handles = vec![];
for thread_id in 0..num_threads {
let factory = factory.clone();
let models = models.clone();
let success_count = Arc::clone(&success_count);
let error_count = Arc::clone(&error_count);
let handle = thread::spawn(move || {
for request_id in 0..requests_per_thread {
// Rotate through different models
let model = &models[(thread_id + request_id) % models.len()];
let parser = factory.get_pooled(model);
// Use blocking lock - this is the realistic scenario
// In production, requests would wait for the parser to be available
// Handle poisoned locks gracefully
let mut p = match parser.lock() {
Ok(guard) => guard,
Err(_poisoned) => {
// Lock was poisoned by a panicking thread
// In production, we might want to recreate the parser
// For testing, we'll just skip this iteration
error_count.fetch_add(1, Ordering::Relaxed);
continue;
}
};
// Simulate realistic parsing work with substantial text
// Typical reasoning can be 500-5000 tokens
let reasoning_text = format!(
"Thread {} is processing request {}. Let me think through this step by step. \
First, I need to understand the problem. The problem involves analyzing data \
and making calculations. Let me break this down: \n\
1. Initial analysis shows that we have multiple variables to consider. \
2. The data suggests a pattern that needs further investigation. \
3. Computing the values: {} * {} = {}. \
4. Cross-referencing with previous results indicates consistency. \
5. The mathematical proof follows from the axioms... \
6. Considering edge cases and boundary conditions... \
7. Validating against known constraints... \
8. The conclusion follows logically from premises A, B, and C. \
This reasoning chain demonstrates the validity of our approach.",
thread_id, request_id, thread_id, request_id, thread_id * request_id
);
let answer_text = format!(
"Based on my analysis, the answer for thread {} request {} is: \
The solution involves multiple steps as outlined in the reasoning. \
The final result is {} with confidence level high. \
This conclusion is supported by rigorous mathematical analysis \
and has been validated against multiple test cases. \
The implementation should handle edge cases appropriately.",
thread_id,
request_id,
thread_id * request_id
);
let input = format!("<think>{}</think>{}", reasoning_text, answer_text);
match p.detect_and_parse_reasoning(&input) {
Ok(result) => {
// Verify parsing worked correctly with substantial content
// Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
assert!(result
.normal_text
.contains(&format!("thread {}", thread_id)));
// For parsers that accumulate reasoning (stream_reasoning=false)
// the reasoning_text should be populated
if !result.reasoning_text.is_empty() {
assert!(result
.reasoning_text
.contains(&format!("Thread {}", thread_id)));
assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
}
// Normal text should always be present
assert!(result.normal_text.len() > 100); // Ensure substantial answer
success_count.fetch_add(1, Ordering::Relaxed);
}
Err(e) => {
eprintln!("Parse error: {:?}", e);
error_count.fetch_add(1, Ordering::Relaxed);
}
}
// Explicitly drop the lock to release it quickly
drop(p);
}
});
handles.push(handle);
}
// Wait for all threads
for handle in handles {
handle.join().unwrap();
}
let duration = start.elapsed();
let total_requests = num_threads * requests_per_thread;
let successes = success_count.load(Ordering::Relaxed);
let errors = error_count.load(Ordering::Relaxed);
// Print stats for debugging
println!(
"High concurrency test: {} threads, {} requests each",
num_threads, requests_per_thread
);
println!(
"Completed in {:?}, {} successes, {} errors",
duration, successes, errors
);
println!(
"Throughput: {:.0} requests/sec",
(total_requests as f64) / duration.as_secs_f64()
);
// All requests should succeed
assert_eq!(successes, total_requests);
assert_eq!(errors, 0);
// Performance check: should handle at least 1000 req/sec
let throughput = (total_requests as f64) / duration.as_secs_f64();
assert!(
throughput > 1000.0,
"Throughput too low: {:.0} req/sec",
throughput
);
}
#[test]
fn test_concurrent_pool_modifications() {
use std::thread;
let factory = ParserFactory::new();
let mut handles = vec![];
// Thread 1: Continuously get parsers
let factory1 = factory.clone();
handles.push(thread::spawn(move || {
for _ in 0..100 {
let _parser = factory1.get_pooled("deepseek-r1");
}
}));
// Thread 2: Continuously clear pool
let factory2 = factory.clone();
handles.push(thread::spawn(move || {
for _ in 0..10 {
factory2.clear_pool();
thread::sleep(std::time::Duration::from_micros(100));
}
}));
// Thread 3: Get different parsers
let factory3 = factory.clone();
handles.push(thread::spawn(move || {
for i in 0..100 {
let models = ["qwen3", "kimi", "unknown"];
let _parser = factory3.get_pooled(models[i % 3]);
}
}));
// Wait for all threads - should not deadlock or panic
for handle in handles {
handle.join().unwrap();
}
}
}

View File

@@ -0,0 +1,10 @@
pub mod factory;
pub mod parsers;
pub mod traits;
pub use factory::{ParserFactory, ParserRegistry, PooledParser};
pub use parsers::{
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
QwenThinkingParser, Step3Parser,
};
pub use traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};

View File

@@ -0,0 +1,386 @@
// Base implementation of reasoning parser that handles common logic
// for detecting and extracting reasoning blocks from text.
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
use tracing as log;
/// Base reasoning parser implementation.
///
/// This parser handles the common logic for detecting reasoning blocks
/// delimited by start and end tokens (e.g., <think> and </think>).
#[derive(Debug, Clone)]
pub struct BaseReasoningParser {
config: ParserConfig,
in_reasoning: bool,
buffer: String,
stripped_think_start: bool,
model_type: String,
}
impl BaseReasoningParser {
/// Create a new BaseReasoningParser with the given configuration.
pub fn new(config: ParserConfig) -> Self {
let in_reasoning = config.initial_in_reasoning;
Self {
config,
in_reasoning,
buffer: String::new(),
stripped_think_start: false,
model_type: "base".to_string(),
}
}
/// Create with custom model type identifier.
pub fn with_model_type(mut self, model_type: String) -> Self {
self.model_type = model_type;
self
}
/// Check if the current buffer is a prefix of one of the tokens.
fn is_partial_token(&self, text: &str) -> bool {
(self.config.think_start_token.starts_with(text) && self.config.think_start_token != text)
|| (self.config.think_end_token.starts_with(text)
&& self.config.think_end_token != text)
}
}
impl ReasoningParser for BaseReasoningParser {
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
log::debug!("detect_and_parse_reasoning called with text: {:?}", text);
// Check input size against buffer limit
if text.len() > self.config.max_buffer_size {
return Err(ParseError::BufferOverflow(text.len()));
}
let in_reasoning = self.in_reasoning || text.contains(&self.config.think_start_token);
log::debug!("in_reasoning: {}", in_reasoning);
if !in_reasoning {
log::debug!("No reasoning detected, returning normal text.");
return Ok(ParserResult::normal(text.to_string()));
}
// The text is considered to be in a reasoning block.
let processed_text = text
.replace(&self.config.think_start_token, "")
.trim()
.to_string();
log::debug!(
"Processed text after removing think_start_token: {:?}",
processed_text
);
if !processed_text.contains(&self.config.think_end_token) {
log::debug!(
"Reasoning truncated, think_end_token not found. Returning reasoning text."
);
// Assume reasoning was truncated before end token
return Ok(ParserResult::reasoning(processed_text));
}
// Extract reasoning content
let splits: Vec<&str> = processed_text
.splitn(2, &self.config.think_end_token)
.collect();
let reasoning_text = splits.first().unwrap_or(&"").to_string();
let normal_text = splits
.get(1)
.map(|s| s.trim().to_string())
.unwrap_or_default();
log::debug!("Extracted reasoning_text: {:?}", reasoning_text);
log::debug!("Extracted normal_text: {:?}", normal_text);
Ok(ParserResult::new(normal_text, reasoning_text))
}
fn parse_reasoning_streaming_incremental(
&mut self,
text: &str,
) -> Result<ParserResult, ParseError> {
// Check if adding this text would exceed buffer limit
if self.buffer.len() + text.len() > self.config.max_buffer_size {
return Err(ParseError::BufferOverflow(self.buffer.len() + text.len()));
}
// Incrementally parse the streaming text
self.buffer.push_str(text);
let mut current_text = self.buffer.clone();
log::debug!(
"parse_reasoning_streaming_incremental called with text: {:?}",
text
);
log::debug!("current buffer: {:?}", self.buffer);
log::debug!("current_text: {:?}", current_text);
log::debug!(
"in_reasoning: {}, stripped_think_start: {}, stream_reasoning: {}",
self.in_reasoning,
self.stripped_think_start,
self.config.stream_reasoning
);
// If the current text is a prefix of a token, keep buffering
if self.is_partial_token(&current_text) {
return Ok(ParserResult::default());
}
// Strip start token if present
if !self.stripped_think_start && current_text.contains(&self.config.think_start_token) {
current_text = current_text.replace(&self.config.think_start_token, "");
self.buffer = current_text.clone();
self.stripped_think_start = true;
self.in_reasoning = true;
}
// Handle end of reasoning block
let think_end_idx = if self.in_reasoning {
current_text
.find(&self.config.think_end_token)
.unwrap_or(current_text.len())
} else {
current_text.len()
};
if self.in_reasoning && think_end_idx < current_text.len() {
let reasoning_text = &current_text[..think_end_idx];
self.buffer.clear();
self.in_reasoning = false;
let start_idx = think_end_idx + self.config.think_end_token.len();
let normal_text = if start_idx < current_text.len() {
&current_text[start_idx..]
} else {
""
};
return Ok(ParserResult::new(
normal_text.to_string(),
reasoning_text.trim().to_string(),
));
}
// Continue with reasoning content
if self.in_reasoning && self.config.stream_reasoning {
// Stream the content immediately
let reasoning_text = current_text;
self.buffer.clear();
Ok(ParserResult::reasoning(reasoning_text))
} else if !self.in_reasoning {
// If we're not in a reasoning block, return as normal text
// CRITICAL FIX: Return current_text (with buffer) not just text
// This prevents buffer loss when partial tokens are followed by normal text
let normal_text = current_text;
self.buffer.clear();
Ok(ParserResult::normal(normal_text))
} else {
// If we are in a reasoning block but no end token is found, buffer it
Ok(ParserResult::default())
}
}
fn reset(&mut self) {
self.in_reasoning = self.config.initial_in_reasoning;
self.buffer.clear();
self.stripped_think_start = false;
}
fn model_type(&self) -> &str {
&self.model_type
}
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_parser(
initial_in_reasoning: bool,
stream_reasoning: bool,
) -> BaseReasoningParser {
let config = ParserConfig {
think_start_token: "<think>".to_string(),
think_end_token: "</think>".to_string(),
stream_reasoning,
max_buffer_size: 65536,
initial_in_reasoning,
};
BaseReasoningParser::new(config)
}
#[test]
fn test_detect_and_parse_reasoning() {
let mut parser = create_test_parser(false, true);
let result = parser
.detect_and_parse_reasoning("<think>with reasoning</think> and more text.")
.unwrap();
assert_eq!(result.normal_text, "and more text.");
assert_eq!(result.reasoning_text, "with reasoning");
}
#[test]
fn test_detect_and_parse_no_reasoning() {
let mut parser = create_test_parser(false, true);
let result = parser
.detect_and_parse_reasoning("This is a test without reasoning.")
.unwrap();
assert_eq!(result.normal_text, "This is a test without reasoning.");
assert_eq!(result.reasoning_text, "");
}
#[test]
fn test_detect_and_parse_truncated_reasoning() {
let mut parser = create_test_parser(false, true);
let result = parser
.detect_and_parse_reasoning("<think>with truncated reasoning")
.unwrap();
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "with truncated reasoning");
}
#[test]
fn test_parse_streaming_partial_token() {
let mut parser = create_test_parser(false, true);
let result = parser
.parse_reasoning_streaming_incremental("<thi")
.unwrap();
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "");
}
#[test]
fn test_parse_streaming_complete() {
let mut parser = create_test_parser(false, true);
let result = parser
.parse_reasoning_streaming_incremental("<think>with reasoning</think> and more text.")
.unwrap();
assert_eq!(result.normal_text, " and more text.");
assert_eq!(result.reasoning_text, "with reasoning");
}
#[test]
fn test_parse_streaming_no_end_token() {
let mut parser = create_test_parser(true, true);
let result = parser
.parse_reasoning_streaming_incremental("<think>with reasoning")
.unwrap();
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "with reasoning");
}
#[test]
fn test_initial_in_reasoning_true() {
// Parser starts with in_reasoning=true (like DeepSeek-R1)
let mut parser = create_test_parser(true, true);
let result = parser
.detect_and_parse_reasoning("no think tags here")
.unwrap();
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "no think tags here");
}
#[test]
fn test_buffer_loss_bug_fix() {
// Critical test for buffer preservation
let mut parser = create_test_parser(false, true);
// Step 1: Send partial end tag when not in reasoning mode
let result1 = parser.parse_reasoning_streaming_incremental("</").unwrap();
assert_eq!(result1.normal_text, "");
assert_eq!(result1.reasoning_text, "");
// Step 2: Send normal text that doesn't complete the end tag
// Must return "</answer" not just "answer"
let result2 = parser
.parse_reasoning_streaming_incremental("answer")
.unwrap();
assert_eq!(result2.normal_text, "</answer");
assert_eq!(result2.reasoning_text, "");
}
#[test]
fn test_streaming_with_stream_reasoning_enabled() {
let mut parser = create_test_parser(false, true);
// Start reasoning block
let result1 = parser
.parse_reasoning_streaming_incremental("<think>reasoning ")
.unwrap();
assert_eq!(result1.normal_text, "");
assert_eq!(result1.reasoning_text, "reasoning ");
// Continue streaming reasoning
let result2 = parser
.parse_reasoning_streaming_incremental("content ")
.unwrap();
assert_eq!(result2.normal_text, "");
assert_eq!(result2.reasoning_text, "content ");
// End reasoning block
let result3 = parser
.parse_reasoning_streaming_incremental("more</think> normal")
.unwrap();
assert_eq!(result3.normal_text, " normal");
assert_eq!(result3.reasoning_text, "more");
}
#[test]
fn test_reset_state() {
let mut parser = create_test_parser(false, true);
// Process some text
parser
.parse_reasoning_streaming_incremental("<think>reasoning</think> normal")
.unwrap();
// Reset and verify state
parser.reset();
assert!(!parser.in_reasoning);
assert!(parser.buffer.is_empty());
assert!(!parser.stripped_think_start);
}
#[test]
fn test_buffer_overflow_detect_and_parse() {
let config = ParserConfig {
max_buffer_size: 10, // Set a very small buffer
..Default::default()
};
let mut parser = BaseReasoningParser::new(config);
let large_text = "a".repeat(20);
let result = parser.detect_and_parse_reasoning(&large_text);
assert!(result.is_err());
match result {
Err(ParseError::BufferOverflow(size)) => {
assert_eq!(size, 20);
}
_ => panic!("Expected BufferOverflow error"),
}
}
#[test]
fn test_buffer_overflow_streaming() {
let config = ParserConfig {
max_buffer_size: 10, // Set a very small buffer
..Default::default()
};
let mut parser = BaseReasoningParser::new(config);
// Send a partial token that will be buffered
let result1 = parser.parse_reasoning_streaming_incremental("<thi");
assert!(result1.is_ok());
assert_eq!(result1.unwrap().normal_text, "");
// Second chunk would exceed buffer
// Buffer has "<thi" (4 chars) + "this_is_too_large" (17 chars) = 21 total
let result2 = parser.parse_reasoning_streaming_incremental("this_is_too_large");
assert!(result2.is_err());
match result2 {
Err(ParseError::BufferOverflow(size)) => {
assert_eq!(size, 21); // 4 + 17
}
_ => panic!("Expected BufferOverflow error"),
}
}
}

View File

@@ -0,0 +1,112 @@
// DeepSeek-R1 specific reasoning parser.
// This parser starts with in_reasoning=true, assuming all text is reasoning
// until an end token is encountered.
use crate::reasoning_parser::parsers::BaseReasoningParser;
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
/// DeepSeek-R1 reasoning parser.
///
/// This parser assumes reasoning from the start of text (in_reasoning=true)
/// and uses <think> and </think> tokens.
pub struct DeepSeekR1Parser {
base: BaseReasoningParser,
}
impl DeepSeekR1Parser {
/// Create a new DeepSeek-R1 parser.
pub fn new() -> Self {
let config = ParserConfig {
think_start_token: "<think>".to_string(),
think_end_token: "</think>".to_string(),
stream_reasoning: true,
max_buffer_size: 65536,
initial_in_reasoning: true, // Always starts with reasoning
};
Self {
base: BaseReasoningParser::new(config).with_model_type("deepseek_r1".to_string()),
}
}
}
impl Default for DeepSeekR1Parser {
fn default() -> Self {
Self::new()
}
}
impl ReasoningParser for DeepSeekR1Parser {
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
self.base.detect_and_parse_reasoning(text)
}
fn parse_reasoning_streaming_incremental(
&mut self,
text: &str,
) -> Result<ParserResult, ParseError> {
self.base.parse_reasoning_streaming_incremental(text)
}
fn reset(&mut self) {
self.base.reset()
}
fn model_type(&self) -> &str {
self.base.model_type()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_deepseek_r1_initial_state() {
let mut parser = DeepSeekR1Parser::new();
// Should treat text as reasoning even without start token
let result = parser
.detect_and_parse_reasoning("This is reasoning content")
.unwrap();
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "This is reasoning content");
}
#[test]
fn test_deepseek_r1_with_end_token() {
let mut parser = DeepSeekR1Parser::new();
// Should extract reasoning until end token
let result = parser
.detect_and_parse_reasoning("reasoning content</think>normal content")
.unwrap();
assert_eq!(result.normal_text, "normal content");
assert_eq!(result.reasoning_text, "reasoning content");
}
#[test]
fn test_deepseek_r1_streaming() {
let mut parser = DeepSeekR1Parser::new();
// First chunk - all reasoning
let result1 = parser
.parse_reasoning_streaming_incremental("thinking about")
.unwrap();
assert_eq!(result1.reasoning_text, "thinking about");
assert_eq!(result1.normal_text, "");
// Second chunk - ends reasoning
let result2 = parser
.parse_reasoning_streaming_incremental(" the problem</think>answer")
.unwrap();
assert_eq!(result2.reasoning_text, "the problem"); // Text is trimmed
assert_eq!(result2.normal_text, "answer");
}
#[test]
fn test_model_type() {
let parser = DeepSeekR1Parser::new();
assert_eq!(parser.model_type(), "deepseek_r1");
}
}

View File

@@ -0,0 +1,118 @@
// GLM45 specific reasoning parser.
// Uses the same format as Qwen3 but has its own implementation for debugging.
use crate::reasoning_parser::parsers::BaseReasoningParser;
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
/// GLM45 reasoning parser.
///
/// This parser uses the same format as Qwen3 (<think>...</think>) but has
/// its own implementation for better debugging and potential future customization.
pub struct Glm45Parser {
base: BaseReasoningParser,
}
impl Glm45Parser {
/// Create a new GLM45 parser.
pub fn new() -> Self {
let config = ParserConfig {
think_start_token: "<think>".to_string(),
think_end_token: "</think>".to_string(),
stream_reasoning: true,
max_buffer_size: 65536,
initial_in_reasoning: false, // Requires explicit start token like Qwen3
};
Self {
base: BaseReasoningParser::new(config).with_model_type("glm45".to_string()),
}
}
}
impl Default for Glm45Parser {
fn default() -> Self {
Self::new()
}
}
impl ReasoningParser for Glm45Parser {
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
self.base.detect_and_parse_reasoning(text)
}
fn parse_reasoning_streaming_incremental(
&mut self,
text: &str,
) -> Result<ParserResult, ParseError> {
self.base.parse_reasoning_streaming_incremental(text)
}
fn reset(&mut self) {
self.base.reset()
}
fn model_type(&self) -> &str {
self.base.model_type()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_glm45_initial_state() {
let mut parser = Glm45Parser::new();
// Should NOT treat text as reasoning without start token
let result = parser
.detect_and_parse_reasoning("This is normal content")
.unwrap();
assert_eq!(result.normal_text, "This is normal content");
assert_eq!(result.reasoning_text, "");
}
#[test]
fn test_glm45_with_tokens() {
let mut parser = Glm45Parser::new();
// Should extract reasoning with proper tokens
let result = parser
.detect_and_parse_reasoning("<think>reasoning content</think>answer")
.unwrap();
assert_eq!(result.normal_text, "answer");
assert_eq!(result.reasoning_text, "reasoning content");
}
#[test]
fn test_glm45_streaming() {
let mut parser = Glm45Parser::new();
// First chunk - normal text
let result1 = parser
.parse_reasoning_streaming_incremental("normal text ")
.unwrap();
assert_eq!(result1.normal_text, "normal text ");
assert_eq!(result1.reasoning_text, "");
// Second chunk - enters reasoning
let result2 = parser
.parse_reasoning_streaming_incremental("<think>reasoning")
.unwrap();
assert_eq!(result2.normal_text, "");
assert_eq!(result2.reasoning_text, "reasoning");
// Third chunk - exits reasoning
let result3 = parser
.parse_reasoning_streaming_incremental("</think>answer")
.unwrap();
assert_eq!(result3.normal_text, "answer");
assert_eq!(result3.reasoning_text, "");
}
#[test]
fn test_model_type() {
let parser = Glm45Parser::new();
assert_eq!(parser.model_type(), "glm45");
}
}

View File

@@ -0,0 +1,137 @@
// Kimi specific reasoning parser.
// This parser uses Unicode tokens and starts with in_reasoning=false.
use crate::reasoning_parser::parsers::BaseReasoningParser;
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
/// Kimi reasoning parser.
///
/// This parser uses Unicode tokens (◁think▷ and ◁/think▷) and requires
/// explicit start tokens to enter reasoning mode.
pub struct KimiParser {
base: BaseReasoningParser,
}
impl KimiParser {
/// Create a new Kimi parser.
pub fn new() -> Self {
let config = ParserConfig {
think_start_token: "◁think▷".to_string(),
think_end_token: "◁/think▷".to_string(),
stream_reasoning: true,
max_buffer_size: 65536,
initial_in_reasoning: false, // Requires explicit start token
};
Self {
base: BaseReasoningParser::new(config).with_model_type("kimi".to_string()),
}
}
}
impl Default for KimiParser {
fn default() -> Self {
Self::new()
}
}
impl ReasoningParser for KimiParser {
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
self.base.detect_and_parse_reasoning(text)
}
fn parse_reasoning_streaming_incremental(
&mut self,
text: &str,
) -> Result<ParserResult, ParseError> {
self.base.parse_reasoning_streaming_incremental(text)
}
fn reset(&mut self) {
self.base.reset()
}
fn model_type(&self) -> &str {
self.base.model_type()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_kimi_initial_state() {
let mut parser = KimiParser::new();
// Should NOT treat text as reasoning without start token
let result = parser
.detect_and_parse_reasoning("This is normal content")
.unwrap();
assert_eq!(result.normal_text, "This is normal content");
assert_eq!(result.reasoning_text, "");
}
#[test]
fn test_kimi_with_unicode_tokens() {
let mut parser = KimiParser::new();
// Should extract reasoning with Unicode tokens
let result = parser
.detect_and_parse_reasoning("◁think▷reasoning content◁/think▷answer")
.unwrap();
assert_eq!(result.normal_text, "answer");
assert_eq!(result.reasoning_text, "reasoning content");
}
#[test]
fn test_kimi_partial_unicode() {
let mut parser = KimiParser::new();
// Test partial Unicode token buffering
let result1 = parser
.parse_reasoning_streaming_incremental("◁thi")
.unwrap();
assert_eq!(result1.normal_text, "");
assert_eq!(result1.reasoning_text, "");
// Complete the token
let result2 = parser
.parse_reasoning_streaming_incremental("nk▷reasoning")
.unwrap();
assert_eq!(result2.normal_text, "");
assert_eq!(result2.reasoning_text, "reasoning");
}
#[test]
fn test_kimi_streaming() {
let mut parser = KimiParser::new();
// Normal text first
let result1 = parser
.parse_reasoning_streaming_incremental("normal ")
.unwrap();
assert_eq!(result1.normal_text, "normal ");
assert_eq!(result1.reasoning_text, "");
// Enter reasoning with Unicode token
let result2 = parser
.parse_reasoning_streaming_incremental("◁think▷thinking")
.unwrap();
assert_eq!(result2.normal_text, "");
assert_eq!(result2.reasoning_text, "thinking");
// Exit reasoning
let result3 = parser
.parse_reasoning_streaming_incremental("◁/think▷answer")
.unwrap();
assert_eq!(result3.normal_text, "answer");
assert_eq!(result3.reasoning_text, ""); // Already returned in stream mode
}
#[test]
fn test_model_type() {
let parser = KimiParser::new();
assert_eq!(parser.model_type(), "kimi");
}
}

View File

@@ -0,0 +1,13 @@
pub mod base;
pub mod deepseek_r1;
pub mod glm45;
pub mod kimi;
pub mod qwen3;
pub mod step3;
pub use base::BaseReasoningParser;
pub use deepseek_r1::DeepSeekR1Parser;
pub use glm45::Glm45Parser;
pub use kimi::KimiParser;
pub use qwen3::{Qwen3Parser, QwenThinkingParser};
pub use step3::Step3Parser;

View File

@@ -0,0 +1,178 @@
// Qwen3 specific reasoning parser.
// This parser starts with in_reasoning=false, requiring an explicit
// start token to enter reasoning mode.
use crate::reasoning_parser::parsers::BaseReasoningParser;
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
/// Qwen3 reasoning parser.
///
/// This parser requires explicit <think> tokens to enter reasoning mode
/// (in_reasoning=false initially).
pub struct Qwen3Parser {
base: BaseReasoningParser,
}
impl Qwen3Parser {
/// Create a new Qwen3 parser.
pub fn new() -> Self {
let config = ParserConfig {
think_start_token: "<think>".to_string(),
think_end_token: "</think>".to_string(),
stream_reasoning: true,
max_buffer_size: 65536,
initial_in_reasoning: false, // Requires explicit start token
};
Self {
base: BaseReasoningParser::new(config).with_model_type("qwen3".to_string()),
}
}
}
impl Default for Qwen3Parser {
fn default() -> Self {
Self::new()
}
}
impl ReasoningParser for Qwen3Parser {
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
self.base.detect_and_parse_reasoning(text)
}
fn parse_reasoning_streaming_incremental(
&mut self,
text: &str,
) -> Result<ParserResult, ParseError> {
self.base.parse_reasoning_streaming_incremental(text)
}
fn reset(&mut self) {
self.base.reset()
}
fn model_type(&self) -> &str {
self.base.model_type()
}
}
/// QwenThinking parser - variant that assumes reasoning from start.
///
/// This is for qwen*thinking models that behave like DeepSeek-R1.
pub struct QwenThinkingParser {
base: BaseReasoningParser,
}
impl QwenThinkingParser {
/// Create a new QwenThinking parser.
pub fn new() -> Self {
let config = ParserConfig {
think_start_token: "<think>".to_string(),
think_end_token: "</think>".to_string(),
stream_reasoning: true,
max_buffer_size: 65536,
initial_in_reasoning: true, // Assumes reasoning from start
};
Self {
base: BaseReasoningParser::new(config).with_model_type("qwen_thinking".to_string()),
}
}
}
impl Default for QwenThinkingParser {
fn default() -> Self {
Self::new()
}
}
impl ReasoningParser for QwenThinkingParser {
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
self.base.detect_and_parse_reasoning(text)
}
fn parse_reasoning_streaming_incremental(
&mut self,
text: &str,
) -> Result<ParserResult, ParseError> {
self.base.parse_reasoning_streaming_incremental(text)
}
fn reset(&mut self) {
self.base.reset()
}
fn model_type(&self) -> &str {
self.base.model_type()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_qwen3_initial_state() {
let mut parser = Qwen3Parser::new();
// Should NOT treat text as reasoning without start token
let result = parser
.detect_and_parse_reasoning("This is normal content")
.unwrap();
assert_eq!(result.normal_text, "This is normal content");
assert_eq!(result.reasoning_text, "");
}
#[test]
fn test_qwen3_with_tokens() {
let mut parser = Qwen3Parser::new();
// Should extract reasoning with proper tokens
let result = parser
.detect_and_parse_reasoning("<think>reasoning</think>answer")
.unwrap();
assert_eq!(result.normal_text, "answer");
assert_eq!(result.reasoning_text, "reasoning");
}
#[test]
fn test_qwen_thinking_initial_state() {
let mut parser = QwenThinkingParser::new();
// Should treat text as reasoning even without start token
let result = parser
.detect_and_parse_reasoning("This is reasoning content")
.unwrap();
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "This is reasoning content");
}
#[test]
fn test_qwen3_streaming() {
let mut parser = Qwen3Parser::new();
// First chunk - normal text (no start token yet)
let result1 = parser
.parse_reasoning_streaming_incremental("normal text ")
.unwrap();
assert_eq!(result1.normal_text, "normal text ");
assert_eq!(result1.reasoning_text, "");
// Second chunk - enters reasoning
let result2 = parser
.parse_reasoning_streaming_incremental("<think>reasoning")
.unwrap();
assert_eq!(result2.normal_text, "");
assert_eq!(result2.reasoning_text, "reasoning");
}
#[test]
fn test_model_types() {
let qwen3 = Qwen3Parser::new();
assert_eq!(qwen3.model_type(), "qwen3");
let qwen_thinking = QwenThinkingParser::new();
assert_eq!(qwen_thinking.model_type(), "qwen_thinking");
}
}

View File

@@ -0,0 +1,123 @@
// Step3 specific reasoning parser.
// Uses the same format as DeepSeek-R1 but has its own implementation for debugging.
use crate::reasoning_parser::parsers::BaseReasoningParser;
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
/// Step3 reasoning parser.
///
/// This parser uses the same format as DeepSeek-R1 (<think>...</think>) but has
/// its own implementation for better debugging and potential future customization.
pub struct Step3Parser {
base: BaseReasoningParser,
}
impl Step3Parser {
/// Create a new Step3 parser.
pub fn new() -> Self {
let config = ParserConfig {
think_start_token: "<think>".to_string(),
think_end_token: "</think>".to_string(),
stream_reasoning: true,
max_buffer_size: 65536,
initial_in_reasoning: true, // Assumes reasoning from start like DeepSeek-R1
};
Self {
base: BaseReasoningParser::new(config).with_model_type("step3".to_string()),
}
}
}
impl Default for Step3Parser {
fn default() -> Self {
Self::new()
}
}
impl ReasoningParser for Step3Parser {
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
self.base.detect_and_parse_reasoning(text)
}
fn parse_reasoning_streaming_incremental(
&mut self,
text: &str,
) -> Result<ParserResult, ParseError> {
self.base.parse_reasoning_streaming_incremental(text)
}
fn reset(&mut self) {
self.base.reset()
}
fn model_type(&self) -> &str {
self.base.model_type()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_step3_initial_state() {
let mut parser = Step3Parser::new();
// Should treat text as reasoning even without start token
let result = parser
.detect_and_parse_reasoning("This is reasoning content")
.unwrap();
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "This is reasoning content");
}
#[test]
fn test_step3_with_end_token() {
let mut parser = Step3Parser::new();
// Should handle text with end token
let result = parser
.detect_and_parse_reasoning("reasoning content</think>answer")
.unwrap();
assert_eq!(result.normal_text, "answer");
assert_eq!(result.reasoning_text, "reasoning content");
}
#[test]
fn test_step3_with_both_tokens() {
let mut parser = Step3Parser::new();
// Should handle both start and end tokens
let result = parser
.detect_and_parse_reasoning("<think>reasoning content</think>answer")
.unwrap();
assert_eq!(result.normal_text, "answer");
assert_eq!(result.reasoning_text, "reasoning content");
}
#[test]
fn test_step3_streaming() {
let mut parser = Step3Parser::new();
// First chunk - treated as reasoning (initial_in_reasoning=true)
let result1 = parser
.parse_reasoning_streaming_incremental("reasoning text ")
.unwrap();
assert_eq!(result1.normal_text, "");
assert_eq!(result1.reasoning_text, "reasoning text ");
// Second chunk - continues reasoning until end token
let result2 = parser
.parse_reasoning_streaming_incremental("more reasoning</think>answer")
.unwrap();
assert_eq!(result2.normal_text, "answer");
assert_eq!(result2.reasoning_text, "more reasoning");
}
#[test]
fn test_model_type() {
let parser = Step3Parser::new();
assert_eq!(parser.model_type(), "step3");
}
}

View File

@@ -0,0 +1,130 @@
use std::fmt;
/// Result of parsing text for reasoning content.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ParserResult {
/// The normal text outside of reasoning blocks.
pub normal_text: String,
/// The extracted reasoning text from within reasoning blocks.
pub reasoning_text: String,
}
impl ParserResult {
/// Create a new ParserResult with the given normal and reasoning text.
pub fn new(normal_text: String, reasoning_text: String) -> Self {
Self {
normal_text,
reasoning_text,
}
}
/// Create a result with only normal text.
pub fn normal(text: String) -> Self {
Self {
normal_text: text,
reasoning_text: String::new(),
}
}
/// Create a result with only reasoning text.
pub fn reasoning(text: String) -> Self {
Self {
normal_text: String::new(),
reasoning_text: text,
}
}
/// Check if this result contains any text.
pub fn is_empty(&self) -> bool {
self.normal_text.is_empty() && self.reasoning_text.is_empty()
}
}
/// Trait for parsing reasoning content from LLM outputs.
pub trait ReasoningParser: Send + Sync {
/// Detects and parses reasoning from the input text (one-time parsing).
///
/// This method is used for non-streaming scenarios where the complete
/// text is available at once.
///
/// Returns an error if the text exceeds buffer limits or contains invalid UTF-8.
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError>;
/// Parses reasoning incrementally from streaming input.
///
/// This method maintains internal state across calls to handle partial
/// tokens and chunk boundaries correctly.
///
/// Returns an error if the buffer exceeds max_buffer_size.
fn parse_reasoning_streaming_incremental(
&mut self,
text: &str,
) -> Result<ParserResult, ParseError>;
/// Reset the parser state for reuse.
///
/// This should clear any buffers and reset flags to initial state.
fn reset(&mut self);
/// Get the model type this parser is designed for.
fn model_type(&self) -> &str;
}
/// Error types for reasoning parsing operations.
#[derive(Debug, thiserror::Error)]
pub enum ParseError {
#[error("Invalid UTF-8 in stream: {0}")]
Utf8Error(#[from] std::str::Utf8Error),
#[error("Buffer overflow: {0} bytes exceeds maximum")]
BufferOverflow(usize),
#[error("Unknown model type: {0}")]
UnknownModel(String),
#[error("Parser configuration error: {0}")]
ConfigError(String),
}
/// Configuration for parser behavior.
#[derive(Debug, Clone)]
pub struct ParserConfig {
/// The token that marks the start of reasoning content.
pub think_start_token: String,
/// The token that marks the end of reasoning content.
pub think_end_token: String,
/// Whether to stream reasoning content as it arrives.
pub stream_reasoning: bool,
/// Maximum buffer size in bytes.
pub max_buffer_size: usize,
/// Initial state for in_reasoning flag (fixed per parser type).
pub initial_in_reasoning: bool,
}
impl Default for ParserConfig {
fn default() -> Self {
Self {
think_start_token: "<think>".to_string(),
think_end_token: "</think>".to_string(),
stream_reasoning: true,
max_buffer_size: 65536, // 64KB default
initial_in_reasoning: false, // Default to false (explicit reasoning)
}
}
}
impl fmt::Display for ParserResult {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"ParserResult {{ normal: {} chars, reasoning: {} chars }}",
self.normal_text.len(),
self.reasoning_text.len()
)
}
}

View File

@@ -0,0 +1,195 @@
//! Factory for creating router instances
use super::{
http::{openai_router::OpenAIRouter, pd_router::PDRouter, router::Router},
RouterTrait,
};
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
use crate::policies::PolicyFactory;
use crate::server::AppContext;
use std::sync::Arc;
/// Factory for creating router instances based on configuration
pub struct RouterFactory;
impl RouterFactory {
/// Create a router instance from application context
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
// Check if IGW mode is enabled
if ctx.router_config.enable_igw {
return Self::create_igw_router(ctx).await;
}
// Check connection mode and route to appropriate implementation
match ctx.router_config.connection_mode {
ConnectionMode::Grpc => {
// Route to gRPC implementation based on routing mode
match &ctx.router_config.mode {
RoutingMode::Regular { worker_urls } => {
Self::create_grpc_router(worker_urls, &ctx.router_config.policy, ctx).await
}
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
prefill_policy,
decode_policy,
} => {
Self::create_grpc_pd_router(
prefill_urls,
decode_urls,
prefill_policy.as_ref(),
decode_policy.as_ref(),
&ctx.router_config.policy,
ctx,
)
.await
}
RoutingMode::OpenAI { .. } => {
Err("OpenAI mode requires HTTP connection_mode".to_string())
}
}
}
ConnectionMode::Http => {
// Route to HTTP implementation based on routing mode
match &ctx.router_config.mode {
RoutingMode::Regular { worker_urls } => {
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
.await
}
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
prefill_policy,
decode_policy,
} => {
Self::create_pd_router(
prefill_urls,
decode_urls,
prefill_policy.as_ref(),
decode_policy.as_ref(),
&ctx.router_config.policy,
ctx,
)
.await
}
RoutingMode::OpenAI { worker_urls, .. } => {
Self::create_openai_router(worker_urls.clone(), ctx).await
}
}
}
}
}
/// Create a regular router with injected policy
async fn create_regular_router(
worker_urls: &[String],
policy_config: &PolicyConfig,
ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
// Create policy
let policy = PolicyFactory::create_from_config(policy_config);
// Create regular router with injected policy and context
let router = Router::new(worker_urls.to_vec(), policy, ctx).await?;
Ok(Box::new(router))
}
/// Create a PD router with injected policy
async fn create_pd_router(
prefill_urls: &[(String, Option<u16>)],
decode_urls: &[String],
prefill_policy_config: Option<&PolicyConfig>,
decode_policy_config: Option<&PolicyConfig>,
main_policy_config: &PolicyConfig,
ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
// Create policies - use specific policies if provided, otherwise fall back to main policy
let prefill_policy =
PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
let decode_policy =
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
// Create PD router with separate policies and context
let router = PDRouter::new(
prefill_urls.to_vec(),
decode_urls.to_vec(),
prefill_policy,
decode_policy,
ctx,
)
.await?;
Ok(Box::new(router))
}
/// Create a gRPC router with injected policy
pub async fn create_grpc_router(
worker_urls: &[String],
policy_config: &PolicyConfig,
ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
use super::grpc::router::GrpcRouter;
// Create policy
let policy = PolicyFactory::create_from_config(policy_config);
// Create gRPC router with context
let router = GrpcRouter::new(worker_urls.to_vec(), policy, ctx).await?;
Ok(Box::new(router))
}
/// Create a gRPC PD router with tokenizer and worker configuration
pub async fn create_grpc_pd_router(
prefill_urls: &[(String, Option<u16>)],
decode_urls: &[String],
prefill_policy_config: Option<&PolicyConfig>,
decode_policy_config: Option<&PolicyConfig>,
main_policy_config: &PolicyConfig,
ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
use super::grpc::pd_router::GrpcPDRouter;
// Create policies - use specific policies if provided, otherwise fall back to main policy
let prefill_policy =
PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
let decode_policy =
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
// Create gRPC PD router with context
let router = GrpcPDRouter::new(
prefill_urls.to_vec(),
decode_urls.to_vec(),
prefill_policy,
decode_policy,
ctx,
)
.await?;
Ok(Box::new(router))
}
/// Create an OpenAI router
async fn create_openai_router(
worker_urls: Vec<String>,
ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
// Use the first worker URL as the OpenAI-compatible base
let base_url = worker_urls
.first()
.cloned()
.ok_or_else(|| "OpenAI mode requires at least one worker URL".to_string())?;
let router =
OpenAIRouter::new(base_url, Some(ctx.router_config.circuit_breaker.clone())).await?;
Ok(Box::new(router))
}
/// Create an IGW router (placeholder for future implementation)
async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
// For now, return an error indicating IGW is not yet implemented
Err("IGW mode is not yet implemented".to_string())
}
}

View File

@@ -0,0 +1,4 @@
//! gRPC router implementations
pub mod pd_router;
pub mod router;

View File

@@ -0,0 +1,328 @@
// PD (Prefill-Decode) gRPC Router Implementation
use crate::config::types::RetryConfig;
use crate::core::{
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
};
use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterTrait, WorkerManagement};
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserRegistry;
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tracing::{info, warn};
/// gRPC PD (Prefill-Decode) router implementation for SGLang
#[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcPDRouter {
/// Prefill worker connections
prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
/// Decode worker connections
decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
/// gRPC clients for prefill workers
prefill_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
/// gRPC clients for decode workers
decode_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
/// Load balancing policy for prefill
prefill_policy: Arc<dyn LoadBalancingPolicy>,
/// Load balancing policy for decode
decode_policy: Arc<dyn LoadBalancingPolicy>,
/// Tokenizer for handling text encoding/decoding
tokenizer: Arc<dyn Tokenizer>,
/// Reasoning parser factory for structured reasoning outputs
reasoning_parser_factory: ParserFactory,
/// Tool parser registry for function/tool calls
tool_parser_registry: &'static ParserRegistry,
/// Worker health checkers
_prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>,
/// Configuration
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
circuit_breaker_config: CircuitBreakerConfig,
}
impl GrpcPDRouter {
/// Create a new gRPC PD router
pub async fn new(
prefill_urls: Vec<(String, Option<u16>)>,
decode_urls: Vec<String>,
prefill_policy: Arc<dyn LoadBalancingPolicy>,
decode_policy: Arc<dyn LoadBalancingPolicy>,
ctx: &Arc<crate::server::AppContext>,
) -> Result<Self, String> {
// Update metrics
RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len());
// Extract necessary components from context
let tokenizer = ctx
.tokenizer
.as_ref()
.ok_or_else(|| "gRPC PD router requires tokenizer".to_string())?
.clone();
let reasoning_parser_factory = ctx
.reasoning_parser_factory
.as_ref()
.ok_or_else(|| "gRPC PD router requires reasoning parser factory".to_string())?
.clone();
let tool_parser_registry = ctx
.tool_parser_registry
.ok_or_else(|| "gRPC PD router requires tool parser registry".to_string())?;
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Create gRPC clients for prefill workers
let mut prefill_grpc_clients = HashMap::new();
for (url, _bootstrap_port) in &prefill_urls {
match SglangSchedulerClient::connect(url).await {
Ok(client) => {
prefill_grpc_clients.insert(url.clone(), client);
info!("Connected to gRPC prefill worker at {}", url);
}
Err(e) => {
warn!("Failed to connect to gRPC prefill worker at {}: {}", url, e);
// Continue with other workers
}
}
}
// Create gRPC clients for decode workers
let mut decode_grpc_clients = HashMap::new();
for url in &decode_urls {
match SglangSchedulerClient::connect(url).await {
Ok(client) => {
decode_grpc_clients.insert(url.clone(), client);
info!("Connected to gRPC decode worker at {}", url);
}
Err(e) => {
warn!("Failed to connect to gRPC decode worker at {}: {}", url, e);
// Continue with other workers
}
}
}
if prefill_grpc_clients.is_empty() && decode_grpc_clients.is_empty() {
return Err("Failed to connect to any gRPC workers".to_string());
}
// Create Prefill Worker trait objects with gRPC connection mode
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
.iter()
.map(|(url, bootstrap_port)| {
let worker = BasicWorker::with_connection_mode(
url.clone(),
WorkerType::Prefill {
bootstrap_port: *bootstrap_port,
},
crate::core::ConnectionMode::Grpc {
port: *bootstrap_port,
},
)
.with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
});
Box::new(worker) as Box<dyn Worker>
})
.collect();
// Create Decode Worker trait objects with gRPC connection mode
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
.iter()
.map(|url| {
let worker = BasicWorker::with_connection_mode(
url.clone(),
WorkerType::Decode,
crate::core::ConnectionMode::Grpc { port: None },
)
.with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
});
Box::new(worker) as Box<dyn Worker>
})
.collect();
// Initialize policies with workers if needed
if let Some(cache_aware) = prefill_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.init_workers(&prefill_workers);
}
if let Some(cache_aware) = decode_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.init_workers(&decode_workers);
}
let prefill_workers = Arc::new(RwLock::new(prefill_workers));
let decode_workers = Arc::new(RwLock::new(decode_workers));
let prefill_health_checker = crate::core::start_health_checker(
Arc::clone(&prefill_workers),
ctx.router_config.worker_startup_check_interval_secs,
);
let decode_health_checker = crate::core::start_health_checker(
Arc::clone(&decode_workers),
ctx.router_config.worker_startup_check_interval_secs,
);
Ok(GrpcPDRouter {
prefill_workers,
decode_workers,
prefill_grpc_clients: Arc::new(RwLock::new(prefill_grpc_clients)),
decode_grpc_clients: Arc::new(RwLock::new(decode_grpc_clients)),
prefill_policy,
decode_policy,
tokenizer,
reasoning_parser_factory,
tool_parser_registry,
_prefill_health_checker: Some(prefill_health_checker),
_decode_health_checker: Some(decode_health_checker),
timeout_secs: ctx.router_config.worker_startup_timeout_secs,
interval_secs: ctx.router_config.worker_startup_check_interval_secs,
dp_aware: ctx.router_config.dp_aware,
api_key: ctx.router_config.api_key.clone(),
retry_config: ctx.router_config.effective_retry_config(),
circuit_breaker_config: core_cb_config,
})
}
}
impl std::fmt::Debug for GrpcPDRouter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GrpcPDRouter")
.field(
"prefill_workers_count",
&self.prefill_workers.read().unwrap().len(),
)
.field(
"decode_workers_count",
&self.decode_workers.read().unwrap().len(),
)
.field("timeout_secs", &self.timeout_secs)
.field("interval_secs", &self.interval_secs)
.field("dp_aware", &self.dp_aware)
.finish()
}
}
#[async_trait]
impl RouterTrait for GrpcPDRouter {
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn health(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn health_generate(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_server_info(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_models(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_model_info(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_generate(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::GenerateRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_chat(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::ChatCompletionRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_completion(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::CompletionRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn flush_cache(&self) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_worker_loads(&self) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
fn router_type(&self) -> &'static str {
"grpc_pd"
}
fn readiness(&self) -> Response {
(StatusCode::SERVICE_UNAVAILABLE).into_response()
}
}
#[async_trait]
impl WorkerManagement for GrpcPDRouter {
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
Err("Not implemented".to_string())
}
fn remove_worker(&self, _worker_url: &str) {}
fn get_worker_urls(&self) -> Vec<String> {
vec![]
}
}

View File

@@ -0,0 +1,266 @@
// gRPC Router Implementation
use crate::config::types::RetryConfig;
use crate::core::{
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
};
use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterTrait, WorkerManagement};
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserRegistry;
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tracing::{info, warn};
/// gRPC router implementation for SGLang
#[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcRouter {
/// Worker connections
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
/// gRPC clients for each worker
grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
/// Load balancing policy
policy: Arc<dyn LoadBalancingPolicy>,
/// Tokenizer for handling text encoding/decoding
tokenizer: Arc<dyn Tokenizer>,
/// Reasoning parser factory for structured reasoning outputs
reasoning_parser_factory: ParserFactory,
/// Tool parser registry for function/tool calls
tool_parser_registry: &'static ParserRegistry,
/// Worker health checker
_health_checker: Option<HealthChecker>,
/// Configuration
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
circuit_breaker_config: CircuitBreakerConfig,
}
impl GrpcRouter {
/// Create a new gRPC router
pub async fn new(
worker_urls: Vec<String>,
policy: Arc<dyn LoadBalancingPolicy>,
ctx: &Arc<crate::server::AppContext>,
) -> Result<Self, String> {
// Update metrics
RouterMetrics::set_active_workers(worker_urls.len());
// Extract necessary components from context
let tokenizer = ctx
.tokenizer
.as_ref()
.ok_or_else(|| "gRPC router requires tokenizer".to_string())?
.clone();
let reasoning_parser_factory = ctx
.reasoning_parser_factory
.as_ref()
.ok_or_else(|| "gRPC router requires reasoning parser factory".to_string())?
.clone();
let tool_parser_registry = ctx
.tool_parser_registry
.ok_or_else(|| "gRPC router requires tool parser registry".to_string())?;
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Create gRPC clients for each worker
let mut grpc_clients = HashMap::new();
for url in &worker_urls {
match SglangSchedulerClient::connect(url).await {
Ok(client) => {
grpc_clients.insert(url.clone(), client);
info!("Connected to gRPC worker at {}", url);
}
Err(e) => {
warn!("Failed to connect to gRPC worker at {}: {}", url, e);
// Continue with other workers
}
}
}
if grpc_clients.is_empty() {
return Err("Failed to connect to any gRPC workers".to_string());
}
// Create Worker trait objects with gRPC connection mode
let mut workers: Vec<Box<dyn Worker>> = Vec::new();
// Move clients from the HashMap to the workers
for url in &worker_urls {
if let Some(client) = grpc_clients.remove(url) {
let worker = BasicWorker::with_connection_mode(
url.clone(),
WorkerType::Regular,
crate::core::ConnectionMode::Grpc { port: None },
)
.with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
})
.with_grpc_client(client);
workers.push(Box::new(worker) as Box<dyn Worker>);
} else {
warn!("No gRPC client for worker {}, skipping", url);
}
}
// Initialize policy with workers if needed
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.init_workers(&workers);
}
let workers = Arc::new(RwLock::new(workers));
let health_checker = crate::core::start_health_checker(
Arc::clone(&workers),
ctx.router_config.worker_startup_check_interval_secs,
);
Ok(GrpcRouter {
workers,
grpc_clients: Arc::new(RwLock::new(grpc_clients)),
policy,
tokenizer,
reasoning_parser_factory,
tool_parser_registry,
_health_checker: Some(health_checker),
timeout_secs: ctx.router_config.worker_startup_timeout_secs,
interval_secs: ctx.router_config.worker_startup_check_interval_secs,
dp_aware: ctx.router_config.dp_aware,
api_key: ctx.router_config.api_key.clone(),
retry_config: ctx.router_config.effective_retry_config(),
circuit_breaker_config: core_cb_config,
})
}
}
impl std::fmt::Debug for GrpcRouter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GrpcRouter")
.field("workers_count", &self.workers.read().unwrap().len())
.field("timeout_secs", &self.timeout_secs)
.field("interval_secs", &self.interval_secs)
.field("dp_aware", &self.dp_aware)
.finish()
}
}
#[async_trait]
impl RouterTrait for GrpcRouter {
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn health(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn health_generate(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_server_info(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_models(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_model_info(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_generate(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::GenerateRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_chat(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::ChatCompletionRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_completion(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::CompletionRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn flush_cache(&self) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_worker_loads(&self) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
fn router_type(&self) -> &'static str {
"grpc"
}
fn readiness(&self) -> Response {
(StatusCode::SERVICE_UNAVAILABLE).into_response()
}
}
#[async_trait]
impl WorkerManagement for GrpcRouter {
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
Err("Not implemented".to_string())
}
fn remove_worker(&self, _worker_url: &str) {}
fn get_worker_urls(&self) -> Vec<String> {
self.workers
.read()
.unwrap()
.iter()
.map(|w| w.url().to_string())
.collect()
}
}

View File

@@ -0,0 +1,53 @@
use axum::body::Body;
use axum::extract::Request;
use axum::http::HeaderMap;
/// Copy request headers to a Vec of name-value string pairs
/// Used for forwarding headers to backend workers
pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
req.headers()
.iter()
.filter_map(|(name, value)| {
// Convert header value to string, skipping non-UTF8 headers
value
.to_str()
.ok()
.map(|v| (name.to_string(), v.to_string()))
})
.collect()
}
/// Convert headers from reqwest Response to axum HeaderMap
/// Filters out hop-by-hop headers that shouldn't be forwarded
pub fn preserve_response_headers(reqwest_headers: &HeaderMap) -> HeaderMap {
let mut headers = HeaderMap::new();
for (name, value) in reqwest_headers.iter() {
// Skip hop-by-hop headers that shouldn't be forwarded
let name_str = name.as_str().to_lowercase();
if should_forward_header(&name_str) {
// The original name and value are already valid, so we can just clone them
headers.insert(name.clone(), value.clone());
}
}
headers
}
/// Determine if a header should be forwarded from backend to client
fn should_forward_header(name: &str) -> bool {
// List of headers that should NOT be forwarded (hop-by-hop headers)
!matches!(
name,
"connection" |
"keep-alive" |
"proxy-authenticate" |
"proxy-authorization" |
"te" |
"trailers" |
"transfer-encoding" |
"upgrade" |
"content-encoding" | // Let axum/hyper handle encoding
"host" // Should not forward the backend's host header
)
}

View File

@@ -0,0 +1,6 @@
//! HTTP router implementations
pub mod openai_router;
pub mod pd_router;
pub mod pd_types;
pub mod router;

View File

@@ -0,0 +1,379 @@
//! OpenAI router implementation (reqwest-based)
use crate::config::CircuitBreakerConfig;
use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig};
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
};
use futures_util::StreamExt;
use std::{
any::Any,
sync::atomic::{AtomicBool, Ordering},
};
/// Router for OpenAI backend
#[derive(Debug)]
pub struct OpenAIRouter {
/// HTTP client for upstream OpenAI-compatible API
client: reqwest::Client,
/// Base URL for identification (no trailing slash)
base_url: String,
/// Circuit breaker
circuit_breaker: CircuitBreaker,
/// Health status
healthy: AtomicBool,
}
impl OpenAIRouter {
/// Create a new OpenAI router
pub async fn new(
base_url: String,
circuit_breaker_config: Option<CircuitBreakerConfig>,
) -> Result<Self, String> {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(300))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
let base_url = base_url.trim_end_matches('/').to_string();
// Convert circuit breaker config
let core_cb_config = circuit_breaker_config
.map(|cb| CoreCircuitBreakerConfig {
failure_threshold: cb.failure_threshold,
success_threshold: cb.success_threshold,
timeout_duration: std::time::Duration::from_secs(cb.timeout_duration_secs),
window_duration: std::time::Duration::from_secs(cb.window_duration_secs),
})
.unwrap_or_default();
let circuit_breaker = CircuitBreaker::with_config(core_cb_config);
Ok(Self {
client,
base_url,
circuit_breaker,
healthy: AtomicBool::new(true),
})
}
}
#[async_trait]
impl super::super::WorkerManagement for OpenAIRouter {
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
Err("Cannot add workers to OpenAI router".to_string())
}
fn remove_worker(&self, _worker_url: &str) {
// No-op for OpenAI router
}
fn get_worker_urls(&self) -> Vec<String> {
vec![self.base_url.clone()]
}
}
#[async_trait]
impl super::super::RouterTrait for OpenAIRouter {
fn as_any(&self) -> &dyn Any {
self
}
async fn health(&self, _req: Request<Body>) -> Response {
// Simple upstream probe: GET {base}/v1/models without auth
let url = format!("{}/v1/models", self.base_url);
match self
.client
.get(&url)
.timeout(std::time::Duration::from_secs(2))
.send()
.await
{
Ok(resp) => {
let code = resp.status();
// Treat success and auth-required as healthy (endpoint reachable)
if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
(StatusCode::OK, "OK").into_response()
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
format!("Upstream status: {}", code),
)
.into_response()
}
}
Err(e) => (
StatusCode::SERVICE_UNAVAILABLE,
format!("Upstream error: {}", e),
)
.into_response(),
}
}
async fn health_generate(&self, _req: Request<Body>) -> Response {
// For OpenAI, health_generate is the same as health
self.health(_req).await
}
async fn get_server_info(&self, _req: Request<Body>) -> Response {
let info = serde_json::json!({
"router_type": "openai",
"workers": 1,
"base_url": &self.base_url
});
(StatusCode::OK, info.to_string()).into_response()
}
async fn get_models(&self, req: Request<Body>) -> Response {
// Proxy to upstream /v1/models; forward Authorization header if provided
let headers = req.headers();
let mut upstream = self.client.get(format!("{}/v1/models", self.base_url));
if let Some(auth) = headers
.get("authorization")
.or_else(|| headers.get("Authorization"))
{
upstream = upstream.header("Authorization", auth);
}
match upstream.send().await {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let content_type = res.headers().get(CONTENT_TYPE).cloned();
match res.bytes().await {
Ok(body) => {
let mut response = Response::new(axum::body::Body::from(body));
*response.status_mut() = status;
if let Some(ct) = content_type {
response.headers_mut().insert(CONTENT_TYPE, ct);
}
response
}
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read upstream response: {}", e),
)
.into_response(),
}
}
Err(e) => (
StatusCode::BAD_GATEWAY,
format!("Failed to contact upstream: {}", e),
)
.into_response(),
}
}
async fn get_model_info(&self, _req: Request<Body>) -> Response {
// Not directly supported without model param; return 501
(
StatusCode::NOT_IMPLEMENTED,
"get_model_info not implemented for OpenAI router",
)
.into_response()
}
async fn route_generate(
&self,
_headers: Option<&HeaderMap>,
_body: &GenerateRequest,
) -> Response {
// Generate endpoint is SGLang-specific, not supported for OpenAI backend
(
StatusCode::NOT_IMPLEMENTED,
"Generate endpoint not supported for OpenAI backend",
)
.into_response()
}
async fn route_chat(
&self,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
) -> Response {
if !self.circuit_breaker.can_execute() {
return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
}
// Serialize request body, removing SGLang-only fields
let mut payload = match serde_json::to_value(body) {
Ok(v) => v,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
format!("Failed to serialize request: {}", e),
)
.into_response();
}
};
if let Some(obj) = payload.as_object_mut() {
for key in [
"top_k",
"min_p",
"min_tokens",
"regex",
"ebnf",
"stop_token_ids",
"no_stop_trim",
"ignore_eos",
"continue_final_message",
"skip_special_tokens",
"lora_path",
"session_params",
"separate_reasoning",
"stream_reasoning",
"chat_template_kwargs",
"return_hidden_states",
"repetition_penalty",
] {
obj.remove(key);
}
}
let url = format!("{}/v1/chat/completions", self.base_url);
let mut req = self.client.post(&url).json(&payload);
// Forward Authorization header if provided
if let Some(h) = headers {
if let Some(auth) = h.get("authorization").or_else(|| h.get("Authorization")) {
req = req.header("Authorization", auth);
}
}
// Accept SSE when stream=true
if body.stream {
req = req.header("Accept", "text/event-stream");
}
let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
self.circuit_breaker.record_failure();
return (
StatusCode::SERVICE_UNAVAILABLE,
format!("Failed to contact upstream: {}", e),
)
.into_response();
}
};
let status = StatusCode::from_u16(resp.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
if !body.stream {
// Capture Content-Type before consuming response body
let content_type = resp.headers().get(CONTENT_TYPE).cloned();
match resp.bytes().await {
Ok(body) => {
self.circuit_breaker.record_success();
let mut response = Response::new(axum::body::Body::from(body));
*response.status_mut() = status;
if let Some(ct) = content_type {
response.headers_mut().insert(CONTENT_TYPE, ct);
}
response
}
Err(e) => {
self.circuit_breaker.record_failure();
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response: {}", e),
)
.into_response()
}
}
} else {
// Stream SSE bytes to client
let stream = resp.bytes_stream();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
tokio::spawn(async move {
let mut s = stream;
while let Some(chunk) = s.next().await {
match chunk {
Ok(bytes) => {
if tx.send(Ok(bytes)).is_err() {
break;
}
}
Err(e) => {
let _ = tx.send(Err(format!("Stream error: {}", e)));
break;
}
}
}
});
let mut response = Response::new(Body::from_stream(
tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
));
*response.status_mut() = status;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
}
}
async fn route_completion(
&self,
_headers: Option<&HeaderMap>,
_body: &CompletionRequest,
) -> Response {
// Completion endpoint not implemented for OpenAI backend
(
StatusCode::NOT_IMPLEMENTED,
"Completion endpoint not implemented for OpenAI backend",
)
.into_response()
}
async fn flush_cache(&self) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"flush_cache not supported for OpenAI router",
)
.into_response()
}
async fn get_worker_loads(&self) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"get_worker_loads not supported for OpenAI router",
)
.into_response()
}
fn router_type(&self) -> &'static str {
"openai"
}
fn readiness(&self) -> Response {
if self.healthy.load(Ordering::Acquire) && self.circuit_breaker.can_execute() {
(StatusCode::OK, "Ready").into_response()
} else {
(StatusCode::SERVICE_UNAVAILABLE, "Not ready").into_response()
}
}
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"Embeddings endpoint not implemented for OpenAI backend",
)
.into_response()
}
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"Rerank endpoint not implemented for OpenAI backend",
)
.into_response()
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,81 @@
// Custom error type for PD router operations
#[derive(Debug, thiserror::Error)]
pub enum PDRouterError {
#[error("Worker already exists: {url}")]
WorkerAlreadyExists { url: String },
#[error("Worker not found: {url}")]
WorkerNotFound { url: String },
#[error("Lock acquisition failed: {operation}")]
LockError { operation: String },
#[error("Health check failed for worker: {url}")]
HealthCheckFailed { url: String },
#[error("Invalid worker configuration: {reason}")]
InvalidConfiguration { reason: String },
#[error("Network error: {message}")]
NetworkError { message: String },
#[error("Timeout waiting for worker: {url}")]
Timeout { url: String },
}
// Helper functions for workers
pub fn api_path(url: &str, api_path: &str) -> String {
if api_path.starts_with("/") {
format!("{}{}", url, api_path)
} else {
format!("{}/{}", url, api_path)
}
}
pub fn get_hostname(url: &str) -> String {
// Simple hostname extraction without external dependencies
let url = url
.trim_start_matches("http://")
.trim_start_matches("https://");
url.split(':').next().unwrap_or("localhost").to_string()
}
use serde::Serialize;
// Optimized bootstrap wrapper for single requests
#[derive(Serialize)]
pub struct RequestWithBootstrap<'a, T: Serialize> {
#[serde(flatten)]
pub original: &'a T,
pub bootstrap_host: String,
pub bootstrap_port: Option<u16>,
pub bootstrap_room: u64,
}
// Optimized bootstrap wrapper for batch requests
#[derive(Serialize)]
pub struct BatchRequestWithBootstrap<'a, T: Serialize> {
#[serde(flatten)]
pub original: &'a T,
pub bootstrap_host: Vec<String>,
pub bootstrap_port: Vec<Option<u16>>,
pub bootstrap_room: Vec<u64>,
}
// Helper to generate bootstrap room ID
pub fn generate_room_id() -> u64 {
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
rand::random::<u64>() & (i64::MAX as u64)
}
// PD-specific routing policies
#[derive(Debug, Clone, PartialEq)]
pub enum PDSelectionPolicy {
Random,
PowerOfTwo,
CacheAware {
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
},
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,107 @@
//! Router implementations
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use std::fmt::Debug;
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
pub mod factory;
pub mod grpc;
pub mod header_utils;
pub mod http;
pub use factory::RouterFactory;
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
pub use http::{openai_router, pd_router, pd_types, router};
/// Worker management trait for administrative operations
///
/// This trait is separate from RouterTrait to allow Send futures
/// for use in service discovery and other background tasks
#[async_trait]
pub trait WorkerManagement: Send + Sync {
/// Add a worker to the router
async fn add_worker(&self, worker_url: &str) -> Result<String, String>;
/// Remove a worker from the router
fn remove_worker(&self, worker_url: &str);
/// Get all worker URLs
fn get_worker_urls(&self) -> Vec<String>;
}
/// Core trait for all router implementations
///
/// This trait provides a unified interface for routing requests,
/// regardless of whether it's a regular router or PD router.
#[async_trait]
pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
/// Get a reference to self as Any for downcasting
fn as_any(&self) -> &dyn std::any::Any;
/// Route a health check request
async fn health(&self, req: Request<Body>) -> Response;
/// Route a health generate request
async fn health_generate(&self, req: Request<Body>) -> Response;
/// Get server information
async fn get_server_info(&self, req: Request<Body>) -> Response;
/// Get available models
async fn get_models(&self, req: Request<Body>) -> Response;
/// Get model information
async fn get_model_info(&self, req: Request<Body>) -> Response;
/// Route a generate request
async fn route_generate(&self, headers: Option<&HeaderMap>, body: &GenerateRequest)
-> Response;
/// Route a chat completion request
async fn route_chat(
&self,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
) -> Response;
/// Route a completion request
async fn route_completion(
&self,
headers: Option<&HeaderMap>,
body: &CompletionRequest,
) -> Response;
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
/// Flush cache on all workers
async fn flush_cache(&self) -> Response;
/// Get worker loads (for monitoring)
async fn get_worker_loads(&self) -> Response;
/// Get router type name
fn router_type(&self) -> &'static str;
/// Check if this is a PD router
fn is_pd_mode(&self) -> bool {
self.router_type() == "pd"
}
/// Server liveness check - is the server process running
fn liveness(&self) -> Response {
// Simple liveness check - if we can respond, we're alive
(StatusCode::OK, "OK").into_response()
}
/// Server readiness check - is the server ready to handle requests
fn readiness(&self) -> Response;
}

474
sgl-router/src/server.rs Normal file
View File

@@ -0,0 +1,474 @@
use crate::config::RouterConfig;
use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig};
use crate::middleware::TokenBucket;
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer};
use crate::tool_parser::ParserRegistry;
use axum::{
extract::{Query, Request, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
};
use reqwest::Client;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::signal;
use tokio::spawn;
use tracing::{error, info, warn, Level};
#[derive(Clone)]
pub struct AppContext {
pub client: Client,
pub router_config: RouterConfig,
pub rate_limiter: Arc<TokenBucket>,
pub tokenizer: Option<Arc<dyn Tokenizer>>,
pub reasoning_parser_factory: Option<ParserFactory>,
pub tool_parser_registry: Option<&'static ParserRegistry>,
}
impl AppContext {
pub fn new(
router_config: RouterConfig,
client: Client,
max_concurrent_requests: usize,
rate_limit_tokens_per_second: Option<usize>,
) -> Result<Self, String> {
let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
// Initialize gRPC-specific components only when in gRPC mode
let (tokenizer, reasoning_parser_factory, tool_parser_registry) =
if router_config.connection_mode == crate::config::ConnectionMode::Grpc {
// Get tokenizer path (required for gRPC mode)
let tokenizer_path = router_config
.tokenizer_path
.clone()
.or_else(|| router_config.model_path.clone())
.ok_or_else(|| {
"gRPC mode requires either --tokenizer-path or --model-path to be specified"
.to_string()
})?;
// Initialize all gRPC components
let tokenizer = Some(
tokenizer_factory::create_tokenizer(&tokenizer_path)
.map_err(|e| format!("Failed to create tokenizer: {}", e))?,
);
let reasoning_parser_factory = Some(ParserFactory::new());
let tool_parser_registry = Some(ParserRegistry::new());
(tokenizer, reasoning_parser_factory, tool_parser_registry)
} else {
// HTTP mode doesn't need these components
(None, None, None)
};
Ok(Self {
client,
router_config,
rate_limiter,
tokenizer,
reasoning_parser_factory,
tool_parser_registry,
})
}
}
#[derive(Clone)]
pub struct AppState {
pub router: Arc<dyn RouterTrait>,
pub context: Arc<AppContext>,
pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<crate::middleware::QueuedRequest>>,
}
// Fallback handler for unmatched routes
async fn sink_handler() -> Response {
StatusCode::NOT_FOUND.into_response()
}
// Health check endpoints
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
state.router.liveness()
}
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
state.router.readiness()
}
async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.health(req).await
}
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.health_generate(req).await
}
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.get_server_info(req).await
}
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.get_models(req).await
}
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.get_model_info(req).await
}
// Generation endpoints
// The RouterTrait now accepts optional headers and typed body directly
async fn generate(
State(state): State<Arc<AppState>>,
headers: http::HeaderMap,
Json(body): Json<GenerateRequest>,
) -> Response {
state.router.route_generate(Some(&headers), &body).await
}
async fn v1_chat_completions(
State(state): State<Arc<AppState>>,
headers: http::HeaderMap,
Json(body): Json<ChatCompletionRequest>,
) -> Response {
state.router.route_chat(Some(&headers), &body).await
}
async fn v1_completions(
State(state): State<Arc<AppState>>,
headers: http::HeaderMap,
Json(body): Json<CompletionRequest>,
) -> Response {
state.router.route_completion(Some(&headers), &body).await
}
// Worker management endpoints
async fn add_worker(
State(state): State<Arc<AppState>>,
Query(params): Query<HashMap<String, String>>,
) -> Response {
let worker_url = match params.get("url") {
Some(url) => url.to_string(),
None => {
return (
StatusCode::BAD_REQUEST,
"Worker URL required. Provide 'url' query parameter",
)
.into_response();
}
};
match state.router.add_worker(&worker_url).await {
Ok(message) => (StatusCode::OK, message).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
}
}
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
let worker_list = state.router.get_worker_urls();
Json(serde_json::json!({ "urls": worker_list })).into_response()
}
async fn remove_worker(
State(state): State<Arc<AppState>>,
Query(params): Query<HashMap<String, String>>,
) -> Response {
let worker_url = match params.get("url") {
Some(url) => url.to_string(),
None => return StatusCode::BAD_REQUEST.into_response(),
};
state.router.remove_worker(&worker_url);
(
StatusCode::OK,
format!("Successfully removed worker: {}", worker_url),
)
.into_response()
}
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
state.router.flush_cache().await
}
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
state.router.get_worker_loads().await
}
pub struct ServerConfig {
pub host: String,
pub port: u16,
pub router_config: RouterConfig,
pub max_payload_size: usize,
pub log_dir: Option<String>,
pub log_level: Option<String>,
pub service_discovery_config: Option<ServiceDiscoveryConfig>,
pub prometheus_config: Option<PrometheusConfig>,
pub request_timeout_secs: u64,
pub request_id_headers: Option<Vec<String>>,
}
/// Build the Axum application with all routes and middleware
pub fn build_app(
app_state: Arc<AppState>,
max_payload_size: usize,
request_id_headers: Vec<String>,
cors_allowed_origins: Vec<String>,
) -> Router {
// Create routes
let protected_routes = Router::new()
.route("/generate", post(generate))
.route("/v1/chat/completions", post(v1_chat_completions))
.route("/v1/completions", post(v1_completions))
.route_layer(axum::middleware::from_fn_with_state(
app_state.clone(),
crate::middleware::concurrency_limit_middleware,
));
let public_routes = Router::new()
.route("/liveness", get(liveness))
.route("/readiness", get(readiness))
.route("/health", get(health))
.route("/health_generate", get(health_generate))
.route("/v1/models", get(v1_models))
.route("/get_model_info", get(get_model_info))
.route("/get_server_info", get(get_server_info));
let admin_routes = Router::new()
.route("/add_worker", post(add_worker))
.route("/remove_worker", post(remove_worker))
.route("/list_workers", get(list_workers))
.route("/flush_cache", post(flush_cache))
.route("/get_loads", get(get_loads));
// Build app with all routes and middleware
Router::new()
.merge(protected_routes)
.merge(public_routes)
.merge(admin_routes)
// Request body size limiting
.layer(tower_http::limit::RequestBodyLimitLayer::new(
max_payload_size,
))
// Request ID layer - must be added AFTER logging layer in the code
// so it executes BEFORE logging layer at runtime (layers execute bottom-up)
.layer(crate::middleware::RequestIdLayer::new(request_id_headers))
// Custom logging layer that can now see request IDs from extensions
.layer(crate::middleware::create_logging_layer())
// CORS (should be outermost)
.layer(create_cors_layer(cors_allowed_origins))
// Fallback
.fallback(sink_handler)
// State - apply last to get Router<Arc<AppState>>
.with_state(app_state)
}
pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
// Only initialize logging if not already done (for Python bindings support)
static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);
let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
Some(logging::init_logging(LoggingConfig {
level: config
.log_level
.as_deref()
.and_then(|s| match s.to_uppercase().parse::<Level>() {
Ok(l) => Some(l),
Err(_) => {
warn!("Invalid log level string: '{}'. Defaulting to INFO.", s);
None
}
})
.unwrap_or(Level::INFO),
json_format: false,
log_dir: config.log_dir.clone(),
colorize: true,
log_file_name: "sgl-router".to_string(),
log_targets: None,
}))
} else {
None
};
// Initialize prometheus metrics exporter
if let Some(prometheus_config) = config.prometheus_config {
metrics::start_prometheus(prometheus_config);
}
info!(
"Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
config.host,
config.port,
config.router_config.mode,
config.router_config.policy,
config.max_payload_size / (1024 * 1024)
);
let client = Client::builder()
.pool_idle_timeout(Some(Duration::from_secs(50)))
.pool_max_idle_per_host(500) // Increase to 500 connections per host
.timeout(Duration::from_secs(config.request_timeout_secs))
.connect_timeout(Duration::from_secs(10)) // Separate connection timeout
.tcp_nodelay(true)
.tcp_keepalive(Some(Duration::from_secs(30))) // Keep connections alive
.build()
.expect("Failed to create HTTP client");
// Create the application context with all dependencies
let app_context = Arc::new(AppContext::new(
config.router_config.clone(),
client.clone(),
config.router_config.max_concurrent_requests,
config.router_config.rate_limit_tokens_per_second,
)?);
// Create router with the context
let router = RouterFactory::create_router(&app_context).await?;
// Set up concurrency limiter with queue if configured
let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new(
app_context.rate_limiter.clone(),
config.router_config.queue_size,
Duration::from_secs(config.router_config.queue_timeout_secs),
);
// Start queue processor if enabled
if let Some(processor) = processor {
tokio::spawn(processor.run());
info!(
"Started request queue with size: {}, timeout: {}s",
config.router_config.queue_size, config.router_config.queue_timeout_secs
);
}
// Create app state with router and context
let app_state = Arc::new(AppState {
router: Arc::from(router),
context: app_context.clone(),
concurrency_queue_tx: limiter.queue_tx.clone(),
});
let router_arc = Arc::clone(&app_state.router);
// Start the service discovery if enabled
if let Some(service_discovery_config) = config.service_discovery_config {
if service_discovery_config.enabled {
match start_service_discovery(service_discovery_config, router_arc).await {
Ok(handle) => {
info!("Service discovery started");
// Spawn a task to handle the service discovery thread
spawn(async move {
if let Err(e) = handle.await {
error!("Service discovery task failed: {:?}", e);
}
});
}
Err(e) => {
error!("Failed to start service discovery: {}", e);
warn!("Continuing without service discovery");
}
}
}
}
info!(
"Router ready | workers: {:?}",
app_state.router.get_worker_urls()
);
// Configure request ID headers
let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
vec![
"x-request-id".to_string(),
"x-correlation-id".to_string(),
"x-trace-id".to_string(),
"request-id".to_string(),
]
});
// Build the application
let app = build_app(
app_state,
config.max_payload_size,
request_id_headers,
config.router_config.cors_allowed_origins.clone(),
);
// Create TCP listener - use the configured host
let addr = format!("{}:{}", config.host, config.port);
let listener = TcpListener::bind(&addr).await?;
// Start server with graceful shutdown
info!("Starting server on {}", addr);
// Serve the application with graceful shutdown
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
Ok(())
}
// Graceful shutdown handler
async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {
info!("Received Ctrl+C, starting graceful shutdown");
},
_ = terminate => {
info!("Received terminate signal, starting graceful shutdown");
},
}
}
// CORS Layer Creation
fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
use tower_http::cors::Any;
let cors = if allowed_origins.is_empty() {
// Allow all origins if none specified
tower_http::cors::CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any)
.expose_headers(Any)
} else {
// Restrict to specific origins
let origins: Vec<http::HeaderValue> = allowed_origins
.into_iter()
.filter_map(|origin| origin.parse().ok())
.collect();
tower_http::cors::CorsLayer::new()
.allow_origin(origins)
.allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
.allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
.expose_headers([http::header::HeaderName::from_static("x-request-id")])
};
cors.max_age(Duration::from_secs(3600))
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,182 @@
//! Chat template support for tokenizers using Jinja2 templates
//!
//! This module provides functionality to apply chat templates to messages,
//! similar to HuggingFace transformers' apply_chat_template method.
use anyhow::{anyhow, Result};
use minijinja::{context, Environment, Value};
use serde::{Deserialize, Serialize};
use serde_json;
/// Represents a chat message with role and content
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub content: String,
}
impl ChatMessage {
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
ChatMessage {
role: role.into(),
content: content.into(),
}
}
pub fn system(content: impl Into<String>) -> Self {
Self::new("system", content)
}
pub fn user(content: impl Into<String>) -> Self {
Self::new("user", content)
}
pub fn assistant(content: impl Into<String>) -> Self {
Self::new("assistant", content)
}
}
/// Chat template processor using Jinja2
pub struct ChatTemplateProcessor {
template: String,
bos_token: Option<String>,
eos_token: Option<String>,
}
impl ChatTemplateProcessor {
/// Create a new chat template processor
pub fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
ChatTemplateProcessor {
template,
bos_token,
eos_token,
}
}
/// Apply the chat template to a list of messages
///
/// This mimics the behavior of HuggingFace's apply_chat_template method
/// but returns the formatted string instead of token IDs.
pub fn apply_chat_template(
&self,
messages: &[ChatMessage],
add_generation_prompt: bool,
) -> Result<String> {
let mut env = Environment::new();
// Register the template
env.add_template("chat", &self.template)
.map_err(|e| anyhow!("Failed to add template: {}", e))?;
// Get the template
let tmpl = env
.get_template("chat")
.map_err(|e| anyhow!("Failed to get template: {}", e))?;
// Convert messages to a format Jinja can work with
let messages_value: Vec<Value> = messages
.iter()
.map(|msg| {
context! {
role => msg.role.clone(),
content => msg.content.clone()
}
})
.collect();
// Render the template
let rendered = tmpl
.render(context! {
messages => messages_value,
add_generation_prompt => add_generation_prompt,
bos_token => self.bos_token.clone().unwrap_or_default(),
eos_token => self.eos_token.clone().unwrap_or_default()
})
.map_err(|e| anyhow!("Failed to render template: {}", e))?;
Ok(rendered)
}
}
/// Load chat template from tokenizer config JSON
pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
use std::fs;
let content = fs::read_to_string(config_path)?;
let config: serde_json::Value = serde_json::from_str(&content)?;
// Look for chat_template in the config
if let Some(template) = config.get("chat_template") {
if let Some(template_str) = template.as_str() {
return Ok(Some(template_str.to_string()));
}
}
Ok(None)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_chat_message_creation() {
let msg = ChatMessage::system("You are a helpful assistant");
assert_eq!(msg.role, "system");
assert_eq!(msg.content, "You are a helpful assistant");
let user_msg = ChatMessage::user("Hello!");
assert_eq!(user_msg.role, "user");
let assistant_msg = ChatMessage::assistant("Hi there!");
assert_eq!(assistant_msg.role, "assistant");
}
#[test]
fn test_simple_chat_template() {
// Simple template that formats messages
let template = r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt -%}
assistant:
{%- endif -%}
"#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
let messages = vec![
ChatMessage::system("You are helpful"),
ChatMessage::user("Hello"),
];
let result = processor.apply_chat_template(&messages, true).unwrap();
assert!(result.contains("system: You are helpful"));
assert!(result.contains("user: Hello"));
assert!(result.contains("assistant:"));
}
#[test]
fn test_chat_template_with_tokens() {
// Template that uses special tokens
let template = r#"
{{ bos_token }}
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}{{ eos_token }}
{% endfor -%}
"#;
let processor = ChatTemplateProcessor::new(
template.to_string(),
Some("<s>".to_string()),
Some("</s>".to_string()),
);
let messages = vec![ChatMessage::user("Test")];
let result = processor.apply_chat_template(&messages, false).unwrap();
assert!(result.contains("<s>"));
assert!(result.contains("</s>"));
}
}

View File

@@ -0,0 +1,318 @@
use super::traits;
use anyhow::{Error, Result};
use std::fs::File;
use std::io::Read;
use std::path::Path;
use std::sync::Arc;
use super::huggingface::HuggingFaceTokenizer;
use super::tiktoken::TiktokenTokenizer;
use crate::tokenizer::hub::download_tokenizer_from_hf;
/// Represents the type of tokenizer being used
#[derive(Debug, Clone)]
pub enum TokenizerType {
HuggingFace(String),
Mock,
Tiktoken(String),
// Future: SentencePiece, GGUF
}
/// Create a tokenizer from a file path to a tokenizer file.
/// The file extension is used to determine the tokenizer type.
/// Supported file types are:
/// - json: HuggingFace tokenizer
/// - For testing: can return mock tokenizer
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
create_tokenizer_with_chat_template(file_path, None)
}
/// Create a tokenizer from a file path with an optional chat template
pub fn create_tokenizer_with_chat_template(
file_path: &str,
chat_template_path: Option<&str>,
) -> Result<Arc<dyn traits::Tokenizer>> {
// Special case for testing
if file_path == "mock" || file_path == "test" {
return Ok(Arc::new(super::mock::MockTokenizer::new()));
}
let path = Path::new(file_path);
// Check if file exists
if !path.exists() {
return Err(Error::msg(format!("File not found: {}", file_path)));
}
// Try to determine tokenizer type from extension
let extension = path
.extension()
.and_then(std::ffi::OsStr::to_str)
.map(|s| s.to_lowercase());
let result = match extension.as_deref() {
Some("json") => {
let tokenizer =
HuggingFaceTokenizer::from_file_with_chat_template(file_path, chat_template_path)?;
Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
}
Some("model") => {
// SentencePiece model file
Err(Error::msg("SentencePiece models not yet supported"))
}
Some("gguf") => {
// GGUF format
Err(Error::msg("GGUF format not yet supported"))
}
_ => {
// Try to auto-detect by reading file content
auto_detect_tokenizer(file_path)
}
};
result
}
/// Auto-detect tokenizer type by examining file content
fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
let mut file = File::open(file_path)?;
let mut buffer = vec![0u8; 512]; // Read first 512 bytes for detection
let bytes_read = file.read(&mut buffer)?;
buffer.truncate(bytes_read);
// Check for JSON (HuggingFace format)
if is_likely_json(&buffer) {
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
return Ok(Arc::new(tokenizer));
}
// Check for GGUF magic number
if buffer.len() >= 4 && &buffer[0..4] == b"GGUF" {
return Err(Error::msg("GGUF format detected but not yet supported"));
}
// Check for SentencePiece model
if is_likely_sentencepiece(&buffer) {
return Err(Error::msg(
"SentencePiece model detected but not yet supported",
));
}
Err(Error::msg(format!(
"Unable to determine tokenizer type for file: {}",
file_path
)))
}
/// Check if the buffer likely contains JSON data
fn is_likely_json(buffer: &[u8]) -> bool {
// Skip UTF-8 BOM if present
let content = if buffer.len() >= 3 && buffer[0..3] == [0xEF, 0xBB, 0xBF] {
&buffer[3..]
} else {
buffer
};
// Find first non-whitespace character without allocation
if let Some(first_byte) = content.iter().find(|&&b| !b.is_ascii_whitespace()) {
*first_byte == b'{' || *first_byte == b'['
} else {
false
}
}
/// Check if the buffer likely contains a SentencePiece model
fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
// SentencePiece models often start with specific patterns
// This is a simplified check
buffer.len() >= 12
&& (buffer.starts_with(b"\x0a\x09")
|| buffer.starts_with(b"\x08\x00")
|| buffer.windows(4).any(|w| w == b"<unk")
|| buffer.windows(4).any(|w| w == b"<s>")
|| buffer.windows(4).any(|w| w == b"</s>"))
}
/// Factory function to create tokenizer from a model name or path (async version)
pub async fn create_tokenizer_async(
model_name_or_path: &str,
) -> Result<Arc<dyn traits::Tokenizer>> {
// Check if it's a file path
let path = Path::new(model_name_or_path);
if path.exists() {
return create_tokenizer_from_file(model_name_or_path);
}
// Check if it's a GPT model name that should use Tiktoken
if model_name_or_path.contains("gpt-")
|| model_name_or_path.contains("davinci")
|| model_name_or_path.contains("curie")
|| model_name_or_path.contains("babbage")
|| model_name_or_path.contains("ada")
{
let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
return Ok(Arc::new(tokenizer));
}
// Try to download tokenizer files from HuggingFace
match download_tokenizer_from_hf(model_name_or_path).await {
Ok(cache_dir) => {
// Look for tokenizer.json in the cache directory
let tokenizer_path = cache_dir.join("tokenizer.json");
if tokenizer_path.exists() {
create_tokenizer_from_file(tokenizer_path.to_str().unwrap())
} else {
// Try other common tokenizer file names
let possible_files = ["tokenizer_config.json", "vocab.json"];
for file_name in &possible_files {
let file_path = cache_dir.join(file_name);
if file_path.exists() {
return create_tokenizer_from_file(file_path.to_str().unwrap());
}
}
Err(Error::msg(format!(
"Downloaded model '{}' but couldn't find a suitable tokenizer file",
model_name_or_path
)))
}
}
Err(e) => Err(Error::msg(format!(
"Failed to download tokenizer from HuggingFace: {}",
e
))),
}
}
/// Factory function to create tokenizer from a model name or path (blocking version)
pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
// Check if it's a file path
let path = Path::new(model_name_or_path);
if path.exists() {
return create_tokenizer_from_file(model_name_or_path);
}
// Check if it's a GPT model name that should use Tiktoken
if model_name_or_path.contains("gpt-")
|| model_name_or_path.contains("davinci")
|| model_name_or_path.contains("curie")
|| model_name_or_path.contains("babbage")
|| model_name_or_path.contains("ada")
{
let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
return Ok(Arc::new(tokenizer));
}
// Only use tokio for HuggingFace downloads
// Check if we're already in a tokio runtime
if let Ok(handle) = tokio::runtime::Handle::try_current() {
// We're in a runtime, use block_in_place
tokio::task::block_in_place(|| handle.block_on(create_tokenizer_async(model_name_or_path)))
} else {
// No runtime, create a temporary one
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(create_tokenizer_async(model_name_or_path))
}
}
/// Get information about a tokenizer file
pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
let path = Path::new(file_path);
if !path.exists() {
return Err(Error::msg(format!("File not found: {}", file_path)));
}
let extension = path
.extension()
.and_then(std::ffi::OsStr::to_str)
.map(|s| s.to_lowercase());
match extension.as_deref() {
Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
_ => {
// Try auto-detection
use std::fs::File;
use std::io::Read;
let mut file = File::open(file_path)?;
let mut buffer = vec![0u8; 512];
let bytes_read = file.read(&mut buffer)?;
buffer.truncate(bytes_read);
if is_likely_json(&buffer) {
Ok(TokenizerType::HuggingFace(file_path.to_string()))
} else {
Err(Error::msg("Unknown tokenizer type"))
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_json_detection() {
assert!(is_likely_json(b"{\"test\": \"value\"}"));
assert!(is_likely_json(b" \n\t{\"test\": \"value\"}"));
assert!(is_likely_json(b"[1, 2, 3]"));
assert!(!is_likely_json(b"not json"));
assert!(!is_likely_json(b""));
}
#[test]
fn test_mock_tokenizer_creation() {
let tokenizer = create_tokenizer_from_file("mock").unwrap();
assert_eq!(tokenizer.vocab_size(), 8); // Mock tokenizer has 8 tokens
}
#[test]
fn test_file_not_found() {
let result = create_tokenizer_from_file("/nonexistent/file.json");
assert!(result.is_err());
if let Err(e) = result {
assert!(e.to_string().contains("File not found"));
}
}
#[test]
fn test_create_tiktoken_tokenizer() {
// Test creating tokenizer for GPT models
let tokenizer = create_tokenizer("gpt-4").unwrap();
assert!(tokenizer.vocab_size() > 0);
// Test encoding and decoding
let text = "Hello, world!";
let encoding = tokenizer.encode(text).unwrap();
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
assert_eq!(decoded, text);
}
#[tokio::test]
async fn test_download_tokenizer_from_hf() {
// Test with a small model that should have tokenizer files
// Skip this test if HF_TOKEN is not set and we're in CI
if std::env::var("CI").is_ok() && std::env::var("HF_TOKEN").is_err() {
println!("Skipping HF download test in CI without HF_TOKEN");
return;
}
// Try to create tokenizer for a known small model
let result = create_tokenizer_async("bert-base-uncased").await;
// The test might fail due to network issues or rate limiting
// so we just check that the function executes without panic
match result {
Ok(tokenizer) => {
assert!(tokenizer.vocab_size() > 0);
println!("Successfully downloaded and created tokenizer");
}
Err(e) => {
println!("Download failed (this might be expected): {}", e);
// Don't fail the test - network issues shouldn't break CI
}
}
}
}

View File

@@ -0,0 +1,238 @@
use hf_hub::api::tokio::ApiBuilder;
use std::env;
use std::path::{Path, PathBuf};
const IGNORED: [&str; 5] = [
".gitattributes",
"LICENSE",
"LICENSE.txt",
"README.md",
"USE_POLICY.md",
];
const HF_TOKEN_ENV_VAR: &str = "HF_TOKEN";
/// Checks if a file is a model weight file
fn is_weight_file(filename: &str) -> bool {
filename.ends_with(".bin")
|| filename.ends_with(".safetensors")
|| filename.ends_with(".h5")
|| filename.ends_with(".msgpack")
|| filename.ends_with(".ckpt.index")
}
/// Checks if a file is an image file
fn is_image(filename: &str) -> bool {
filename.ends_with(".png")
|| filename.ends_with("PNG")
|| filename.ends_with(".jpg")
|| filename.ends_with("JPG")
|| filename.ends_with(".jpeg")
|| filename.ends_with("JPEG")
}
/// Checks if a file is a tokenizer file
fn is_tokenizer_file(filename: &str) -> bool {
filename.ends_with("tokenizer.json")
|| filename.ends_with("tokenizer_config.json")
|| filename.ends_with("special_tokens_map.json")
|| filename.ends_with("vocab.json")
|| filename.ends_with("merges.txt")
|| filename.ends_with(".model") // SentencePiece models
|| filename.ends_with(".tiktoken")
}
/// Attempt to download tokenizer files from Hugging Face
/// Returns the directory containing the downloaded tokenizer files
pub async fn download_tokenizer_from_hf(model_id: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
let model_id = model_id.as_ref();
let token = env::var(HF_TOKEN_ENV_VAR).ok();
let api = ApiBuilder::new()
.with_progress(true)
.with_token(token)
.build()?;
let model_name = model_id.display().to_string();
let repo = api.model(model_name.clone());
let info = match repo.info().await {
Ok(info) => info,
Err(e) => {
return Err(anyhow::anyhow!(
"Failed to fetch model '{}' from HuggingFace: {}. Is this a valid HuggingFace ID?",
model_name,
e
));
}
};
if info.siblings.is_empty() {
return Err(anyhow::anyhow!(
"Model '{}' exists but contains no downloadable files.",
model_name
));
}
let mut cache_dir = None;
let mut tokenizer_files_found = false;
// First, identify all tokenizer files to download
let tokenizer_files: Vec<_> = info
.siblings
.iter()
.filter(|sib| {
!IGNORED.contains(&sib.rfilename.as_str())
&& !is_image(&sib.rfilename)
&& !is_weight_file(&sib.rfilename)
&& is_tokenizer_file(&sib.rfilename)
})
.collect();
if tokenizer_files.is_empty() {
return Err(anyhow::anyhow!(
"No tokenizer files found for model '{}'.",
model_name
));
}
// Download all tokenizer files
for sib in tokenizer_files {
match repo.get(&sib.rfilename).await {
Ok(path) => {
if cache_dir.is_none() {
cache_dir = path.parent().map(|p| p.to_path_buf());
}
tokenizer_files_found = true;
}
Err(e) => {
return Err(anyhow::anyhow!(
"Failed to download tokenizer file '{}' from model '{}': {}",
sib.rfilename,
model_name,
e
));
}
}
}
if !tokenizer_files_found {
return Err(anyhow::anyhow!(
"No tokenizer files could be downloaded for model '{}'.",
model_name
));
}
match cache_dir {
Some(dir) => Ok(dir),
None => Err(anyhow::anyhow!(
"Invalid HF cache path for model '{}'",
model_name
)),
}
}
/// Attempt to download a model from Hugging Face (including weights)
/// Returns the directory it is in
/// If ignore_weights is true, model weight files will be skipped
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
let name = name.as_ref();
let token = env::var(HF_TOKEN_ENV_VAR).ok();
let api = ApiBuilder::new()
.with_progress(true)
.with_token(token)
.build()?;
let model_name = name.display().to_string();
let repo = api.model(model_name.clone());
let info = match repo.info().await {
Ok(info) => info,
Err(e) => {
return Err(anyhow::anyhow!(
"Failed to fetch model '{}' from HuggingFace: {}. Is this a valid HuggingFace ID?",
model_name,
e
));
}
};
if info.siblings.is_empty() {
return Err(anyhow::anyhow!(
"Model '{}' exists but contains no downloadable files.",
model_name
));
}
let mut p = PathBuf::new();
let mut files_downloaded = false;
for sib in info.siblings {
if IGNORED.contains(&sib.rfilename.as_str()) || is_image(&sib.rfilename) {
continue;
}
// If ignore_weights is true, skip weight files
if ignore_weights && is_weight_file(&sib.rfilename) {
continue;
}
match repo.get(&sib.rfilename).await {
Ok(path) => {
p = path;
files_downloaded = true;
}
Err(e) => {
return Err(anyhow::anyhow!(
"Failed to download file '{}' from model '{}': {}",
sib.rfilename,
model_name,
e
));
}
}
}
if !files_downloaded {
let file_type = if ignore_weights {
"non-weight"
} else {
"valid"
};
return Err(anyhow::anyhow!(
"No {} files found for model '{}'.",
file_type,
model_name
));
}
match p.parent() {
Some(p) => Ok(p.to_path_buf()),
None => Err(anyhow::anyhow!("Invalid HF cache path: {}", p.display())),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_is_tokenizer_file() {
assert!(is_tokenizer_file("tokenizer.json"));
assert!(is_tokenizer_file("tokenizer_config.json"));
assert!(is_tokenizer_file("special_tokens_map.json"));
assert!(is_tokenizer_file("vocab.json"));
assert!(is_tokenizer_file("merges.txt"));
assert!(is_tokenizer_file("spiece.model"));
assert!(!is_tokenizer_file("model.bin"));
assert!(!is_tokenizer_file("README.md"));
}
#[test]
fn test_is_weight_file() {
assert!(is_weight_file("model.bin"));
assert!(is_weight_file("model.safetensors"));
assert!(is_weight_file("pytorch_model.bin"));
assert!(!is_weight_file("tokenizer.json"));
assert!(!is_weight_file("config.json"));
}
}

View File

@@ -0,0 +1,234 @@
use super::traits::{
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
};
use anyhow::{Error, Result};
use std::collections::HashMap;
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
use super::chat_template::{ChatMessage, ChatTemplateProcessor};
/// HuggingFace tokenizer wrapper
pub struct HuggingFaceTokenizer {
tokenizer: HfTokenizer,
special_tokens: SpecialTokens,
vocab: HashMap<String, TokenIdType>,
reverse_vocab: HashMap<TokenIdType, String>,
chat_template: Option<String>,
}
impl HuggingFaceTokenizer {
/// Create a tokenizer from a HuggingFace tokenizer JSON file
pub fn from_file(file_path: &str) -> Result<Self> {
Self::from_file_with_chat_template(file_path, None)
}
/// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
pub fn from_file_with_chat_template(
file_path: &str,
chat_template_path: Option<&str>,
) -> Result<Self> {
let tokenizer = HfTokenizer::from_file(file_path)
.map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;
// Extract special tokens
let special_tokens = Self::extract_special_tokens(&tokenizer);
// Build vocab mappings
let vocab = tokenizer.get_vocab(false);
let reverse_vocab: HashMap<TokenIdType, String> = vocab
.iter()
.map(|(token, &id)| (id, token.clone()))
.collect();
// Load chat template
let chat_template = if let Some(template_path) = chat_template_path {
// Load from specified .jinja file
Self::load_chat_template_from_file(template_path)?
} else {
// Try to load from tokenizer_config.json
Self::load_chat_template(file_path)
};
Ok(HuggingFaceTokenizer {
tokenizer,
special_tokens,
vocab,
reverse_vocab,
chat_template,
})
}
/// Create from an existing HuggingFace tokenizer
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
let special_tokens = Self::extract_special_tokens(&tokenizer);
let vocab = tokenizer.get_vocab(false);
let reverse_vocab: HashMap<TokenIdType, String> = vocab
.iter()
.map(|(token, &id)| (id, token.clone()))
.collect();
HuggingFaceTokenizer {
tokenizer,
special_tokens,
vocab,
reverse_vocab,
chat_template: None,
}
}
/// Extract special tokens from the tokenizer
fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
// Try to get special tokens from the tokenizer
// This is a simplified version - actual implementation would need to handle various formats
let vocab = tokenizer.get_vocab(true);
let find_token = |patterns: &[&str]| -> Option<String> {
for pattern in patterns {
if vocab.contains_key(*pattern) {
return Some(pattern.to_string());
}
}
None
};
SpecialTokens {
bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
additional_special_tokens: vec![],
}
}
/// Try to load chat template from tokenizer_config.json
fn load_chat_template(tokenizer_path: &str) -> Option<String> {
// Try to find tokenizer_config.json in the same directory
let path = std::path::Path::new(tokenizer_path);
let dir = path.parent()?;
let config_path = dir.join("tokenizer_config.json");
if config_path.exists() {
if let Ok(template) =
super::chat_template::load_chat_template_from_config(config_path.to_str()?)
{
return template;
}
}
None
}
/// Load chat template from a .jinja file
fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
use std::fs;
let content = fs::read_to_string(template_path)
.map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;
// Clean up the template (similar to Python implementation)
let template = content.trim().replace("\\n", "\n");
Ok(Some(template))
}
/// Set or override the chat template
pub fn set_chat_template(&mut self, template: String) {
self.chat_template = Some(template);
}
/// Apply chat template if available
pub fn apply_chat_template(
&self,
messages: &[ChatMessage],
add_generation_prompt: bool,
) -> Result<String> {
if let Some(ref template) = self.chat_template {
let processor = ChatTemplateProcessor::new(
template.clone(),
self.special_tokens.bos_token.clone(),
self.special_tokens.eos_token.clone(),
);
processor.apply_chat_template(messages, add_generation_prompt)
} else {
// Fallback to simple formatting if no template is available
let mut result = String::new();
for msg in messages {
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
}
if add_generation_prompt {
result.push_str("assistant: ");
}
Ok(result)
}
}
}
impl Encoder for HuggingFaceTokenizer {
fn encode(&self, input: &str) -> Result<Encoding> {
self.tokenizer
.encode(input, false)
.map_err(|e| Error::msg(format!("Encoding failed: {}", e)))
.map(|encoding| Encoding::Hf(Box::new(encoding)))
}
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
let encodings = self
.tokenizer
.encode_batch(inputs.to_vec(), false)
.map_err(|e| Error::msg(format!("Batch encoding failed: {}", e)))?;
Ok(encodings
.into_iter()
.map(|e| Encoding::Hf(Box::new(e)))
.collect())
}
}
impl Decoder for HuggingFaceTokenizer {
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
self.tokenizer
.decode(token_ids, skip_special_tokens)
.map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
}
}
impl TokenizerTrait for HuggingFaceTokenizer {
fn vocab_size(&self) -> usize {
self.tokenizer.get_vocab_size(false)
}
fn get_special_tokens(&self) -> &SpecialTokens {
&self.special_tokens
}
fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
self.vocab.get(token).copied()
}
fn id_to_token(&self, id: TokenIdType) -> Option<String> {
self.reverse_vocab.get(&id).cloned()
}
}
#[cfg(test)]
mod tests {
use super::ChatMessage;
#[test]
fn test_chat_message_creation() {
let msg = ChatMessage::system("You are a helpful assistant");
assert_eq!(msg.role, "system");
assert_eq!(msg.content, "You are a helpful assistant");
let user_msg = ChatMessage::user("Hello!");
assert_eq!(user_msg.role, "user");
let assistant_msg = ChatMessage::assistant("Hi there!");
assert_eq!(assistant_msg.role, "assistant");
}
// Note: Actual tokenizer tests would require a real tokenizer file
// These would be integration tests rather than unit tests
}

View File

@@ -0,0 +1,112 @@
//! Mock tokenizer implementation for testing
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
use anyhow::Result;
use std::collections::HashMap;
/// Mock tokenizer for testing purposes
pub struct MockTokenizer {
vocab: HashMap<String, u32>,
reverse_vocab: HashMap<u32, String>,
special_tokens: SpecialTokens,
}
impl Default for MockTokenizer {
fn default() -> Self {
Self::new()
}
}
impl MockTokenizer {
pub fn new() -> Self {
let mut vocab = HashMap::new();
let mut reverse_vocab = HashMap::new();
// Add some basic tokens
let tokens = vec![
("Hello", 1),
("world", 2),
("test", 3),
("token", 4),
(" ", 5),
(".", 6),
("<eos>", 999),
("<bos>", 1000),
];
for (token, id) in tokens {
vocab.insert(token.to_string(), id);
reverse_vocab.insert(id, token.to_string());
}
let special_tokens = SpecialTokens {
bos_token: Some("<bos>".to_string()),
eos_token: Some("<eos>".to_string()),
unk_token: Some("<unk>".to_string()),
sep_token: None,
pad_token: None,
cls_token: None,
mask_token: None,
additional_special_tokens: vec![],
};
Self {
vocab,
reverse_vocab,
special_tokens,
}
}
}
impl Encoder for MockTokenizer {
fn encode(&self, input: &str) -> Result<Encoding> {
// Simple word-based tokenization for testing
let tokens: Vec<u32> = input
.split_whitespace()
.filter_map(|word| self.vocab.get(word).copied())
.collect();
Ok(Encoding::Sp(tokens))
}
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
inputs.iter().map(|input| self.encode(input)).collect()
}
}
impl Decoder for MockTokenizer {
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
let tokens: Vec<String> = token_ids
.iter()
.filter_map(|id| {
self.reverse_vocab.get(id).and_then(|token| {
if skip_special_tokens && (token == "<eos>" || token == "<bos>") {
None
} else {
Some(token.clone())
}
})
})
.collect();
Ok(tokens.join(" "))
}
}
impl TokenizerTrait for MockTokenizer {
fn vocab_size(&self) -> usize {
self.vocab.len()
}
fn get_special_tokens(&self) -> &SpecialTokens {
&self.special_tokens
}
fn token_to_id(&self, token: &str) -> Option<u32> {
self.vocab.get(token).copied()
}
fn id_to_token(&self, id: u32) -> Option<String> {
self.reverse_vocab.get(&id).cloned()
}
}

View File

@@ -0,0 +1,123 @@
use anyhow::Result;
use std::ops::Deref;
use std::sync::Arc;
pub mod factory;
pub mod hub;
pub mod mock;
pub mod sequence;
pub mod stop;
pub mod stream;
pub mod traits;
// Feature-gated modules
pub mod chat_template;
pub mod huggingface;
pub mod tiktoken;
#[cfg(test)]
mod tests;
// Re-exports
pub use factory::{
create_tokenizer, create_tokenizer_async, create_tokenizer_from_file,
create_tokenizer_with_chat_template, TokenizerType,
};
pub use sequence::Sequence;
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
pub use stream::DecodeStream;
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
pub use huggingface::HuggingFaceTokenizer;
pub use chat_template::ChatMessage;
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
#[derive(Clone)]
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);
impl Tokenizer {
/// Create a tokenizer from a file path
pub fn from_file(file_path: &str) -> Result<Tokenizer> {
Ok(Tokenizer(factory::create_tokenizer_from_file(file_path)?))
}
/// Create a tokenizer from a file path with an optional chat template
pub fn from_file_with_chat_template(
file_path: &str,
chat_template_path: Option<&str>,
) -> Result<Tokenizer> {
Ok(Tokenizer(factory::create_tokenizer_with_chat_template(
file_path,
chat_template_path,
)?))
}
/// Create a tokenizer from an Arc<dyn Tokenizer>
pub fn from_arc(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
Tokenizer(tokenizer)
}
/// Create a stateful sequence object for decoding token_ids into text
pub fn decode_stream(
&self,
prompt_token_ids: &[u32],
skip_special_tokens: bool,
) -> DecodeStream {
DecodeStream::new(self.0.clone(), prompt_token_ids, skip_special_tokens)
}
/// Direct encode method
pub fn encode(&self, input: &str) -> Result<Encoding> {
self.0.encode(input)
}
/// Direct batch encode method
pub fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
self.0.encode_batch(inputs)
}
/// Direct decode method
pub fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
self.0.decode(token_ids, skip_special_tokens)
}
/// Get vocabulary size
pub fn vocab_size(&self) -> usize {
self.0.vocab_size()
}
/// Get special tokens
pub fn get_special_tokens(&self) -> &SpecialTokens {
self.0.get_special_tokens()
}
/// Convert token string to ID
pub fn token_to_id(&self, token: &str) -> Option<u32> {
self.0.token_to_id(token)
}
/// Convert ID to token string
pub fn id_to_token(&self, id: u32) -> Option<String> {
self.0.id_to_token(id)
}
}
impl Deref for Tokenizer {
type Target = Arc<dyn traits::Tokenizer>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Arc<dyn traits::Tokenizer>> for Tokenizer {
fn from(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
Tokenizer(tokenizer)
}
}

View File

@@ -0,0 +1,238 @@
use super::traits::{TokenIdType, Tokenizer as TokenizerTrait};
use anyhow::Result;
use std::sync::Arc;
/// Maintains state for an ongoing sequence of tokens and their decoded text
/// This provides a cleaner abstraction for managing token sequences
pub struct Sequence {
/// The tokenizer used for encoding/decoding
tokenizer: Arc<dyn TokenizerTrait>,
/// The current sequence of token ids
token_ids: Vec<TokenIdType>,
/// The position in the current sequence the last decoded token completed
prefix_offset: usize,
/// Current position in the sequence
read_offset: usize,
}
impl std::fmt::Debug for Sequence {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Sequence")
.field("tokenizer", &"Arc<dyn Tokenizer>")
.field(
"token_ids",
&format_args!("{}", {
let token_ids = self.token_ids();
if token_ids.len() <= 20 {
format!("{:?}", token_ids)
} else {
let first_ten = &token_ids[..10];
let last_ten = &token_ids[token_ids.len() - 10..];
format!("{:?} ... {:?}", first_ten, last_ten)
}
}),
)
.field("prefix_offset", &self.prefix_offset)
.field("read_offset", &self.read_offset)
.field("token count", &self.token_ids.len())
.finish()
}
}
impl Sequence {
/// Create a new empty sequence
pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
Self {
tokenizer,
token_ids: Vec::new(),
prefix_offset: 0,
read_offset: 0,
}
}
/// Create a sequence with initial tokens
pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
let len = token_ids.len();
Self {
tokenizer,
token_ids,
prefix_offset: 0,
read_offset: len,
}
}
/// Check if the sequence is empty
pub fn is_empty(&self) -> bool {
self.token_ids.is_empty()
}
/// Get the length of the sequence
pub fn len(&self) -> usize {
self.token_ids.len()
}
/// Clear the sequence
pub fn clear(&mut self) {
self.token_ids.clear();
self.prefix_offset = 0;
self.read_offset = 0;
}
/// Append text to the sequence by encoding it
pub fn append_text(&mut self, input: &str) -> Result<()> {
let encoding = self.tokenizer.encode(input)?;
self.token_ids.extend(encoding.token_ids());
Ok(())
}
/// Append a single token to the sequence and return newly decoded text
/// Based on HuggingFace TGI incremental decoding
pub fn append_token(&mut self, token_id: TokenIdType) -> Result<String> {
// Store the old read offset before adding the new token
let old_read_offset = self.read_offset;
self.token_ids.push(token_id);
self.read_offset = self.token_ids.len();
// If this is the first token or we're at the beginning, decode everything
if self.prefix_offset == 0 && old_read_offset == 0 {
let text = self.tokenizer.decode(&self.token_ids, false)?;
if text.ends_with("<EFBFBD>") {
// Incomplete UTF-8 sequence, wait for more tokens
return Ok(String::new());
}
self.prefix_offset = 0;
return Ok(text);
}
// Decode the text up to the previous position
let prefix_text = self
.tokenizer
.decode(&self.token_ids[self.prefix_offset..old_read_offset], false)?;
// Decode the text including the new token
let new_text = self
.tokenizer
.decode(&self.token_ids[self.prefix_offset..], false)?;
// Handle multi-byte character boundaries
let mut prefix_text_len = prefix_text.len();
while !new_text.is_char_boundary(prefix_text_len) && prefix_text_len > 0 {
prefix_text_len -= 1;
}
if new_text.len() > prefix_text.len() {
if new_text.ends_with("<EFBFBD>") {
// Incomplete UTF-8 sequence, wait for more tokens
return Ok(String::new());
} else {
// Return the new text portion
let incremental_text = new_text[prefix_text_len..].to_string().replace("<EFBFBD>", "");
self.prefix_offset = old_read_offset;
return Ok(incremental_text);
}
}
Ok(String::new())
}
/// Get a reference to the tokenizer
pub fn tokenizer(&self) -> &Arc<dyn TokenizerTrait> {
&self.tokenizer
}
/// Get the current token ids
pub fn token_ids(&self) -> &[TokenIdType] {
&self.token_ids
}
/// Decode the entire sequence to text
pub fn text(&self) -> Result<String> {
self.tokenizer.decode(&self.token_ids, false)
}
/// Get the prefix offset
pub fn prefix_offset(&self) -> usize {
self.prefix_offset
}
/// Get the read offset
pub fn read_offset(&self) -> usize {
self.read_offset
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizer::mock::MockTokenizer;
#[test]
fn test_sequence_new() {
let tokenizer = Arc::new(MockTokenizer::new());
let seq = Sequence::new(tokenizer);
assert!(seq.is_empty());
assert_eq!(seq.len(), 0);
}
#[test]
fn test_sequence_append_text() {
let tokenizer = Arc::new(MockTokenizer::new());
let mut seq = Sequence::new(tokenizer);
seq.append_text("Hello").unwrap();
assert!(!seq.is_empty());
assert!(!seq.is_empty());
let text = seq.text().unwrap();
assert_eq!(text, "Hello");
}
#[test]
fn test_sequence_append_token() {
let tokenizer = Arc::new(MockTokenizer::new());
let mut seq = Sequence::new(tokenizer.clone());
// Start with an empty sequence and append token 1 ("Hello")
let text1 = seq.append_token(1).unwrap();
assert_eq!(text1, "Hello");
// Now append token 2 ("world")
// The mock tokenizer will decode [1, 2] as "Hello world" (with a space)
let text2 = seq.append_token(2).unwrap();
// The incremental text should be " world" (with the space that the mock tokenizer adds)
assert_eq!(text2, " world");
// Verify the full text
assert_eq!(seq.text().unwrap(), "Hello world");
}
#[test]
fn test_sequence_clear() {
let tokenizer = Arc::new(MockTokenizer::new());
let mut seq = Sequence::new(tokenizer);
seq.append_text("Hello world").unwrap();
assert!(!seq.is_empty());
seq.clear();
assert!(seq.is_empty());
assert_eq!(seq.len(), 0);
assert_eq!(seq.prefix_offset(), 0);
assert_eq!(seq.read_offset(), 0);
}
#[test]
fn test_sequence_debug() {
let tokenizer = Arc::new(MockTokenizer::new());
let mut seq = Sequence::new(tokenizer);
seq.append_text("Test").unwrap();
let debug_str = format!("{:?}", seq);
assert!(debug_str.contains("Sequence"));
assert!(debug_str.contains("token count"));
}
}

View File

@@ -0,0 +1,506 @@
use super::traits::{self, TokenIdType};
use anyhow::Result;
use std::collections::HashSet;
use std::sync::Arc;
/// Output from the sequence decoder
#[derive(Debug, Clone, PartialEq)]
pub enum SequenceDecoderOutput {
/// Normal text output
Text(String),
/// Text is being held due to partial stop sequence match
Held,
/// Stop sequence matched (hidden - not included in output)
Stopped,
/// Stop sequence matched with text (visible - included in output)
StoppedWithText(String),
}
/// Configuration for stop sequences
#[derive(Debug, Clone, Default)]
pub struct StopSequenceConfig {
/// Token IDs that trigger a stop
pub stop_tokens: HashSet<TokenIdType>,
/// String sequences that trigger a stop
pub stop_sequences: Vec<String>,
/// Token IDs for visible stops (included in output)
pub visible_stop_tokens: HashSet<TokenIdType>,
/// String sequences for visible stops (included in output)
pub visible_stop_sequences: Vec<String>,
}
impl StopSequenceConfig {
/// Builder pattern - add a stop token
pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self {
self.stop_tokens.insert(token_id);
self
}
/// Builder pattern - add a stop sequence
pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
self.stop_sequences.push(sequence.into());
self
}
/// Builder pattern - add a visible stop token
pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self {
self.visible_stop_tokens.insert(token_id);
self
}
/// Builder pattern - add a visible stop sequence
pub fn with_visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
self.visible_stop_sequences.push(sequence.into());
self
}
}
/// Decoder that handles stop sequences
pub struct StopSequenceDecoder {
tokenizer: Arc<dyn traits::Tokenizer>,
config: StopSequenceConfig,
/// Buffer for partial matches (the "jail")
jail_buffer: String,
/// Accumulated tokens
token_buffer: Vec<TokenIdType>,
/// Offset where the prefix text starts (for context)
prefix_offset: usize,
/// Offset marking the end of previously decoded text
read_offset: usize,
/// Whether we've stopped
stopped: bool,
skip_special_tokens: bool,
}
impl StopSequenceDecoder {
/// Create a new stop sequence decoder
pub fn new(
tokenizer: Arc<dyn traits::Tokenizer>,
config: StopSequenceConfig,
skip_special_tokens: bool,
) -> Self {
StopSequenceDecoder {
tokenizer,
config,
jail_buffer: String::new(),
token_buffer: Vec::new(),
prefix_offset: 0,
read_offset: 0,
stopped: false,
skip_special_tokens,
}
}
/// Process a single token
pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
if self.stopped {
return Ok(SequenceDecoderOutput::Stopped);
}
// Check for token-level stops first
if self.config.stop_tokens.contains(&token_id) {
self.stopped = true;
// Flush any jailed text before stopping
if !self.jail_buffer.is_empty() {
let output = self.jail_buffer.clone();
self.jail_buffer.clear();
return Ok(SequenceDecoderOutput::StoppedWithText(output));
}
return Ok(SequenceDecoderOutput::Stopped);
}
if self.config.visible_stop_tokens.contains(&token_id) {
self.stopped = true;
// Include jailed text plus the stop token
let stop_text = self
.tokenizer
.decode(&[token_id], self.skip_special_tokens)?;
let output = format!("{}{}", self.jail_buffer, stop_text);
self.jail_buffer.clear();
return Ok(SequenceDecoderOutput::StoppedWithText(output));
}
// Add token to buffer
self.token_buffer.push(token_id);
// Use incremental decoding like DecodeStream
// First decode the previous context (what we've already output)
let prefix_text = if self.read_offset > self.prefix_offset {
self.tokenizer.decode(
&self.token_buffer[self.prefix_offset..self.read_offset],
self.skip_special_tokens,
)?
} else {
String::new()
};
// Now decode from prefix to current position
let new_full_text = self.tokenizer.decode(
&self.token_buffer[self.prefix_offset..],
self.skip_special_tokens,
)?;
// Check for incomplete UTF-8 sequence
if new_full_text.ends_with("<EFBFBD>") {
// Wait for more tokens to complete the sequence
return Ok(SequenceDecoderOutput::Held);
}
// Calculate only the NEW text since last successful decode
let new_text = if new_full_text.len() > prefix_text.len() {
&new_full_text[prefix_text.len()..]
} else {
// No new text produced (can happen with special tokens)
return Ok(SequenceDecoderOutput::Held);
};
// Combine jail buffer with new text for checking
let check_text = format!("{}{}", self.jail_buffer, new_text);
// Check for complete stop sequences
for stop_seq in &self.config.stop_sequences {
if let Some(pos) = check_text.find(stop_seq) {
self.stopped = true;
// Output text before the stop sequence
let output = check_text[..pos].to_string();
self.jail_buffer.clear();
return Ok(if output.is_empty() {
SequenceDecoderOutput::Stopped
} else {
SequenceDecoderOutput::StoppedWithText(output)
});
}
}
// Check for visible stop sequences
for stop_seq in &self.config.visible_stop_sequences {
if let Some(pos) = check_text.find(stop_seq) {
self.stopped = true;
// Include the stop sequence in output
let end_pos = pos + stop_seq.len();
let output = check_text[..end_pos].to_string();
self.jail_buffer.clear();
return Ok(SequenceDecoderOutput::StoppedWithText(output));
}
}
// Check for partial matches at the end of check_text
let mut partial_match_len = 0;
for stop_seq in self
.config
.stop_sequences
.iter()
.chain(&self.config.visible_stop_sequences)
{
// Check all possible suffixes that could be a prefix of stop_seq
for i in 1..=check_text.len().min(stop_seq.len() - 1) {
let suffix = &check_text[check_text.len() - i..];
if stop_seq.starts_with(suffix) {
partial_match_len = partial_match_len.max(i);
}
}
}
if partial_match_len > 0 {
// Split: output safe text, jail the potential match
let safe_end = check_text.len() - partial_match_len;
let safe_text = &check_text[..safe_end];
self.jail_buffer = check_text[safe_end..].to_string();
// Update offsets for next iteration
self.prefix_offset = self.read_offset;
self.read_offset = self.token_buffer.len();
if safe_text.is_empty() {
Ok(SequenceDecoderOutput::Held)
} else {
Ok(SequenceDecoderOutput::Text(safe_text.to_string()))
}
} else {
// No partial matches - output everything
self.jail_buffer.clear();
// Update offsets for next iteration
self.prefix_offset = self.read_offset;
self.read_offset = self.token_buffer.len();
Ok(SequenceDecoderOutput::Text(check_text))
}
}
/// Process multiple tokens
pub fn process_tokens(
&mut self,
token_ids: &[TokenIdType],
) -> Result<Vec<SequenceDecoderOutput>> {
let mut outputs = Vec::new();
for &token_id in token_ids {
outputs.push(self.process_token(token_id)?);
}
Ok(outputs)
}
/// Flush any held text
pub fn flush(&mut self) -> SequenceDecoderOutput {
if !self.jail_buffer.is_empty() {
let output = self.jail_buffer.clone();
self.jail_buffer.clear();
SequenceDecoderOutput::Text(output)
} else {
SequenceDecoderOutput::Text(String::new())
}
}
/// Check if decoding has stopped
pub fn is_stopped(&self) -> bool {
self.stopped
}
/// Reset the decoder state
pub fn reset(&mut self) {
self.jail_buffer.clear();
self.token_buffer.clear();
self.prefix_offset = 0;
self.read_offset = 0;
self.stopped = false;
}
}
/// Builder for StopSequenceDecoder
pub struct StopSequenceDecoderBuilder {
tokenizer: Arc<dyn traits::Tokenizer>,
config: StopSequenceConfig,
skip_special_tokens: bool,
}
impl StopSequenceDecoderBuilder {
pub fn new(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
StopSequenceDecoderBuilder {
tokenizer,
config: StopSequenceConfig::default(),
skip_special_tokens: true,
}
}
pub fn stop_token(mut self, token_id: TokenIdType) -> Self {
self.config.stop_tokens.insert(token_id);
self
}
pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
self.config.stop_sequences.push(sequence.into());
self
}
pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self {
self.config.visible_stop_tokens.insert(token_id);
self
}
pub fn visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
self.config.visible_stop_sequences.push(sequence.into());
self
}
pub fn skip_special_tokens(mut self, skip: bool) -> Self {
self.skip_special_tokens = skip;
self
}
pub fn build(self) -> StopSequenceDecoder {
StopSequenceDecoder::new(self.tokenizer, self.config, self.skip_special_tokens)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizer::mock::MockTokenizer;
#[test]
fn test_stop_token_detection() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default().with_stop_token(999); // <eos> token
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
// Process tokens before stop
let result = decoder.process_token(1).unwrap(); // "Hello"
assert!(matches!(result, SequenceDecoderOutput::Text(_)));
// Process stop token
let result = decoder.process_token(999).unwrap(); // <eos>
assert_eq!(result, SequenceDecoderOutput::Stopped);
// Further tokens should also return Stopped
let result = decoder.process_token(2).unwrap();
assert_eq!(result, SequenceDecoderOutput::Stopped);
}
#[test]
fn test_visible_stop_token() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default().with_visible_stop_token(999);
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
let result = decoder.process_token(999).unwrap();
assert!(matches!(result, SequenceDecoderOutput::StoppedWithText(_)));
}
#[test]
fn test_builder_pattern() {
let tokenizer = Arc::new(MockTokenizer::new());
let decoder = StopSequenceDecoderBuilder::new(tokenizer)
.stop_token(999)
.stop_sequence("STOP")
.visible_stop_token(1000)
.skip_special_tokens(true)
.build();
assert!(!decoder.is_stopped());
}
#[test]
fn test_incremental_decoding_no_repetition() {
// This test verifies the critical fix: no repeated output
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default();
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
// Process tokens one by one and collect outputs
let mut outputs = Vec::new();
// Token 1: "Hello"
let result = decoder.process_token(1).unwrap();
if let SequenceDecoderOutput::Text(text) = result {
outputs.push(text.clone());
}
// Token 2: "world"
let result = decoder.process_token(2).unwrap();
if let SequenceDecoderOutput::Text(text) = result {
outputs.push(text.clone());
}
// Token 3: "test"
let result = decoder.process_token(3).unwrap();
if let SequenceDecoderOutput::Text(text) = result {
outputs.push(text.clone());
}
// CRITICAL: Each output should be unique (no accumulation)
// The fix ensures we only output NEW text, not accumulated text
assert_eq!(outputs.len(), 3);
// Verify no text is repeated
for i in 0..outputs.len() {
for j in i + 1..outputs.len() {
// No output should contain another (no accumulation)
assert!(!outputs[j].contains(&outputs[i]));
}
}
}
#[test]
fn test_stop_sequence_detection() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default().with_stop_sequence("test");
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
// Process "Hello world"
decoder.process_token(1).unwrap(); // "Hello"
decoder.process_token(2).unwrap(); // "world"
// Process "test" which should trigger stop
let result = decoder.process_token(3).unwrap(); // "test"
// Should stop when we hit "test"
assert!(matches!(
result,
SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
));
}
#[test]
fn test_flush_after_partial() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default().with_stop_sequence("NEVER_MATCH");
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
// Process a token
decoder.process_token(1).unwrap(); // "Hello"
// Flush should return any remaining text in jail
let result = decoder.flush();
// After processing, flush should work
assert!(matches!(result, SequenceDecoderOutput::Text(_)));
}
#[test]
fn test_reset_functionality() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default().with_stop_token(999);
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
// Process and stop
decoder.process_token(1).unwrap();
decoder.process_token(999).unwrap();
assert!(decoder.is_stopped());
// Reset should clear everything
decoder.reset();
assert!(!decoder.is_stopped());
// Should be able to process again
let result = decoder.process_token(2).unwrap();
assert!(matches!(result, SequenceDecoderOutput::Text(_)));
}
#[test]
fn test_visible_stop_sequence() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default().with_visible_stop_sequence("world");
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
// Process "Hello"
decoder.process_token(1).unwrap();
// Process "world" - should include it in output
let result = decoder.process_token(2).unwrap();
if let SequenceDecoderOutput::StoppedWithText(text) = result {
// Should include "world" in the output
assert!(text.contains("world"));
} else {
panic!("Expected StoppedWithText with visible stop sequence");
}
}
#[test]
fn test_multiple_tokens_processing() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default();
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
// Process multiple tokens at once
let results = decoder.process_tokens(&[1, 2, 3]).unwrap();
// Should get results for each token
assert_eq!(results.len(), 3);
// Each result should be Text (no stops configured)
for result in results {
assert!(matches!(
result,
SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
));
}
}
}

View File

@@ -0,0 +1,105 @@
// src/tokenizer/stream.rs
use super::traits::{self, TokenIdType};
use anyhow::Result;
use std::sync::Arc;
const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;
/// DecodeStream will keep the state necessary to produce individual chunks of
/// strings given an input stream of token_ids
pub struct DecodeStream {
/// The tokenizer used to decode token_ids
tokenizer: Arc<dyn traits::Tokenizer>,
skip_special_tokens: bool,
/// A temporary buffer of the necessary token_ids needed
/// to produce valid string chunks
all_token_ids: Vec<TokenIdType>,
prefix_offset: usize,
read_offset: usize,
}
impl DecodeStream {
pub fn new(
tokenizer: Arc<dyn traits::Tokenizer>,
prompt_token_ids: &[TokenIdType],
skip_special_tokens: bool,
) -> Self {
let num_input_tokens = prompt_token_ids.len();
let prompt_token_ids = prompt_token_ids.to_vec();
Self {
tokenizer,
skip_special_tokens,
all_token_ids: prompt_token_ids,
prefix_offset: num_input_tokens
.saturating_sub(INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET),
read_offset: num_input_tokens,
}
}
/// Step appends a token_id to the internal state and tries to produce a text chunk.
/// Returning `None` means the given id is not enough to produce a chunk.
pub fn step(&mut self, id: TokenIdType) -> Result<Option<String>> {
self.all_token_ids.push(id);
let prefix_text = self.tokenizer.decode(
&self.all_token_ids[self.prefix_offset..self.read_offset],
self.skip_special_tokens,
)?;
let new_text = self.tokenizer.decode(
&self.all_token_ids[self.prefix_offset..],
self.skip_special_tokens,
)?;
if new_text.len() > prefix_text.len() && !new_text.ends_with("<EFBFBD>") {
let new_text = new_text[prefix_text.len()..].to_string();
self.prefix_offset = self.read_offset;
self.read_offset = self.all_token_ids.len();
Ok(Some(new_text))
} else {
Ok(None)
}
}
/// Process multiple tokens at once
pub fn step_batch(&mut self, token_ids: &[u32]) -> Result<Vec<String>> {
let mut chunks = Vec::new();
for &token_id in token_ids {
if let Some(text) = self.step(token_id)? {
chunks.push(text);
}
}
Ok(chunks)
}
/// Force flush any remaining text
pub fn flush(&mut self) -> Result<Option<String>> {
if self.read_offset < self.all_token_ids.len() {
let remaining = self.tokenizer.decode(
&self.all_token_ids[self.read_offset..],
self.skip_special_tokens,
)?;
self.read_offset = self.all_token_ids.len();
if !remaining.is_empty() {
return Ok(Some(remaining));
}
}
Ok(None)
}
/// Get all tokens processed so far
pub fn tokens(&self) -> &[u32] {
&self.all_token_ids
}
}

View File

@@ -0,0 +1,143 @@
#[cfg(test)]
use super::*;
#[cfg(test)]
use std::sync::Arc;
#[test]
fn test_mock_tokenizer_encode() {
let tokenizer = mock::MockTokenizer::new();
let encoding = tokenizer.encode("Hello world").unwrap();
let token_ids = encoding.token_ids();
assert_eq!(token_ids, &[1, 2]); // "Hello" -> 1, "world" -> 2
}
#[test]
fn test_mock_tokenizer_decode() {
let tokenizer = mock::MockTokenizer::new();
let text = tokenizer.decode(&[1, 2], false).unwrap();
assert_eq!(text, "Hello world");
}
#[test]
fn test_mock_tokenizer_decode_skip_special() {
let tokenizer = mock::MockTokenizer::new();
// With special tokens
let text = tokenizer.decode(&[1000, 1, 2, 999], false).unwrap();
assert_eq!(text, "<bos> Hello world <eos>");
// Without special tokens
let text = tokenizer.decode(&[1000, 1, 2, 999], true).unwrap();
assert_eq!(text, "Hello world");
}
#[test]
fn test_tokenizer_wrapper() {
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
// Test encoding
let encoding = tokenizer.encode("Hello world").unwrap();
assert_eq!(encoding.token_ids(), &[1, 2]);
// Test decoding
let text = tokenizer.decode(&[1, 2], false).unwrap();
assert_eq!(text, "Hello world");
// Test vocab size
assert_eq!(tokenizer.vocab_size(), 8);
// Test token to ID
assert_eq!(tokenizer.token_to_id("Hello"), Some(1));
assert_eq!(tokenizer.token_to_id("unknown"), None);
// Test ID to token
assert_eq!(tokenizer.id_to_token(1), Some("Hello".to_string()));
assert_eq!(tokenizer.id_to_token(9999), None);
}
#[test]
fn test_decode_stream_basic() {
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
// Create a decode stream with initial tokens
let initial_tokens = vec![1, 2]; // "Hello world"
let mut stream = tokenizer.decode_stream(&initial_tokens, false);
// Add a new token
let result = stream.step(3).unwrap(); // "test"
// Since we're using a mock, the actual incremental behavior depends on implementation
// For now, we just verify it doesn't crash
assert!(result.is_some() || result.is_none());
}
#[test]
fn test_decode_stream_flush() {
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
let initial_tokens = vec![1];
let mut stream = tokenizer.decode_stream(&initial_tokens, false);
// Add tokens
stream.step(2).unwrap();
stream.step(3).unwrap();
// Flush remaining
let flushed = stream.flush().unwrap();
// The flush behavior depends on the implementation
assert!(flushed.is_some() || flushed.is_none());
}
#[test]
fn test_special_tokens() {
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
let special_tokens = tokenizer.get_special_tokens();
assert_eq!(special_tokens.bos_token, Some("<bos>".to_string()));
assert_eq!(special_tokens.eos_token, Some("<eos>".to_string()));
assert_eq!(special_tokens.unk_token, Some("<unk>".to_string()));
assert!(special_tokens.sep_token.is_none());
assert!(special_tokens.pad_token.is_none());
}
#[test]
fn test_batch_encode() {
let tokenizer = mock::MockTokenizer::new();
let inputs = vec!["Hello", "world", "test"];
let encodings = tokenizer.encode_batch(&inputs).unwrap();
assert_eq!(encodings.len(), 3);
assert_eq!(encodings[0].token_ids(), &[1]); // "Hello" -> 1
assert_eq!(encodings[1].token_ids(), &[2]); // "world" -> 2
assert_eq!(encodings[2].token_ids(), &[3]); // "test" -> 3
}
#[test]
fn test_thread_safety() {
use std::thread;
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
// Spawn multiple threads that use the same tokenizer
let handles: Vec<_> = (0..10)
.map(|i| {
let tokenizer_clone = tokenizer.clone();
thread::spawn(move || {
let text = "Hello test".to_string();
let encoding = tokenizer_clone.encode(&text).unwrap();
let decoded = tokenizer_clone.decode(encoding.token_ids(), false).unwrap();
assert!(decoded.contains("Hello") || decoded.contains("test"));
i
})
})
.collect();
// Wait for all threads to complete
for handle in handles {
handle.join().unwrap();
}
}

View File

@@ -0,0 +1,276 @@
use super::traits::{
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
};
use anyhow::{Error, Result};
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
/// Tiktoken tokenizer wrapper for OpenAI GPT models
pub struct TiktokenTokenizer {
tokenizer: CoreBPE,
#[allow(dead_code)]
model: TiktokenModel,
special_tokens: SpecialTokens,
vocab_size: usize,
}
/// Supported Tiktoken models
#[derive(Debug, Clone, Copy)]
pub enum TiktokenModel {
/// GPT-4, GPT-3.5-turbo, text-embedding-ada-002
Cl100kBase,
/// Codex models, text-davinci-002, text-davinci-003
P50kBase,
/// Use for edit models like text-davinci-edit-001, code-davinci-edit-001
P50kEdit,
/// GPT-3 models like davinci
R50kBase,
}
impl TiktokenTokenizer {
/// Create a new Tiktoken tokenizer for the specified model
pub fn new(model: TiktokenModel) -> Result<Self> {
let tokenizer =
match model {
TiktokenModel::Cl100kBase => cl100k_base()
.map_err(|e| Error::msg(format!("Failed to load cl100k_base: {}", e)))?,
TiktokenModel::P50kBase => p50k_base()
.map_err(|e| Error::msg(format!("Failed to load p50k_base: {}", e)))?,
TiktokenModel::P50kEdit => p50k_edit()
.map_err(|e| Error::msg(format!("Failed to load p50k_edit: {}", e)))?,
TiktokenModel::R50kBase => r50k_base()
.map_err(|e| Error::msg(format!("Failed to load r50k_base: {}", e)))?,
};
// Extract special tokens (tiktoken-rs doesn't expose them directly)
// We'll use common ones for GPT models
let special_tokens = Self::get_special_tokens_for_model(model);
// Get vocabulary size (this is an approximation)
let vocab_size = match model {
TiktokenModel::Cl100kBase => 100256, // cl100k has ~100k tokens
TiktokenModel::P50kBase | TiktokenModel::P50kEdit => 50281, // p50k has ~50k tokens
TiktokenModel::R50kBase => 50257, // r50k has ~50k tokens
};
Ok(TiktokenTokenizer {
tokenizer,
model,
special_tokens,
vocab_size,
})
}
/// Create a tokenizer from a model string (e.g., "gpt-4", "gpt-3.5-turbo")
pub fn from_model_name(model_name: &str) -> Result<Self> {
let model = Self::model_from_name(model_name)?;
Self::new(model)
}
/// Determine the appropriate model from a model name
fn model_from_name(model_name: &str) -> Result<TiktokenModel> {
// Based on OpenAI's model-to-encoding mapping
if model_name.contains("gpt-4")
|| model_name.contains("gpt-3.5")
|| model_name.contains("turbo")
{
Ok(TiktokenModel::Cl100kBase)
} else if model_name.contains("davinci-002")
|| model_name.contains("davinci-003")
|| model_name.contains("codex")
{
Ok(TiktokenModel::P50kBase)
} else if model_name.contains("edit") {
Ok(TiktokenModel::P50kEdit)
} else if model_name.contains("davinci")
|| model_name.contains("curie")
|| model_name.contains("babbage")
|| model_name.contains("ada")
{
Ok(TiktokenModel::R50kBase)
} else {
// Return an error for unrecognized model names to prevent silent failures
Err(anyhow::anyhow!(
"Unrecognized OpenAI model name: '{}'. Expected GPT-3, GPT-3.5, GPT-4, or related model names",
model_name
))
}
}
/// Get special tokens for a specific model
fn get_special_tokens_for_model(model: TiktokenModel) -> SpecialTokens {
// These are common special tokens for GPT models
// The actual token IDs might vary by model
match model {
TiktokenModel::Cl100kBase => SpecialTokens {
bos_token: Some("<|endoftext|>".to_string()),
eos_token: Some("<|endoftext|>".to_string()),
unk_token: None,
sep_token: None,
pad_token: Some("<|endoftext|>".to_string()),
cls_token: None,
mask_token: None,
additional_special_tokens: vec![
"<|fim_prefix|>".to_string(),
"<|fim_middle|>".to_string(),
"<|fim_suffix|>".to_string(),
"<|endofprompt|>".to_string(),
],
},
_ => SpecialTokens {
bos_token: Some("<|endoftext|>".to_string()),
eos_token: Some("<|endoftext|>".to_string()),
unk_token: None,
sep_token: None,
pad_token: Some("<|endoftext|>".to_string()),
cls_token: None,
mask_token: None,
additional_special_tokens: vec![],
},
}
}
}
impl Encoder for TiktokenTokenizer {
fn encode(&self, input: &str) -> Result<Encoding> {
let tokens = self.tokenizer.encode_ordinary(input);
Ok(Encoding::Tiktoken(tokens))
}
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
inputs.iter().map(|input| self.encode(input)).collect()
}
}
impl Decoder for TiktokenTokenizer {
fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
// tiktoken-rs 0.7.0 now uses u32 (Rank type)
self.tokenizer
.decode(token_ids.to_vec())
.map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
}
}
impl TokenizerTrait for TiktokenTokenizer {
fn vocab_size(&self) -> usize {
self.vocab_size
}
fn get_special_tokens(&self) -> &SpecialTokens {
&self.special_tokens
}
fn token_to_id(&self, _token: &str) -> Option<TokenIdType> {
// Tiktoken doesn't provide direct token-to-id mapping
// We'd need to encode the token and check if it produces a single ID
None
}
fn id_to_token(&self, _id: TokenIdType) -> Option<String> {
// Tiktoken doesn't provide direct id-to-token mapping
// We can only decode IDs to text
None
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tiktoken_creation() {
let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
assert_eq!(tokenizer.vocab_size(), 100256);
}
#[test]
fn test_model_from_name() {
assert!(matches!(
TiktokenTokenizer::model_from_name("gpt-4").unwrap(),
TiktokenModel::Cl100kBase
));
assert!(matches!(
TiktokenTokenizer::model_from_name("gpt-3.5-turbo").unwrap(),
TiktokenModel::Cl100kBase
));
assert!(matches!(
TiktokenTokenizer::model_from_name("text-davinci-003").unwrap(),
TiktokenModel::P50kBase
));
assert!(matches!(
TiktokenTokenizer::model_from_name("text-davinci-edit-001").unwrap(),
TiktokenModel::P50kEdit
));
assert!(matches!(
TiktokenTokenizer::model_from_name("davinci").unwrap(),
TiktokenModel::R50kBase
));
}
#[test]
fn test_encode_decode() {
let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
let text = "Hello, world!";
let encoding = tokenizer.encode(text).unwrap();
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
assert_eq!(decoded, text);
}
#[test]
fn test_batch_encode() {
let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
let texts = vec!["Hello", "World", "Test"];
let encodings = tokenizer.encode_batch(&texts).unwrap();
assert_eq!(encodings.len(), 3);
for (i, encoding) in encodings.iter().enumerate() {
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
assert_eq!(decoded, texts[i]);
}
}
#[test]
fn test_special_tokens() {
let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
let special_tokens = tokenizer.get_special_tokens();
assert!(special_tokens.eos_token.is_some());
assert_eq!(special_tokens.eos_token.as_ref().unwrap(), "<|endoftext|>");
}
#[test]
fn test_unrecognized_model_name_returns_error() {
// Test that unrecognized model names return an error
let result = TiktokenTokenizer::from_model_name("distilgpt-2");
assert!(result.is_err());
if let Err(e) = result {
assert!(e.to_string().contains("Unrecognized OpenAI model name"));
}
let result = TiktokenTokenizer::from_model_name("bert-base-uncased");
assert!(result.is_err());
if let Err(e) = result {
assert!(e.to_string().contains("Unrecognized OpenAI model name"));
}
let result = TiktokenTokenizer::from_model_name("llama-7b");
assert!(result.is_err());
if let Err(e) = result {
assert!(e.to_string().contains("Unrecognized OpenAI model name"));
}
}
#[test]
fn test_recognized_model_names() {
// Test that recognized model names work correctly
assert!(TiktokenTokenizer::from_model_name("gpt-4").is_ok());
assert!(TiktokenTokenizer::from_model_name("gpt-3.5-turbo").is_ok());
assert!(TiktokenTokenizer::from_model_name("text-davinci-003").is_ok());
assert!(TiktokenTokenizer::from_model_name("code-davinci-002").is_ok());
assert!(TiktokenTokenizer::from_model_name("text-curie-001").is_ok());
assert!(TiktokenTokenizer::from_model_name("text-babbage-001").is_ok());
assert!(TiktokenTokenizer::from_model_name("text-ada-001").is_ok());
}
}

View File

@@ -0,0 +1,83 @@
use anyhow::Result;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// Type alias for token IDs
pub type TokenIdType = u32;
/// Core encoding trait - separate from decoding for modularity
pub trait Encoder: Send + Sync {
fn encode(&self, input: &str) -> Result<Encoding>;
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>>;
}
/// Core decoding trait - can be implemented independently
pub trait Decoder: Send + Sync {
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
}
/// Combined tokenizer trait
pub trait Tokenizer: Encoder + Decoder {
fn vocab_size(&self) -> usize;
fn get_special_tokens(&self) -> &SpecialTokens;
fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
fn id_to_token(&self, id: TokenIdType) -> Option<String>;
}
/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
#[derive(Debug, Clone)]
pub enum Encoding {
/// Hugging Face
Hf(Box<tokenizers::tokenizer::Encoding>),
/// Sentence Piece
Sp(Vec<TokenIdType>),
/// Tiktoken (for GPT models) - now uses u32 in tiktoken-rs 0.7.0
Tiktoken(Vec<TokenIdType>),
}
impl Encoding {
/// Returns a reference to token IDs - zero-copy operation
pub fn token_ids(&self) -> &[TokenIdType] {
match self {
Encoding::Hf(inner) => inner.get_ids(),
Encoding::Sp(inner) => inner,
Encoding::Tiktoken(inner) => inner,
}
}
/// Deprecated: Use token_ids() instead (kept for compatibility)
#[deprecated(since = "0.1.0", note = "Use token_ids() instead")]
pub fn token_ids_ref(&self) -> &[TokenIdType] {
self.token_ids()
}
/// Get a hash of the token IDs for caching purposes
pub fn get_hash(&self) -> u64 {
let mut hasher = DefaultHasher::new();
self.hash(&mut hasher);
hasher.finish()
}
}
/// Hash implementation for Encoding
impl Hash for Encoding {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
Encoding::Hf(inner) => inner.get_ids().hash(state),
Encoding::Sp(inner) => inner.hash(state),
Encoding::Tiktoken(inner) => inner.hash(state),
}
}
}
#[derive(Debug, Clone)]
pub struct SpecialTokens {
pub bos_token: Option<String>,
pub eos_token: Option<String>,
pub unk_token: Option<String>,
pub sep_token: Option<String>,
pub pad_token: Option<String>,
pub cls_token: Option<String>,
pub mask_token: Option<String>,
pub additional_special_tokens: Vec<String>,
}

View File

@@ -0,0 +1,32 @@
use thiserror::Error;
/// Result type for tool parser operations
pub type ToolParserResult<T> = Result<T, ToolParserError>;
/// Errors that can occur during tool parsing
#[derive(Debug, Error)]
pub enum ToolParserError {
#[error("Parsing failed: {0}")]
ParsingFailed(String),
#[error("Model not supported: {0}")]
ModelNotSupported(String),
#[error("Parse depth exceeded: max {0}")]
DepthExceeded(usize),
#[error("Invalid JSON: {0}")]
JsonError(#[from] serde_json::Error),
#[error("Regex error: {0}")]
RegexError(#[from] regex::Error),
#[error("Incomplete tool call")]
Incomplete,
#[error("Invalid tool name: {0}")]
InvalidToolName(String),
#[error("Token not found: {0}")]
TokenNotFound(String),
}

View File

@@ -0,0 +1,30 @@
/// Tool parser module for handling function/tool calls in model outputs
///
/// This module provides infrastructure for parsing tool calls from various model formats.
// Core modules
pub mod errors;
pub mod partial_json;
pub mod python_literal_parser;
pub mod registry;
pub mod state;
pub mod traits;
pub mod types;
// Parser implementations
pub mod parsers;
#[cfg(test)]
mod tests;
// Re-export commonly used types
pub use errors::{ToolParserError, ToolParserResult};
pub use registry::ParserRegistry;
pub use state::{ParsePhase, ParseState};
pub use traits::{PartialJsonParser, ToolParser};
pub use types::{FunctionCall, PartialToolCall, StreamResult, TokenConfig, ToolCall};
// Re-export parsers for convenience
pub use parsers::{
DeepSeekParser, Glm4MoeParser, GptOssParser, JsonParser, KimiK2Parser, LlamaParser,
MistralParser, PythonicParser, QwenParser, Step3Parser,
};

View File

@@ -0,0 +1,277 @@
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
partial_json::PartialJson,
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
};
/// DeepSeek V3 format parser for tool calls
///
/// Handles the DeepSeek V3 specific format that uses Unicode tokens:
/// `<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>{name}\n```json\n{args}\n```<tool▁call▁end><tool▁calls▁end>`
///
/// Features:
/// - Unicode token delimiters
/// - JSON arguments in code blocks
/// - Support for multiple sequential tool calls
pub struct DeepSeekParser {
/// Parser for handling incomplete JSON during streaming
partial_json: PartialJson,
/// Regex for extracting complete tool calls
tool_call_extractor: Regex,
/// Regex for extracting function details
func_detail_extractor: Regex,
}
impl DeepSeekParser {
/// Create a new DeepSeek parser
pub fn new() -> Self {
// Use (?s) flag for DOTALL mode to handle newlines
let tool_call_pattern = r"(?s)<tool▁call▁begin>.*?<tool▁call▁end>";
let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
let func_detail_pattern = r"(?s)<tool▁call▁begin>(.*?)<tool▁sep>(.*?)\n```json\n(.*?)\n```<tool▁call▁end>";
let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");
Self {
partial_json: PartialJson::default(),
tool_call_extractor,
func_detail_extractor,
}
}
/// Check if text contains DeepSeek tool markers
fn has_tool_markers(&self, text: &str) -> bool {
text.contains("<tool▁calls▁begin>")
}
/// Extract all tool call blocks from text
fn extract_tool_calls<'a>(&self, text: &'a str) -> Vec<&'a str> {
self.tool_call_extractor
.find_iter(text)
.map(|m| m.as_str())
.collect()
}
/// Parse a single tool call block
fn parse_tool_call(&self, block: &str) -> ToolParserResult<Option<ToolCall>> {
if let Some(captures) = self.func_detail_extractor.captures(block) {
// Get function type (should be "function")
let func_type = captures.get(1).map_or("", |m| m.as_str());
if func_type != "function" {
return Ok(None);
}
// Get function name
let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
// Get JSON arguments
let json_args = captures.get(3).map_or("{}", |m| m.as_str()).trim();
// Parse JSON arguments
match serde_json::from_str::<Value>(json_args) {
Ok(value) => {
// Create arguments object
let args = if value.is_object() {
value
} else {
// If not an object, wrap it
serde_json::json!({ "value": value })
};
let arguments = serde_json::to_string(&args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID
let id = format!("deepseek_call_{}", uuid::Uuid::new_v4());
Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: func_name.to_string(),
arguments,
},
}))
}
Err(_) => Ok(None),
}
} else {
Ok(None)
}
}
}
impl Default for DeepSeekParser {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ToolParser for DeepSeekParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
// Check if text contains DeepSeek format
if !self.has_tool_markers(text) {
return Ok(vec![]);
}
// Extract all tool call blocks
let tool_blocks = self.extract_tool_calls(text);
let mut tools = Vec::new();
for block in tool_blocks {
if let Some(tool) = self.parse_tool_call(block)? {
tools.push(tool);
}
}
Ok(tools)
}
async fn parse_incremental(
&self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
// Check for tool markers
if !self.has_tool_markers(&state.buffer) {
// No markers found, return as incomplete
return Ok(StreamResult::Incomplete);
}
// Look for start of tool calls
if let Some(start_pos) = state.buffer.find("<tool▁calls▁begin>") {
// Look for individual tool call start
let search_from = start_pos + "<tool▁calls▁begin>".len();
if let Some(call_start) = state.buffer[search_from..].find("<tool▁call▁begin>")
{
let call_start_abs = search_from + call_start;
// Look for the end of this tool call
let search_end_from = call_start_abs + "<tool▁call▁begin>".len();
if let Some(call_end) = state.buffer[search_end_from..].find("<tool▁call▁end>")
{
let call_end_abs = search_end_from + call_end + "<tool▁call▁end>".len();
// Extract and parse the complete tool call
let tool_call_text = &state.buffer[call_start_abs..call_end_abs];
if let Some(tool) = self.parse_tool_call(tool_call_text)? {
// Remove the processed part from buffer
state.buffer.drain(..call_end_abs);
return Ok(StreamResult::ToolComplete(tool));
}
} else {
// Tool call not complete yet, try to extract partial info
let partial = &state.buffer[search_end_from..];
// Try to extract function name
if let Some(sep_pos) = partial.find("<tool▁sep>") {
if let Some(_func_start) = partial[..sep_pos].rfind("function") {
// We have the function type marker
let after_sep = &partial[sep_pos + "<tool▁sep>".len()..];
// Look for function name (ends at newline before ```json)
if let Some(name_end) = after_sep.find("\n```json\n") {
let func_name = after_sep[..name_end].trim();
if !state.in_string {
state.in_string = true; // Mark name as sent
return Ok(StreamResult::ToolName {
index: 0,
name: func_name.to_string(),
});
}
// Try to extract partial arguments
let args_start = name_end + "\n```json\n".len();
let partial_args = &after_sep[args_start..];
// Check if we can parse partial JSON
if !partial_args.is_empty() {
match self.partial_json.parse_value(partial_args) {
Ok((value, _consumed)) => {
let args_str = serde_json::to_string(&value)
.unwrap_or_else(|_| "{}".to_string());
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
});
}
Err(_) => {
// Can't parse yet, keep buffering
}
}
}
}
}
}
}
}
}
Ok(StreamResult::Incomplete)
}
fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_parse_deepseek_single_tool() {
let parser = DeepSeekParser::new();
let input = r#"Some text
<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
```json
{"location": "Tokyo", "units": "celsius"}
```<tool▁call▁end><tool▁calls▁end>More text"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
assert!(result[0].function.arguments.contains("Tokyo"));
}
#[tokio::test]
async fn test_parse_deepseek_multiple_tools() {
let parser = DeepSeekParser::new();
let input = r#"<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
```json
{"location": "Tokyo"}
```<tool▁call▁end>
<tool▁call▁begin>function<tool▁sep>get_weather
```json
{"location": "Paris"}
```<tool▁call▁end><tool▁calls▁end>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "get_weather");
assert_eq!(result[1].function.name, "get_weather");
assert!(result[0].function.arguments.contains("Tokyo"));
assert!(result[1].function.arguments.contains("Paris"));
}
#[test]
fn test_detect_format() {
let parser = DeepSeekParser::new();
assert!(parser.detect_format("<tool▁calls▁begin>"));
assert!(!parser.detect_format("plain text"));
assert!(!parser.detect_format("[TOOL_CALLS]"));
}
}

View File

@@ -0,0 +1,292 @@
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
};
/// GLM-4 MoE format parser for tool calls
///
/// Handles the GLM-4 MoE specific format:
/// `<tool_call>{name}\n<arg_key>{key}</arg_key>\n<arg_value>{value}</arg_value>\n</tool_call>`
///
/// Features:
/// - XML-style tags for tool calls
/// - Key-value pairs for arguments
/// - Support for multiple sequential tool calls
pub struct Glm4MoeParser {
/// Regex for extracting complete tool calls
tool_call_extractor: Regex,
/// Regex for extracting function details
func_detail_extractor: Regex,
/// Regex for extracting argument key-value pairs
arg_extractor: Regex,
}
impl Glm4MoeParser {
/// Create a new GLM-4 MoE parser
pub fn new() -> Self {
// Use (?s) flag for DOTALL mode to handle newlines
let tool_call_pattern = r"(?s)<tool_call>.*?</tool_call>";
let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
let func_detail_pattern = r"(?s)<tool_call>([^\n]*)\n(.*)</tool_call>";
let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");
let arg_pattern = r"(?s)<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>";
let arg_extractor = Regex::new(arg_pattern).expect("Valid regex pattern");
Self {
tool_call_extractor,
func_detail_extractor,
arg_extractor,
}
}
/// Check if text contains GLM-4 MoE tool markers
fn has_tool_markers(&self, text: &str) -> bool {
text.contains("<tool_call>")
}
/// Parse arguments from key-value pairs
fn parse_arguments(&self, args_text: &str) -> ToolParserResult<serde_json::Map<String, Value>> {
let mut arguments = serde_json::Map::new();
for capture in self.arg_extractor.captures_iter(args_text) {
let key = capture.get(1).map_or("", |m| m.as_str()).trim();
let value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
// Try to parse the value as JSON first, fallback to string
let value = if let Ok(json_val) = serde_json::from_str::<Value>(value_str) {
json_val
} else {
// Try parsing as Python literal (similar to Python's ast.literal_eval)
if value_str == "true" || value_str == "True" {
Value::Bool(true)
} else if value_str == "false" || value_str == "False" {
Value::Bool(false)
} else if value_str == "null" || value_str == "None" {
Value::Null
} else if let Ok(num) = value_str.parse::<i64>() {
Value::Number(num.into())
} else if let Ok(num) = value_str.parse::<f64>() {
if let Some(n) = serde_json::Number::from_f64(num) {
Value::Number(n)
} else {
Value::String(value_str.to_string())
}
} else {
Value::String(value_str.to_string())
}
};
arguments.insert(key.to_string(), value);
}
Ok(arguments)
}
/// Parse a single tool call block
fn parse_tool_call(&self, block: &str) -> ToolParserResult<Option<ToolCall>> {
if let Some(captures) = self.func_detail_extractor.captures(block) {
// Get function name
let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
// Get arguments text
let args_text = captures.get(2).map_or("", |m| m.as_str());
// Parse arguments
let arguments = self.parse_arguments(args_text)?;
let arguments_str = serde_json::to_string(&arguments)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID
let id = format!("glm4_call_{}", uuid::Uuid::new_v4());
Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: func_name.to_string(),
arguments: arguments_str,
},
}))
} else {
Ok(None)
}
}
}
impl Default for Glm4MoeParser {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ToolParser for Glm4MoeParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
// Check if text contains GLM-4 MoE format
if !self.has_tool_markers(text) {
return Ok(vec![]);
}
// Extract all tool call blocks
let mut tools = Vec::new();
for mat in self.tool_call_extractor.find_iter(text) {
if let Some(tool) = self.parse_tool_call(mat.as_str())? {
tools.push(tool);
}
}
Ok(tools)
}
async fn parse_incremental(
&self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
// Check for tool markers
if !self.has_tool_markers(&state.buffer) {
// No markers found, return as incomplete
return Ok(StreamResult::Incomplete);
}
// Look for start of tool call
if let Some(start_pos) = state.buffer.find("<tool_call>") {
// Look for the end of this tool call
let search_from = start_pos + "<tool_call>".len();
if let Some(end_pos) = state.buffer[search_from..].find("</tool_call>") {
let end_abs = search_from + end_pos + "</tool_call>".len();
// Extract and parse the complete tool call
let tool_call_text = &state.buffer[start_pos..end_abs];
if let Some(tool) = self.parse_tool_call(tool_call_text)? {
// Remove the processed part from buffer
state.buffer.drain(..end_abs);
return Ok(StreamResult::ToolComplete(tool));
}
} else {
// Tool call not complete yet, try to extract partial info
let partial = &state.buffer[search_from..];
// Try to extract function name (first line after <tool_call>)
if let Some(name_end) = partial.find('\n') {
let func_name = partial[..name_end].trim();
if !func_name.is_empty() && !state.in_string {
state.in_string = true; // Mark name as sent
return Ok(StreamResult::ToolName {
index: 0,
name: func_name.to_string(),
});
}
// Try to extract partial arguments
let args_text = &partial[name_end + 1..];
let partial_args = self.parse_arguments(args_text)?;
if !partial_args.is_empty() {
let args_str = serde_json::to_string(&partial_args)
.unwrap_or_else(|_| "{}".to_string());
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
});
}
}
}
}
Ok(StreamResult::Incomplete)
}
fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_parse_glm4_single_tool() {
let parser = Glm4MoeParser::new();
let input = r#"Some text
<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2024-06-27</arg_value>
</tool_call>More text"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
assert!(result[0].function.arguments.contains("Beijing"));
assert!(result[0].function.arguments.contains("2024-06-27"));
}
#[tokio::test]
async fn test_parse_glm4_multiple_tools() {
let parser = Glm4MoeParser::new();
let input = r#"<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
</tool_call>
<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Shanghai</arg_value>
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "get_weather");
assert_eq!(result[1].function.name, "get_weather");
assert!(result[0].function.arguments.contains("Beijing"));
assert!(result[1].function.arguments.contains("Shanghai"));
}
#[tokio::test]
async fn test_parse_glm4_mixed_types() {
let parser = Glm4MoeParser::new();
let input = r#"<tool_call>process_data
<arg_key>count</arg_key>
<arg_value>42</arg_value>
<arg_key>active</arg_key>
<arg_value>true</arg_value>
<arg_key>name</arg_key>
<arg_value>test</arg_value>
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "process_data");
// Parse arguments to check types
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["count"], 42);
assert_eq!(args["active"], true);
assert_eq!(args["name"], "test");
}
#[test]
fn test_detect_format() {
let parser = Glm4MoeParser::new();
assert!(parser.detect_format("<tool_call>"));
assert!(!parser.detect_format("plain text"));
assert!(!parser.detect_format("[TOOL_CALLS]"));
}
}

View File

@@ -0,0 +1,292 @@
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
partial_json::PartialJson,
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
};
/// GPT-OSS format parser for tool calls
///
/// Handles the GPT-OSS specific channel format:
/// `<|channel|>commentary to={namespace.function}<|constrain|>json<|message|>{json_args}<|call|>`
///
/// Features:
/// - Channel-based format with commentary
/// - Namespaced function calls
/// - JSON arguments
pub struct GptOssParser {
/// Parser for handling incomplete JSON during streaming
partial_json: PartialJson,
/// Regex for extracting complete function calls
function_call_extractor: Regex,
/// Regex for extracting streaming function calls
streaming_extractor: Regex,
}
impl GptOssParser {
/// Create a new GPT-OSS parser
pub fn new() -> Self {
// Pattern for complete function calls with to= parameter
// Handles optional <|start|>assistant prefix and whitespace after function name
let function_call_pattern = r"(?s)(?:<\|start\|>assistant)?<\|channel\|>commentary to=([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*<\|constrain\|>json<\|message\|>(.*?)<\|call\|>(?:commentary)?";
let function_call_extractor =
Regex::new(function_call_pattern).expect("Valid regex pattern");
// Pattern for streaming function calls (incomplete)
let streaming_pattern = r"(?s)(?:<\|start\|>assistant)?<\|channel\|>commentary to=([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*<\|constrain\|>json<\|message\|>(.*)";
let streaming_extractor = Regex::new(streaming_pattern).expect("Valid regex pattern");
Self {
partial_json: PartialJson::default(),
function_call_extractor,
streaming_extractor,
}
}
/// Check if text contains GPT-OSS tool markers
fn has_tool_markers(&self, text: &str) -> bool {
text.contains("<|channel|>commentary to=")
}
/// Extract function name from full namespace (e.g., "functions.get_weather" -> "get_weather")
fn extract_function_name(&self, full_name: &str) -> String {
if let Some(dot_pos) = full_name.rfind('.') {
full_name[dot_pos + 1..].to_string()
} else {
full_name.to_string()
}
}
}
impl Default for GptOssParser {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ToolParser for GptOssParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
// Check if text contains GPT-OSS format
if !self.has_tool_markers(text) {
return Ok(vec![]);
}
let mut tools = Vec::new();
let mut _tool_index = 0;
// Extract all function calls
for captures in self.function_call_extractor.captures_iter(text) {
if let (Some(name_match), Some(args_match)) = (captures.get(1), captures.get(2)) {
let full_function_name = name_match.as_str();
let args_content = args_match.as_str().trim();
// Extract actual function name
let function_name = self.extract_function_name(full_function_name);
// Parse JSON arguments
let arguments = if args_content.is_empty() {
"{}".to_string()
} else {
match serde_json::from_str::<Value>(args_content) {
Ok(value) => serde_json::to_string(&value)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?,
Err(_) => {
// Skip malformed JSON
continue;
}
}
};
// Generate unique ID
let id = format!("gpt_oss_call_{}", uuid::Uuid::new_v4());
tools.push(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: function_name,
arguments,
},
});
_tool_index += 1;
}
}
Ok(tools)
}
async fn parse_incremental(
&self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
// Check for tool markers
if !self.has_tool_markers(&state.buffer) {
// No markers found, clear buffer and return
state.buffer.clear();
return Ok(StreamResult::Incomplete);
}
// Try to match streaming pattern
if let Some(captures) = self.streaming_extractor.captures(&state.buffer) {
if let (Some(name_match), Some(args_match)) = (captures.get(1), captures.get(2)) {
let full_function_name = name_match.as_str();
let partial_args = args_match.as_str();
// Extract actual function name
let function_name = self.extract_function_name(full_function_name);
// Send function name if not sent yet
if !state.in_string {
state.in_string = true; // Mark name as sent
return Ok(StreamResult::ToolName {
index: 0,
name: function_name.clone(),
});
}
// Check if we have a complete function call
if let Some(complete_match) = self.function_call_extractor.captures(&state.buffer) {
if let Some(args_match) = complete_match.get(2) {
let args_content = args_match.as_str().trim();
// Parse JSON arguments
let arguments = if args_content.is_empty() {
"{}".to_string()
} else {
match serde_json::from_str::<Value>(args_content) {
Ok(value) => serde_json::to_string(&value)
.unwrap_or_else(|_| "{}".to_string()),
Err(_) => "{}".to_string(),
}
};
// Generate unique ID
let id = format!("gpt_oss_call_{}", uuid::Uuid::new_v4());
let tool = ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: function_name,
arguments,
},
};
// Remove the processed part from buffer
let complete_end = complete_match.get(0).unwrap().end();
state.buffer.drain(..complete_end);
// Reset state for next tool
state.in_string = false;
return Ok(StreamResult::ToolComplete(tool));
}
} else {
// Try to parse partial JSON for streaming arguments
if !partial_args.is_empty() {
// Look for the end of JSON (before <|call|>)
let json_part = if let Some(call_pos) = partial_args.find("<|call|>") {
&partial_args[..call_pos]
} else {
partial_args
};
match self.partial_json.parse_value(json_part) {
Ok((value, _consumed)) => {
let args_str = serde_json::to_string(&value)
.unwrap_or_else(|_| "{}".to_string());
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
});
}
Err(_) => {
// Can't parse yet, keep buffering
}
}
}
}
}
}
Ok(StreamResult::Incomplete)
}
fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text) || text.contains("<|channel|>commentary")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_parse_gpt_oss_single_tool() {
let parser = GptOssParser::new();
let input = r#"Some text
<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "San Francisco"}<|call|>
More text"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
assert!(result[0].function.arguments.contains("San Francisco"));
}
#[tokio::test]
async fn test_parse_gpt_oss_multiple_tools() {
let parser = GptOssParser::new();
let input = r#"<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "Paris"}<|call|>commentary
<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "Paris tourism"}<|call|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "get_weather");
assert_eq!(result[1].function.name, "search");
assert!(result[0].function.arguments.contains("Paris"));
assert!(result[1].function.arguments.contains("Paris tourism"));
}
#[tokio::test]
async fn test_parse_gpt_oss_with_prefix() {
let parser = GptOssParser::new();
let input = r#"<|start|>assistant<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"key": "value"}<|call|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
}
#[tokio::test]
async fn test_parse_gpt_oss_empty_args() {
let parser = GptOssParser::new();
let input =
r#"<|channel|>commentary to=functions.get_time<|constrain|>json<|message|>{}<|call|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_time");
assert_eq!(result[0].function.arguments, "{}");
}
#[test]
fn test_detect_format() {
let parser = GptOssParser::new();
assert!(parser.detect_format("<|channel|>commentary to="));
assert!(parser.detect_format("<|channel|>commentary"));
assert!(!parser.detect_format("plain text"));
assert!(!parser.detect_format("[TOOL_CALLS]"));
}
}

View File

@@ -0,0 +1,619 @@
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
partial_json::PartialJson,
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, TokenConfig, ToolCall},
};
/// JSON format parser for tool calls
///
/// Handles various JSON formats for function calling:
/// - Single tool call: {"name": "fn", "arguments": {...}}
/// - Multiple tool calls: [{"name": "fn1", "arguments": {...}}, ...]
/// - With parameters instead of arguments: {"name": "fn", "parameters": {...}}
///
/// Supports configurable token markers for different models
pub struct JsonParser {
/// Token configuration for parsing
token_config: TokenConfig,
/// Parser for handling incomplete JSON during streaming
partial_json: PartialJson,
/// Regex patterns for extracting content between tokens
extractors: Vec<Regex>,
}
impl JsonParser {
/// Create a new JSON parser with default configuration
pub fn new() -> Self {
Self::with_config(TokenConfig {
start_tokens: vec![],
end_tokens: vec![],
separator: ", ".to_string(),
})
}
/// Create a parser with custom token configuration
pub fn with_config(config: TokenConfig) -> Self {
// Build extraction patterns for each token pair
let extractors: Vec<Regex> = config
.iter_pairs()
.filter_map(|(start, end)| {
if !start.is_empty() && !end.is_empty() {
// Use (?s) flag to enable DOTALL mode so . matches newlines
let pattern =
format!(r"(?s){}(.*?){}", regex::escape(start), regex::escape(end));
Regex::new(&pattern).ok()
} else {
None
}
})
.collect();
Self {
token_config: config,
partial_json: PartialJson::default(),
extractors,
}
}
/// Extract JSON content from text, handling wrapper tokens if configured
fn extract_json_content<'a>(&self, text: &'a str) -> &'a str {
let mut content = text;
// Try each extractor pattern (for tokens with both start and end)
for extractor in &self.extractors {
if let Some(captures) = extractor.captures(content) {
if let Some(matched) = captures.get(1) {
return matched.as_str().trim();
}
}
}
// Handle special case where there's a start token but no end token
for (start, end) in self.token_config.iter_pairs() {
if !start.is_empty() && end.is_empty() {
// Find the start token and extract everything after it
if let Some(pos) = content.find(start) {
content = &content[pos + start.len()..];
return content.trim();
}
}
}
content.trim()
}
/// Try to extract a JSON object or array from text that may contain other content
fn extract_json_from_text(&self, text: &str) -> Option<String> {
// Look for JSON object starting with {
if let Some(start) = text.find('{') {
let mut depth = 0;
let mut in_string = false;
let mut escape_next = false;
for (i, ch) in text[start..].char_indices() {
if escape_next {
escape_next = false;
continue;
}
match ch {
'\\' if in_string => escape_next = true,
'"' if !in_string => in_string = true,
'"' if in_string => in_string = false,
'{' if !in_string => depth += 1,
'}' if !in_string => {
depth -= 1;
if depth == 0 {
return Some(text[start..start + i + 1].to_string());
}
}
_ => {}
}
}
}
// Look for JSON array starting with [
if let Some(start) = text.find('[') {
let mut depth = 0;
let mut in_string = false;
let mut escape_next = false;
for (i, ch) in text[start..].char_indices() {
if escape_next {
escape_next = false;
continue;
}
match ch {
'\\' if in_string => escape_next = true,
'"' if !in_string => in_string = true,
'"' if in_string => in_string = false,
'[' if !in_string => depth += 1,
']' if !in_string => {
depth -= 1;
if depth == 0 {
return Some(text[start..start + i + 1].to_string());
}
}
_ => {}
}
}
}
None
}
/// Parse a single JSON object into a ToolCall
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
// Check if this looks like a tool call
let name = obj
.get("name")
.or_else(|| obj.get("function"))
.and_then(|v| v.as_str());
if let Some(name) = name {
// Get arguments - support both "arguments" and "parameters" keys
let empty_obj = Value::Object(serde_json::Map::new());
let args = obj
.get("arguments")
.or_else(|| obj.get("parameters"))
.unwrap_or(&empty_obj);
// Convert arguments to JSON string
let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate a unique ID if not provided
let id = obj
.get("id")
.and_then(|v| v.as_str())
.map(String::from)
.unwrap_or_else(|| format!("call_{}", uuid::Uuid::new_v4()));
Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: name.to_string(),
arguments,
},
}))
} else {
Ok(None)
}
}
/// Parse JSON value(s) into tool calls
fn parse_json_value(&self, value: &Value) -> ToolParserResult<Vec<ToolCall>> {
let mut tools = Vec::new();
match value {
Value::Array(arr) => {
// Parse each element in the array
for item in arr {
if let Some(tool) = self.parse_single_object(item)? {
tools.push(tool);
}
}
}
Value::Object(_) => {
// Single tool call
if let Some(tool) = self.parse_single_object(value)? {
tools.push(tool);
}
}
_ => {
// Not a valid tool call format
return Ok(vec![]);
}
}
Ok(tools)
}
/// Check if text contains potential tool call markers
fn has_tool_markers(&self, text: &str) -> bool {
// If no start tokens configured, check for JSON structure
if self.token_config.start_tokens.is_empty() {
// For JSON, we just need to see the start of an object or array
return text.contains('{') || text.contains('[');
}
// Check for any start token
self.token_config
.start_tokens
.iter()
.any(|token| text.contains(token))
}
}
impl Default for JsonParser {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ToolParser for JsonParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
// Check if we have multiple start tokens (e.g., multiple <|python_tag|> markers)
if !self.token_config.start_tokens.is_empty() {
let start_token = &self.token_config.start_tokens[0];
if !start_token.is_empty() && text.matches(start_token).count() > 1 {
// We have multiple occurrences of the start token
let mut all_tools = Vec::new();
let mut remaining = text;
while let Some(start_pos) = remaining.find(start_token.as_str()) {
// Extract content after this start token
let after_token = &remaining[start_pos + start_token.len()..];
// Find where this JSON ends (look for the next start token or end of string)
let end_pos = if let Some(next_start) = after_token.find(start_token.as_str()) {
next_start
} else {
after_token.len()
};
let json_content = &after_token[..end_pos];
// Try to extract and parse JSON from this segment
if let Some(extracted) = self.extract_json_from_text(json_content) {
if let Ok(value) = serde_json::from_str::<Value>(&extracted) {
if let Ok(tools) = self.parse_json_value(&value) {
all_tools.extend(tools);
}
}
}
// Move to the next segment
remaining = &remaining[start_pos + start_token.len() + end_pos..];
if remaining.is_empty() {
break;
}
}
if !all_tools.is_empty() {
return Ok(all_tools);
}
}
}
// Extract JSON content from wrapper tokens if present
let json_content = self.extract_json_content(text);
// Try to parse as JSON first
match serde_json::from_str::<Value>(json_content) {
Ok(value) => self.parse_json_value(&value),
Err(_) => {
// If parse failed, check if we have multiple JSON objects separated by the configured separator
// This handles cases like: {"name": "func1", ...};{"name": "func2", ...}
if !self.token_config.separator.is_empty()
&& json_content.contains(&self.token_config.separator)
{
let mut all_tools = Vec::new();
// Split by separator and try to parse each part
let parts: Vec<&str> =
json_content.split(&self.token_config.separator).collect();
for part in parts {
let trimmed = part.trim();
if trimmed.is_empty() {
continue;
}
// Try to parse this part as JSON
if let Ok(value) = serde_json::from_str::<Value>(trimmed) {
if let Ok(tools) = self.parse_json_value(&value) {
all_tools.extend(tools);
}
} else if let Some(extracted) = self.extract_json_from_text(trimmed) {
// Try extracting JSON from this part
if let Ok(value) = serde_json::from_str::<Value>(&extracted) {
if let Ok(tools) = self.parse_json_value(&value) {
all_tools.extend(tools);
}
}
}
}
if !all_tools.is_empty() {
return Ok(all_tools);
}
}
// If no wrapper tokens configured and parse failed,
// try to extract JSON from mixed text
if self.token_config.start_tokens.is_empty() {
if let Some(extracted) = self.extract_json_from_text(text) {
if let Ok(value) = serde_json::from_str::<Value>(&extracted) {
return self.parse_json_value(&value);
}
}
}
// Not valid JSON, return empty
Ok(vec![])
}
}
}
async fn parse_incremental(
&self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
// Check if we have potential tool calls
if !self.has_tool_markers(&state.buffer) {
// No tool markers, return as incomplete
return Ok(StreamResult::Incomplete);
}
// Extract JSON content first to check for separators
let extracted_json = self.extract_json_content(&state.buffer);
// Handle multiple JSON objects with separators
// Check if we have a separator and potentially multiple JSON objects
let separator = &self.token_config.separator;
if !separator.is_empty() && extracted_json.contains(separator.as_str()) {
// Try to find a complete JSON object before the separator
if let Some(separator_pos) = extracted_json.find(separator.as_str()) {
// Get JSON before separator
let before_separator = &extracted_json[..separator_pos];
// Try to parse the JSON before the separator
match serde_json::from_str::<Value>(before_separator) {
Ok(value) => {
// Parse tool calls from this JSON
let tools = self.parse_json_value(&value)?;
if !tools.is_empty() {
// We need to figure out how much to remove from the original buffer
// Find where the separator is in the original buffer and remove up to and including it
if let Some(sep_in_original) = state.buffer.find(separator.as_str()) {
let remaining =
state.buffer[sep_in_original + separator.len()..].to_string();
state.buffer = remaining;
}
// Return the first tool as complete
if let Some(tool) = tools.into_iter().next() {
return Ok(StreamResult::ToolComplete(tool));
}
}
}
Err(_) => {
// Failed to parse, continue to try other methods
}
}
}
}
// Handle multiple start tokens (e.g., multiple <|python_tag|> markers)
if !self.token_config.start_tokens.is_empty() {
let start_token = &self.token_config.start_tokens[0];
if !start_token.is_empty() {
// Find all occurrences of start token
let occurrences: Vec<_> =
state.buffer.match_indices(start_token.as_str()).collect();
if occurrences.len() > 1 {
// We have multiple start tokens, try to process the first complete one
let first_pos = occurrences[0].0;
let second_pos = occurrences[1].0;
// Extract content between first and second start token
let first_json_section = &state.buffer[first_pos..second_pos];
let json_content = self.extract_json_content(first_json_section);
// Try to parse this as complete JSON
if let Ok(value) = serde_json::from_str::<Value>(json_content) {
// Parse tool calls from this JSON
let tools = self.parse_json_value(&value)?;
if !tools.is_empty() {
// Remove the processed section from buffer
let remaining = state.buffer[second_pos..].to_string();
state.buffer = remaining;
// Return the first tool as complete
if let Some(tool) = tools.into_iter().next() {
return Ok(StreamResult::ToolComplete(tool));
}
}
}
}
}
}
// Regular single JSON parsing
// Extract JSON content
let json_content = self.extract_json_content(&state.buffer);
// Try to parse with partial JSON parser
match self.partial_json.parse_value(json_content) {
Ok((value, consumed)) => {
// Check if we have a complete JSON structure
if consumed == json_content.len() {
// Check if this is truly complete or just has null from incomplete parsing
// We need to ensure the JSON actually ends properly (not cut off mid-key)
let trimmed = json_content.trim();
let looks_complete = trimmed.ends_with('}') || trimmed.ends_with(']');
if looks_complete {
// Complete JSON, parse tool calls
let tools = self.parse_json_value(&value)?;
if !tools.is_empty() {
// Clear buffer since we consumed everything
state.buffer.clear();
// Return the first tool as complete
// TODO simplified version, address more complex version
if let Some(tool) = tools.into_iter().next() {
return Ok(StreamResult::ToolComplete(tool));
}
}
}
} else {
// Partial JSON, try to extract tool name
if let Some(name) = value.get("name").and_then(|v| v.as_str()) {
// TODO simplified version, address more complex version
// Just return the tool name once we see it
if !state.in_string {
state.in_string = true; // Use as a flag for "name sent"
return Ok(StreamResult::ToolName {
index: 0,
name: name.to_string(),
});
}
// Check for complete arguments
if let Some(args) =
value.get("arguments").or_else(|| value.get("parameters"))
{
if let Ok(args_str) = serde_json::to_string(args) {
// Return arguments as a single update
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
});
}
}
}
}
}
Err(_) => {
// Failed to parse even as partial JSON
// Keep buffering
}
}
Ok(StreamResult::Incomplete)
}
fn detect_format(&self, text: &str) -> bool {
// Check if text contains JSON-like structure
if self.has_tool_markers(text) {
// Try to extract and parse
let json_content = self.extract_json_content(text);
// Check if it looks like valid JSON for tool calls
if let Ok(value) = serde_json::from_str::<Value>(json_content) {
match value {
Value::Object(ref obj) => {
// Check for tool call structure
obj.contains_key("name") || obj.contains_key("function")
}
Value::Array(ref arr) => {
// Check if array contains tool-like objects
arr.iter().any(|v| {
if let Some(obj) = v.as_object() {
obj.contains_key("name") || obj.contains_key("function")
} else {
false
}
})
}
_ => false,
}
} else {
false
}
} else {
false
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_parse_single_tool_call() {
let parser = JsonParser::new();
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
}
#[tokio::test]
async fn test_parse_multiple_tool_calls() {
let parser = JsonParser::new();
let input = r#"[
{"name": "get_weather", "arguments": {"location": "SF"}},
{"name": "search", "arguments": {"query": "news"}}
]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "get_weather");
assert_eq!(result[1].function.name, "search");
}
#[tokio::test]
async fn test_parse_with_parameters_key() {
let parser = JsonParser::new();
let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "calculate");
assert!(result[0].function.arguments.contains("10"));
}
#[tokio::test]
async fn test_parse_with_wrapper_tokens() {
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["<tool>".to_string()],
end_tokens: vec!["</tool>".to_string()],
separator: ", ".to_string(),
});
let input = r#"<tool>{"name": "test", "arguments": {}}</tool>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
}
#[test]
fn test_detect_format() {
let parser = JsonParser::new();
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
assert!(parser.detect_format(r#"[{"name": "test"}]"#));
assert!(!parser.detect_format("plain text"));
assert!(!parser.detect_format(r#"{"key": "value"}"#));
}
#[tokio::test]
async fn test_streaming_parse() {
// Just verify that streaming eventually produces a complete tool call
let parser = JsonParser::new();
let mut state = ParseState::new();
// Send complete JSON in one go
// TODO simplified version, address more complex version
let full_json = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;
let result = parser
.parse_incremental(full_json, &mut state)
.await
.unwrap();
// Should get a complete tool immediately with complete JSON
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
assert!(tool.function.arguments.contains("SF"));
}
_ => panic!("Expected ToolComplete for complete JSON input"),
}
}
}

View File

@@ -0,0 +1,270 @@
use async_trait::async_trait;
use regex::Regex;
use crate::tool_parser::{
errors::ToolParserResult,
partial_json::PartialJson,
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
};
/// Kimi K2 format parser for tool calls
///
/// Handles the Kimi K2 specific format:
/// `<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:{index}<|tool_call_argument_begin|>{json_args}<|tool_call_end|><|tool_calls_section_end|>`
///
/// Features:
/// - Token-based delimiters
/// - Function calls with explicit indexing
/// - JSON arguments
pub struct KimiK2Parser {
/// Parser for handling incomplete JSON during streaming
partial_json: PartialJson,
/// Regex for extracting complete tool calls
tool_call_extractor: Regex,
/// Regex for extracting partial tool calls (streaming)
stream_tool_call_extractor: Regex,
}
impl KimiK2Parser {
/// Create a new Kimi K2 parser
pub fn new() -> Self {
// Pattern for complete tool calls
let tool_call_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>";
let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
// Pattern for streaming (partial) tool calls
let stream_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)";
let stream_tool_call_extractor = Regex::new(stream_pattern).expect("Valid regex pattern");
Self {
partial_json: PartialJson::default(),
tool_call_extractor,
stream_tool_call_extractor,
}
}
/// Check if text contains Kimi K2 tool markers
fn has_tool_markers(&self, text: &str) -> bool {
text.contains("<|tool_calls_section_begin|>")
}
/// Parse function ID to extract name and index
fn parse_function_id(&self, id: &str) -> Option<(String, usize)> {
// Format: functions.{name}:{index} or namespace.functions.{name}:{index}
// Extract everything after the last dot before the colon as the function name
if let Some(colon_pos) = id.rfind(':') {
let before_colon = &id[..colon_pos];
let index_str = &id[colon_pos + 1..];
// Find the last dot to extract the function name
if let Some(dot_pos) = before_colon.rfind('.') {
let func_name = &before_colon[dot_pos + 1..];
if let Ok(index) = index_str.parse::<usize>() {
return Some((func_name.to_string(), index));
}
}
}
None
}
}
impl Default for KimiK2Parser {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ToolParser for KimiK2Parser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
// Check if text contains Kimi K2 format
if !self.has_tool_markers(text) {
return Ok(vec![]);
}
let mut tools = Vec::new();
// Extract all tool calls
for captures in self.tool_call_extractor.captures_iter(text) {
if let (Some(id_match), Some(args_match)) = (
captures.name("tool_call_id"),
captures.name("function_arguments"),
) {
let function_id = id_match.as_str();
let function_args = args_match.as_str();
// Parse function ID
if let Some((func_name, _index)) = self.parse_function_id(function_id) {
// Validate JSON arguments
if serde_json::from_str::<serde_json::Value>(function_args).is_ok() {
// Generate unique ID
let id = format!("kimi_call_{}", uuid::Uuid::new_v4());
tools.push(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: func_name,
arguments: function_args.to_string(),
},
});
}
}
}
}
Ok(tools)
}
async fn parse_incremental(
&self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
// Check for tool markers
let has_tool_call =
self.has_tool_markers(&state.buffer) || state.buffer.contains("<|tool_call_begin|>");
if !has_tool_call {
// No markers found, clear buffer and return
state.buffer.clear();
return Ok(StreamResult::Incomplete);
}
// Try to match streaming pattern
if let Some(captures) = self.stream_tool_call_extractor.captures(&state.buffer) {
if let (Some(id_match), Some(args_match)) = (
captures.name("tool_call_id"),
captures.name("function_arguments"),
) {
let function_id = id_match.as_str();
let partial_args = args_match.as_str();
// Parse function ID
if let Some((func_name, _index)) = self.parse_function_id(function_id) {
// Send function name if not sent yet
if !state.in_string {
state.in_string = true; // Mark name as sent
return Ok(StreamResult::ToolName {
index: 0,
name: func_name.clone(),
});
}
// Check if we have a complete tool call
if let Some(end_pos) = partial_args.find("<|tool_call_end|>") {
// Extract just the JSON part
let json_args = &partial_args[..end_pos];
// Validate and parse JSON
if serde_json::from_str::<serde_json::Value>(json_args).is_ok() {
// Generate unique ID
let id = format!("kimi_call_{}", uuid::Uuid::new_v4());
let tool = ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: func_name,
arguments: json_args.to_string(),
},
};
// Find where this tool call ends in the buffer
if let Some(tool_end) = state.buffer.find("<|tool_call_end|>") {
let end_pos = tool_end + "<|tool_call_end|>".len();
state.buffer.drain(..end_pos);
}
// Reset state for next tool
state.in_string = false;
return Ok(StreamResult::ToolComplete(tool));
}
} else {
// Try to parse partial JSON for streaming arguments
match self.partial_json.parse_value(partial_args) {
Ok((value, _consumed)) => {
let args_str = serde_json::to_string(&value)
.unwrap_or_else(|_| "{}".to_string());
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
});
}
Err(_) => {
// Can't parse yet, keep buffering
}
}
}
}
}
}
Ok(StreamResult::Incomplete)
}
fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text) || text.contains("<|tool_call_begin|>")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_parse_kimi_single_tool() {
let parser = KimiK2Parser::new();
let input = r#"Some text
<|tool_calls_section_begin|>
<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Tokyo", "units": "celsius"}<|tool_call_end|>
<|tool_calls_section_end|>More text"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
assert!(result[0].function.arguments.contains("Tokyo"));
}
#[tokio::test]
async fn test_parse_kimi_multiple_tools() {
let parser = KimiK2Parser::new();
let input = r#"<|tool_calls_section_begin|>
<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{"query": "rust"}<|tool_call_end|>
<|tool_call_begin|>functions.calculate:1<|tool_call_argument_begin|>{"expression": "2+2"}<|tool_call_end|>
<|tool_calls_section_end|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "search");
assert_eq!(result[1].function.name, "calculate");
}
#[tokio::test]
async fn test_parse_kimi_with_whitespace() {
let parser = KimiK2Parser::new();
let input = r#"<|tool_calls_section_begin|>
<|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value"} <|tool_call_end|>
<|tool_calls_section_end|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
}
#[test]
fn test_detect_format() {
let parser = KimiK2Parser::new();
assert!(parser.detect_format("<|tool_calls_section_begin|>"));
assert!(parser.detect_format("<|tool_call_begin|>"));
assert!(!parser.detect_format("plain text"));
assert!(!parser.detect_format("[TOOL_CALLS]"));
}
}

View File

@@ -0,0 +1,156 @@
use async_trait::async_trait;
use super::json_parser::JsonParser;
use crate::tool_parser::{
errors::ToolParserResult,
state::ParseState,
traits::ToolParser,
types::{StreamResult, TokenConfig, ToolCall},
};
/// Llama 3.2 format parser for tool calls
///
/// Handles the Llama 3.2 specific format:
/// `<|python_tag|>{"name": "func", "arguments": {...}}`
///
/// Also supports plain JSON without the python_tag prefix
pub struct LlamaParser {
/// Underlying JSON parser with Llama-specific configuration
json_parser: JsonParser,
}
impl LlamaParser {
/// Create a new Llama parser
pub fn new() -> Self {
// Configure JSON parser with Llama's python_tag token
// Note: No end token for python_tag format
let json_parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["<|python_tag|>".to_string()],
end_tokens: vec!["".to_string()], // Empty end token
separator: ";".to_string(), // Llama uses semicolon for multiple calls (though not well supported)
});
Self { json_parser }
}
}
impl Default for LlamaParser {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ToolParser for LlamaParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
// First try with the configured python_tag parser
let result = self.json_parser.parse_complete(text).await?;
if !result.is_empty() {
return Ok(result);
}
// If no results and text starts with '{', try plain JSON
if text.trim_start().starts_with('{') {
// Create a temporary plain JSON parser
let plain_parser = JsonParser::new();
return plain_parser.parse_complete(text).await;
}
Ok(vec![])
}
async fn parse_incremental(
&self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
// Try with the python_tag parser first
let result = self.json_parser.parse_incremental(chunk, state).await?;
// If we get Incomplete and buffer starts with '{', might be plain JSON
if matches!(result, StreamResult::Incomplete) && state.buffer.trim_start().starts_with('{')
{
// Check if we have python_tag in the buffer
if !state.buffer.contains("<|python_tag|>") {
// Likely plain JSON, create temporary parser
let plain_parser = JsonParser::new();
return plain_parser.parse_incremental("", state).await;
}
}
Ok(result)
}
fn detect_format(&self, text: &str) -> bool {
// Llama format if contains python_tag or starts with JSON object
text.contains("<|python_tag|>")
|| (text.trim_start().starts_with('{')
&& (text.contains(r#""name""#) || text.contains(r#""function""#)))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_parse_with_python_tag() {
let parser = LlamaParser::new();
let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "search");
assert!(result[0].function.arguments.contains("weather"));
}
#[tokio::test]
async fn test_parse_plain_json() {
let parser = LlamaParser::new();
let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "calculate");
}
#[tokio::test]
async fn test_parse_with_text_before() {
let parser = LlamaParser::new();
let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_time");
}
#[test]
fn test_detect_format() {
let parser = LlamaParser::new();
assert!(parser.detect_format(r#"<|python_tag|>{"name": "test"}"#));
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
assert!(!parser.detect_format("plain text"));
assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field
}
#[tokio::test]
async fn test_single_call_with_semicolon() {
let parser = LlamaParser::new();
// Note: Llama 3.2 doesn't handle multiple calls well
// Test that we can at least parse a single call followed by semicolon
let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#;
let result = parser.parse_complete(input).await.unwrap();
// We expect this to either parse the first JSON object or fail gracefully
// Since the semicolon makes it invalid JSON, it will likely return empty
// This is acceptable as Llama 3.2 doesn't reliably support parallel calls
// If it parses anything, it should be func1
if !result.is_empty() {
assert_eq!(result[0].function.name, "func1");
}
}
}

View File

@@ -0,0 +1,347 @@
use async_trait::async_trait;
use serde_json::Value;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
partial_json::PartialJson,
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
};
/// Mistral format parser for tool calls
///
/// Handles the Mistral-specific format:
/// `[TOOL_CALLS] [{"name": "func", "arguments": {...}}, ...]`
///
/// Features:
/// - Bracket counting for proper JSON array extraction
/// - Support for multiple tool calls in a single array
/// - String-aware parsing to handle nested brackets in JSON
pub struct MistralParser {
/// Parser for handling incomplete JSON during streaming
partial_json: PartialJson,
}
impl MistralParser {
/// Create a new Mistral parser
pub fn new() -> Self {
Self {
partial_json: PartialJson::default(),
}
}
/// Extract JSON array using bracket counting
///
/// Handles nested brackets in JSON content by tracking:
/// - String boundaries (quotes)
/// - Escape sequences
/// - Bracket depth
fn extract_json_array<'a>(&self, text: &'a str) -> Option<&'a str> {
const BOT_TOKEN: &str = "[TOOL_CALLS] [";
// Find the start of the token
let start_idx = text.find(BOT_TOKEN)?;
// Start from the opening bracket after [TOOL_CALLS]
// The -1 is to include the opening bracket that's part of the token
let json_start = start_idx + BOT_TOKEN.len() - 1;
let mut bracket_count = 0;
let mut in_string = false;
let mut escape_next = false;
let bytes = text.as_bytes();
for i in json_start..text.len() {
let char = bytes[i];
if escape_next {
escape_next = false;
continue;
}
if char == b'\\' {
escape_next = true;
continue;
}
if char == b'"' && !escape_next {
in_string = !in_string;
continue;
}
if !in_string {
if char == b'[' {
bracket_count += 1;
} else if char == b']' {
bracket_count -= 1;
if bracket_count == 0 {
// Found the matching closing bracket
return Some(&text[json_start..=i]);
}
}
}
}
// Incomplete array (no matching closing bracket found)
None
}
/// Parse tool calls from a JSON array
fn parse_json_array(&self, json_str: &str) -> ToolParserResult<Vec<ToolCall>> {
let value: Value = serde_json::from_str(json_str)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
let mut tools = Vec::new();
if let Value::Array(arr) = value {
for (index, item) in arr.iter().enumerate() {
if let Some(tool) = self.parse_single_object(item, index)? {
tools.push(tool);
}
}
} else {
// Single object case (shouldn't happen with Mistral format, but handle it)
if let Some(tool) = self.parse_single_object(&value, 0)? {
tools.push(tool);
}
}
Ok(tools)
}
/// Parse a single JSON object into a ToolCall
fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult<Option<ToolCall>> {
let name = obj.get("name").and_then(|v| v.as_str());
if let Some(name) = name {
// Get arguments - Mistral uses "arguments" key
let empty_obj = Value::Object(serde_json::Map::new());
let args = obj.get("arguments").unwrap_or(&empty_obj);
// Convert arguments to JSON string
let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID with index for multiple tools
let id = format!("mistral_call_{}", index);
Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: name.to_string(),
arguments,
},
}))
} else {
Ok(None)
}
}
/// Check if text contains Mistral tool markers
fn has_tool_markers(&self, text: &str) -> bool {
text.contains("[TOOL_CALLS]")
}
}
impl Default for MistralParser {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ToolParser for MistralParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
// Check if text contains Mistral format
if !self.has_tool_markers(text) {
return Ok(vec![]);
}
// Extract JSON array from Mistral format
if let Some(json_array) = self.extract_json_array(text) {
self.parse_json_array(json_array)
} else {
// Markers present but no complete array found
Ok(vec![])
}
}
async fn parse_incremental(
&self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
// Check if we have the start marker
if !self.has_tool_markers(&state.buffer) {
return Ok(StreamResult::Incomplete);
}
// Try to extract complete JSON array
if let Some(json_array) = self.extract_json_array(&state.buffer) {
// Parse with partial JSON to handle incomplete content
match self.partial_json.parse_value(json_array) {
Ok((value, consumed)) => {
// Check if we have a complete JSON structure
if consumed == json_array.len() {
// Complete JSON, parse tool calls
let tools = if let Value::Array(arr) = value {
let mut result = Vec::new();
for (index, item) in arr.iter().enumerate() {
if let Some(tool) = self.parse_single_object(item, index)? {
result.push(tool);
}
}
result
} else {
vec![]
};
if !tools.is_empty() {
// Clear buffer since we consumed everything
state.buffer.clear();
// Return the first tool (simplified for Phase 3)
// Full multi-tool streaming will be implemented later
if let Some(tool) = tools.into_iter().next() {
return Ok(StreamResult::ToolComplete(tool));
}
}
} else {
// Partial JSON - try to extract tool name for streaming
if let Value::Array(arr) = value {
if let Some(first_tool) = arr.first() {
if let Some(name) = first_tool.get("name").and_then(|v| v.as_str())
{
// Check if we've already sent the name
if !state.in_string {
state.in_string = true; // Use as flag for "name sent"
return Ok(StreamResult::ToolName {
index: 0,
name: name.to_string(),
});
}
// Check for arguments
if let Some(args) = first_tool.get("arguments") {
if let Ok(args_str) = serde_json::to_string(args) {
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
});
}
}
}
}
}
}
}
Err(_) => {
// Failed to parse even as partial JSON
// Keep buffering
}
}
}
Ok(StreamResult::Incomplete)
}
fn detect_format(&self, text: &str) -> bool {
// Check if text contains Mistral-specific markers
if self.has_tool_markers(text) {
// Try to extract and validate the array
if let Some(json_array) = self.extract_json_array(text) {
// Check if it's valid JSON
if let Ok(value) = serde_json::from_str::<Value>(json_array) {
// Check if it contains tool-like structures
match value {
Value::Array(ref arr) => arr.iter().any(|v| {
v.as_object().is_some_and(|o| {
o.contains_key("name") && o.contains_key("arguments")
})
}),
Value::Object(ref obj) => {
obj.contains_key("name") && obj.contains_key("arguments")
}
_ => false,
}
} else {
false
}
} else {
// Has markers but no complete array - might be streaming
true
}
} else {
false
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_parse_mistral_format() {
let parser = MistralParser::new();
let input = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "Paris", "units": "celsius"}}]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
assert!(result[0].function.arguments.contains("Paris"));
}
#[tokio::test]
async fn test_parse_multiple_tools() {
let parser = MistralParser::new();
let input = r#"[TOOL_CALLS] [
{"name": "search", "arguments": {"query": "rust programming"}},
{"name": "calculate", "arguments": {"expression": "2 + 2"}}
]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "search");
assert_eq!(result[1].function.name, "calculate");
}
#[tokio::test]
async fn test_nested_brackets_in_json() {
let parser = MistralParser::new();
let input = r#"[TOOL_CALLS] [{"name": "process", "arguments": {"data": [1, 2, [3, 4]], "config": {"nested": [5, 6]}}}]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "process");
// JSON serialization removes spaces, so check for [3,4] without spaces
assert!(result[0].function.arguments.contains("[3,4]"));
}
#[tokio::test]
async fn test_escaped_quotes_in_strings() {
let parser = MistralParser::new();
let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"message": "He said \"Hello [World]\""}}]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "echo");
}
#[test]
fn test_detect_format() {
let parser = MistralParser::new();
assert!(parser.detect_format(r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]"#));
assert!(
parser.detect_format(r#"Some text [TOOL_CALLS] [{"name": "test", "arguments": {}}]"#)
);
assert!(!parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
assert!(!parser.detect_format("plain text"));
}
}

View File

@@ -0,0 +1,27 @@
/// Parser implementations for different model formats
///
/// This module contains concrete parser implementations for various model-specific
/// tool/function call formats.
// Individual parser modules
pub mod deepseek_parser;
pub mod glm4_moe_parser;
pub mod gpt_oss_parser;
pub mod json_parser;
pub mod kimik2_parser;
pub mod llama_parser;
pub mod mistral_parser;
pub mod pythonic_parser;
pub mod qwen_parser;
pub mod step3_parser;
// Re-export parser types for convenience
pub use deepseek_parser::DeepSeekParser;
pub use glm4_moe_parser::Glm4MoeParser;
pub use gpt_oss_parser::GptOssParser;
pub use json_parser::JsonParser;
pub use kimik2_parser::KimiK2Parser;
pub use llama_parser::LlamaParser;
pub use mistral_parser::MistralParser;
pub use pythonic_parser::PythonicParser;
pub use qwen_parser::QwenParser;
pub use step3_parser::Step3Parser;

View File

@@ -0,0 +1,434 @@
/// Pythonic format parser for tool calls
///
/// Handles Python function call syntax within square brackets:
/// ```text
/// [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
/// ```
///
/// This format is used by Llama-4 models and uses Python literals
/// rather than JSON for arguments.
use async_trait::async_trait;
use regex::Regex;
use serde_json::{json, Value};
use crate::tool_parser::{
errors::ToolParserResult,
python_literal_parser::parse_python_literal,
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
};
/// Parser for Pythonic tool call format
pub struct PythonicParser {
/// Regex to detect tool calls in Pythonic format
tool_call_regex: Regex,
/// Regex to parse function calls - cached for reuse
call_regex: Regex,
}
impl PythonicParser {
/// Create a new Pythonic parser
pub fn new() -> Self {
// Simple regex to detect start of Pythonic tool calls
// We'll use manual parsing for the actual extraction
let pattern = r"\[[a-zA-Z_]\w*\(";
let tool_call_regex = Regex::new(pattern).expect("Valid regex pattern");
// Compile the function call regex once
let call_regex = Regex::new(r"(?s)^([a-zA-Z_]\w*)\((.*)\)$").expect("Valid regex pattern");
Self {
tool_call_regex,
call_regex,
}
}
/// Extract tool calls using bracket counting (similar to MistralParser)
fn extract_tool_calls(&self, text: &str) -> Option<String> {
// Find the start of a tool call list - look for [ followed by a function name
let chars: Vec<char> = text.chars().collect();
for start_idx in 0..chars.len() {
if chars[start_idx] != '[' {
continue;
}
// Check if this looks like a tool call
// Skip whitespace after [
let mut check_idx = start_idx + 1;
while check_idx < chars.len() && chars[check_idx].is_whitespace() {
check_idx += 1;
}
// Check if we have a function name (starts with letter or underscore)
if check_idx >= chars.len()
|| (!chars[check_idx].is_alphabetic() && chars[check_idx] != '_')
{
continue;
}
// Now count brackets to find the matching ]
let mut bracket_count = 0;
let mut _paren_count = 0;
let mut _brace_count = 0;
let mut in_string = false;
let mut string_char = ' ';
let mut escape_next = false;
for i in start_idx..chars.len() {
let ch = chars[i];
if escape_next {
escape_next = false;
continue;
}
if ch == '\\' && in_string {
escape_next = true;
continue;
}
if !in_string && (ch == '"' || ch == '\'') {
in_string = true;
string_char = ch;
} else if in_string && ch == string_char && !escape_next {
in_string = false;
} else if !in_string {
match ch {
'[' => bracket_count += 1,
']' => {
bracket_count -= 1;
if bracket_count == 0 {
// Found the matching bracket
let extracted: String = chars[start_idx..=i].iter().collect();
// Verify this actually contains a function call
if extracted.contains('(') && extracted.contains(')') {
return Some(extracted);
}
}
}
'(' => _paren_count += 1,
')' => _paren_count -= 1,
'{' => _brace_count += 1,
'}' => _brace_count -= 1,
_ => {}
}
}
}
}
None
}
/// Strip special tokens that Llama 4 might output
fn strip_special_tokens(text: &str) -> String {
text.replace("<|python_start|>", "")
.replace("<|python_end|>", "")
}
/// Parse a single function call from Python syntax
fn parse_function_call(&self, call_str: &str) -> ToolParserResult<Option<ToolCall>> {
// Use cached regex instead of creating new one
if let Some(captures) = self.call_regex.captures(call_str.trim()) {
let function_name = captures.get(1).unwrap().as_str();
let args_str = captures.get(2).unwrap().as_str();
// Parse arguments
let arguments = self.parse_arguments(args_str)?;
Ok(Some(ToolCall {
id: format!("call_{}", uuid::Uuid::new_v4()),
r#type: "function".to_string(),
function: FunctionCall {
name: function_name.to_string(),
arguments: serde_json::to_string(&arguments)?,
},
}))
} else {
Ok(None)
}
}
/// Parse Python-style arguments into JSON
fn parse_arguments(&self, args_str: &str) -> ToolParserResult<Value> {
if args_str.trim().is_empty() {
return Ok(json!({}));
}
let mut result = serde_json::Map::new();
let mut current_key = String::new();
let mut current_value = String::new();
let mut in_key = true;
let mut depth = 0;
let mut in_string = false;
let mut string_char = ' ';
let mut escape_next = false;
let chars: Vec<char> = args_str.chars().collect();
let mut i = 0;
while i < chars.len() {
let ch = chars[i];
if escape_next {
if in_key {
current_key.push(ch);
} else {
current_value.push(ch);
}
escape_next = false;
i += 1;
continue;
}
if ch == '\\' && in_string {
escape_next = true;
current_value.push(ch);
i += 1;
continue;
}
// Handle string literals
if !in_string && (ch == '"' || ch == '\'') {
in_string = true;
string_char = ch;
if !in_key {
current_value.push(ch);
}
} else if in_string && ch == string_char && !escape_next {
in_string = false;
if !in_key {
current_value.push(ch);
}
} else if in_string {
if in_key {
current_key.push(ch);
} else {
current_value.push(ch);
}
} else {
// Not in string
match ch {
'=' if in_key && depth == 0 => {
in_key = false;
}
',' if depth == 0 => {
// End of current argument
if !current_key.is_empty() {
let value = parse_python_literal(current_value.trim())?;
result.insert(current_key.trim().to_string(), value);
}
current_key.clear();
current_value.clear();
in_key = true;
}
'[' | '{' | '(' => {
depth += 1;
if !in_key {
current_value.push(ch);
}
}
']' | '}' | ')' => {
depth -= 1;
if !in_key {
current_value.push(ch);
}
}
_ => {
if in_key {
if !ch.is_whitespace() || !current_key.is_empty() {
current_key.push(ch);
}
} else {
current_value.push(ch);
}
}
}
}
i += 1;
}
// Handle the last argument
if !current_key.is_empty() {
let value = parse_python_literal(current_value.trim())?;
result.insert(current_key.trim().to_string(), value);
}
Ok(Value::Object(result))
}
}
#[async_trait]
impl ToolParser for PythonicParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
let cleaned = Self::strip_special_tokens(text);
// Extract tool calls using bracket counting
if let Some(tool_calls_text) = self.extract_tool_calls(&cleaned) {
// Remove the outer brackets
let tool_calls_str = &tool_calls_text[1..tool_calls_text.len() - 1];
// Split into individual function calls
let mut calls = Vec::new();
let mut current_call = String::new();
let mut paren_depth = 0;
let mut in_string = false;
let mut string_char = ' ';
for ch in tool_calls_str.chars() {
if !in_string && (ch == '"' || ch == '\'') {
in_string = true;
string_char = ch;
current_call.push(ch);
} else if in_string && ch == string_char {
in_string = false;
current_call.push(ch);
} else if in_string {
current_call.push(ch);
} else {
match ch {
'(' => {
paren_depth += 1;
current_call.push(ch);
}
')' => {
paren_depth -= 1;
current_call.push(ch);
}
',' if paren_depth == 0 => {
// End of current function call
if let Some(call) = self.parse_function_call(current_call.trim())? {
calls.push(call);
}
current_call.clear();
}
_ => {
if !ch.is_whitespace() || !current_call.is_empty() {
current_call.push(ch);
}
}
}
}
}
// Handle the last call (important for single calls or the last call in a list)
if !current_call.trim().is_empty() {
if let Some(call) = self.parse_function_call(current_call.trim())? {
calls.push(call);
}
}
Ok(calls)
} else {
Ok(vec![])
}
}
async fn parse_incremental(
&self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
// For Pythonic format, we accumulate until we have a complete tool call
// This is a simplified implementation
state.buffer.push_str(chunk);
// Try to parse if we have a complete tool call
let cleaned = Self::strip_special_tokens(&state.buffer);
if self.extract_tool_calls(&cleaned).is_some() {
let result = self.parse_complete(&state.buffer).await?;
if !result.is_empty() {
state.buffer.clear();
return Ok(StreamResult::ToolComplete(
result.into_iter().next().unwrap(),
));
}
}
Ok(StreamResult::Incomplete)
}
fn detect_format(&self, text: &str) -> bool {
let cleaned = Self::strip_special_tokens(text);
self.tool_call_regex.is_match(&cleaned)
}
}
impl Default for PythonicParser {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_single_function_call() {
let parser = PythonicParser::new();
let input = r#"[search_web(query="Rust programming", max_results=5)]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "search_web");
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["query"], "Rust programming");
assert_eq!(args["max_results"], 5);
}
#[tokio::test]
async fn test_multiple_function_calls() {
let parser = PythonicParser::new();
let input = r#"[get_weather(city="Tokyo"), search(query="news")]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "get_weather");
assert_eq!(result[1].function.name, "search");
}
#[tokio::test]
async fn test_python_literals() {
let parser = PythonicParser::new();
let input = r#"[test(flag=True, disabled=False, optional=None)]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["flag"], true);
assert_eq!(args["disabled"], false);
assert_eq!(args["optional"], Value::Null);
}
#[tokio::test]
async fn test_special_tokens() {
let parser = PythonicParser::new();
let input = r#"<|python_start|>[calculate(x=10, y=20)]<|python_end|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "calculate");
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["x"], 10);
assert_eq!(args["y"], 20);
}
#[tokio::test]
async fn test_llama4_format() {
let parser = PythonicParser::new();
let input = r#"[get_weather(city="London", units="celsius")]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["city"], "London");
assert_eq!(args["units"], "celsius");
}
}

View File

@@ -0,0 +1,396 @@
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
partial_json::PartialJson,
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
};
/// Qwen format parser for tool calls
///
/// Handles the Qwen 2.5/3 specific format:
/// `<tool_call>\n{"name": "func", "arguments": {...}}\n</tool_call>`
///
/// Features:
/// - XML-style tags with JSON content
/// - Support for multiple sequential tool calls
/// - Newline-aware parsing
pub struct QwenParser {
/// Parser for handling incomplete JSON during streaming
partial_json: PartialJson,
/// Regex for extracting tool calls
extractor: Regex,
}
impl QwenParser {
/// Create a new Qwen parser
pub fn new() -> Self {
// Use (?s) flag for DOTALL mode to handle newlines
let pattern = r"(?s)<tool_call>\n(.*?)\n</tool_call>";
let extractor = Regex::new(pattern).expect("Valid regex pattern");
Self {
partial_json: PartialJson::default(),
extractor,
}
}
/// Extract all tool call blocks from text
fn extract_tool_calls<'a>(&self, text: &'a str) -> Vec<&'a str> {
self.extractor
.captures_iter(text)
.filter_map(|cap| cap.get(1).map(|m| m.as_str()))
.collect()
}
/// Parse a single JSON object into a ToolCall
fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult<Option<ToolCall>> {
let name = obj.get("name").and_then(|v| v.as_str());
if let Some(name) = name {
// Get arguments - Qwen uses "arguments" key
let empty_obj = Value::Object(serde_json::Map::new());
let args = obj.get("arguments").unwrap_or(&empty_obj);
// Convert arguments to JSON string
let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID with index for multiple tools
let id = format!("qwen_call_{}", index);
Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: name.to_string(),
arguments,
},
}))
} else {
Ok(None)
}
}
/// Check if text contains Qwen tool markers
fn has_tool_markers(&self, text: &str) -> bool {
text.contains("<tool_call>")
}
/// Find the start position of a tool call
fn find_tool_start(&self, text: &str) -> Option<usize> {
text.find("<tool_call>\n")
}
/// Find the end position of a tool call
fn find_tool_end(&self, text: &str, start_pos: usize) -> Option<usize> {
let search_from = start_pos + "<tool_call>\n".len();
text[search_from..]
.find("\n</tool_call>")
.map(|pos| search_from + pos + "\n</tool_call>".len())
}
/// Check if buffer ends with a partial token
fn ends_with_partial_token(&self, buffer: &str) -> Option<usize> {
// Check for partial start token
let start_token = "<tool_call>\n";
// Use inclusive range to check if entire buffer could be a prefix
for i in 1..=start_token.len().min(buffer.len()) {
if start_token.starts_with(&buffer[buffer.len() - i..]) {
return Some(i);
}
}
// Check for partial end token
let end_token = "\n</tool_call>";
// Only check if buffer ends with a partial match (not the complete token without newline)
// If buffer ends with "</tool_call>", that's not a partial token - it's missing the newline
if buffer.ends_with("</tool_call>") {
// This is a complete end tag, just missing the leading newline
// Not a partial token situation
return None;
}
// Use inclusive range to check if entire buffer could be a prefix
(1..=end_token.len().min(buffer.len()))
.find(|&i| end_token.starts_with(&buffer[buffer.len() - i..]))
}
}
impl Default for QwenParser {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ToolParser for QwenParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
// Check if text contains Qwen format
if !self.has_tool_markers(text) {
return Ok(vec![]);
}
// Extract all tool call blocks
let tool_blocks = self.extract_tool_calls(text);
let mut tools = Vec::new();
for (index, json_str) in tool_blocks.iter().enumerate() {
// Parse each JSON block
match serde_json::from_str::<Value>(json_str.trim()) {
Ok(value) => {
if let Some(tool) = self.parse_single_object(&value, index)? {
tools.push(tool);
}
}
Err(_) => {
// Skip malformed JSON blocks
continue;
}
}
}
Ok(tools)
}
async fn parse_incremental(
&self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
// Check for partial token at end of buffer
if let Some(_partial_len) = self.ends_with_partial_token(&state.buffer) {
// Hold back the partial token
return Ok(StreamResult::Incomplete);
}
// Check if we have the start marker
if !self.has_tool_markers(&state.buffer) {
return Ok(StreamResult::Incomplete);
}
// Find start and end positions
if let Some(start_pos) = self.find_tool_start(&state.buffer) {
// Check if we have the complete tool call
if let Some(end_pos) = self.find_tool_end(&state.buffer, start_pos) {
// Extract the JSON content
let json_start = start_pos + "<tool_call>\n".len();
let json_end = end_pos - "\n</tool_call>".len();
let json_str = &state.buffer[json_start..json_end];
// Parse the complete JSON
match serde_json::from_str::<Value>(json_str.trim()) {
Ok(value) => {
if let Some(tool) = self.parse_single_object(&value, 0)? {
// Clear the consumed part from buffer using drain for efficiency
state.buffer.drain(..end_pos);
return Ok(StreamResult::ToolComplete(tool));
}
}
Err(_) => {
// JSON parsing failed, might be incomplete
}
}
} else {
// We have start but no end yet - try partial parsing
let json_start = start_pos + "<tool_call>\n".len();
let partial_json = &state.buffer[json_start..];
// Remove trailing newline if present (might be start of end token)
let partial_json = partial_json.trim_end();
// Try to parse with partial JSON parser
match self.partial_json.parse_value(partial_json) {
Ok((value, _consumed)) => {
// Extract tool name if available
if let Some(name) = value.get("name").and_then(|v| v.as_str()) {
// Check if we've already sent the name
if !state.in_string {
state.in_string = true; // Use as flag for "name sent"
return Ok(StreamResult::ToolName {
index: 0,
name: name.to_string(),
});
}
// Check for arguments
if let Some(args) = value.get("arguments") {
if let Ok(args_str) = serde_json::to_string(args) {
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
});
}
}
}
}
Err(_) => {
// Failed to parse even as partial JSON
// Keep buffering
}
}
}
}
Ok(StreamResult::Incomplete)
}
fn detect_format(&self, text: &str) -> bool {
// Check if text contains Qwen-specific markers. If not, it's not this format.
if !self.has_tool_markers(text) {
return false;
}
// Try to extract tool calls to see if we have a complete, valid one.
let tool_blocks = self.extract_tool_calls(text);
for json_str in &tool_blocks {
if let Ok(value) = serde_json::from_str::<Value>(json_str.trim()) {
if let Some(obj) = value.as_object() {
if obj.contains_key("name") && obj.contains_key("arguments") {
// Found a valid, complete tool call.
return true;
}
}
}
}
// If we have the marker but no valid complete tool call,
// it could be a partial stream. We should detect this as the format.
true
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_parse_qwen_format() {
let parser = QwenParser::new();
let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "Beijing", "units": "celsius"}}
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
assert!(result[0].function.arguments.contains("Beijing"));
}
#[tokio::test]
async fn test_parse_multiple_tools() {
let parser = QwenParser::new();
let input = r#"<tool_call>
{"name": "search", "arguments": {"query": "rust programming"}}
</tool_call>
<tool_call>
{"name": "calculate", "arguments": {"expression": "2 + 2"}}
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "search");
assert_eq!(result[1].function.name, "calculate");
}
#[tokio::test]
async fn test_with_normal_text() {
let parser = QwenParser::new();
let input = r#"Let me help you with that.
<tool_call>
{"name": "get_info", "arguments": {"topic": "Rust"}}
</tool_call>
Here are the results."#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_info");
}
#[tokio::test]
async fn test_nested_json_structures() {
let parser = QwenParser::new();
let input = r#"<tool_call>
{
"name": "process_data",
"arguments": {
"data": {
"nested": {
"array": [1, 2, 3],
"object": {"key": "value"}
}
}
}
}
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "process_data");
assert!(result[0].function.arguments.contains("nested"));
}
#[test]
fn test_detect_format() {
let parser = QwenParser::new();
assert!(parser.detect_format(
r#"<tool_call>
{"name": "test", "arguments": {}}
</tool_call>"#
));
assert!(parser.detect_format(
r#"Text before <tool_call>
{"name": "test", "arguments": {}}
</tool_call> text after"#
));
assert!(!parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
assert!(!parser.detect_format("plain text"));
// Partial format should still be detected
assert!(parser.detect_format("<tool_call>"));
}
#[tokio::test]
async fn test_streaming_partial() {
let parser = QwenParser::new();
let mut state = ParseState::new();
// Simulate streaming chunks
let chunks = vec![
"<tool_call>\n",
r#"{"name": "search","#,
r#" "arguments": {"query":"#,
r#" "rust"}}"#,
"\n</tool_call>",
];
let mut found_name = false;
let mut found_complete = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
match result {
StreamResult::ToolName { name, .. } => {
assert_eq!(name, "search");
found_name = true;
}
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "search");
found_complete = true;
}
_ => {}
}
}
assert!(found_name || found_complete); // At least one should be found
}
}

View File

@@ -0,0 +1,348 @@
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
};
/// Step3 format parser for tool calls
///
/// Handles the Step3 specific format with steptml XML:
/// `<tool_calls_begin><tool_call_begin>function<tool_sep><steptml:invoke name="{name}"><steptml:parameter name="{k}">{v}</steptml:parameter></steptml:invoke><tool_call_end><tool_calls_end>`
///
/// Features:
/// - Unicode token delimiters
/// - StepTML XML format for invocations
/// - Support for multiple sequential tool calls
pub struct Step3Parser {
/// Regex for extracting tool call blocks
tool_call_extractor: Regex,
/// Regex for extracting steptml invocations
invoke_extractor: Regex,
/// Regex for extracting parameters
param_extractor: Regex,
}
impl Step3Parser {
/// Create a new Step3 parser
pub fn new() -> Self {
// Pattern for individual tool calls
let tool_call_pattern = r"(?s)<tool_call_begin>.*?<tool_call_end>";
let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
// Pattern for steptml invocations
let invoke_pattern = r#"(?s)<steptml:invoke name="([^"]+)">(.+?)</steptml:invoke>"#;
let invoke_extractor = Regex::new(invoke_pattern).expect("Valid regex pattern");
// Pattern for steptml parameters - using non-greedy match for values to handle < characters
let param_pattern = r#"(?s)<steptml:parameter name="([^"]+)">(.+?)</steptml:parameter>"#;
let param_extractor = Regex::new(param_pattern).expect("Valid regex pattern");
Self {
tool_call_extractor,
invoke_extractor,
param_extractor,
}
}
/// Check if text contains Step3 tool markers
fn has_tool_markers(&self, text: &str) -> bool {
text.contains("<tool_calls_begin>")
}
/// Parse parameters from steptml format
fn parse_steptml_parameters(
&self,
params_text: &str,
) -> ToolParserResult<serde_json::Map<String, Value>> {
let mut parameters = serde_json::Map::new();
for capture in self.param_extractor.captures_iter(params_text) {
let param_name = capture.get(1).map_or("", |m| m.as_str()).trim();
let param_value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
// Try to parse the value as JSON first, fallback to string
let param_value = if let Ok(json_val) = serde_json::from_str::<Value>(param_value_str) {
json_val
} else {
// Try parsing as Python literal
if param_value_str == "true" || param_value_str == "True" {
Value::Bool(true)
} else if param_value_str == "false" || param_value_str == "False" {
Value::Bool(false)
} else if param_value_str == "null" || param_value_str == "None" {
Value::Null
} else if let Ok(num) = param_value_str.parse::<i64>() {
Value::Number(num.into())
} else if let Ok(num) = param_value_str.parse::<f64>() {
if let Some(n) = serde_json::Number::from_f64(num) {
Value::Number(n)
} else {
Value::String(param_value_str.to_string())
}
} else {
Value::String(param_value_str.to_string())
}
};
parameters.insert(param_name.to_string(), param_value);
}
Ok(parameters)
}
/// Parse a single tool call block
fn parse_tool_call(&self, block: &str) -> ToolParserResult<Option<ToolCall>> {
// Check if it contains function marker and tool separator
if !block.contains("function") || !block.contains("<tool_sep>") {
return Ok(None);
}
// Split by tool separator
let parts: Vec<&str> = block.split("<tool_sep>").collect();
if parts.len() != 2 {
return Ok(None);
}
// Check if it's a function type
if !parts[0].contains("function") {
return Ok(None);
}
let invoke_part = parts[1];
// Extract steptml invoke
if let Some(captures) = self.invoke_extractor.captures(invoke_part) {
let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
// Validate function name is not empty
if func_name.is_empty() {
return Ok(None);
}
let params_text = captures.get(2).map_or("", |m| m.as_str());
// Parse parameters
let parameters = self.parse_steptml_parameters(params_text)?;
let arguments_str = serde_json::to_string(&parameters)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID
let id = format!("step3_call_{}", uuid::Uuid::new_v4());
Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: func_name.to_string(),
arguments: arguments_str,
},
}))
} else {
Ok(None)
}
}
}
impl Default for Step3Parser {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ToolParser for Step3Parser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
// Check if text contains Step3 format
if !self.has_tool_markers(text) {
return Ok(vec![]);
}
// Find the tool calls section
if let Some(start_pos) = text.find("<tool_calls_begin>") {
let search_from = start_pos + "<tool_calls_begin>".len();
// Find the end of tool calls section
if let Some(end_pos) = text[search_from..].find("<tool_calls_end>") {
let tool_section = &text[search_from..search_from + end_pos];
// Extract all tool call blocks
let mut tools = Vec::new();
for mat in self.tool_call_extractor.find_iter(tool_section) {
if let Some(tool) = self.parse_tool_call(mat.as_str())? {
tools.push(tool);
}
}
return Ok(tools);
}
}
Ok(vec![])
}
async fn parse_incremental(
&self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
// Check for tool markers
if !self.has_tool_markers(&state.buffer) {
// No markers found, return as incomplete
return Ok(StreamResult::Incomplete);
}
// Look for start of tool calls
if let Some(start_pos) = state.buffer.find("<tool_calls_begin>") {
let search_from = start_pos + "<tool_calls_begin>".len();
// Look for individual tool call start
if let Some(call_start) = state.buffer[search_from..].find("<tool_call_begin>") {
let call_start_abs = search_from + call_start;
// Look for the end of this tool call
let search_end_from = call_start_abs + "<tool_call_begin>".len();
if let Some(call_end) = state.buffer[search_end_from..].find("<tool_call_end>")
{
let call_end_abs = search_end_from + call_end + "<tool_call_end>".len();
// Extract and parse the complete tool call
let tool_call_text = &state.buffer[call_start_abs..call_end_abs];
if let Some(tool) = self.parse_tool_call(tool_call_text)? {
// Remove the processed part from buffer
state.buffer.drain(..call_end_abs);
return Ok(StreamResult::ToolComplete(tool));
}
} else {
// Tool call not complete yet, try to extract partial info
let partial = &state.buffer[search_end_from..];
// Check for tool separator
if let Some(sep_pos) = partial.find("<tool_sep>") {
// Check if it's a function
if partial[..sep_pos].contains("function") {
let after_sep = &partial[sep_pos + "<tool_sep>".len()..];
// Try to extract function name from steptml:invoke
if let Some(name_match) = self.invoke_extractor.captures(after_sep) {
let func_name = name_match.get(1).map_or("", |m| m.as_str()).trim();
if !state.in_string && !func_name.is_empty() {
state.in_string = true; // Mark name as sent
return Ok(StreamResult::ToolName {
index: 0,
name: func_name.to_string(),
});
}
// Try to extract partial parameters
if let Some(params_text) = name_match.get(2) {
let parameters =
self.parse_steptml_parameters(params_text.as_str())?;
if !parameters.is_empty() {
let args_str = serde_json::to_string(&parameters)
.unwrap_or_else(|_| "{}".to_string());
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
});
}
}
}
}
}
}
}
}
Ok(StreamResult::Incomplete)
}
fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_parse_step3_single_tool() {
let parser = Step3Parser::new();
let input = r#"Some text
<tool_calls_begin>
<tool_call_begin>function<tool_sep><steptml:invoke name="get_weather">
<steptml:parameter name="location">Tokyo</steptml:parameter>
<steptml:parameter name="units">celsius</steptml:parameter>
</steptml:invoke><tool_call_end>
<tool_calls_end>More text"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
assert!(result[0].function.arguments.contains("Tokyo"));
assert!(result[0].function.arguments.contains("celsius"));
}
#[tokio::test]
async fn test_parse_step3_multiple_tools() {
let parser = Step3Parser::new();
let input = r#"<tool_calls_begin>
<tool_call_begin>function<tool_sep><steptml:invoke name="search">
<steptml:parameter name="query">rust programming</steptml:parameter>
</steptml:invoke><tool_call_end>
<tool_call_begin>function<tool_sep><steptml:invoke name="calculate">
<steptml:parameter name="expression">2 + 2</steptml:parameter>
</steptml:invoke><tool_call_end>
<tool_calls_end>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "search");
assert_eq!(result[1].function.name, "calculate");
}
#[tokio::test]
async fn test_parse_step3_mixed_types() {
let parser = Step3Parser::new();
let input = r#"<tool_calls_begin>
<tool_call_begin>function<tool_sep><steptml:invoke name="process_data">
<steptml:parameter name="count">42</steptml:parameter>
<steptml:parameter name="active">true</steptml:parameter>
<steptml:parameter name="rate">1.5</steptml:parameter>
<steptml:parameter name="name">test</steptml:parameter>
</steptml:invoke><tool_call_end>
<tool_calls_end>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "process_data");
// Parse arguments to check types
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["count"], 42);
assert_eq!(args["active"], true);
assert_eq!(args["rate"], 1.5);
assert_eq!(args["name"], "test");
}
#[test]
fn test_detect_format() {
let parser = Step3Parser::new();
assert!(parser.detect_format("<tool_calls_begin>"));
assert!(!parser.detect_format("plain text"));
assert!(!parser.detect_format("[TOOL_CALLS]"));
}
}

View File

@@ -0,0 +1,527 @@
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
traits::PartialJsonParser,
};
use serde_json::{Map, Value};
/// Parser for incomplete JSON
pub struct PartialJson {
/// Maximum depth for nested structures
max_depth: usize,
/// Whether to allow incomplete values
allow_incomplete: bool,
}
impl PartialJson {
/// Create a new partial JSON parser
pub fn new(max_depth: usize, allow_incomplete: bool) -> Self {
Self {
max_depth,
allow_incomplete,
}
}
/// Parse potentially incomplete JSON, returning parsed value and consumed bytes
pub fn parse_value(&self, input: &str) -> ToolParserResult<(Value, usize)> {
let mut parser = Parser::new(input, self.max_depth, self.allow_incomplete);
let value = parser.parse_value(0)?;
Ok((value, parser.position))
}
}
impl Default for PartialJson {
fn default() -> Self {
Self::new(32, true)
}
}
impl PartialJsonParser for PartialJson {
fn parse(&self, input: &str) -> ToolParserResult<(Value, usize)> {
self.parse_value(input)
}
fn is_complete(&self, input: &str) -> bool {
// Try to parse as complete JSON
serde_json::from_str::<Value>(input).is_ok()
}
fn max_depth(&self) -> usize {
self.max_depth
}
}
/// Internal parser state
struct Parser<'a> {
chars: std::iter::Peekable<std::str::Chars<'a>>,
position: usize,
max_depth: usize,
allow_incomplete: bool,
}
impl<'a> Parser<'a> {
fn new(input: &'a str, max_depth: usize, allow_incomplete: bool) -> Self {
Self {
chars: input.chars().peekable(),
position: 0,
max_depth,
allow_incomplete,
}
}
fn peek(&mut self) -> Option<char> {
self.chars.peek().copied()
}
fn advance(&mut self) {
if self.chars.next().is_some() {
self.position += 1;
}
}
fn skip_whitespace(&mut self) {
while let Some(ch) = self.peek() {
if ch.is_whitespace() {
self.advance();
} else {
break;
}
}
}
fn parse_value(&mut self, depth: usize) -> ToolParserResult<Value> {
if depth > self.max_depth {
return Err(ToolParserError::DepthExceeded(self.max_depth));
}
self.skip_whitespace();
match self.peek() {
Some('{') => self.parse_object(depth + 1),
Some('[') => self.parse_array(depth + 1),
Some('"') => self.parse_string(),
Some('t') | Some('f') => self.parse_bool(),
Some('n') => self.parse_null(),
Some(c) if c == '-' || c.is_ascii_digit() => self.parse_number(),
_ => {
if self.allow_incomplete {
Ok(Value::Null)
} else {
Err(ToolParserError::ParsingFailed(
"Unexpected character".into(),
))
}
}
}
}
fn parse_object(&mut self, depth: usize) -> ToolParserResult<Value> {
if depth > self.max_depth {
return Err(ToolParserError::DepthExceeded(self.max_depth));
}
let mut object = Map::new();
// Consume '{'
self.advance();
self.skip_whitespace();
// Check for empty object
if self.peek() == Some('}') {
self.advance();
return Ok(Value::Object(object));
}
loop {
// Parse key
let key = match self.parse_string() {
Ok(Value::String(s)) => s,
Err(_) if self.allow_incomplete => {
// Incomplete object
return Ok(Value::Object(object));
}
Err(e) => return Err(e),
_ => return Err(ToolParserError::ParsingFailed("Expected string key".into())),
};
self.skip_whitespace();
// Expect ':'
if self.peek() != Some(':') {
if self.allow_incomplete {
// Add null value for incomplete pair
object.insert(key, Value::Null);
return Ok(Value::Object(object));
}
return Err(ToolParserError::ParsingFailed("Expected ':'".into()));
}
self.advance();
self.skip_whitespace();
// Parse value (keep same depth - we already incremented in parse_object)
let value = match self.parse_value(depth) {
Ok(v) => v,
Err(_) if self.allow_incomplete => {
// Add null for incomplete value
object.insert(key, Value::Null);
return Ok(Value::Object(object));
}
Err(e) => return Err(e),
};
object.insert(key, value);
self.skip_whitespace();
match self.peek() {
Some(',') => {
self.advance();
self.skip_whitespace();
// Check for trailing comma
if self.peek() == Some('}') {
self.advance();
return Ok(Value::Object(object));
}
}
Some('}') => {
self.advance();
return Ok(Value::Object(object));
}
None if self.allow_incomplete => {
return Ok(Value::Object(object));
}
_ => {
if self.allow_incomplete {
return Ok(Value::Object(object));
}
return Err(ToolParserError::ParsingFailed("Expected ',' or '}'".into()));
}
}
}
}
fn parse_array(&mut self, depth: usize) -> ToolParserResult<Value> {
if depth > self.max_depth {
return Err(ToolParserError::DepthExceeded(self.max_depth));
}
let mut array = Vec::new();
// Consume '['
self.advance();
self.skip_whitespace();
// Check for empty array
if self.peek() == Some(']') {
self.advance();
return Ok(Value::Array(array));
}
loop {
// Parse value (keep same depth - we already incremented in parse_object)
let value = match self.parse_value(depth) {
Ok(v) => v,
Err(_) if self.allow_incomplete => {
return Ok(Value::Array(array));
}
Err(e) => return Err(e),
};
array.push(value);
self.skip_whitespace();
match self.peek() {
Some(',') => {
self.advance();
self.skip_whitespace();
// Check for trailing comma
if self.peek() == Some(']') {
self.advance();
return Ok(Value::Array(array));
}
}
Some(']') => {
self.advance();
return Ok(Value::Array(array));
}
None if self.allow_incomplete => {
return Ok(Value::Array(array));
}
_ => {
if self.allow_incomplete {
return Ok(Value::Array(array));
}
return Err(ToolParserError::ParsingFailed("Expected ',' or ']'".into()));
}
}
}
}
fn parse_string(&mut self) -> ToolParserResult<Value> {
if self.peek() != Some('"') {
return Err(ToolParserError::ParsingFailed("Expected '\"'".into()));
}
// Consume opening quote
self.advance();
let mut string = String::new();
let mut escaped = false;
while let Some(ch) = self.peek() {
if escaped {
// Handle escape sequences
let escaped_char = match ch {
'"' | '\\' | '/' => ch,
'b' => '\u{0008}',
'f' => '\u{000C}',
'n' => '\n',
'r' => '\r',
't' => '\t',
'u' => {
// Unicode escape
self.advance();
let hex = self.parse_unicode_escape()?;
string.push(hex);
escaped = false;
continue;
}
_ => ch, // Invalid escape, but be lenient
};
string.push(escaped_char);
escaped = false;
} else if ch == '\\' {
escaped = true;
} else if ch == '"' {
// End of string
self.advance();
return Ok(Value::String(string));
} else {
string.push(ch);
}
self.advance();
}
// Incomplete string
if self.allow_incomplete {
Ok(Value::String(string))
} else {
Err(ToolParserError::ParsingFailed("Unterminated string".into()))
}
}
fn parse_unicode_escape(&mut self) -> ToolParserResult<char> {
let mut hex = String::new();
for _ in 0..4 {
if let Some(ch) = self.peek() {
if ch.is_ascii_hexdigit() {
hex.push(ch);
self.advance();
} else {
break;
}
} else {
break;
}
}
if hex.len() == 4 {
u32::from_str_radix(&hex, 16)
.ok()
.and_then(char::from_u32)
.ok_or_else(|| ToolParserError::ParsingFailed("Invalid unicode escape".into()))
} else if self.allow_incomplete {
Ok('\u{FFFD}') // Replacement character
} else {
Err(ToolParserError::ParsingFailed(
"Incomplete unicode escape".into(),
))
}
}
fn parse_number(&mut self) -> ToolParserResult<Value> {
let mut number = String::new();
// Handle negative sign
if self.peek() == Some('-') {
number.push('-');
self.advance();
}
// Parse integer part
if self.peek() == Some('0') {
number.push('0');
self.advance();
} else {
while let Some(ch) = self.peek() {
if ch.is_ascii_digit() {
number.push(ch);
self.advance();
} else {
break;
}
}
}
// Parse decimal part
if self.peek() == Some('.') {
number.push('.');
self.advance();
while let Some(ch) = self.peek() {
if ch.is_ascii_digit() {
number.push(ch);
self.advance();
} else {
break;
}
}
}
// Parse exponent
if let Some(ch) = self.peek() {
if ch == 'e' || ch == 'E' {
number.push(ch);
self.advance();
if let Some(sign) = self.peek() {
if sign == '+' || sign == '-' {
number.push(sign);
self.advance();
}
}
while let Some(ch) = self.peek() {
if ch.is_ascii_digit() {
number.push(ch);
self.advance();
} else {
break;
}
}
}
}
// Try to parse as integer first, then as float
if let Ok(n) = number.parse::<i64>() {
Ok(Value::Number(serde_json::Number::from(n)))
} else if let Ok(n) = number.parse::<f64>() {
Ok(Value::Number(
serde_json::Number::from_f64(n).unwrap_or_else(|| serde_json::Number::from(0)),
))
} else if self.allow_incomplete {
Ok(Value::Number(serde_json::Number::from(0)))
} else {
Err(ToolParserError::ParsingFailed("Invalid number".into()))
}
}
fn parse_bool(&mut self) -> ToolParserResult<Value> {
let mut word = String::new();
// Peek at upcoming characters to validate it looks like a boolean
let mut temp_chars = self.chars.clone();
while let Some(&ch) = temp_chars.peek() {
if ch.is_alphabetic() && word.len() < 5 {
// "false" is 5 chars
word.push(ch);
temp_chars.next();
} else {
break;
}
}
// Check if it's a valid boolean prefix
let is_valid = word == "true"
|| word == "false"
|| (self.allow_incomplete && ("true".starts_with(&word) || "false".starts_with(&word)));
if !is_valid {
return Err(ToolParserError::ParsingFailed("Invalid boolean".into()));
}
// Now actually consume the characters
word.clear();
while let Some(ch) = self.peek() {
if ch.is_alphabetic() {
word.push(ch);
self.advance();
} else {
break;
}
}
match word.as_str() {
"true" => Ok(Value::Bool(true)),
"false" => Ok(Value::Bool(false)),
partial if self.allow_incomplete => {
if "true".starts_with(partial) {
Ok(Value::Bool(true))
} else if "false".starts_with(partial) {
Ok(Value::Bool(false))
} else {
Err(ToolParserError::ParsingFailed("Invalid boolean".into()))
}
}
_ => Err(ToolParserError::ParsingFailed("Invalid boolean".into())),
}
}
fn parse_null(&mut self) -> ToolParserResult<Value> {
let mut word = String::new();
// Peek at upcoming characters to validate it looks like "null"
let mut temp_chars = self.chars.clone();
while let Some(&ch) = temp_chars.peek() {
if ch.is_alphabetic() && word.len() < 4 {
// "null" is 4 chars
word.push(ch);
temp_chars.next();
} else {
break;
}
}
// Check if it's a valid null prefix
let is_valid = word == "null" || (self.allow_incomplete && "null".starts_with(&word));
if !is_valid {
return Err(ToolParserError::ParsingFailed("Invalid null".into()));
}
// Now actually consume the characters
word.clear();
while let Some(ch) = self.peek() {
if ch.is_alphabetic() {
word.push(ch);
self.advance();
} else {
break;
}
}
if word == "null" || (self.allow_incomplete && "null".starts_with(&word)) {
Ok(Value::Null)
} else {
Err(ToolParserError::ParsingFailed("Invalid null".into()))
}
}
}
/// Utility function to check if a string contains complete JSON
pub fn is_complete_json(input: &str) -> bool {
serde_json::from_str::<Value>(input).is_ok()
}
/// Utility function to find common prefix between two strings
pub fn find_common_prefix(s1: &str, s2: &str) -> usize {
s1.chars()
.zip(s2.chars())
.take_while(|(a, b)| a == b)
.count()
}
/// Utility function to compute diff between old and new strings
pub fn compute_diff(old: &str, new: &str) -> String {
let common_len = find_common_prefix(old, new);
// Convert character count to byte offset
new.chars().skip(common_len).collect()
}

View File

@@ -0,0 +1,442 @@
/// Minimal Python literal parser for Pythonic tool call format
///
/// This module provides a recursive descent parser for Python literals
/// (strings, numbers, booleans, None, lists, dicts) without requiring
/// a full Python AST parser.
use serde_json::{json, Value};
use std::collections::HashMap;
use crate::tool_parser::errors::{ToolParserError, ToolParserResult};
/// Token types for Python literals
#[derive(Debug, Clone, PartialEq)]
enum Token {
// Literals
String(String),
Number(String),
True,
False,
None,
// Delimiters
LeftBracket, // [
RightBracket, // ]
LeftBrace, // {
RightBrace, // }
LeftParen, // (
RightParen, // )
Comma, // ,
Colon, // :
Equals, // =
// Identifier for function names
Identifier(String),
// End of input
Eof,
}
/// Lexer for Python literals
struct Lexer {
input: Vec<char>,
position: usize,
}
impl Lexer {
fn new(input: &str) -> Self {
Self {
input: input.chars().collect(),
position: 0,
}
}
fn current_char(&self) -> Option<char> {
self.input.get(self.position).copied()
}
fn advance(&mut self) {
if self.position < self.input.len() {
self.position += 1;
}
}
fn skip_whitespace(&mut self) {
while let Some(ch) = self.current_char() {
if ch.is_whitespace() {
self.advance();
} else {
break;
}
}
}
fn read_string(&mut self, quote_char: char) -> Result<String, ToolParserError> {
let mut result = String::new();
self.advance(); // Skip opening quote
while let Some(ch) = self.current_char() {
if ch == '\\' {
self.advance();
if let Some(escaped) = self.current_char() {
match escaped {
'n' => result.push('\n'),
't' => result.push('\t'),
'r' => result.push('\r'),
'\\' => result.push('\\'),
'\'' => result.push('\''),
'"' => result.push('"'),
_ => {
result.push('\\');
result.push(escaped);
}
}
self.advance();
}
} else if ch == quote_char {
self.advance(); // Skip closing quote
return Ok(result);
} else {
result.push(ch);
self.advance();
}
}
Err(ToolParserError::ParsingFailed("Unterminated string".into()))
}
fn read_number(&mut self) -> String {
let mut result = String::new();
// Handle negative numbers
if self.current_char() == Some('-') {
result.push('-');
self.advance();
}
// Read digits and decimal point
while let Some(ch) = self.current_char() {
if ch.is_ascii_digit() || ch == '.' || ch == 'e' || ch == 'E' || ch == '+' || ch == '-'
{
result.push(ch);
self.advance();
} else {
break;
}
}
result
}
fn read_identifier(&mut self) -> String {
let mut result = String::new();
while let Some(ch) = self.current_char() {
if ch.is_alphanumeric() || ch == '_' {
result.push(ch);
self.advance();
} else {
break;
}
}
result
}
fn next_token(&mut self) -> Result<Token, ToolParserError> {
self.skip_whitespace();
match self.current_char() {
None => Ok(Token::Eof),
Some('[') => {
self.advance();
Ok(Token::LeftBracket)
}
Some(']') => {
self.advance();
Ok(Token::RightBracket)
}
Some('{') => {
self.advance();
Ok(Token::LeftBrace)
}
Some('}') => {
self.advance();
Ok(Token::RightBrace)
}
Some('(') => {
self.advance();
Ok(Token::LeftParen)
}
Some(')') => {
self.advance();
Ok(Token::RightParen)
}
Some(',') => {
self.advance();
Ok(Token::Comma)
}
Some(':') => {
self.advance();
Ok(Token::Colon)
}
Some('=') => {
self.advance();
Ok(Token::Equals)
}
Some('"') => Ok(Token::String(self.read_string('"')?)),
Some('\'') => Ok(Token::String(self.read_string('\'')?)),
Some(ch) if ch == '-' || ch.is_ascii_digit() => Ok(Token::Number(self.read_number())),
Some(ch) if ch.is_alphabetic() || ch == '_' => {
let ident = self.read_identifier();
match ident.as_str() {
"True" => Ok(Token::True),
"False" => Ok(Token::False),
"None" => Ok(Token::None),
_ => Ok(Token::Identifier(ident)),
}
}
Some(ch) => Err(ToolParserError::ParsingFailed(format!(
"Unexpected character: {}",
ch
))),
}
}
}
/// Parser for Python literals
pub struct PythonLiteralParser {
lexer: Lexer,
current_token: Token,
}
impl PythonLiteralParser {
pub fn new(input: &str) -> Result<Self, ToolParserError> {
let mut lexer = Lexer::new(input);
let current_token = lexer.next_token()?;
Ok(Self {
lexer,
current_token,
})
}
fn advance(&mut self) -> Result<(), ToolParserError> {
self.current_token = self.lexer.next_token()?;
Ok(())
}
fn expect(&mut self, expected: Token) -> Result<(), ToolParserError> {
if self.current_token == expected {
self.advance()?;
Ok(())
} else {
Err(ToolParserError::ParsingFailed(format!(
"Expected {:?}, got {:?}",
expected, self.current_token
)))
}
}
/// Parse a Python literal value
pub fn parse_value(&mut self) -> Result<Value, ToolParserError> {
match &self.current_token.clone() {
Token::String(s) => {
let value = s.clone();
self.advance()?;
Ok(json!(value))
}
Token::Number(n) => {
let value = if let Ok(int_val) = n.parse::<i64>() {
json!(int_val)
} else if let Ok(float_val) = n.parse::<f64>() {
json!(float_val)
} else {
return Err(ToolParserError::ParsingFailed(format!(
"Invalid number: {}",
n
)));
};
self.advance()?;
Ok(value)
}
Token::True => {
self.advance()?;
Ok(json!(true))
}
Token::False => {
self.advance()?;
Ok(json!(false))
}
Token::None => {
self.advance()?;
Ok(Value::Null)
}
Token::LeftBracket => self.parse_list(),
Token::LeftBrace => self.parse_dict(),
_ => Err(ToolParserError::ParsingFailed(format!(
"Unexpected token: {:?}",
self.current_token
))),
}
}
/// Parse a Python list: [item1, item2, ...]
fn parse_list(&mut self) -> Result<Value, ToolParserError> {
self.expect(Token::LeftBracket)?;
let mut items = Vec::new();
// Handle empty list
if self.current_token == Token::RightBracket {
self.advance()?;
return Ok(json!(items));
}
loop {
items.push(self.parse_value()?);
if self.current_token == Token::Comma {
self.advance()?;
// Handle trailing comma
if self.current_token == Token::RightBracket {
break;
}
} else if self.current_token == Token::RightBracket {
break;
} else {
return Err(ToolParserError::ParsingFailed(format!(
"Expected ',' or ']', got {:?}",
self.current_token
)));
}
}
self.expect(Token::RightBracket)?;
Ok(json!(items))
}
/// Parse a Python dict: {key1: value1, key2: value2, ...}
fn parse_dict(&mut self) -> Result<Value, ToolParserError> {
self.expect(Token::LeftBrace)?;
let mut map = HashMap::new();
// Handle empty dict
if self.current_token == Token::RightBrace {
self.advance()?;
return Ok(json!(map));
}
loop {
// Parse key (must be a string)
let key = match &self.current_token {
Token::String(s) => {
let k = s.clone();
self.advance()?;
k
}
_ => {
return Err(ToolParserError::ParsingFailed(format!(
"Expected string key, got {:?}",
self.current_token
)))
}
};
self.expect(Token::Colon)?;
// Parse value
let value = self.parse_value()?;
map.insert(key, value);
if self.current_token == Token::Comma {
self.advance()?;
// Handle trailing comma
if self.current_token == Token::RightBrace {
break;
}
} else if self.current_token == Token::RightBrace {
break;
} else {
return Err(ToolParserError::ParsingFailed(format!(
"Expected ',' or '}}', got {:?}",
self.current_token
)));
}
}
self.expect(Token::RightBrace)?;
Ok(json!(map))
}
}
/// Parse a Python literal string into a JSON value
pub fn parse_python_literal(input: &str) -> ToolParserResult<Value> {
let mut parser = PythonLiteralParser::new(input)?;
parser.parse_value()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_primitives() {
assert_eq!(parse_python_literal("True").unwrap(), json!(true));
assert_eq!(parse_python_literal("False").unwrap(), json!(false));
assert_eq!(parse_python_literal("None").unwrap(), Value::Null);
assert_eq!(parse_python_literal("42").unwrap(), json!(42));
assert_eq!(parse_python_literal("12.345").unwrap(), json!(12.345));
assert_eq!(parse_python_literal("-42").unwrap(), json!(-42));
assert_eq!(parse_python_literal("\"hello\"").unwrap(), json!("hello"));
assert_eq!(parse_python_literal("'world'").unwrap(), json!("world"));
}
#[test]
fn test_parse_list() {
assert_eq!(parse_python_literal("[]").unwrap(), json!([]));
assert_eq!(parse_python_literal("[1, 2, 3]").unwrap(), json!([1, 2, 3]));
assert_eq!(
parse_python_literal("[\"a\", \"b\", \"c\"]").unwrap(),
json!(["a", "b", "c"])
);
assert_eq!(
parse_python_literal("[True, False, None]").unwrap(),
json!([true, false, null])
);
// Nested list
assert_eq!(
parse_python_literal("[[1, 2], [3, 4]]").unwrap(),
json!([[1, 2], [3, 4]])
);
}
#[test]
fn test_parse_dict() {
assert_eq!(parse_python_literal("{}").unwrap(), json!({}));
assert_eq!(
parse_python_literal("{\"a\": 1, \"b\": 2}").unwrap(),
json!({"a": 1, "b": 2})
);
assert_eq!(
parse_python_literal("{'x': True, 'y': False}").unwrap(),
json!({"x": true, "y": false})
);
// Nested dict
assert_eq!(
parse_python_literal("{\"nested\": {\"value\": [1, 2, 3]}}").unwrap(),
json!({"nested": {"value": [1, 2, 3]}})
);
}
#[test]
fn test_complex_nested() {
let input = r#"{"config": {"nested": {"value": [1, 2, 3]}, "enabled": True}}"#;
let expected = json!({
"config": {
"nested": {
"value": [1, 2, 3]
},
"enabled": true
}
});
assert_eq!(parse_python_literal(input).unwrap(), expected);
}
}

View File

@@ -0,0 +1,224 @@
use crate::tool_parser::parsers::{
DeepSeekParser, Glm4MoeParser, GptOssParser, JsonParser, KimiK2Parser, LlamaParser,
MistralParser, PythonicParser, QwenParser, Step3Parser,
};
use crate::tool_parser::traits::ToolParser;
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::sync::Arc;
/// Global singleton registry instance - created once and reused
pub static GLOBAL_REGISTRY: Lazy<ParserRegistry> = Lazy::new(ParserRegistry::new_internal);
/// Registry for tool parsers and model mappings
pub struct ParserRegistry {
/// Map of parser name to parser instance
parsers: HashMap<String, Arc<dyn ToolParser>>,
/// Map of model name/pattern to parser name
model_mapping: HashMap<String, String>,
/// Default parser to use when no match found
default_parser: String,
}
impl ParserRegistry {
/// Get the global singleton instance
pub fn new() -> &'static Self {
&GLOBAL_REGISTRY
}
/// Create a new instance for testing (not the singleton)
#[cfg(test)]
pub fn new_for_testing() -> Self {
Self::new_internal()
}
/// Internal constructor for creating the singleton instance
fn new_internal() -> Self {
let mut registry = Self {
parsers: HashMap::new(),
model_mapping: HashMap::new(),
default_parser: "json".to_string(),
};
// Register default parsers
registry.register_default_parsers();
// Register default model mappings
registry.register_default_mappings();
registry
}
/// Register a parser
pub fn register_parser(&mut self, name: impl Into<String>, parser: Arc<dyn ToolParser>) {
self.parsers.insert(name.into(), parser);
}
/// Map a model name/pattern to a parser
pub fn map_model(&mut self, model: impl Into<String>, parser: impl Into<String>) {
self.model_mapping.insert(model.into(), parser.into());
}
/// Get parser for a specific model
pub fn get_parser(&self, model: &str) -> Option<Arc<dyn ToolParser>> {
// Try exact match first
if let Some(parser_name) = self.model_mapping.get(model) {
if let Some(parser) = self.parsers.get(parser_name) {
return Some(parser.clone());
}
}
// Try prefix matching with more specific patterns first
// Collect all matching patterns and sort by specificity (longer = more specific)
let mut matches: Vec<(&String, &String)> = self
.model_mapping
.iter()
.filter(|(pattern, _)| {
if pattern.ends_with('*') {
let prefix = &pattern[..pattern.len() - 1];
model.starts_with(prefix)
} else {
false
}
})
.collect();
// Sort by pattern length in descending order (longer patterns are more specific)
matches.sort_by_key(|(pattern, _)| std::cmp::Reverse(pattern.len()));
// Return the first matching parser
for (_, parser_name) in matches {
if let Some(parser) = self.parsers.get(parser_name) {
return Some(parser.clone());
}
}
// Fall back to default parser if it exists
self.parsers.get(&self.default_parser).cloned()
}
/// List all registered parsers
pub fn list_parsers(&self) -> Vec<&str> {
self.parsers.keys().map(|s| s.as_str()).collect()
}
/// List all model mappings
pub fn list_mappings(&self) -> Vec<(&str, &str)> {
self.model_mapping
.iter()
.map(|(k, v)| (k.as_str(), v.as_str()))
.collect()
}
/// Register default parsers
fn register_default_parsers(&mut self) {
// JSON parser - most common format
self.register_parser("json", Arc::new(JsonParser::new()));
// Mistral parser - [TOOL_CALLS] [...] format
self.register_parser("mistral", Arc::new(MistralParser::new()));
// Qwen parser - <tool_call>...</tool_call> format
self.register_parser("qwen", Arc::new(QwenParser::new()));
// Pythonic parser - [func(arg=val)] format
self.register_parser("pythonic", Arc::new(PythonicParser::new()));
// Llama parser - <|python_tag|>{...} or plain JSON format
self.register_parser("llama", Arc::new(LlamaParser::new()));
// DeepSeek V3 parser - Unicode tokens with JSON blocks
self.register_parser("deepseek", Arc::new(DeepSeekParser::new()));
// GLM-4 MoE parser - XML-style key-value format
self.register_parser("glm4_moe", Arc::new(Glm4MoeParser::new()));
// Step3 parser - StepTML XML format
self.register_parser("step3", Arc::new(Step3Parser::new()));
// Kimi K2 parser - Token-based with indexed functions
self.register_parser("kimik2", Arc::new(KimiK2Parser::new()));
// GPT-OSS parser - Channel format
self.register_parser("gpt_oss", Arc::new(GptOssParser::new()));
}
/// Register default model mappings
fn register_default_mappings(&mut self) {
// OpenAI models
self.map_model("gpt-4*", "json");
self.map_model("gpt-3.5*", "json");
self.map_model("gpt-4o*", "json");
// Anthropic models
self.map_model("claude-*", "json");
// Mistral models - use Mistral parser
self.map_model("mistral-*", "mistral");
self.map_model("mixtral-*", "mistral");
// Qwen models - use Qwen parser
self.map_model("qwen*", "qwen");
self.map_model("Qwen*", "qwen");
// Llama models
// Llama 4 uses pythonic format
self.map_model("llama-4*", "pythonic");
self.map_model("meta-llama-4*", "pythonic");
// Llama 3.2 uses python_tag format
self.map_model("llama-3.2*", "llama");
self.map_model("meta-llama-3.2*", "llama");
// Other Llama models use JSON
self.map_model("llama-*", "json");
self.map_model("meta-llama-*", "json");
// DeepSeek models
// DeepSeek V3 uses custom Unicode token format
self.map_model("deepseek-v3*", "deepseek");
self.map_model("deepseek-ai/DeepSeek-V3*", "deepseek");
// DeepSeek V2 uses pythonic format
self.map_model("deepseek-*", "pythonic");
// GLM models
// GLM-4 MoE uses XML-style format
self.map_model("glm-4-moe*", "glm4_moe");
self.map_model("THUDM/glm-4-moe*", "glm4_moe");
self.map_model("glm-4.5*", "glm4_moe");
// Other GLM models may use JSON
self.map_model("glm-*", "json");
// Step3 models
self.map_model("step3*", "step3");
self.map_model("Step-3*", "step3");
// Kimi models
self.map_model("kimi-k2*", "kimik2");
self.map_model("Kimi-K2*", "kimik2");
self.map_model("moonshot*/Kimi-K2*", "kimik2");
// GPT-OSS models (T4-style)
self.map_model("gpt-oss*", "gpt_oss");
self.map_model("t4-*", "gpt_oss");
// Other models default to JSON
self.map_model("gemini-*", "json");
self.map_model("palm-*", "json");
self.map_model("gemma-*", "json");
}
/// Set the default parser
pub fn set_default_parser(&mut self, name: impl Into<String>) {
self.default_parser = name.into();
}
/// Check if a parser is registered
pub fn has_parser(&self, name: &str) -> bool {
self.parsers.contains_key(name)
}
}
impl Default for &'static ParserRegistry {
fn default() -> Self {
ParserRegistry::new()
}
}

View File

@@ -0,0 +1,181 @@
use crate::tool_parser::types::{PartialToolCall, ToolCall};
/// Current phase of parsing
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParsePhase {
/// Looking for start of tool call
Searching,
/// Parsing function name
InName,
/// Parsing function arguments
InArguments,
/// Tool call complete
Complete,
}
/// State for streaming parser
#[derive(Debug, Clone)]
pub struct ParseState {
/// Buffer for accumulating input
pub buffer: String,
/// Position of last consumed character
pub consumed: usize,
/// Current partial tool being parsed
pub partial_tool: Option<PartialToolCall>,
/// Completed tool calls
pub completed_tools: Vec<ToolCall>,
/// Current parsing phase
pub phase: ParsePhase,
/// Bracket/brace depth for JSON parsing
pub bracket_depth: i32,
/// Whether currently inside a string literal
pub in_string: bool,
/// Whether next character should be escaped
pub escape_next: bool,
/// Current tool index (for streaming)
pub tool_index: usize,
}
impl ParseState {
/// Create a new parse state
pub fn new() -> Self {
Self {
buffer: String::new(),
consumed: 0,
partial_tool: None,
completed_tools: Vec::new(),
phase: ParsePhase::Searching,
bracket_depth: 0,
in_string: false,
escape_next: false,
tool_index: 0,
}
}
/// Reset state for parsing next tool
pub fn reset(&mut self) {
self.partial_tool = None;
self.phase = ParsePhase::Searching;
self.bracket_depth = 0;
self.in_string = false;
self.escape_next = false;
}
/// Process a single character for JSON parsing
pub fn process_char(&mut self, ch: char) {
// Handle escape sequences
if self.escape_next {
self.escape_next = false;
self.buffer.push(ch);
return;
}
if ch == '\\' && self.in_string {
self.escape_next = true;
self.buffer.push(ch);
return;
}
// Track string boundaries
if ch == '"' && !self.escape_next {
self.in_string = !self.in_string;
}
// Track bracket depth for JSON
if !self.in_string {
match ch {
'{' | '[' => {
self.bracket_depth += 1;
}
'}' | ']' => {
self.bracket_depth -= 1;
if self.bracket_depth == 0 && self.partial_tool.is_some() {
// Complete tool call found
self.phase = ParsePhase::Complete;
}
}
_ => {}
}
}
self.buffer.push(ch);
}
/// Check if we have a complete JSON object/array
pub fn has_complete_json(&self) -> bool {
self.bracket_depth == 0 && !self.in_string && !self.buffer.is_empty()
}
/// Extract content from buffer starting at position
pub fn extract_from(&self, start: usize) -> &str {
if start >= self.buffer.len() {
return "";
}
// Find the nearest character boundary at or after start
let mut safe_start = start;
while safe_start < self.buffer.len() && !self.buffer.is_char_boundary(safe_start) {
safe_start += 1;
}
if safe_start < self.buffer.len() {
&self.buffer[safe_start..]
} else {
""
}
}
/// Mark content as consumed up to position
pub fn consume_to(&mut self, position: usize) {
if position > self.consumed {
self.consumed = position;
}
}
/// Get unconsumed content
pub fn unconsumed(&self) -> &str {
if self.consumed >= self.buffer.len() {
return "";
}
// Find the nearest character boundary at or after consumed
let mut safe_consumed = self.consumed;
while safe_consumed < self.buffer.len() && !self.buffer.is_char_boundary(safe_consumed) {
safe_consumed += 1;
}
if safe_consumed < self.buffer.len() {
&self.buffer[safe_consumed..]
} else {
""
}
}
/// Clear consumed content from buffer
pub fn clear_consumed(&mut self) {
if self.consumed > 0 {
// Find the nearest character boundary at or before consumed
let mut safe_consumed = self.consumed;
while safe_consumed > 0 && !self.buffer.is_char_boundary(safe_consumed) {
safe_consumed -= 1;
}
if safe_consumed > 0 {
self.buffer.drain(..safe_consumed);
self.consumed = self.consumed.saturating_sub(safe_consumed);
}
}
}
/// Add completed tool
pub fn add_completed_tool(&mut self, tool: ToolCall) {
self.completed_tools.push(tool);
self.tool_index += 1;
}
}
impl Default for ParseState {
fn default() -> Self {
Self::new()
}
}

View File

@@ -0,0 +1,886 @@
use super::*;
use crate::tool_parser::parsers::JsonParser;
use crate::tool_parser::partial_json::{
compute_diff, find_common_prefix, is_complete_json, PartialJson,
};
use crate::tool_parser::traits::ToolParser;
use crate::tool_parser::types::TokenConfig;
#[test]
fn test_parse_state_new() {
let state = ParseState::new();
assert_eq!(state.phase, ParsePhase::Searching);
assert_eq!(state.buffer, "");
assert_eq!(state.consumed, 0);
assert_eq!(state.bracket_depth, 0);
assert!(!state.in_string);
assert!(!state.escape_next);
}
#[test]
fn test_parse_state_process_char() {
let mut state = ParseState::new();
// Test bracket tracking
state.process_char('{');
assert_eq!(state.bracket_depth, 1);
state.process_char('}');
assert_eq!(state.bracket_depth, 0);
// Test string tracking
state.process_char('"');
assert!(state.in_string);
state.process_char('"');
assert!(!state.in_string);
// Test escape handling
state.process_char('"');
state.process_char('\\');
assert!(state.escape_next);
state.process_char('"');
assert!(!state.escape_next);
assert!(state.in_string); // Still in string because quote was escaped
}
#[test]
fn test_token_config() {
let config = TokenConfig {
start_tokens: vec!["<start>".to_string(), "[".to_string()],
end_tokens: vec!["</end>".to_string(), "]".to_string()],
separator: ", ".to_string(),
};
let pairs: Vec<_> = config.iter_pairs().collect();
assert_eq!(pairs.len(), 2);
assert_eq!(pairs[0], ("<start>", "</end>"));
assert_eq!(pairs[1], ("[", "]"));
}
#[test]
fn test_parser_registry() {
let registry = ParserRegistry::new();
// Test has default mappings
assert!(!registry.list_mappings().is_empty());
// Test model pattern matching
let mappings = registry.list_mappings();
let has_gpt = mappings.iter().any(|(m, _)| m.starts_with("gpt"));
assert!(has_gpt);
}
#[test]
fn test_parser_registry_pattern_matching() {
let mut registry = ParserRegistry::new_for_testing();
// Test that model mappings work by checking the list
registry.map_model("test-model", "json");
// Verify through list_mappings
let mappings = registry.list_mappings();
let has_test = mappings
.iter()
.any(|(m, p)| *m == "test-model" && *p == "json");
assert!(has_test);
}
#[test]
fn test_tool_call_serialization() {
let tool_call = ToolCall {
id: "call-123".to_string(),
r#type: "function".to_string(),
function: FunctionCall {
name: "search".to_string(),
arguments: r#"{"query": "rust programming"}"#.to_string(),
},
};
let json = serde_json::to_string(&tool_call).unwrap();
assert!(json.contains("call-123"));
assert!(json.contains("search"));
assert!(json.contains("rust programming"));
let parsed: ToolCall = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.id, "call-123");
assert_eq!(parsed.function.name, "search");
}
#[test]
fn test_partial_json_parser() {
let parser = PartialJson::default();
// Test complete JSON
let input = r#"{"name": "test", "value": 42}"#;
let (value, consumed) = parser.parse_value(input).unwrap();
assert_eq!(value["name"], "test");
assert_eq!(value["value"], 42);
assert_eq!(consumed, input.len());
// Test incomplete JSON object
let input = r#"{"name": "test", "value": "#;
let (value, _consumed) = parser.parse_value(input).unwrap();
assert_eq!(value["name"], "test");
assert!(value["value"].is_null());
// Test incomplete string
let input = r#"{"name": "tes"#;
let (value, _consumed) = parser.parse_value(input).unwrap();
assert_eq!(value["name"], "tes");
// Test incomplete array
let input = r#"[1, 2, "#;
let (value, _consumed) = parser.parse_value(input).unwrap();
assert!(value.is_array());
assert_eq!(value[0], 1);
assert_eq!(value[1], 2);
}
#[test]
fn test_partial_json_depth_limit() {
// max_depth of 3 allows nesting up to 3 levels
// Set allow_incomplete to false to get errors instead of partial results
let parser = PartialJson::new(3, false);
// This should work (simple object)
let input = r#"{"a": 1}"#;
let result = parser.parse_value(input);
assert!(result.is_ok());
// This should work (nested to depth 3)
let input = r#"{"a": {"b": {"c": 1}}}"#;
let result = parser.parse_value(input);
assert!(result.is_ok());
// This should fail (nested to depth 4, exceeds limit)
let input = r#"{"a": {"b": {"c": {"d": 1}}}}"#;
let result = parser.parse_value(input);
assert!(result.is_err());
}
#[test]
fn test_is_complete_json() {
assert!(is_complete_json(r#"{"name": "test"}"#));
assert!(is_complete_json(r#"[1, 2, 3]"#));
assert!(is_complete_json(r#""string""#));
assert!(is_complete_json("42"));
assert!(is_complete_json("true"));
assert!(is_complete_json("null"));
assert!(!is_complete_json(r#"{"name": "#));
assert!(!is_complete_json(r#"[1, 2, "#));
assert!(!is_complete_json(r#""unclosed"#));
}
#[test]
fn test_find_common_prefix() {
assert_eq!(find_common_prefix("hello", "hello"), 5);
assert_eq!(find_common_prefix("hello", "help"), 3);
assert_eq!(find_common_prefix("hello", "world"), 0);
assert_eq!(find_common_prefix("", "hello"), 0);
assert_eq!(find_common_prefix("hello", ""), 0);
}
#[test]
fn test_compute_diff() {
assert_eq!(compute_diff("hello", "hello world"), " world");
assert_eq!(compute_diff("", "hello"), "hello");
assert_eq!(compute_diff("hello", "hello"), "");
assert_eq!(compute_diff("test", "hello"), "hello");
}
#[test]
fn test_stream_result_variants() {
// Test Incomplete
let result = StreamResult::Incomplete;
matches!(result, StreamResult::Incomplete);
// Test ToolName
let result = StreamResult::ToolName {
index: 0,
name: "test".to_string(),
};
if let StreamResult::ToolName { index, name } = result {
assert_eq!(index, 0);
assert_eq!(name, "test");
} else {
panic!("Expected ToolName variant");
}
// Test ToolComplete
let tool = ToolCall {
id: "123".to_string(),
r#type: "function".to_string(),
function: FunctionCall {
name: "test".to_string(),
arguments: "{}".to_string(),
},
};
let result = StreamResult::ToolComplete(tool.clone());
if let StreamResult::ToolComplete(t) = result {
assert_eq!(t.id, "123");
} else {
panic!("Expected ToolComplete variant");
}
}
#[test]
fn test_partial_tool_call() {
let mut partial = PartialToolCall {
name: None,
arguments_buffer: String::new(),
start_position: 0,
name_sent: false,
streamed_args: String::new(),
};
// Set name
partial.name = Some("test_function".to_string());
assert_eq!(partial.name.as_ref().unwrap(), "test_function");
// Append arguments
partial.arguments_buffer.push_str(r#"{"key": "value"}"#);
assert_eq!(partial.arguments_buffer, r#"{"key": "value"}"#);
// Update streaming state
partial.name_sent = true;
partial.streamed_args = r#"{"key": "#.to_string();
assert!(partial.name_sent);
assert_eq!(partial.streamed_args, r#"{"key": "#);
}
#[tokio::test]
async fn test_json_parser_complete_single() {
let parser = JsonParser::new();
// Test single tool call with arguments
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco", "units": "celsius"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
assert!(result[0].function.arguments.contains("San Francisco"));
assert!(result[0].function.arguments.contains("celsius"));
}
#[tokio::test]
async fn test_json_parser_complete_array() {
let parser = JsonParser::new();
// Test array of tool calls
let input = r#"[
{"name": "get_weather", "arguments": {"location": "SF"}},
{"name": "get_news", "arguments": {"query": "technology"}}
]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "get_weather");
assert_eq!(result[1].function.name, "get_news");
}
#[tokio::test]
async fn test_json_parser_with_parameters() {
let parser = JsonParser::new();
// Test with "parameters" instead of "arguments"
let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20, "operation": "add"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "calculate");
assert!(result[0].function.arguments.contains("10"));
assert!(result[0].function.arguments.contains("20"));
assert!(result[0].function.arguments.contains("add"));
}
#[tokio::test]
async fn test_json_parser_with_tokens() {
// Test with custom wrapper tokens
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["[TOOL_CALLS] [".to_string()],
end_tokens: vec!["]".to_string()],
separator: ", ".to_string(),
});
let input = r#"[TOOL_CALLS] [{"name": "search", "arguments": {"query": "rust programming"}}]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "search");
}
#[tokio::test]
async fn test_multiline_json_with_tokens() {
// Test that regex with (?s) flag properly handles multi-line JSON
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["<tool>".to_string()],
end_tokens: vec!["</tool>".to_string()],
separator: ", ".to_string(),
});
// Pretty-printed multi-line JSON
let input = r#"<tool>{
"name": "get_weather",
"arguments": {
"location": "San Francisco",
"units": "celsius",
"include_forecast": true
}
}</tool>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
assert!(result[0].function.arguments.contains("San Francisco"));
assert!(result[0].function.arguments.contains("celsius"));
assert!(result[0].function.arguments.contains("true"));
}
#[tokio::test]
async fn test_multiline_json_array() {
// Test multi-line JSON array without wrapper tokens
let parser = JsonParser::new();
let input = r#"[
{
"name": "function1",
"arguments": {
"param1": "value1",
"param2": 42
}
},
{
"name": "function2",
"parameters": {
"data": [1, 2, 3],
"flag": false
}
}
]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "function1");
assert_eq!(result[1].function.name, "function2");
assert!(result[0].function.arguments.contains("value1"));
assert!(result[1].function.arguments.contains("[1,2,3]"));
}
#[test]
fn test_json_parser_format_detection() {
let parser = JsonParser::new();
// Should detect valid tool call formats
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
assert!(parser.detect_format(r#"{"name": "test", "parameters": {"x": 1}}"#));
assert!(parser.detect_format(r#"[{"name": "test"}]"#));
// Should not detect non-tool formats
assert!(!parser.detect_format("plain text"));
assert!(!parser.detect_format(r#"{"key": "value"}"#));
assert!(!parser.detect_format(r#"{"data": {"nested": true}}"#));
}
#[tokio::test]
async fn test_json_parser_streaming() {
let parser = JsonParser::new();
let mut state = ParseState::new();
// Test with complete JSON
let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
let result = parser
.parse_incremental(full_json, &mut state)
.await
.unwrap();
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
assert!(tool.function.arguments.contains("San Francisco"));
}
_ => panic!("Expected ToolComplete for complete JSON"),
}
}
#[tokio::test]
async fn test_registry_with_json_parser() {
let registry = ParserRegistry::new();
// JSON parser should be registered by default
assert!(registry.has_parser("json"));
// Should get JSON parser for OpenAI models
let parser = registry.get_parser("gpt-4-turbo").unwrap();
// Test that the parser works
let input = r#"{"name": "test", "arguments": {"x": 1}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
}
#[tokio::test]
async fn test_json_parser_invalid_input() {
let parser = JsonParser::new();
// Invalid JSON should return empty results
assert_eq!(parser.parse_complete("not json").await.unwrap().len(), 0);
assert_eq!(parser.parse_complete("{invalid}").await.unwrap().len(), 0);
assert_eq!(parser.parse_complete("").await.unwrap().len(), 0);
}
#[tokio::test]
async fn test_json_parser_empty_arguments() {
let parser = JsonParser::new();
// Tool call with no arguments
let input = r#"{"name": "get_time"}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_time");
assert_eq!(result[0].function.arguments, "{}");
}
#[cfg(test)]
mod failure_cases {
use super::*;
#[tokio::test]
async fn test_malformed_tool_missing_name() {
let parser = JsonParser::new();
// Missing name field
let input = r#"{"arguments": {"x": 1}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0, "Should return empty for tool without name");
// Empty name
let input = r#"{"name": "", "arguments": {"x": 1}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1, "Should accept empty name string");
assert_eq!(result[0].function.name, "");
}
#[tokio::test]
async fn test_invalid_arguments_json() {
let parser = JsonParser::new();
// Arguments is a string instead of object
let input = r#"{"name": "test", "arguments": "not an object"}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
// Should serialize the string as JSON
assert!(result[0].function.arguments.contains("not an object"));
// Arguments is a number
let input = r#"{"name": "test", "arguments": 42}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.arguments, "42");
// Arguments is null
let input = r#"{"name": "test", "arguments": null}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.arguments, "null");
}
#[tokio::test]
async fn test_broken_wrapper_tokens() {
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["<tool>".to_string()],
end_tokens: vec!["</tool>".to_string()],
separator: ", ".to_string(),
});
// Missing end token
let input = r#"<tool>{"name": "test", "arguments": {}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(
result.len(),
0,
"Should fail to parse without complete wrapper"
);
// Missing start token - parser looks for complete wrapper, so this won't parse
let input = r#"{"name": "test", "arguments": {}}</tool>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(
result.len(),
0,
"Should not parse JSON with incomplete wrapper"
);
// Mismatched tokens
let input = r#"<tool>{"name": "test", "arguments": {}}</wrong>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0, "Should fail with mismatched tokens");
}
#[tokio::test]
async fn test_invalid_json_structures() {
let parser = JsonParser::new();
// Trailing comma
let input = r#"{"name": "test", "arguments": {"x": 1,}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0, "Should reject JSON with trailing comma");
// Missing quotes on keys
let input = r#"{name: "test", arguments: {}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0, "Should reject invalid JSON syntax");
// Unclosed object
let input = r#"{"name": "test", "arguments": {"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0, "Should reject incomplete JSON");
}
}
#[cfg(test)]
mod edge_cases {
use super::*;
#[tokio::test]
async fn test_unicode_in_names_and_arguments() {
let parser = JsonParser::new();
// Unicode in function name
let input = r#"{"name": "获取天气", "arguments": {"location": "北京"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "获取天气");
assert!(result[0].function.arguments.contains("北京"));
// Emoji in arguments
let input = r#"{"name": "send_message", "arguments": {"text": "Hello 👋 World 🌍"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].function.arguments.contains("👋"));
assert!(result[0].function.arguments.contains("🌍"));
}
#[tokio::test]
async fn test_escaped_characters() {
let parser = JsonParser::new();
// Escaped quotes in arguments
let input = r#"{"name": "echo", "arguments": {"text": "He said \"hello\""}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].function.arguments.contains(r#"\"hello\""#));
// Escaped backslashes
let input = r#"{"name": "path", "arguments": {"dir": "C:\\Users\\test"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].function.arguments.contains("\\\\"));
// Newlines and tabs
let input = r#"{"name": "format", "arguments": {"text": "line1\nline2\ttabbed"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].function.arguments.contains("\\n"));
assert!(result[0].function.arguments.contains("\\t"));
}
#[tokio::test]
async fn test_very_large_payloads() {
let parser = JsonParser::new();
// Large arguments object
let mut large_args = r#"{"name": "process", "arguments": {"#.to_string();
for i in 0..1000 {
large_args.push_str(&format!(r#""field_{}": "value_{}","#, i, i));
}
large_args.push_str(r#""final": "value"}}"#);
let result = parser.parse_complete(&large_args).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "process");
assert!(result[0].function.arguments.contains("field_999"));
// Large array of tool calls
let mut large_array = "[".to_string();
for i in 0..100 {
if i > 0 {
large_array.push(',');
}
large_array.push_str(&format!(r#"{{"name": "func_{}", "arguments": {{}}}}"#, i));
}
large_array.push(']');
let result = parser.parse_complete(&large_array).await.unwrap();
assert_eq!(result.len(), 100);
assert_eq!(result[99].function.name, "func_99");
}
#[tokio::test]
async fn test_mixed_array_tools_and_non_tools() {
let parser = JsonParser::new();
// Array with both tool calls and non-tool objects
let input = r#"[
{"name": "tool1", "arguments": {}},
{"not_a_tool": "just_data"},
{"name": "tool2", "parameters": {"x": 1}},
{"key": "value", "another": "field"}
]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2, "Should only parse valid tool calls");
assert_eq!(result[0].function.name, "tool1");
assert_eq!(result[1].function.name, "tool2");
}
#[tokio::test]
async fn test_duplicate_keys_in_json() {
let parser = JsonParser::new();
// JSON with duplicate keys (last one wins in most parsers)
let input = r#"{"name": "first", "name": "second", "arguments": {"x": 1, "x": 2}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(
result[0].function.name, "second",
"Last duplicate key should win"
);
assert!(
result[0].function.arguments.contains("2"),
"Last duplicate value should win"
);
}
#[tokio::test]
async fn test_null_values_in_arguments() {
let parser = JsonParser::new();
// Null values in arguments
let input = r#"{"name": "test", "arguments": {"required": "value", "optional": null}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].function.arguments.contains("null"));
// Array with null
let input = r#"{"name": "test", "arguments": {"items": [1, null, "three"]}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].function.arguments.contains("null"));
}
#[tokio::test]
async fn test_multiple_token_pairs_with_conflicts() {
// Test with overlapping token patterns
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["<<".to_string(), "<tool>".to_string()],
end_tokens: vec![">>".to_string(), "</tool>".to_string()],
separator: ", ".to_string(),
});
// First pattern
let input = r#"<<{"name": "test1", "arguments": {}}>>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test1");
// Second pattern
let input = r#"<tool>{"name": "test2", "arguments": {}}</tool>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test2");
// Nested patterns (should use first match)
let input = r#"<<tool>{"name": "test3", "arguments": {}}</tool>>"#;
let result = parser.parse_complete(input).await.unwrap();
// This is tricky - depends on regex behavior
// The parser should handle this gracefully
assert!(result.len() <= 1, "Should not parse multiple times");
}
#[tokio::test]
async fn test_streaming_with_partial_chunks() {
let parser = JsonParser::new();
// Test 1: Very incomplete JSON (just opening brace) should return Incomplete
let mut state1 = ParseState::new();
let partial = r#"{"#;
let result = parser
.parse_incremental(partial, &mut state1)
.await
.unwrap();
assert!(
matches!(result, StreamResult::Incomplete),
"Should return Incomplete for just opening brace"
);
// Test 2: Complete JSON should return ToolComplete
let mut state2 = ParseState::new();
let complete = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;
let result = parser
.parse_incremental(complete, &mut state2)
.await
.unwrap();
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
let args: serde_json::Value =
serde_json::from_str(&tool.function.arguments).unwrap();
assert_eq!(args["location"], "SF");
}
_ => panic!("Expected ToolComplete for complete JSON"),
}
// Test 3: Partial JSON with name
// The PartialJson parser can complete partial JSON by filling in missing values
let mut state3 = ParseState::new();
let partial_with_name = r#"{"name": "test", "argum"#;
let result = parser
.parse_incremental(partial_with_name, &mut state3)
.await
.unwrap();
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "test");
// Arguments will be empty object since "argum" is incomplete
assert_eq!(tool.function.arguments, "{}");
}
StreamResult::ToolName { name, .. } => {
assert_eq!(name, "test");
}
StreamResult::Incomplete => {
// Also acceptable if parser decides to wait
}
_ => panic!("Unexpected result for partial JSON with name"),
}
}
#[tokio::test]
async fn test_special_json_values() {
let parser = JsonParser::new();
// Boolean values
let input = r#"{"name": "toggle", "arguments": {"enabled": true, "disabled": false}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].function.arguments.contains("true"));
assert!(result[0].function.arguments.contains("false"));
// Numbers (including float and negative)
let input = r#"{"name": "calc", "arguments": {"int": 42, "float": 3.14, "negative": -17}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].function.arguments.contains("42"));
assert!(result[0].function.arguments.contains("3.14"));
assert!(result[0].function.arguments.contains("-17"));
// Empty arrays and objects
let input = r#"{"name": "test", "arguments": {"empty_arr": [], "empty_obj": {}}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].function.arguments.contains("[]"));
assert!(result[0].function.arguments.contains("{}"));
}
#[tokio::test]
async fn test_function_field_alternative() {
let parser = JsonParser::new();
// Using "function" instead of "name"
let input = r#"{"function": "test_func", "arguments": {"x": 1}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test_func");
// Both "name" and "function" present (name should take precedence)
let input = r#"{"name": "primary", "function": "secondary", "arguments": {}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "primary");
}
#[tokio::test]
async fn test_whitespace_handling() {
let parser = JsonParser::new();
// Extra whitespace everywhere
let input = r#" {
"name" : "test" ,
"arguments" : {
"key" : "value"
}
} "#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
// Minified JSON (no whitespace)
let input = r#"{"name":"compact","arguments":{"a":1,"b":2}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "compact");
}
}
#[cfg(test)]
mod stress_tests {
use super::*;
#[tokio::test]
async fn test_deeply_nested_arguments() {
let parser = JsonParser::new();
// Deeply nested structure
let input = r#"{
"name": "nested",
"arguments": {
"level1": {
"level2": {
"level3": {
"level4": {
"level5": {
"value": "deep"
}
}
}
}
}
}
}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].function.arguments.contains("deep"));
}
#[tokio::test]
async fn test_concurrent_parser_usage() {
// Test that parser can be used concurrently
let parser = std::sync::Arc::new(JsonParser::new());
let mut handles = vec![];
for i in 0..10 {
let parser_clone = parser.clone();
let handle = tokio::spawn(async move {
let input = format!(r#"{{"name": "func_{}", "arguments": {{}}}}"#, i);
let result = parser_clone.parse_complete(&input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, format!("func_{}", i));
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
}
}

View File

@@ -0,0 +1,35 @@
use crate::tool_parser::{
errors::ToolParserResult,
state::ParseState,
types::{StreamResult, ToolCall},
};
use async_trait::async_trait;
/// Core trait for all tool parsers
#[async_trait]
pub trait ToolParser: Send + Sync {
/// Parse complete tool calls from final output
async fn parse_complete(&self, output: &str) -> ToolParserResult<Vec<ToolCall>>;
/// Parse tool calls from model output (streaming)
async fn parse_incremental(
&self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult>;
/// Check if text contains tool calls in this parser's format
fn detect_format(&self, text: &str) -> bool;
}
/// Trait for partial JSON parsing
pub trait PartialJsonParser: Send + Sync {
/// Parse potentially incomplete JSON
fn parse(&self, input: &str) -> ToolParserResult<(serde_json::Value, usize)>;
/// Check if JSON is complete
fn is_complete(&self, input: &str) -> bool;
/// Get the maximum parsing depth
fn max_depth(&self) -> usize;
}

View File

@@ -0,0 +1,73 @@
use serde::{Deserialize, Serialize};
/// Parsed tool call from model output (OpenAI format)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
/// Unique identifier for the tool call
pub id: String,
/// Type of tool call (currently always "function")
#[serde(rename = "type")]
pub r#type: String,
/// Function call details
pub function: FunctionCall,
}
/// Function call within a tool call
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
/// Name of the function to call
pub name: String,
/// Arguments as JSON string
pub arguments: String,
}
/// Streaming parse result
#[derive(Debug, Clone)]
pub enum StreamResult {
/// Need more data to continue parsing
Incomplete,
/// Found a tool name (for streaming)
ToolName { index: usize, name: String },
/// Found incremental arguments (for streaming)
ToolArguments { index: usize, arguments: String },
/// Completed parsing a tool
ToolComplete(ToolCall),
/// Normal text (not part of tool call)
NormalText(String),
}
/// Token configuration for parsing
#[derive(Debug, Clone)]
pub struct TokenConfig {
/// Start tokens for tool calls
pub start_tokens: Vec<String>,
/// End tokens for tool calls
pub end_tokens: Vec<String>,
/// Separator between multiple tool calls
pub separator: String,
}
impl TokenConfig {
/// Iterate over start/end token pairs
pub fn iter_pairs(&self) -> impl Iterator<Item = (&str, &str)> {
self.start_tokens
.iter()
.zip(self.end_tokens.iter())
.map(|(s, e)| (s.as_str(), e.as_str()))
}
}
/// Simple partial tool call for streaming
#[derive(Debug, Clone)]
pub struct PartialToolCall {
/// Tool name (if parsed)
pub name: Option<String>,
/// Buffer for accumulating arguments
pub arguments_buffer: String,
/// Start position in the input buffer
pub start_position: usize,
/// Whether the name has been sent (for streaming)
pub name_sent: bool,
/// Arguments already streamed
pub streamed_args: String,
}

1478
sgl-router/src/tree.rs Normal file

File diff suppressed because it is too large Load Diff