sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
28
sgl-router/src/config/mod.rs
Normal file
28
sgl-router/src/config/mod.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
pub mod types;
|
||||
pub mod validation;
|
||||
|
||||
pub use types::*;
|
||||
pub use validation::*;
|
||||
|
||||
/// Configuration errors
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConfigError {
|
||||
#[error("Validation failed: {reason}")]
|
||||
ValidationFailed { reason: String },
|
||||
|
||||
#[error("Invalid value for field '{field}': {value} - {reason}")]
|
||||
InvalidValue {
|
||||
field: String,
|
||||
value: String,
|
||||
reason: String,
|
||||
},
|
||||
|
||||
#[error("Incompatible configuration: {reason}")]
|
||||
IncompatibleConfig { reason: String },
|
||||
|
||||
#[error("Missing required field: {field}")]
|
||||
MissingRequired { field: String },
|
||||
}
|
||||
|
||||
/// Result type for configuration operations
|
||||
pub type ConfigResult<T> = Result<T, ConfigError>;
|
||||
1241
sgl-router/src/config/types.rs
Normal file
1241
sgl-router/src/config/types.rs
Normal file
File diff suppressed because it is too large
Load Diff
776
sgl-router/src/config/validation.rs
Normal file
776
sgl-router/src/config/validation.rs
Normal file
@@ -0,0 +1,776 @@
|
||||
use super::*;
|
||||
|
||||
/// Configuration validator
|
||||
pub struct ConfigValidator;
|
||||
|
||||
impl ConfigValidator {
|
||||
/// Validate a complete router configuration
|
||||
pub fn validate(config: &RouterConfig) -> ConfigResult<()> {
|
||||
// Check if service discovery is enabled
|
||||
let has_service_discovery = config.discovery.as_ref().is_some_and(|d| d.enabled);
|
||||
|
||||
Self::validate_mode(&config.mode, has_service_discovery)?;
|
||||
Self::validate_policy(&config.policy)?;
|
||||
Self::validate_server_settings(config)?;
|
||||
|
||||
if let Some(discovery) = &config.discovery {
|
||||
Self::validate_discovery(discovery, &config.mode)?;
|
||||
}
|
||||
|
||||
if let Some(metrics) = &config.metrics {
|
||||
Self::validate_metrics(metrics)?;
|
||||
}
|
||||
|
||||
Self::validate_compatibility(config)?;
|
||||
|
||||
// Validate effective retry/CB configs (respect disable flags)
|
||||
let retry_cfg = config.effective_retry_config();
|
||||
let cb_cfg = config.effective_circuit_breaker_config();
|
||||
Self::validate_retry(&retry_cfg)?;
|
||||
Self::validate_circuit_breaker(&cb_cfg)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate routing mode configuration
|
||||
fn validate_mode(mode: &RoutingMode, has_service_discovery: bool) -> ConfigResult<()> {
|
||||
match mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
// Validate URLs if any are provided
|
||||
if !worker_urls.is_empty() {
|
||||
Self::validate_urls(worker_urls)?;
|
||||
}
|
||||
// Note: We allow empty worker URLs even without service discovery
|
||||
// to let the router start and fail at runtime when routing requests.
|
||||
// This matches legacy behavior and test expectations.
|
||||
}
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
} => {
|
||||
// Only require URLs if service discovery is disabled
|
||||
if !has_service_discovery {
|
||||
if prefill_urls.is_empty() {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "PD mode requires at least one prefill worker URL".to_string(),
|
||||
});
|
||||
}
|
||||
if decode_urls.is_empty() {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "PD mode requires at least one decode worker URL".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Validate URLs if any are provided
|
||||
if !prefill_urls.is_empty() {
|
||||
let prefill_url_strings: Vec<String> =
|
||||
prefill_urls.iter().map(|(url, _)| url.clone()).collect();
|
||||
Self::validate_urls(&prefill_url_strings)?;
|
||||
}
|
||||
if !decode_urls.is_empty() {
|
||||
Self::validate_urls(decode_urls)?;
|
||||
}
|
||||
|
||||
// Validate bootstrap ports
|
||||
for (_url, port) in prefill_urls {
|
||||
if let Some(port) = port {
|
||||
if *port == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "bootstrap_port".to_string(),
|
||||
value: port.to_string(),
|
||||
reason: "Port must be between 1 and 65535".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate optional prefill and decode policies
|
||||
if let Some(p_policy) = prefill_policy {
|
||||
Self::validate_policy(p_policy)?;
|
||||
}
|
||||
if let Some(d_policy) = decode_policy {
|
||||
Self::validate_policy(d_policy)?;
|
||||
}
|
||||
}
|
||||
RoutingMode::OpenAI { worker_urls } => {
|
||||
// Require exactly one worker URL for OpenAI router
|
||||
if worker_urls.len() != 1 {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "OpenAI mode requires exactly one --worker-urls entry".to_string(),
|
||||
});
|
||||
}
|
||||
// Validate URL format
|
||||
if let Err(e) = url::Url::parse(&worker_urls[0]) {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: format!("Invalid OpenAI worker URL '{}': {}", &worker_urls[0], e),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate policy configuration
|
||||
fn validate_policy(policy: &PolicyConfig) -> ConfigResult<()> {
|
||||
match policy {
|
||||
PolicyConfig::Random | PolicyConfig::RoundRobin => {
|
||||
// No specific validation needed
|
||||
}
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold,
|
||||
balance_abs_threshold: _,
|
||||
balance_rel_threshold,
|
||||
eviction_interval_secs,
|
||||
max_tree_size,
|
||||
} => {
|
||||
if !(0.0..=1.0).contains(cache_threshold) {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "cache_threshold".to_string(),
|
||||
value: cache_threshold.to_string(),
|
||||
reason: "Must be between 0.0 and 1.0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if *balance_rel_threshold < 1.0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "balance_rel_threshold".to_string(),
|
||||
value: balance_rel_threshold.to_string(),
|
||||
reason: "Must be >= 1.0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if *eviction_interval_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "eviction_interval_secs".to_string(),
|
||||
value: eviction_interval_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if *max_tree_size == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "max_tree_size".to_string(),
|
||||
value: max_tree_size.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs,
|
||||
} => {
|
||||
if *load_check_interval_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "load_check_interval_secs".to_string(),
|
||||
value: load_check_interval_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate server configuration
|
||||
fn validate_server_settings(config: &RouterConfig) -> ConfigResult<()> {
|
||||
if config.port == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "port".to_string(),
|
||||
value: config.port.to_string(),
|
||||
reason: "Port must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if config.max_payload_size == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "max_payload_size".to_string(),
|
||||
value: config.max_payload_size.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if config.request_timeout_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "request_timeout_secs".to_string(),
|
||||
value: config.request_timeout_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if config.worker_startup_timeout_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "worker_startup_timeout_secs".to_string(),
|
||||
value: config.worker_startup_timeout_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if config.worker_startup_check_interval_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "worker_startup_check_interval_secs".to_string(),
|
||||
value: config.worker_startup_check_interval_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate service discovery configuration
|
||||
fn validate_discovery(discovery: &DiscoveryConfig, mode: &RoutingMode) -> ConfigResult<()> {
|
||||
if !discovery.enabled {
|
||||
return Ok(()); // No validation needed if disabled
|
||||
}
|
||||
|
||||
if discovery.port == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "discovery.port".to_string(),
|
||||
value: discovery.port.to_string(),
|
||||
reason: "Port must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if discovery.check_interval_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "discovery.check_interval_secs".to_string(),
|
||||
value: discovery.check_interval_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Validate selectors based on mode
|
||||
match mode {
|
||||
RoutingMode::Regular { .. } => {
|
||||
if discovery.selector.is_empty() {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "Regular mode with service discovery requires a non-empty selector"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
RoutingMode::PrefillDecode { .. } => {
|
||||
if discovery.prefill_selector.is_empty() && discovery.decode_selector.is_empty() {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "PD mode with service discovery requires at least one non-empty selector (prefill or decode)".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
RoutingMode::OpenAI { .. } => {
|
||||
// OpenAI mode doesn't use service discovery
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "OpenAI mode does not support service discovery".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate metrics configuration
|
||||
fn validate_metrics(metrics: &MetricsConfig) -> ConfigResult<()> {
|
||||
if metrics.port == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "metrics.port".to_string(),
|
||||
value: metrics.port.to_string(),
|
||||
reason: "Port must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if metrics.host.is_empty() {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "metrics.host".to_string(),
|
||||
value: metrics.host.clone(),
|
||||
reason: "Host cannot be empty".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate retry configuration
|
||||
fn validate_retry(retry: &RetryConfig) -> ConfigResult<()> {
|
||||
if retry.max_retries < 1 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "retry.max_retries".to_string(),
|
||||
value: retry.max_retries.to_string(),
|
||||
reason: "Must be >= 1 (set to 1 to effectively disable retries)".to_string(),
|
||||
});
|
||||
}
|
||||
if retry.initial_backoff_ms == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "retry.initial_backoff_ms".to_string(),
|
||||
value: retry.initial_backoff_ms.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
if retry.max_backoff_ms < retry.initial_backoff_ms {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "retry.max_backoff_ms".to_string(),
|
||||
value: retry.max_backoff_ms.to_string(),
|
||||
reason: "Must be >= initial_backoff_ms".to_string(),
|
||||
});
|
||||
}
|
||||
if retry.backoff_multiplier < 1.0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "retry.backoff_multiplier".to_string(),
|
||||
value: retry.backoff_multiplier.to_string(),
|
||||
reason: "Must be >= 1.0".to_string(),
|
||||
});
|
||||
}
|
||||
if !(0.0..=1.0).contains(&retry.jitter_factor) {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "retry.jitter_factor".to_string(),
|
||||
value: retry.jitter_factor.to_string(),
|
||||
reason: "Must be between 0.0 and 1.0".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate circuit breaker configuration
|
||||
fn validate_circuit_breaker(cb: &CircuitBreakerConfig) -> ConfigResult<()> {
|
||||
if cb.failure_threshold < 1 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "circuit_breaker.failure_threshold".to_string(),
|
||||
value: cb.failure_threshold.to_string(),
|
||||
reason: "Must be >= 1 (set to u32::MAX to effectively disable CB)".to_string(),
|
||||
});
|
||||
}
|
||||
if cb.success_threshold < 1 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "circuit_breaker.success_threshold".to_string(),
|
||||
value: cb.success_threshold.to_string(),
|
||||
reason: "Must be >= 1".to_string(),
|
||||
});
|
||||
}
|
||||
if cb.timeout_duration_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "circuit_breaker.timeout_duration_secs".to_string(),
|
||||
value: cb.timeout_duration_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
if cb.window_duration_secs == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "circuit_breaker.window_duration_secs".to_string(),
|
||||
value: cb.window_duration_secs.to_string(),
|
||||
reason: "Must be > 0".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate compatibility between different configuration sections
|
||||
fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> {
|
||||
// IGW mode is independent - skip other compatibility checks when enabled
|
||||
if config.enable_igw {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Validate gRPC connection mode requires tokenizer configuration
|
||||
if config.connection_mode == ConnectionMode::Grpc
|
||||
&& config.tokenizer_path.is_none()
|
||||
&& config.model_path.is_none()
|
||||
{
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "gRPC connection mode requires either --tokenizer-path or --model-path to be specified".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// All policies are now supported for both router types thanks to the unified trait design
|
||||
// No mode/policy restrictions needed anymore
|
||||
|
||||
// Check if service discovery is enabled for worker count validation
|
||||
let has_service_discovery = config.discovery.as_ref().is_some_and(|d| d.enabled);
|
||||
|
||||
// Only validate worker counts if service discovery is disabled
|
||||
if !has_service_discovery {
|
||||
// Check if power-of-two policy makes sense with insufficient workers
|
||||
if let PolicyConfig::PowerOfTwo { .. } = &config.policy {
|
||||
let worker_count = config.mode.worker_count();
|
||||
if worker_count < 2 {
|
||||
return Err(ConfigError::IncompatibleConfig {
|
||||
reason: "Power-of-two policy requires at least 2 workers".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// For PD mode, validate that policies have sufficient workers
|
||||
if let RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
} = &config.mode
|
||||
{
|
||||
// Check power-of-two for prefill
|
||||
if let Some(PolicyConfig::PowerOfTwo { .. }) = prefill_policy {
|
||||
if prefill_urls.len() < 2 {
|
||||
return Err(ConfigError::IncompatibleConfig {
|
||||
reason: "Power-of-two policy for prefill requires at least 2 prefill workers".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Check power-of-two for decode
|
||||
if let Some(PolicyConfig::PowerOfTwo { .. }) = decode_policy {
|
||||
if decode_urls.len() < 2 {
|
||||
return Err(ConfigError::IncompatibleConfig {
|
||||
reason:
|
||||
"Power-of-two policy for decode requires at least 2 decode workers"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Service discovery is conflict with dp_aware routing for now
|
||||
// since it's not fully supported yet
|
||||
if has_service_discovery && config.dp_aware {
|
||||
return Err(ConfigError::IncompatibleConfig {
|
||||
reason: "DP-aware routing is not compatible with service discovery".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate URL format
|
||||
fn validate_urls(urls: &[String]) -> ConfigResult<()> {
|
||||
for url in urls {
|
||||
if url.is_empty() {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "worker_url".to_string(),
|
||||
value: url.clone(),
|
||||
reason: "URL cannot be empty".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if !url.starts_with("http://")
|
||||
&& !url.starts_with("https://")
|
||||
&& !url.starts_with("grpc://")
|
||||
{
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "worker_url".to_string(),
|
||||
value: url.clone(),
|
||||
reason: "URL must start with http://, https://, or grpc://".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Basic URL validation
|
||||
match ::url::Url::parse(url) {
|
||||
Ok(parsed) => {
|
||||
// Additional validation
|
||||
if parsed.host_str().is_none() {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "worker_url".to_string(),
|
||||
value: url.clone(),
|
||||
reason: "URL must have a valid host".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "worker_url".to_string(),
|
||||
value: url.clone(),
|
||||
reason: format!("Invalid URL format: {}", e),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_validate_regular_mode() {
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["http://worker:8000".to_string()],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
assert!(ConfigValidator::validate(&config).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_empty_worker_urls() {
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
// Empty worker URLs are now allowed to match legacy behavior
|
||||
assert!(ConfigValidator::validate(&config).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_empty_worker_urls_with_service_discovery() {
|
||||
let mut config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
// Enable service discovery
|
||||
config.discovery = Some(DiscoveryConfig {
|
||||
enabled: true,
|
||||
selector: vec![("app".to_string(), "test".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// Should pass validation since service discovery is enabled
|
||||
assert!(ConfigValidator::validate(&config).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_invalid_urls() {
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["invalid-url".to_string()],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
assert!(ConfigValidator::validate(&config).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_cache_aware_thresholds() {
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![
|
||||
"http://worker1:8000".to_string(),
|
||||
"http://worker2:8000".to_string(),
|
||||
],
|
||||
},
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold: 1.5, // Invalid: > 1.0
|
||||
balance_abs_threshold: 32,
|
||||
balance_rel_threshold: 1.1,
|
||||
eviction_interval_secs: 60,
|
||||
max_tree_size: 1000,
|
||||
},
|
||||
);
|
||||
|
||||
assert!(ConfigValidator::validate(&config).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_cache_aware_single_worker() {
|
||||
// Cache-aware with single worker should be allowed (even if not optimal)
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["http://worker1:8000".to_string()],
|
||||
},
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold: 0.5,
|
||||
balance_abs_threshold: 32,
|
||||
balance_rel_threshold: 1.1,
|
||||
eviction_interval_secs: 60,
|
||||
max_tree_size: 1000,
|
||||
},
|
||||
);
|
||||
|
||||
assert!(ConfigValidator::validate(&config).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pd_mode() {
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill:8000".to_string(), Some(8081))],
|
||||
decode_urls: vec!["http://decode:8000".to_string()],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
assert!(ConfigValidator::validate(&config).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_roundrobin_with_pd_mode() {
|
||||
// RoundRobin with PD mode is now supported
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
|
||||
decode_urls: vec!["http://decode:8000".to_string()],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
},
|
||||
PolicyConfig::RoundRobin,
|
||||
);
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_cache_aware_with_pd_mode() {
|
||||
// CacheAware with PD mode is now supported
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
|
||||
decode_urls: vec!["http://decode:8000".to_string()],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
},
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold: 0.5,
|
||||
balance_abs_threshold: 32,
|
||||
balance_rel_threshold: 1.1,
|
||||
eviction_interval_secs: 60,
|
||||
max_tree_size: 1000,
|
||||
},
|
||||
);
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_power_of_two_with_regular_mode() {
|
||||
// PowerOfTwo with Regular mode is now supported
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![
|
||||
"http://worker1:8000".to_string(),
|
||||
"http://worker2:8000".to_string(),
|
||||
],
|
||||
},
|
||||
PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 60,
|
||||
},
|
||||
);
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pd_mode_with_separate_policies() {
|
||||
// Test PD mode with different policies for prefill and decode
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![
|
||||
("http://prefill1:8000".to_string(), None),
|
||||
("http://prefill2:8000".to_string(), None),
|
||||
],
|
||||
decode_urls: vec![
|
||||
"http://decode1:8000".to_string(),
|
||||
"http://decode2:8000".to_string(),
|
||||
],
|
||||
prefill_policy: Some(PolicyConfig::CacheAware {
|
||||
cache_threshold: 0.5,
|
||||
balance_abs_threshold: 32,
|
||||
balance_rel_threshold: 1.1,
|
||||
eviction_interval_secs: 60,
|
||||
max_tree_size: 1000,
|
||||
}),
|
||||
decode_policy: Some(PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 60,
|
||||
}),
|
||||
},
|
||||
PolicyConfig::Random, // Main policy as fallback
|
||||
);
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pd_mode_power_of_two_insufficient_workers() {
|
||||
// Test that power-of-two policy requires at least 2 workers
|
||||
let config = RouterConfig::new(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill1:8000".to_string(), None)], // Only 1 prefill
|
||||
decode_urls: vec![
|
||||
"http://decode1:8000".to_string(),
|
||||
"http://decode2:8000".to_string(),
|
||||
],
|
||||
prefill_policy: Some(PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 60,
|
||||
}), // Requires 2+ workers
|
||||
decode_policy: None,
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert!(e.to_string().contains("prefill requires at least 2"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_grpc_requires_tokenizer() {
|
||||
// Test that gRPC connection mode requires tokenizer configuration
|
||||
let mut config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["grpc://worker:50051".to_string()],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
// Set connection mode to gRPC without tokenizer config
|
||||
config.connection_mode = ConnectionMode::Grpc;
|
||||
config.tokenizer_path = None;
|
||||
config.model_path = None;
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert!(e.to_string().contains("gRPC connection mode requires"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_grpc_with_model_path() {
|
||||
// Test that gRPC works with model_path
|
||||
let mut config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["grpc://worker:50051".to_string()],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
config.connection_mode = ConnectionMode::Grpc;
|
||||
config.model_path = Some("meta-llama/Llama-3-8B".to_string());
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_grpc_with_tokenizer_path() {
|
||||
// Test that gRPC works with tokenizer_path
|
||||
let mut config = RouterConfig::new(
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec!["grpc://worker:50051".to_string()],
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
config.connection_mode = ConnectionMode::Grpc;
|
||||
config.tokenizer_path = Some("/path/to/tokenizer.json".to_string());
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
555
sgl-router/src/core/circuit_breaker.rs
Normal file
555
sgl-router/src/core/circuit_breaker.rs
Normal file
@@ -0,0 +1,555 @@
|
||||
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
/// Circuit breaker configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CircuitBreakerConfig {
|
||||
/// Number of consecutive failures to open the circuit
|
||||
pub failure_threshold: u32,
|
||||
/// Success threshold to close circuit from half-open
|
||||
pub success_threshold: u32,
|
||||
/// Duration to wait before attempting half-open
|
||||
pub timeout_duration: Duration,
|
||||
/// Time window for failure counting
|
||||
pub window_duration: Duration,
|
||||
}
|
||||
|
||||
impl Default for CircuitBreakerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
failure_threshold: 5,
|
||||
success_threshold: 2,
|
||||
timeout_duration: Duration::from_secs(30),
|
||||
window_duration: Duration::from_secs(60),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Circuit breaker state
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum CircuitState {
|
||||
/// Normal operation - requests are allowed
|
||||
Closed,
|
||||
/// Circuit is open - requests are rejected
|
||||
Open,
|
||||
/// Testing if service has recovered - limited requests allowed
|
||||
HalfOpen,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CircuitState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
CircuitState::Closed => write!(f, "Closed"),
|
||||
CircuitState::Open => write!(f, "Open"),
|
||||
CircuitState::HalfOpen => write!(f, "HalfOpen"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Circuit breaker implementation
|
||||
#[derive(Debug)]
|
||||
pub struct CircuitBreaker {
|
||||
state: Arc<RwLock<CircuitState>>,
|
||||
consecutive_failures: Arc<AtomicU32>,
|
||||
consecutive_successes: Arc<AtomicU32>,
|
||||
total_failures: Arc<AtomicU64>,
|
||||
total_successes: Arc<AtomicU64>,
|
||||
last_failure_time: Arc<RwLock<Option<Instant>>>,
|
||||
last_state_change: Arc<RwLock<Instant>>,
|
||||
config: CircuitBreakerConfig,
|
||||
}
|
||||
|
||||
impl CircuitBreaker {
|
||||
/// Create a new circuit breaker with default configuration
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(CircuitBreakerConfig::default())
|
||||
}
|
||||
|
||||
/// Create a new circuit breaker with custom configuration
|
||||
pub fn with_config(config: CircuitBreakerConfig) -> Self {
|
||||
Self {
|
||||
state: Arc::new(RwLock::new(CircuitState::Closed)),
|
||||
consecutive_failures: Arc::new(AtomicU32::new(0)),
|
||||
consecutive_successes: Arc::new(AtomicU32::new(0)),
|
||||
total_failures: Arc::new(AtomicU64::new(0)),
|
||||
total_successes: Arc::new(AtomicU64::new(0)),
|
||||
last_failure_time: Arc::new(RwLock::new(None)),
|
||||
last_state_change: Arc::new(RwLock::new(Instant::now())),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a request can be executed
|
||||
pub fn can_execute(&self) -> bool {
|
||||
// First check if we need to transition from Open to HalfOpen
|
||||
self.check_and_update_state();
|
||||
|
||||
let state = *self.state.read().unwrap();
|
||||
match state {
|
||||
CircuitState::Closed => true,
|
||||
CircuitState::Open => false,
|
||||
CircuitState::HalfOpen => true, // Allow limited requests in half-open state
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current state
|
||||
pub fn state(&self) -> CircuitState {
|
||||
self.check_and_update_state();
|
||||
*self.state.read().unwrap()
|
||||
}
|
||||
|
||||
/// Record the outcome of a request
|
||||
pub fn record_outcome(&self, success: bool) {
|
||||
if success {
|
||||
self.record_success();
|
||||
} else {
|
||||
self.record_failure();
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a successful request
|
||||
pub fn record_success(&self) {
|
||||
self.total_successes.fetch_add(1, Ordering::Relaxed);
|
||||
self.consecutive_failures.store(0, Ordering::Release);
|
||||
let successes = self.consecutive_successes.fetch_add(1, Ordering::AcqRel) + 1;
|
||||
// Outcome-level metrics are recorded at the worker level where the worker label is known
|
||||
|
||||
let current_state = *self.state.read().unwrap();
|
||||
|
||||
match current_state {
|
||||
CircuitState::HalfOpen => {
|
||||
// Check if we've reached the success threshold to close the circuit
|
||||
if successes >= self.config.success_threshold {
|
||||
self.transition_to(CircuitState::Closed);
|
||||
}
|
||||
}
|
||||
CircuitState::Closed => {
|
||||
// Already closed, nothing to do
|
||||
}
|
||||
CircuitState::Open => {
|
||||
// Shouldn't happen, but if it does, stay open
|
||||
tracing::warn!("Success recorded while circuit is open");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a failed request
|
||||
pub fn record_failure(&self) {
|
||||
self.total_failures.fetch_add(1, Ordering::Relaxed);
|
||||
self.consecutive_successes.store(0, Ordering::Release);
|
||||
let failures = self.consecutive_failures.fetch_add(1, Ordering::AcqRel) + 1;
|
||||
// Outcome-level metrics are recorded at the worker level where the worker label is known
|
||||
|
||||
// Update last failure time
|
||||
{
|
||||
let mut last_failure = self.last_failure_time.write().unwrap();
|
||||
*last_failure = Some(Instant::now());
|
||||
}
|
||||
|
||||
let current_state = *self.state.read().unwrap();
|
||||
|
||||
match current_state {
|
||||
CircuitState::Closed => {
|
||||
// Check if we've reached the failure threshold to open the circuit
|
||||
if failures >= self.config.failure_threshold {
|
||||
self.transition_to(CircuitState::Open);
|
||||
}
|
||||
}
|
||||
CircuitState::HalfOpen => {
|
||||
// Single failure in half-open state reopens the circuit
|
||||
self.transition_to(CircuitState::Open);
|
||||
}
|
||||
CircuitState::Open => {
|
||||
// Already open, nothing to do
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check and update state based on timeout
|
||||
fn check_and_update_state(&self) {
|
||||
let current_state = *self.state.read().unwrap();
|
||||
|
||||
if current_state == CircuitState::Open {
|
||||
// Check if timeout has expired
|
||||
let last_change = *self.last_state_change.read().unwrap();
|
||||
if last_change.elapsed() >= self.config.timeout_duration {
|
||||
self.transition_to(CircuitState::HalfOpen);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Transition to a new state
|
||||
fn transition_to(&self, new_state: CircuitState) {
|
||||
let mut state = self.state.write().unwrap();
|
||||
let old_state = *state;
|
||||
|
||||
if old_state != new_state {
|
||||
*state = new_state;
|
||||
|
||||
// Update last state change time
|
||||
let mut last_change = self.last_state_change.write().unwrap();
|
||||
*last_change = Instant::now();
|
||||
|
||||
// Reset counters based on transition
|
||||
match new_state {
|
||||
CircuitState::Closed => {
|
||||
self.consecutive_failures.store(0, Ordering::Release);
|
||||
self.consecutive_successes.store(0, Ordering::Release);
|
||||
}
|
||||
CircuitState::Open => {
|
||||
self.consecutive_successes.store(0, Ordering::Release);
|
||||
}
|
||||
CircuitState::HalfOpen => {
|
||||
self.consecutive_failures.store(0, Ordering::Release);
|
||||
self.consecutive_successes.store(0, Ordering::Release);
|
||||
}
|
||||
}
|
||||
|
||||
let from = match old_state {
|
||||
CircuitState::Closed => "closed",
|
||||
CircuitState::Open => "open",
|
||||
CircuitState::HalfOpen => "half_open",
|
||||
};
|
||||
let to = match new_state {
|
||||
CircuitState::Closed => "closed",
|
||||
CircuitState::Open => "open",
|
||||
CircuitState::HalfOpen => "half_open",
|
||||
};
|
||||
info!("Circuit breaker state transition: {} -> {}", from, to);
|
||||
// Transition metrics are recorded at the worker level where the worker label is known
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of consecutive failures
|
||||
pub fn failure_count(&self) -> u32 {
|
||||
self.consecutive_failures.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
/// Get the number of consecutive successes
|
||||
pub fn success_count(&self) -> u32 {
|
||||
self.consecutive_successes.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
/// Get total failures
|
||||
pub fn total_failures(&self) -> u64 {
|
||||
self.total_failures.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get total successes
|
||||
pub fn total_successes(&self) -> u64 {
|
||||
self.total_successes.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get time since last failure
|
||||
pub fn time_since_last_failure(&self) -> Option<Duration> {
|
||||
self.last_failure_time.read().unwrap().map(|t| t.elapsed())
|
||||
}
|
||||
|
||||
/// Get time since last state change
|
||||
pub fn time_since_last_state_change(&self) -> Duration {
|
||||
self.last_state_change.read().unwrap().elapsed()
|
||||
}
|
||||
|
||||
/// Check if the circuit is in a half-open state
|
||||
pub fn is_half_open(&self) -> bool {
|
||||
self.state() == CircuitState::HalfOpen
|
||||
}
|
||||
|
||||
/// Record a test success (for health check probing)
|
||||
pub fn record_test_success(&self) {
|
||||
if self.is_half_open() {
|
||||
self.record_success();
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a test failure (for health check probing)
|
||||
pub fn record_test_failure(&self) {
|
||||
if self.is_half_open() {
|
||||
self.record_failure();
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the circuit breaker to closed state
|
||||
pub fn reset(&self) {
|
||||
self.transition_to(CircuitState::Closed);
|
||||
self.consecutive_failures.store(0, Ordering::Release);
|
||||
self.consecutive_successes.store(0, Ordering::Release);
|
||||
}
|
||||
|
||||
/// Force the circuit to open (for manual intervention)
|
||||
pub fn force_open(&self) {
|
||||
self.transition_to(CircuitState::Open);
|
||||
}
|
||||
|
||||
/// Get circuit breaker statistics
|
||||
pub fn stats(&self) -> CircuitBreakerStats {
|
||||
CircuitBreakerStats {
|
||||
state: self.state(),
|
||||
consecutive_failures: self.failure_count(),
|
||||
consecutive_successes: self.success_count(),
|
||||
total_failures: self.total_failures(),
|
||||
total_successes: self.total_successes(),
|
||||
time_since_last_failure: self.time_since_last_failure(),
|
||||
time_since_last_state_change: self.time_since_last_state_change(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for CircuitBreaker {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
state: Arc::clone(&self.state),
|
||||
consecutive_failures: Arc::clone(&self.consecutive_failures),
|
||||
consecutive_successes: Arc::clone(&self.consecutive_successes),
|
||||
total_failures: Arc::clone(&self.total_failures),
|
||||
total_successes: Arc::clone(&self.total_successes),
|
||||
last_failure_time: Arc::clone(&self.last_failure_time),
|
||||
last_state_change: Arc::clone(&self.last_state_change),
|
||||
config: self.config.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CircuitBreaker {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Circuit breaker statistics
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CircuitBreakerStats {
|
||||
pub state: CircuitState,
|
||||
pub consecutive_failures: u32,
|
||||
pub consecutive_successes: u32,
|
||||
pub total_failures: u64,
|
||||
pub total_successes: u64,
|
||||
pub time_since_last_failure: Option<Duration>,
|
||||
pub time_since_last_state_change: Duration,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::thread;
|
||||
|
||||
#[test]
|
||||
fn test_circuit_breaker_initial_state() {
|
||||
let cb = CircuitBreaker::new();
|
||||
assert_eq!(cb.state(), CircuitState::Closed);
|
||||
assert!(cb.can_execute());
|
||||
assert_eq!(cb.failure_count(), 0);
|
||||
assert_eq!(cb.success_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circuit_opens_on_threshold() {
|
||||
let config = CircuitBreakerConfig {
|
||||
failure_threshold: 3,
|
||||
..Default::default()
|
||||
};
|
||||
let cb = CircuitBreaker::with_config(config);
|
||||
|
||||
// Record failures up to threshold
|
||||
assert_eq!(cb.state(), CircuitState::Closed);
|
||||
cb.record_failure();
|
||||
assert_eq!(cb.state(), CircuitState::Closed);
|
||||
cb.record_failure();
|
||||
assert_eq!(cb.state(), CircuitState::Closed);
|
||||
cb.record_failure();
|
||||
|
||||
// Circuit should now be open
|
||||
assert_eq!(cb.state(), CircuitState::Open);
|
||||
assert!(!cb.can_execute());
|
||||
assert_eq!(cb.failure_count(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circuit_half_open_after_timeout() {
|
||||
let config = CircuitBreakerConfig {
|
||||
failure_threshold: 1,
|
||||
timeout_duration: Duration::from_millis(100),
|
||||
..Default::default()
|
||||
};
|
||||
let cb = CircuitBreaker::with_config(config);
|
||||
|
||||
// Open the circuit
|
||||
cb.record_failure();
|
||||
assert_eq!(cb.state(), CircuitState::Open);
|
||||
|
||||
// Wait for timeout
|
||||
thread::sleep(Duration::from_millis(150));
|
||||
|
||||
// Circuit should be half-open
|
||||
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
||||
assert!(cb.can_execute());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circuit_closes_on_success_threshold() {
|
||||
let config = CircuitBreakerConfig {
|
||||
failure_threshold: 1,
|
||||
success_threshold: 2,
|
||||
timeout_duration: Duration::from_millis(50),
|
||||
..Default::default()
|
||||
};
|
||||
let cb = CircuitBreaker::with_config(config);
|
||||
|
||||
// Open the circuit
|
||||
cb.record_failure();
|
||||
assert_eq!(cb.state(), CircuitState::Open);
|
||||
|
||||
// Wait for timeout
|
||||
thread::sleep(Duration::from_millis(100));
|
||||
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
||||
|
||||
// Record successes
|
||||
cb.record_success();
|
||||
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
||||
cb.record_success();
|
||||
|
||||
// Circuit should now be closed
|
||||
assert_eq!(cb.state(), CircuitState::Closed);
|
||||
assert!(cb.can_execute());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circuit_reopens_on_half_open_failure() {
|
||||
let config = CircuitBreakerConfig {
|
||||
failure_threshold: 1,
|
||||
timeout_duration: Duration::from_millis(50),
|
||||
..Default::default()
|
||||
};
|
||||
let cb = CircuitBreaker::with_config(config);
|
||||
|
||||
// Open the circuit
|
||||
cb.record_failure();
|
||||
assert_eq!(cb.state(), CircuitState::Open);
|
||||
|
||||
// Wait for timeout
|
||||
thread::sleep(Duration::from_millis(100));
|
||||
assert_eq!(cb.state(), CircuitState::HalfOpen);
|
||||
|
||||
// Record a failure in half-open state
|
||||
cb.record_failure();
|
||||
|
||||
// Circuit should reopen immediately
|
||||
assert_eq!(cb.state(), CircuitState::Open);
|
||||
assert!(!cb.can_execute());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_success_resets_failure_count() {
|
||||
let config = CircuitBreakerConfig {
|
||||
failure_threshold: 3,
|
||||
..Default::default()
|
||||
};
|
||||
let cb = CircuitBreaker::with_config(config);
|
||||
|
||||
// Record some failures
|
||||
cb.record_failure();
|
||||
cb.record_failure();
|
||||
assert_eq!(cb.failure_count(), 2);
|
||||
|
||||
// Success should reset failure count
|
||||
cb.record_success();
|
||||
assert_eq!(cb.failure_count(), 0);
|
||||
assert_eq!(cb.success_count(), 1);
|
||||
|
||||
// Can now record more failures without opening
|
||||
cb.record_failure();
|
||||
cb.record_failure();
|
||||
assert_eq!(cb.state(), CircuitState::Closed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_manual_reset() {
|
||||
let config = CircuitBreakerConfig {
|
||||
failure_threshold: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let cb = CircuitBreaker::with_config(config);
|
||||
|
||||
// Open the circuit
|
||||
cb.record_failure();
|
||||
assert_eq!(cb.state(), CircuitState::Open);
|
||||
|
||||
// Manual reset
|
||||
cb.reset();
|
||||
assert_eq!(cb.state(), CircuitState::Closed);
|
||||
assert_eq!(cb.failure_count(), 0);
|
||||
assert_eq!(cb.success_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_force_open() {
|
||||
let cb = CircuitBreaker::new();
|
||||
assert_eq!(cb.state(), CircuitState::Closed);
|
||||
|
||||
cb.force_open();
|
||||
assert_eq!(cb.state(), CircuitState::Open);
|
||||
assert!(!cb.can_execute());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stats() {
|
||||
let config = CircuitBreakerConfig {
|
||||
failure_threshold: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let cb = CircuitBreaker::with_config(config);
|
||||
|
||||
cb.record_success();
|
||||
cb.record_failure();
|
||||
cb.record_failure();
|
||||
|
||||
let stats = cb.stats();
|
||||
assert_eq!(stats.state, CircuitState::Open);
|
||||
assert_eq!(stats.consecutive_failures, 2);
|
||||
assert_eq!(stats.consecutive_successes, 0);
|
||||
assert_eq!(stats.total_failures, 2);
|
||||
assert_eq!(stats.total_successes, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clone() {
|
||||
let cb1 = CircuitBreaker::new();
|
||||
cb1.record_failure();
|
||||
|
||||
let cb2 = cb1.clone();
|
||||
assert_eq!(cb2.failure_count(), 1);
|
||||
|
||||
// Changes to cb1 affect cb2 (shared state)
|
||||
cb1.record_failure();
|
||||
assert_eq!(cb2.failure_count(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thread_safety() {
|
||||
use std::sync::Arc;
|
||||
|
||||
let cb = Arc::new(CircuitBreaker::new());
|
||||
let mut handles = vec![];
|
||||
|
||||
// Spawn threads that record failures
|
||||
for _ in 0..10 {
|
||||
let cb_clone = Arc::clone(&cb);
|
||||
let handle = thread::spawn(move || {
|
||||
for _ in 0..100 {
|
||||
cb_clone.record_failure();
|
||||
}
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all threads
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
|
||||
// Should have recorded 1000 failures
|
||||
assert_eq!(cb.total_failures(), 1000);
|
||||
}
|
||||
}
|
||||
240
sgl-router/src/core/error.rs
Normal file
240
sgl-router/src/core/error.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
//! Error types for the SGLang router core
|
||||
//!
|
||||
//! This module defines error types used throughout the router for worker operations.
|
||||
|
||||
use std::fmt;
|
||||
|
||||
/// Worker-related errors
|
||||
#[derive(Debug)]
|
||||
pub enum WorkerError {
|
||||
/// Health check failed
|
||||
HealthCheckFailed { url: String, reason: String },
|
||||
/// Worker not found
|
||||
WorkerNotFound { url: String },
|
||||
/// Invalid worker configuration
|
||||
InvalidConfiguration { message: String },
|
||||
/// Network error
|
||||
NetworkError { url: String, error: String },
|
||||
/// Worker is at capacity
|
||||
WorkerAtCapacity { url: String },
|
||||
/// Invalid URL format
|
||||
InvalidUrl { url: String },
|
||||
}
|
||||
|
||||
impl fmt::Display for WorkerError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
WorkerError::HealthCheckFailed { url, reason } => {
|
||||
write!(f, "Health check failed for worker {}: {}", url, reason)
|
||||
}
|
||||
WorkerError::WorkerNotFound { url } => {
|
||||
write!(f, "Worker not found: {}", url)
|
||||
}
|
||||
WorkerError::InvalidConfiguration { message } => {
|
||||
write!(f, "Invalid worker configuration: {}", message)
|
||||
}
|
||||
WorkerError::NetworkError { url, error } => {
|
||||
write!(f, "Network error for worker {}: {}", url, error)
|
||||
}
|
||||
WorkerError::WorkerAtCapacity { url } => {
|
||||
write!(f, "Worker at capacity: {}", url)
|
||||
}
|
||||
WorkerError::InvalidUrl { url } => {
|
||||
write!(f, "Invalid URL format: {}", url)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for WorkerError {}
|
||||
|
||||
/// Result type for worker operations
|
||||
pub type WorkerResult<T> = Result<T, WorkerError>;
|
||||
|
||||
/// Convert from reqwest errors to worker errors
|
||||
impl From<reqwest::Error> for WorkerError {
|
||||
fn from(err: reqwest::Error) -> Self {
|
||||
WorkerError::NetworkError {
|
||||
url: err.url().map(|u| u.to_string()).unwrap_or_default(),
|
||||
error: err.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::error::Error;
|
||||
|
||||
#[test]
|
||||
fn test_health_check_failed_display() {
|
||||
let error = WorkerError::HealthCheckFailed {
|
||||
url: "http://worker1:8080".to_string(),
|
||||
reason: "Connection refused".to_string(),
|
||||
};
|
||||
assert_eq!(
|
||||
error.to_string(),
|
||||
"Health check failed for worker http://worker1:8080: Connection refused"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_worker_not_found_display() {
|
||||
let error = WorkerError::WorkerNotFound {
|
||||
url: "http://worker2:8080".to_string(),
|
||||
};
|
||||
assert_eq!(error.to_string(), "Worker not found: http://worker2:8080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_configuration_display() {
|
||||
let error = WorkerError::InvalidConfiguration {
|
||||
message: "Missing port number".to_string(),
|
||||
};
|
||||
assert_eq!(
|
||||
error.to_string(),
|
||||
"Invalid worker configuration: Missing port number"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_network_error_display() {
|
||||
let error = WorkerError::NetworkError {
|
||||
url: "http://worker3:8080".to_string(),
|
||||
error: "Timeout after 30s".to_string(),
|
||||
};
|
||||
assert_eq!(
|
||||
error.to_string(),
|
||||
"Network error for worker http://worker3:8080: Timeout after 30s"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_worker_at_capacity_display() {
|
||||
let error = WorkerError::WorkerAtCapacity {
|
||||
url: "http://worker4:8080".to_string(),
|
||||
};
|
||||
assert_eq!(error.to_string(), "Worker at capacity: http://worker4:8080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_worker_error_implements_std_error() {
|
||||
let error = WorkerError::WorkerNotFound {
|
||||
url: "http://test".to_string(),
|
||||
};
|
||||
// Verify it implements Error trait
|
||||
let _: &dyn Error = &error;
|
||||
assert!(error.source().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_send_sync() {
|
||||
fn assert_send_sync<T: Send + Sync>() {}
|
||||
assert_send_sync::<WorkerError>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_worker_result_type_alias() {
|
||||
// Test Ok variant
|
||||
let result: WorkerResult<i32> = Ok(42);
|
||||
assert!(matches!(result, Ok(42)));
|
||||
|
||||
// Test Err variant
|
||||
let error = WorkerError::WorkerNotFound {
|
||||
url: "test".to_string(),
|
||||
};
|
||||
let result: WorkerResult<i32> = Err(error);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_url_handling() {
|
||||
// Test empty URLs in error variants
|
||||
let error1 = WorkerError::HealthCheckFailed {
|
||||
url: "".to_string(),
|
||||
reason: "No connection".to_string(),
|
||||
};
|
||||
assert_eq!(
|
||||
error1.to_string(),
|
||||
"Health check failed for worker : No connection"
|
||||
);
|
||||
|
||||
let error2 = WorkerError::NetworkError {
|
||||
url: "".to_string(),
|
||||
error: "DNS failure".to_string(),
|
||||
};
|
||||
assert_eq!(error2.to_string(), "Network error for worker : DNS failure");
|
||||
|
||||
let error3 = WorkerError::WorkerNotFound {
|
||||
url: "".to_string(),
|
||||
};
|
||||
assert_eq!(error3.to_string(), "Worker not found: ");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_special_characters_in_messages() {
|
||||
// Test with special characters
|
||||
let error = WorkerError::InvalidConfiguration {
|
||||
message: "Invalid JSON: {\"error\": \"test\"}".to_string(),
|
||||
};
|
||||
assert_eq!(
|
||||
error.to_string(),
|
||||
"Invalid worker configuration: Invalid JSON: {\"error\": \"test\"}"
|
||||
);
|
||||
|
||||
// Test with unicode
|
||||
let error2 = WorkerError::HealthCheckFailed {
|
||||
url: "http://测试:8080".to_string(),
|
||||
reason: "连接被拒绝".to_string(),
|
||||
};
|
||||
assert_eq!(
|
||||
error2.to_string(),
|
||||
"Health check failed for worker http://测试:8080: 连接被拒绝"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_very_long_error_messages() {
|
||||
let long_message = "A".repeat(10000);
|
||||
let error = WorkerError::InvalidConfiguration {
|
||||
message: long_message.clone(),
|
||||
};
|
||||
let display = error.to_string();
|
||||
assert!(display.contains(&long_message));
|
||||
assert_eq!(
|
||||
display.len(),
|
||||
"Invalid worker configuration: ".len() + long_message.len()
|
||||
);
|
||||
}
|
||||
|
||||
// Mock reqwest error for testing conversion
|
||||
#[test]
|
||||
fn test_reqwest_error_conversion() {
|
||||
// Test that NetworkError is the correct variant
|
||||
let network_error = WorkerError::NetworkError {
|
||||
url: "http://example.com".to_string(),
|
||||
error: "connection timeout".to_string(),
|
||||
};
|
||||
|
||||
match network_error {
|
||||
WorkerError::NetworkError { url, error } => {
|
||||
assert_eq!(url, "http://example.com");
|
||||
assert_eq!(error, "connection timeout");
|
||||
}
|
||||
_ => panic!("Expected NetworkError variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_equality() {
|
||||
// WorkerError doesn't implement PartialEq, but we can test that
|
||||
// the same error construction produces the same display output
|
||||
let error1 = WorkerError::WorkerNotFound {
|
||||
url: "http://test".to_string(),
|
||||
};
|
||||
let error2 = WorkerError::WorkerNotFound {
|
||||
url: "http://test".to_string(),
|
||||
};
|
||||
assert_eq!(error1.to_string(), error2.to_string());
|
||||
}
|
||||
}
|
||||
24
sgl-router/src/core/mod.rs
Normal file
24
sgl-router/src/core/mod.rs
Normal file
@@ -0,0 +1,24 @@
|
||||
//! Core abstractions for the SGLang router
|
||||
//!
|
||||
//! This module contains the fundamental types and traits used throughout the router:
|
||||
//! - Worker trait and implementations
|
||||
//! - Error types
|
||||
//! - Circuit breaker for reliability
|
||||
//! - Common utilities
|
||||
|
||||
pub mod circuit_breaker;
|
||||
pub mod error;
|
||||
pub mod retry;
|
||||
pub mod token_bucket;
|
||||
pub mod worker;
|
||||
|
||||
// Re-export commonly used types at the module level
|
||||
pub use circuit_breaker::{
|
||||
CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState,
|
||||
};
|
||||
pub use error::{WorkerError, WorkerResult};
|
||||
pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor};
|
||||
pub use worker::{
|
||||
start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig,
|
||||
Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType,
|
||||
};
|
||||
409
sgl-router/src/core/retry.rs
Normal file
409
sgl-router/src/core/retry.rs
Normal file
@@ -0,0 +1,409 @@
|
||||
use crate::config::types::RetryConfig;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::Response;
|
||||
use rand::Rng;
|
||||
use std::time::Duration;
|
||||
use tracing::debug;
|
||||
|
||||
/// Check if an HTTP status code indicates a retryable error
|
||||
pub fn is_retryable_status(status: StatusCode) -> bool {
|
||||
matches!(
|
||||
status,
|
||||
StatusCode::REQUEST_TIMEOUT
|
||||
| StatusCode::TOO_MANY_REQUESTS
|
||||
| StatusCode::INTERNAL_SERVER_ERROR
|
||||
| StatusCode::BAD_GATEWAY
|
||||
| StatusCode::SERVICE_UNAVAILABLE
|
||||
| StatusCode::GATEWAY_TIMEOUT
|
||||
)
|
||||
}
|
||||
|
||||
/// Computes exponential backoff with optional jitter.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BackoffCalculator;
|
||||
|
||||
impl BackoffCalculator {
|
||||
/// Calculate backoff delay for a given attempt index (0-based).
|
||||
pub fn calculate_delay(config: &RetryConfig, attempt: u32) -> Duration {
|
||||
// Base exponential backoff
|
||||
let pow = config.backoff_multiplier.powi(attempt as i32);
|
||||
let mut delay_ms = (config.initial_backoff_ms as f32 * pow) as u64;
|
||||
if delay_ms > config.max_backoff_ms {
|
||||
delay_ms = config.max_backoff_ms;
|
||||
}
|
||||
|
||||
// Apply jitter in range [-j, +j]
|
||||
let jitter = config.jitter_factor.clamp(0.0, 1.0);
|
||||
if jitter > 0.0 {
|
||||
let mut rng = rand::rng();
|
||||
let jitter_scale: f32 = rng.random_range(-jitter..=jitter);
|
||||
let jitter_ms = (delay_ms as f32 * jitter_scale)
|
||||
.round()
|
||||
.max(-(delay_ms as f32));
|
||||
let adjusted = (delay_ms as i64 + jitter_ms as i64).max(0) as u64;
|
||||
return Duration::from_millis(adjusted);
|
||||
}
|
||||
|
||||
Duration::from_millis(delay_ms)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RetryError {
|
||||
#[error("no available workers")]
|
||||
NoAvailableWorkers,
|
||||
#[error("maximum retry attempts exceeded")]
|
||||
MaxRetriesExceeded,
|
||||
}
|
||||
|
||||
/// A thin async retry executor for generic operations.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct RetryExecutor;
|
||||
|
||||
impl RetryExecutor {
|
||||
/// Execute an async operation with retries and backoff.
|
||||
/// The `operation` closure is invoked each attempt with the attempt index.
|
||||
pub async fn execute_with_retry<F, Fut, T>(
|
||||
config: &RetryConfig,
|
||||
mut operation: F,
|
||||
) -> Result<T, RetryError>
|
||||
where
|
||||
F: FnMut(u32) -> Fut,
|
||||
Fut: std::future::Future<Output = Result<T, ()>>,
|
||||
{
|
||||
let max = config.max_retries.max(1);
|
||||
let mut attempt: u32 = 0;
|
||||
loop {
|
||||
match operation(attempt).await {
|
||||
Ok(val) => return Ok(val),
|
||||
Err(_) => {
|
||||
// Use the number of failures so far (0-indexed) to compute delay,
|
||||
// so the first retry uses `initial_backoff_ms`.
|
||||
let is_last = attempt + 1 >= max;
|
||||
if is_last {
|
||||
return Err(RetryError::MaxRetriesExceeded);
|
||||
}
|
||||
let delay = BackoffCalculator::calculate_delay(config, attempt);
|
||||
attempt += 1; // advance to the next attempt after computing delay
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute an operation that returns an HTTP Response with retries and backoff.
|
||||
///
|
||||
/// Usage pattern:
|
||||
/// - `operation(attempt)`: perform one attempt (0-based). Construct and send the request,
|
||||
/// then return the `Response`. Do any per-attempt bookkeeping (e.g., load tracking,
|
||||
/// circuit-breaker outcome recording) inside this closure.
|
||||
/// - `should_retry(&response, attempt)`: decide if the given response should be retried
|
||||
/// (e.g., based on HTTP status). Returning false short-circuits and returns the response.
|
||||
/// - `on_backoff(delay, next_attempt)`: called before sleeping between attempts.
|
||||
/// Use this to record metrics.
|
||||
/// - `on_exhausted()`: called when the executor has exhausted all retry attempts.
|
||||
///
|
||||
/// Example:
|
||||
/// ```ignore
|
||||
/// let resp = RetryExecutor::execute_response_with_retry(
|
||||
/// &retry_cfg,
|
||||
/// |attempt| async move {
|
||||
/// let worker = select_cb_aware_worker()?;
|
||||
/// let resp = send_request(worker).await;
|
||||
/// worker.record_outcome(resp.status().is_success());
|
||||
/// resp
|
||||
/// },
|
||||
/// |res, _| matches!(res.status(), StatusCode::REQUEST_TIMEOUT | StatusCode::TOO_MANY_REQUESTS | StatusCode::INTERNAL_SERVER_ERROR | StatusCode::BAD_GATEWAY | StatusCode::SERVICE_UNAVAILABLE | StatusCode::GATEWAY_TIMEOUT),
|
||||
/// |delay, attempt| RouterMetrics::record_retry_backoff_duration(delay, attempt),
|
||||
/// || RouterMetrics::record_retries_exhausted("/route"),
|
||||
/// ).await;
|
||||
/// ```
|
||||
pub async fn execute_response_with_retry<Op, Fut, ShouldRetry, OnBackoff, OnExhausted>(
|
||||
config: &RetryConfig,
|
||||
mut operation: Op,
|
||||
should_retry: ShouldRetry,
|
||||
on_backoff: OnBackoff,
|
||||
mut on_exhausted: OnExhausted,
|
||||
) -> Response
|
||||
where
|
||||
Op: FnMut(u32) -> Fut,
|
||||
Fut: std::future::Future<Output = Response>,
|
||||
ShouldRetry: Fn(&Response, u32) -> bool,
|
||||
OnBackoff: Fn(Duration, u32),
|
||||
OnExhausted: FnMut(),
|
||||
{
|
||||
let max = config.max_retries.max(1);
|
||||
|
||||
let mut attempt: u32 = 0;
|
||||
loop {
|
||||
let response = operation(attempt).await;
|
||||
let is_last = attempt + 1 >= max;
|
||||
|
||||
if !should_retry(&response, attempt) {
|
||||
return response;
|
||||
}
|
||||
|
||||
if is_last {
|
||||
// Exhausted retries
|
||||
on_exhausted();
|
||||
return response;
|
||||
}
|
||||
|
||||
// Backoff before next attempt
|
||||
let next_attempt = attempt + 1;
|
||||
// Compute delay based on the number of failures so far (0-indexed)
|
||||
let delay = BackoffCalculator::calculate_delay(config, attempt);
|
||||
debug!(
|
||||
attempt = attempt,
|
||||
next_attempt = next_attempt,
|
||||
delay_ms = delay.as_millis() as u64,
|
||||
"Retry backoff"
|
||||
);
|
||||
on_backoff(delay, next_attempt);
|
||||
tokio::time::sleep(delay).await;
|
||||
|
||||
attempt = next_attempt;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::IntoResponse;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn base_retry_config() -> RetryConfig {
|
||||
RetryConfig {
|
||||
max_retries: 3,
|
||||
initial_backoff_ms: 1,
|
||||
max_backoff_ms: 4,
|
||||
backoff_multiplier: 2.0,
|
||||
jitter_factor: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backoff_no_jitter_progression_and_cap() {
|
||||
let cfg = RetryConfig {
|
||||
max_retries: 10,
|
||||
initial_backoff_ms: 100,
|
||||
max_backoff_ms: 250,
|
||||
backoff_multiplier: 2.0,
|
||||
jitter_factor: 0.0,
|
||||
};
|
||||
// attempt=0 => 100ms
|
||||
assert_eq!(
|
||||
BackoffCalculator::calculate_delay(&cfg, 0),
|
||||
Duration::from_millis(100)
|
||||
);
|
||||
// attempt=1 => 200ms
|
||||
assert_eq!(
|
||||
BackoffCalculator::calculate_delay(&cfg, 1),
|
||||
Duration::from_millis(200)
|
||||
);
|
||||
// attempt=2 => 400ms -> capped to 250ms
|
||||
assert_eq!(
|
||||
BackoffCalculator::calculate_delay(&cfg, 2),
|
||||
Duration::from_millis(250)
|
||||
);
|
||||
// large attempt still capped
|
||||
assert_eq!(
|
||||
BackoffCalculator::calculate_delay(&cfg, 10),
|
||||
Duration::from_millis(250)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backoff_with_jitter_within_bounds() {
|
||||
let cfg = RetryConfig {
|
||||
max_retries: 5,
|
||||
initial_backoff_ms: 100,
|
||||
max_backoff_ms: 10_000,
|
||||
backoff_multiplier: 2.0,
|
||||
jitter_factor: 0.5,
|
||||
};
|
||||
// attempt=2 => base 400ms, jitter in [0.5x, 1.5x]
|
||||
let base = 400.0;
|
||||
for _ in 0..50 {
|
||||
let d = BackoffCalculator::calculate_delay(&cfg, 2).as_millis() as f32;
|
||||
assert!(d >= base * 0.5 - 1.0 && d <= base * 1.5 + 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_with_retry_success_after_failures() {
|
||||
let cfg = base_retry_config();
|
||||
let remaining = Arc::new(AtomicU32::new(2));
|
||||
let calls = Arc::new(AtomicU32::new(0));
|
||||
|
||||
let res: Result<u32, RetryError> = RetryExecutor::execute_with_retry(&cfg, {
|
||||
let remaining = remaining.clone();
|
||||
let calls = calls.clone();
|
||||
move |_attempt| {
|
||||
calls.fetch_add(1, Ordering::Relaxed);
|
||||
let remaining = remaining.clone();
|
||||
async move {
|
||||
if remaining
|
||||
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1))
|
||||
.is_ok()
|
||||
{
|
||||
Err(())
|
||||
} else {
|
||||
Ok(42u32)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(res.is_ok());
|
||||
assert_eq!(res.unwrap(), 42);
|
||||
assert_eq!(calls.load(Ordering::Relaxed), 3); // 2 fails + 1 success
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_with_retry_exhausted() {
|
||||
let cfg = base_retry_config();
|
||||
let calls = Arc::new(AtomicU32::new(0));
|
||||
let res: Result<u32, RetryError> = RetryExecutor::execute_with_retry(&cfg, {
|
||||
let calls = calls.clone();
|
||||
move |_attempt| {
|
||||
calls.fetch_add(1, Ordering::Relaxed);
|
||||
async move { Err(()) }
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(matches!(res, Err(RetryError::MaxRetriesExceeded)));
|
||||
assert_eq!(calls.load(Ordering::Relaxed), cfg.max_retries);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_response_with_retry_success_path_and_hooks() {
|
||||
let cfg = base_retry_config();
|
||||
let remaining = Arc::new(AtomicU32::new(2));
|
||||
let calls = Arc::new(AtomicU32::new(0));
|
||||
let backoffs = Arc::new(AtomicU32::new(0));
|
||||
let exhausted = Arc::new(AtomicU32::new(0));
|
||||
|
||||
let response = RetryExecutor::execute_response_with_retry(
|
||||
&cfg,
|
||||
{
|
||||
let remaining = remaining.clone();
|
||||
let calls = calls.clone();
|
||||
move |_attempt| {
|
||||
calls.fetch_add(1, Ordering::Relaxed);
|
||||
let remaining = remaining.clone();
|
||||
async move {
|
||||
if remaining
|
||||
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1))
|
||||
.is_ok()
|
||||
{
|
||||
(StatusCode::SERVICE_UNAVAILABLE, "fail").into_response()
|
||||
} else {
|
||||
(StatusCode::OK, "ok").into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|res, _attempt| !res.status().is_success(), // retry until success
|
||||
{
|
||||
let backoffs = backoffs.clone();
|
||||
move |_delay, _next_attempt| {
|
||||
backoffs.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
},
|
||||
{
|
||||
let exhausted = exhausted.clone();
|
||||
move || {
|
||||
exhausted.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
assert_eq!(calls.load(Ordering::Relaxed), 3); // 2 fails + 1 success
|
||||
assert_eq!(backoffs.load(Ordering::Relaxed), 2);
|
||||
assert_eq!(exhausted.load(Ordering::Relaxed), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_response_with_retry_non_retryable_short_circuit() {
|
||||
let cfg = base_retry_config();
|
||||
let calls = Arc::new(AtomicU32::new(0));
|
||||
let backoffs = Arc::new(AtomicU32::new(0));
|
||||
let exhausted = Arc::new(AtomicU32::new(0));
|
||||
|
||||
let response = RetryExecutor::execute_response_with_retry(
|
||||
&cfg,
|
||||
{
|
||||
let calls = calls.clone();
|
||||
move |_attempt| {
|
||||
calls.fetch_add(1, Ordering::Relaxed);
|
||||
async move { (StatusCode::BAD_REQUEST, "bad").into_response() }
|
||||
}
|
||||
},
|
||||
|_res, _attempt| false, // never retry
|
||||
{
|
||||
let backoffs = backoffs.clone();
|
||||
move |_delay, _next_attempt| {
|
||||
backoffs.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
},
|
||||
{
|
||||
let exhausted = exhausted.clone();
|
||||
move || {
|
||||
exhausted.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
assert_eq!(calls.load(Ordering::Relaxed), 1);
|
||||
assert_eq!(backoffs.load(Ordering::Relaxed), 0);
|
||||
assert_eq!(exhausted.load(Ordering::Relaxed), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_response_with_retry_exhausted_hooks() {
|
||||
let cfg = base_retry_config();
|
||||
let calls = Arc::new(AtomicU32::new(0));
|
||||
let backoffs = Arc::new(AtomicU32::new(0));
|
||||
let exhausted = Arc::new(AtomicU32::new(0));
|
||||
|
||||
let response = RetryExecutor::execute_response_with_retry(
|
||||
&cfg,
|
||||
{
|
||||
let calls = calls.clone();
|
||||
move |_attempt| {
|
||||
calls.fetch_add(1, Ordering::Relaxed);
|
||||
async move { (StatusCode::SERVICE_UNAVAILABLE, "fail").into_response() }
|
||||
}
|
||||
},
|
||||
|_res, _attempt| true, // keep retrying
|
||||
{
|
||||
let backoffs = backoffs.clone();
|
||||
move |_delay, _next_attempt| {
|
||||
backoffs.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
},
|
||||
{
|
||||
let exhausted = exhausted.clone();
|
||||
move || {
|
||||
exhausted.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
|
||||
assert_eq!(calls.load(Ordering::Relaxed), cfg.max_retries);
|
||||
assert_eq!(backoffs.load(Ordering::Relaxed), cfg.max_retries - 1);
|
||||
assert_eq!(exhausted.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
}
|
||||
195
sgl-router/src/core/token_bucket.rs
Normal file
195
sgl-router/src/core/token_bucket.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{Mutex, Notify};
|
||||
use tracing::{debug, trace};
|
||||
|
||||
/// Token bucket for rate limiting
|
||||
///
|
||||
/// This implementation provides:
|
||||
/// - Smooth rate limiting with configurable refill rate
|
||||
/// - Burst capacity handling
|
||||
/// - Fair queuing for waiting requests
|
||||
#[derive(Clone)]
|
||||
pub struct TokenBucket {
|
||||
inner: Arc<Mutex<TokenBucketInner>>,
|
||||
notify: Arc<Notify>,
|
||||
capacity: f64,
|
||||
refill_rate: f64, // tokens per second
|
||||
}
|
||||
|
||||
struct TokenBucketInner {
|
||||
tokens: f64,
|
||||
last_refill: Instant,
|
||||
}
|
||||
|
||||
impl TokenBucket {
|
||||
/// Create a new token bucket
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `capacity` - Maximum number of tokens (burst capacity)
|
||||
/// * `refill_rate` - Tokens added per second
|
||||
pub fn new(capacity: usize, refill_rate: usize) -> Self {
|
||||
let capacity = capacity as f64;
|
||||
let refill_rate = refill_rate as f64;
|
||||
|
||||
// Ensure refill_rate is not zero to prevent division by zero
|
||||
let refill_rate = if refill_rate > 0.0 {
|
||||
refill_rate
|
||||
} else {
|
||||
1.0 // Default to 1 token per second if zero
|
||||
};
|
||||
|
||||
Self {
|
||||
inner: Arc::new(Mutex::new(TokenBucketInner {
|
||||
tokens: capacity, // Start full
|
||||
last_refill: Instant::now(),
|
||||
})),
|
||||
notify: Arc::new(Notify::new()),
|
||||
capacity,
|
||||
refill_rate,
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to acquire tokens immediately
|
||||
pub async fn try_acquire(&self, tokens: f64) -> Result<(), ()> {
|
||||
let mut inner = self.inner.lock().await;
|
||||
|
||||
// Refill tokens based on elapsed time
|
||||
let now = Instant::now();
|
||||
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
|
||||
let refill_amount = elapsed * self.refill_rate;
|
||||
|
||||
inner.tokens = (inner.tokens + refill_amount).min(self.capacity);
|
||||
inner.last_refill = now;
|
||||
|
||||
trace!(
|
||||
"Token bucket: {} tokens available, requesting {}",
|
||||
inner.tokens,
|
||||
tokens
|
||||
);
|
||||
|
||||
if inner.tokens >= tokens {
|
||||
inner.tokens -= tokens;
|
||||
debug!(
|
||||
"Token bucket: acquired {} tokens, {} remaining",
|
||||
tokens, inner.tokens
|
||||
);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Acquire tokens, waiting if necessary
|
||||
pub async fn acquire(&self, tokens: f64) -> Result<(), tokio::time::error::Elapsed> {
|
||||
// First try to acquire immediately
|
||||
if self.try_acquire(tokens).await.is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Calculate wait time
|
||||
let wait_time = {
|
||||
let inner = self.inner.lock().await;
|
||||
let tokens_needed = tokens - inner.tokens;
|
||||
let wait_secs = tokens_needed / self.refill_rate;
|
||||
Duration::from_secs_f64(wait_secs)
|
||||
};
|
||||
|
||||
debug!(
|
||||
"Token bucket: waiting {:?} for {} tokens",
|
||||
wait_time, tokens
|
||||
);
|
||||
|
||||
// Wait for tokens to be available
|
||||
tokio::time::timeout(wait_time, async {
|
||||
loop {
|
||||
// Check if we can acquire now
|
||||
if self.try_acquire(tokens).await.is_ok() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Wait for notification or small interval
|
||||
tokio::select! {
|
||||
_ = self.notify.notified() => {},
|
||||
_ = tokio::time::sleep(Duration::from_millis(10)) => {},
|
||||
}
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Acquire tokens with custom timeout
|
||||
pub async fn acquire_timeout(
|
||||
&self,
|
||||
tokens: f64,
|
||||
timeout: Duration,
|
||||
) -> Result<(), tokio::time::error::Elapsed> {
|
||||
tokio::time::timeout(timeout, self.acquire(tokens)).await?
|
||||
}
|
||||
|
||||
/// Return tokens to the bucket (for cancelled requests)
|
||||
pub async fn return_tokens(&self, tokens: f64) {
|
||||
let mut inner = self.inner.lock().await;
|
||||
inner.tokens = (inner.tokens + tokens).min(self.capacity);
|
||||
self.notify.notify_waiters();
|
||||
debug!(
|
||||
"Token bucket: returned {} tokens, {} available",
|
||||
tokens, inner.tokens
|
||||
);
|
||||
}
|
||||
|
||||
/// Get current available tokens (for monitoring)
|
||||
pub async fn available_tokens(&self) -> f64 {
|
||||
let mut inner = self.inner.lock().await;
|
||||
|
||||
// Refill before checking
|
||||
let now = Instant::now();
|
||||
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
|
||||
let refill_amount = elapsed * self.refill_rate;
|
||||
|
||||
inner.tokens = (inner.tokens + refill_amount).min(self.capacity);
|
||||
inner.last_refill = now;
|
||||
|
||||
inner.tokens
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_token_bucket_basic() {
|
||||
let bucket = TokenBucket::new(10, 5); // 10 capacity, 5 per second
|
||||
|
||||
// Should succeed - bucket starts full
|
||||
assert!(bucket.try_acquire(5.0).await.is_ok());
|
||||
assert!(bucket.try_acquire(5.0).await.is_ok());
|
||||
|
||||
// Should fail - no tokens left
|
||||
assert!(bucket.try_acquire(1.0).await.is_err());
|
||||
|
||||
// Wait for refill
|
||||
tokio::time::sleep(Duration::from_millis(300)).await;
|
||||
|
||||
// Should have ~1.5 tokens now
|
||||
assert!(bucket.try_acquire(1.0).await.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_token_bucket_refill() {
|
||||
let bucket = TokenBucket::new(10, 10); // 10 capacity, 10 per second
|
||||
|
||||
// Use all tokens
|
||||
assert!(bucket.try_acquire(10.0).await.is_ok());
|
||||
|
||||
// Wait for partial refill
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Should have ~5 tokens
|
||||
let available = bucket.available_tokens().await;
|
||||
assert!((4.0..=6.0).contains(&available));
|
||||
}
|
||||
}
|
||||
1980
sgl-router/src/core/worker.rs
Normal file
1980
sgl-router/src/core/worker.rs
Normal file
File diff suppressed because it is too large
Load Diff
327
sgl-router/src/grpc/client.rs
Normal file
327
sgl-router/src/grpc/client.rs
Normal file
@@ -0,0 +1,327 @@
|
||||
use std::time::Duration;
|
||||
use tonic::{transport::Channel, Request};
|
||||
use tracing::debug;
|
||||
|
||||
// Include the generated protobuf code
|
||||
pub mod proto {
|
||||
tonic::include_proto!("sglang.grpc.scheduler");
|
||||
}
|
||||
|
||||
// The generated module structure depends on the package name in the .proto file
|
||||
// package sglang.grpc.scheduler; generates a nested module structure
|
||||
|
||||
/// gRPC client for SGLang scheduler
|
||||
pub struct SglangSchedulerClient {
|
||||
client: proto::sglang_scheduler_client::SglangSchedulerClient<Channel>,
|
||||
}
|
||||
|
||||
impl SglangSchedulerClient {
|
||||
/// Create a new client and connect to the scheduler
|
||||
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
debug!("Connecting to SGLang scheduler at {}", endpoint);
|
||||
|
||||
// Convert grpc:// to http:// for tonic
|
||||
let http_endpoint = if endpoint.starts_with("grpc://") {
|
||||
endpoint.replace("grpc://", "http://")
|
||||
} else {
|
||||
endpoint.to_string()
|
||||
};
|
||||
|
||||
let channel = Channel::from_shared(http_endpoint)?
|
||||
.timeout(Duration::from_secs(30))
|
||||
.connect()
|
||||
.await?;
|
||||
|
||||
let client = proto::sglang_scheduler_client::SglangSchedulerClient::new(channel);
|
||||
|
||||
Ok(Self { client })
|
||||
}
|
||||
|
||||
/// Initialize the connection
|
||||
pub async fn initialize(
|
||||
&mut self,
|
||||
client_id: String,
|
||||
) -> Result<proto::InitializeResponse, Box<dyn std::error::Error>> {
|
||||
let request = Request::new(proto::InitializeRequest {
|
||||
client_id,
|
||||
client_version: "0.1.0".to_string(),
|
||||
mode: proto::initialize_request::Mode::Regular as i32,
|
||||
});
|
||||
|
||||
let response = self.client.initialize(request).await?;
|
||||
Ok(response.into_inner())
|
||||
}
|
||||
|
||||
/// Submit a generation request (returns streaming response)
|
||||
pub async fn generate_stream(
|
||||
&mut self,
|
||||
req: proto::GenerateRequest,
|
||||
) -> Result<tonic::Streaming<proto::GenerateResponse>, Box<dyn std::error::Error>> {
|
||||
let request = Request::new(req);
|
||||
let response = self.client.generate(request).await?;
|
||||
Ok(response.into_inner())
|
||||
}
|
||||
|
||||
/// Perform health check
|
||||
pub async fn health_check(
|
||||
&mut self,
|
||||
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
|
||||
debug!("Sending health check request");
|
||||
let request = Request::new(proto::HealthCheckRequest {
|
||||
include_detailed_metrics: false,
|
||||
});
|
||||
|
||||
let response = self.client.health_check(request).await?;
|
||||
debug!("Health check response received");
|
||||
Ok(response.into_inner())
|
||||
}
|
||||
|
||||
/// Abort a request
|
||||
pub async fn abort_request(
|
||||
&mut self,
|
||||
request_id: String,
|
||||
reason: String,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let request = Request::new(proto::AbortRequest { request_id, reason });
|
||||
|
||||
self.client.abort(request).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Flush cache
|
||||
pub async fn flush_cache(
|
||||
&mut self,
|
||||
flush_all: bool,
|
||||
session_ids: &[String],
|
||||
) -> Result<proto::FlushCacheResponse, Box<dyn std::error::Error>> {
|
||||
let request = Request::new(proto::FlushCacheRequest {
|
||||
flush_all,
|
||||
session_ids: session_ids.to_vec(),
|
||||
});
|
||||
|
||||
let response = self.client.flush_cache(request).await?;
|
||||
Ok(response.into_inner())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_proto_types_compilation() {
|
||||
// Test that protobuf types can be constructed
|
||||
let init_req = proto::InitializeRequest {
|
||||
client_id: "test-client".to_string(),
|
||||
client_version: "0.1.0".to_string(),
|
||||
mode: 0,
|
||||
};
|
||||
assert_eq!(init_req.client_id, "test-client");
|
||||
assert_eq!(init_req.client_version, "0.1.0");
|
||||
assert_eq!(init_req.mode, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_request_construction() {
|
||||
let sampling_params = proto::SamplingParams {
|
||||
temperature: 0.7,
|
||||
max_new_tokens: 128,
|
||||
top_p: 0.9,
|
||||
top_k: 50,
|
||||
stop: vec!["</s>".to_string()],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let gen_req = proto::GenerateRequest {
|
||||
request_id: "test-req-123".to_string(),
|
||||
input: Some(proto::generate_request::Input::Text(
|
||||
"Hello world".to_string(),
|
||||
)),
|
||||
sampling_params: Some(sampling_params),
|
||||
return_logprob: true,
|
||||
logprob_start_len: 0,
|
||||
top_logprobs_num: 5,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(gen_req.request_id, "test-req-123");
|
||||
if let Some(proto::generate_request::Input::Text(text)) = &gen_req.input {
|
||||
assert_eq!(text, "Hello world");
|
||||
}
|
||||
assert!(gen_req.return_logprob);
|
||||
assert_eq!(gen_req.top_logprobs_num, 5);
|
||||
|
||||
let params = gen_req.sampling_params.unwrap();
|
||||
assert_eq!(params.temperature, 0.7);
|
||||
assert_eq!(params.max_new_tokens, 128);
|
||||
assert_eq!(params.stop, vec!["</s>"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_health_check_request() {
|
||||
let health_req = proto::HealthCheckRequest {
|
||||
include_detailed_metrics: true,
|
||||
};
|
||||
assert!(health_req.include_detailed_metrics);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_abort_request_construction() {
|
||||
let abort_req = proto::AbortRequest {
|
||||
request_id: "req-456".to_string(),
|
||||
reason: "User canceled".to_string(),
|
||||
};
|
||||
assert_eq!(abort_req.request_id, "req-456");
|
||||
assert_eq!(abort_req.reason, "User canceled");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flush_cache_request() {
|
||||
let flush_req = proto::FlushCacheRequest {
|
||||
flush_all: true,
|
||||
session_ids: vec!["session1".to_string(), "session2".to_string()],
|
||||
};
|
||||
assert!(flush_req.flush_all);
|
||||
assert_eq!(flush_req.session_ids.len(), 2);
|
||||
assert_eq!(flush_req.session_ids[0], "session1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sampling_params_defaults() {
|
||||
let params = proto::SamplingParams::default();
|
||||
assert_eq!(params.temperature, 0.0);
|
||||
assert_eq!(params.max_new_tokens, 0);
|
||||
assert_eq!(params.top_p, 0.0);
|
||||
assert_eq!(params.top_k, 0);
|
||||
assert!(params.stop.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multimodal_inputs() {
|
||||
let mm_inputs = proto::MultimodalInputs {
|
||||
image_urls: vec!["http://example.com/image.jpg".to_string()],
|
||||
video_urls: vec![],
|
||||
audio_urls: vec![],
|
||||
image_data: vec![],
|
||||
video_data: vec![],
|
||||
audio_data: vec![],
|
||||
modalities: vec!["image".to_string()],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(mm_inputs.image_urls.len(), 1);
|
||||
assert_eq!(mm_inputs.image_urls[0], "http://example.com/image.jpg");
|
||||
assert_eq!(mm_inputs.modalities[0], "image");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_params() {
|
||||
let session_params = proto::SessionParams {
|
||||
session_id: "sess-789".to_string(),
|
||||
request_id: "req-101".to_string(),
|
||||
offset: 100,
|
||||
replace: true,
|
||||
drop_previous_output: false,
|
||||
};
|
||||
|
||||
assert_eq!(session_params.session_id, "sess-789");
|
||||
assert_eq!(session_params.request_id, "req-101");
|
||||
assert_eq!(session_params.offset, 100);
|
||||
assert!(session_params.replace);
|
||||
assert!(!session_params.drop_previous_output);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embed_request() {
|
||||
let embed_req = proto::EmbedRequest {
|
||||
request_id: "embed-req-202".to_string(),
|
||||
input: Some(proto::embed_request::Input::Text(
|
||||
"This is a test sentence for embedding".to_string(),
|
||||
)),
|
||||
log_metrics: true,
|
||||
data_parallel_rank: 0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(embed_req.request_id, "embed-req-202");
|
||||
if let Some(proto::embed_request::Input::Text(text)) = &embed_req.input {
|
||||
assert_eq!(text, "This is a test sentence for embedding");
|
||||
}
|
||||
assert!(embed_req.log_metrics);
|
||||
assert_eq!(embed_req.data_parallel_rank, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_connect_invalid_endpoint() {
|
||||
// Test connecting to an invalid endpoint should return error
|
||||
let result = SglangSchedulerClient::connect("invalid://endpoint").await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tokenized_input() {
|
||||
let tokenized = proto::TokenizedInput {
|
||||
original_text: "Hello world".to_string(),
|
||||
input_ids: vec![1, 15043, 1917, 2],
|
||||
};
|
||||
|
||||
assert_eq!(tokenized.original_text, "Hello world");
|
||||
assert_eq!(tokenized.input_ids, vec![1, 15043, 1917, 2]);
|
||||
}
|
||||
|
||||
// Test response type construction
|
||||
#[test]
|
||||
fn test_generate_stream_chunk() {
|
||||
let chunk = proto::GenerateStreamChunk {
|
||||
token_id: 1234,
|
||||
text: " world".to_string(),
|
||||
prompt_tokens: 5,
|
||||
completion_tokens: 2,
|
||||
cached_tokens: 3,
|
||||
generation_time: 0.025,
|
||||
queue_time: 10,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(chunk.token_id, 1234);
|
||||
assert_eq!(chunk.text, " world");
|
||||
assert_eq!(chunk.prompt_tokens, 5);
|
||||
assert_eq!(chunk.completion_tokens, 2);
|
||||
assert_eq!(chunk.cached_tokens, 3);
|
||||
assert_eq!(chunk.generation_time, 0.025);
|
||||
assert_eq!(chunk.queue_time, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_info() {
|
||||
let model_info = proto::ModelInfo {
|
||||
model_name: "Meta-Llama-3-8B-Instruct".to_string(),
|
||||
max_context_length: 8192,
|
||||
vocab_size: 128256,
|
||||
supports_tool_calling: true,
|
||||
supports_vision: false,
|
||||
special_tokens: vec![
|
||||
"<|begin_of_text|>".to_string(),
|
||||
"<|end_of_text|>".to_string(),
|
||||
],
|
||||
model_type: "llama".to_string(),
|
||||
num_layers: 32,
|
||||
hidden_size: 4096,
|
||||
num_attention_heads: 32,
|
||||
num_key_value_heads: 8,
|
||||
tokenizer_type: "llama".to_string(),
|
||||
eos_token_ids: vec![128001, 128009],
|
||||
pad_token_id: 128001,
|
||||
bos_token_id: 128000,
|
||||
};
|
||||
|
||||
assert_eq!(model_info.model_name, "Meta-Llama-3-8B-Instruct");
|
||||
assert_eq!(model_info.max_context_length, 8192);
|
||||
assert_eq!(model_info.vocab_size, 128256);
|
||||
assert!(model_info.supports_tool_calling);
|
||||
assert!(!model_info.supports_vision);
|
||||
assert_eq!(model_info.special_tokens.len(), 2);
|
||||
assert_eq!(model_info.num_layers, 32);
|
||||
assert_eq!(model_info.eos_token_ids, vec![128001, 128009]);
|
||||
}
|
||||
}
|
||||
8
sgl-router/src/grpc/mod.rs
Normal file
8
sgl-router/src/grpc/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
//! gRPC client module for communicating with SGLang scheduler
|
||||
//!
|
||||
//! This module provides a gRPC client implementation for the SGLang router.
|
||||
|
||||
pub mod client;
|
||||
|
||||
// Re-export the client
|
||||
pub use client::{proto, SglangSchedulerClient};
|
||||
508
sgl-router/src/lib.rs
Normal file
508
sgl-router/src/lib.rs
Normal file
@@ -0,0 +1,508 @@
|
||||
use pyo3::prelude::*;
|
||||
pub mod config;
|
||||
pub mod logging;
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub mod core;
|
||||
#[cfg(feature = "grpc-client")]
|
||||
pub mod grpc;
|
||||
pub mod mcp;
|
||||
pub mod metrics;
|
||||
pub mod middleware;
|
||||
pub mod policies;
|
||||
pub mod protocols;
|
||||
pub mod reasoning_parser;
|
||||
pub mod routers;
|
||||
pub mod server;
|
||||
pub mod service_discovery;
|
||||
pub mod tokenizer;
|
||||
pub mod tool_parser;
|
||||
pub mod tree;
|
||||
use crate::metrics::PrometheusConfig;
|
||||
|
||||
#[pyclass(eq)]
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub enum PolicyType {
|
||||
Random,
|
||||
RoundRobin,
|
||||
CacheAware,
|
||||
PowerOfTwo,
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
struct Router {
|
||||
host: String,
|
||||
port: u16,
|
||||
worker_urls: Vec<String>,
|
||||
policy: PolicyType,
|
||||
worker_startup_timeout_secs: u64,
|
||||
worker_startup_check_interval: u64,
|
||||
cache_threshold: f32,
|
||||
balance_abs_threshold: usize,
|
||||
balance_rel_threshold: f32,
|
||||
eviction_interval_secs: u64,
|
||||
max_tree_size: usize,
|
||||
max_payload_size: usize,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
log_dir: Option<String>,
|
||||
log_level: Option<String>,
|
||||
service_discovery: bool,
|
||||
selector: HashMap<String, String>,
|
||||
service_discovery_port: u16,
|
||||
service_discovery_namespace: Option<String>,
|
||||
prefill_selector: HashMap<String, String>,
|
||||
decode_selector: HashMap<String, String>,
|
||||
bootstrap_port_annotation: String,
|
||||
prometheus_port: Option<u16>,
|
||||
prometheus_host: Option<String>,
|
||||
request_timeout_secs: u64,
|
||||
request_id_headers: Option<Vec<String>>,
|
||||
pd_disaggregation: bool,
|
||||
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
||||
decode_urls: Option<Vec<String>>,
|
||||
prefill_policy: Option<PolicyType>,
|
||||
decode_policy: Option<PolicyType>,
|
||||
max_concurrent_requests: usize,
|
||||
cors_allowed_origins: Vec<String>,
|
||||
// Retry configuration
|
||||
retry_max_retries: u32,
|
||||
retry_initial_backoff_ms: u64,
|
||||
retry_max_backoff_ms: u64,
|
||||
retry_backoff_multiplier: f32,
|
||||
retry_jitter_factor: f32,
|
||||
disable_retries: bool,
|
||||
// Circuit breaker configuration
|
||||
cb_failure_threshold: u32,
|
||||
cb_success_threshold: u32,
|
||||
cb_timeout_duration_secs: u64,
|
||||
cb_window_duration_secs: u64,
|
||||
disable_circuit_breaker: bool,
|
||||
// Health check configuration
|
||||
health_failure_threshold: u32,
|
||||
health_success_threshold: u32,
|
||||
health_check_timeout_secs: u64,
|
||||
health_check_interval_secs: u64,
|
||||
health_check_endpoint: String,
|
||||
// IGW (Inference Gateway) configuration
|
||||
enable_igw: bool,
|
||||
queue_size: usize,
|
||||
queue_timeout_secs: u64,
|
||||
rate_limit_tokens_per_second: Option<usize>,
|
||||
// Connection mode (determined from worker URLs)
|
||||
connection_mode: config::ConnectionMode,
|
||||
// Model path for tokenizer
|
||||
model_path: Option<String>,
|
||||
// Explicit tokenizer path
|
||||
tokenizer_path: Option<String>,
|
||||
}
|
||||
|
||||
impl Router {
|
||||
/// Determine connection mode from worker URLs
|
||||
fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
|
||||
// Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
|
||||
for url in worker_urls {
|
||||
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
|
||||
return config::ConnectionMode::Grpc;
|
||||
}
|
||||
}
|
||||
// Default to HTTP for all other cases (including http://, https://, or no scheme)
|
||||
config::ConnectionMode::Http
|
||||
}
|
||||
|
||||
/// Convert PyO3 Router to RouterConfig
|
||||
pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
|
||||
use config::{
|
||||
DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
|
||||
};
|
||||
|
||||
// Convert policy helper function
|
||||
let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
|
||||
match policy {
|
||||
PolicyType::Random => ConfigPolicyConfig::Random,
|
||||
PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
|
||||
PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
|
||||
cache_threshold: self.cache_threshold,
|
||||
balance_abs_threshold: self.balance_abs_threshold,
|
||||
balance_rel_threshold: self.balance_rel_threshold,
|
||||
eviction_interval_secs: self.eviction_interval_secs,
|
||||
max_tree_size: self.max_tree_size,
|
||||
},
|
||||
PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 5, // Default value
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
// Determine routing mode
|
||||
let mode = if self.enable_igw {
|
||||
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
}
|
||||
} else if self.pd_disaggregation {
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
|
||||
decode_urls: self.decode_urls.clone().unwrap_or_default(),
|
||||
prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
|
||||
decode_policy: self.decode_policy.as_ref().map(convert_policy),
|
||||
}
|
||||
} else {
|
||||
RoutingMode::Regular {
|
||||
worker_urls: self.worker_urls.clone(),
|
||||
}
|
||||
};
|
||||
|
||||
// Convert main policy
|
||||
let policy = convert_policy(&self.policy);
|
||||
|
||||
// Service discovery configuration
|
||||
let discovery = if self.service_discovery {
|
||||
Some(DiscoveryConfig {
|
||||
enabled: true,
|
||||
namespace: self.service_discovery_namespace.clone(),
|
||||
port: self.service_discovery_port,
|
||||
check_interval_secs: 60,
|
||||
selector: self.selector.clone(),
|
||||
prefill_selector: self.prefill_selector.clone(),
|
||||
decode_selector: self.decode_selector.clone(),
|
||||
bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Metrics configuration
|
||||
let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
|
||||
(Some(port), Some(host)) => Some(MetricsConfig {
|
||||
port,
|
||||
host: host.clone(),
|
||||
}),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Ok(config::RouterConfig {
|
||||
mode,
|
||||
policy,
|
||||
host: self.host.clone(),
|
||||
port: self.port,
|
||||
connection_mode: self.connection_mode.clone(),
|
||||
max_payload_size: self.max_payload_size,
|
||||
request_timeout_secs: self.request_timeout_secs,
|
||||
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
|
||||
worker_startup_check_interval_secs: self.worker_startup_check_interval,
|
||||
dp_aware: self.dp_aware,
|
||||
api_key: self.api_key.clone(),
|
||||
discovery,
|
||||
metrics,
|
||||
log_dir: self.log_dir.clone(),
|
||||
log_level: self.log_level.clone(),
|
||||
request_id_headers: self.request_id_headers.clone(),
|
||||
max_concurrent_requests: self.max_concurrent_requests,
|
||||
queue_size: self.queue_size,
|
||||
queue_timeout_secs: self.queue_timeout_secs,
|
||||
rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
|
||||
cors_allowed_origins: self.cors_allowed_origins.clone(),
|
||||
retry: config::RetryConfig {
|
||||
max_retries: self.retry_max_retries,
|
||||
initial_backoff_ms: self.retry_initial_backoff_ms,
|
||||
max_backoff_ms: self.retry_max_backoff_ms,
|
||||
backoff_multiplier: self.retry_backoff_multiplier,
|
||||
jitter_factor: self.retry_jitter_factor,
|
||||
},
|
||||
circuit_breaker: config::CircuitBreakerConfig {
|
||||
failure_threshold: self.cb_failure_threshold,
|
||||
success_threshold: self.cb_success_threshold,
|
||||
timeout_duration_secs: self.cb_timeout_duration_secs,
|
||||
window_duration_secs: self.cb_window_duration_secs,
|
||||
},
|
||||
disable_retries: self.disable_retries,
|
||||
disable_circuit_breaker: self.disable_circuit_breaker,
|
||||
health_check: config::HealthCheckConfig {
|
||||
failure_threshold: self.health_failure_threshold,
|
||||
success_threshold: self.health_success_threshold,
|
||||
timeout_secs: self.health_check_timeout_secs,
|
||||
check_interval_secs: self.health_check_interval_secs,
|
||||
endpoint: self.health_check_endpoint.clone(),
|
||||
},
|
||||
enable_igw: self.enable_igw,
|
||||
model_path: self.model_path.clone(),
|
||||
tokenizer_path: self.tokenizer_path.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl Router {
|
||||
#[new]
|
||||
#[pyo3(signature = (
|
||||
worker_urls,
|
||||
policy = PolicyType::RoundRobin,
|
||||
host = String::from("127.0.0.1"),
|
||||
port = 3001,
|
||||
worker_startup_timeout_secs = 600,
|
||||
worker_startup_check_interval = 30,
|
||||
cache_threshold = 0.3,
|
||||
balance_abs_threshold = 64,
|
||||
balance_rel_threshold = 1.5,
|
||||
eviction_interval_secs = 120,
|
||||
max_tree_size = 2usize.pow(26),
|
||||
max_payload_size = 512 * 1024 * 1024, // 512MB default for large batches
|
||||
dp_aware = false,
|
||||
api_key = None,
|
||||
log_dir = None,
|
||||
log_level = None,
|
||||
service_discovery = false,
|
||||
selector = HashMap::new(),
|
||||
service_discovery_port = 80,
|
||||
service_discovery_namespace = None,
|
||||
prefill_selector = HashMap::new(),
|
||||
decode_selector = HashMap::new(),
|
||||
bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
|
||||
prometheus_port = None,
|
||||
prometheus_host = None,
|
||||
request_timeout_secs = 1800, // Add configurable request timeout
|
||||
request_id_headers = None, // Custom request ID headers
|
||||
pd_disaggregation = false, // New flag for PD mode
|
||||
prefill_urls = None,
|
||||
decode_urls = None,
|
||||
prefill_policy = None,
|
||||
decode_policy = None,
|
||||
max_concurrent_requests = 256,
|
||||
cors_allowed_origins = vec![],
|
||||
// Retry defaults
|
||||
retry_max_retries = 5,
|
||||
retry_initial_backoff_ms = 50,
|
||||
retry_max_backoff_ms = 30_000,
|
||||
retry_backoff_multiplier = 1.5,
|
||||
retry_jitter_factor = 0.2,
|
||||
disable_retries = false,
|
||||
// Circuit breaker defaults
|
||||
cb_failure_threshold = 10,
|
||||
cb_success_threshold = 3,
|
||||
cb_timeout_duration_secs = 60,
|
||||
cb_window_duration_secs = 120,
|
||||
disable_circuit_breaker = false,
|
||||
// Health check defaults
|
||||
health_failure_threshold = 3,
|
||||
health_success_threshold = 2,
|
||||
health_check_timeout_secs = 5,
|
||||
health_check_interval_secs = 60,
|
||||
health_check_endpoint = String::from("/health"),
|
||||
// IGW defaults
|
||||
enable_igw = false,
|
||||
queue_size = 100,
|
||||
queue_timeout_secs = 60,
|
||||
rate_limit_tokens_per_second = None,
|
||||
// Tokenizer defaults
|
||||
model_path = None,
|
||||
tokenizer_path = None,
|
||||
))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
worker_urls: Vec<String>,
|
||||
policy: PolicyType,
|
||||
host: String,
|
||||
port: u16,
|
||||
worker_startup_timeout_secs: u64,
|
||||
worker_startup_check_interval: u64,
|
||||
cache_threshold: f32,
|
||||
balance_abs_threshold: usize,
|
||||
balance_rel_threshold: f32,
|
||||
eviction_interval_secs: u64,
|
||||
max_tree_size: usize,
|
||||
max_payload_size: usize,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
log_dir: Option<String>,
|
||||
log_level: Option<String>,
|
||||
service_discovery: bool,
|
||||
selector: HashMap<String, String>,
|
||||
service_discovery_port: u16,
|
||||
service_discovery_namespace: Option<String>,
|
||||
prefill_selector: HashMap<String, String>,
|
||||
decode_selector: HashMap<String, String>,
|
||||
bootstrap_port_annotation: String,
|
||||
prometheus_port: Option<u16>,
|
||||
prometheus_host: Option<String>,
|
||||
request_timeout_secs: u64,
|
||||
request_id_headers: Option<Vec<String>>,
|
||||
pd_disaggregation: bool,
|
||||
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
||||
decode_urls: Option<Vec<String>>,
|
||||
prefill_policy: Option<PolicyType>,
|
||||
decode_policy: Option<PolicyType>,
|
||||
max_concurrent_requests: usize,
|
||||
cors_allowed_origins: Vec<String>,
|
||||
retry_max_retries: u32,
|
||||
retry_initial_backoff_ms: u64,
|
||||
retry_max_backoff_ms: u64,
|
||||
retry_backoff_multiplier: f32,
|
||||
retry_jitter_factor: f32,
|
||||
disable_retries: bool,
|
||||
cb_failure_threshold: u32,
|
||||
cb_success_threshold: u32,
|
||||
cb_timeout_duration_secs: u64,
|
||||
cb_window_duration_secs: u64,
|
||||
disable_circuit_breaker: bool,
|
||||
health_failure_threshold: u32,
|
||||
health_success_threshold: u32,
|
||||
health_check_timeout_secs: u64,
|
||||
health_check_interval_secs: u64,
|
||||
health_check_endpoint: String,
|
||||
enable_igw: bool,
|
||||
queue_size: usize,
|
||||
queue_timeout_secs: u64,
|
||||
rate_limit_tokens_per_second: Option<usize>,
|
||||
model_path: Option<String>,
|
||||
tokenizer_path: Option<String>,
|
||||
) -> PyResult<Self> {
|
||||
// Determine connection mode from worker URLs
|
||||
let mut all_urls = worker_urls.clone();
|
||||
|
||||
// Add prefill URLs if in PD mode
|
||||
if let Some(ref prefill_urls) = prefill_urls {
|
||||
for (url, _) in prefill_urls {
|
||||
all_urls.push(url.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Add decode URLs if in PD mode
|
||||
if let Some(ref decode_urls) = decode_urls {
|
||||
all_urls.extend(decode_urls.clone());
|
||||
}
|
||||
|
||||
let connection_mode = Self::determine_connection_mode(&all_urls);
|
||||
|
||||
Ok(Router {
|
||||
host,
|
||||
port,
|
||||
worker_urls,
|
||||
policy,
|
||||
worker_startup_timeout_secs,
|
||||
worker_startup_check_interval,
|
||||
cache_threshold,
|
||||
balance_abs_threshold,
|
||||
balance_rel_threshold,
|
||||
eviction_interval_secs,
|
||||
max_tree_size,
|
||||
max_payload_size,
|
||||
dp_aware,
|
||||
api_key,
|
||||
log_dir,
|
||||
log_level,
|
||||
service_discovery,
|
||||
selector,
|
||||
service_discovery_port,
|
||||
service_discovery_namespace,
|
||||
prefill_selector,
|
||||
decode_selector,
|
||||
bootstrap_port_annotation,
|
||||
prometheus_port,
|
||||
prometheus_host,
|
||||
request_timeout_secs,
|
||||
request_id_headers,
|
||||
pd_disaggregation,
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
max_concurrent_requests,
|
||||
cors_allowed_origins,
|
||||
retry_max_retries,
|
||||
retry_initial_backoff_ms,
|
||||
retry_max_backoff_ms,
|
||||
retry_backoff_multiplier,
|
||||
retry_jitter_factor,
|
||||
disable_retries,
|
||||
cb_failure_threshold,
|
||||
cb_success_threshold,
|
||||
cb_timeout_duration_secs,
|
||||
cb_window_duration_secs,
|
||||
disable_circuit_breaker,
|
||||
health_failure_threshold,
|
||||
health_success_threshold,
|
||||
health_check_timeout_secs,
|
||||
health_check_interval_secs,
|
||||
health_check_endpoint,
|
||||
enable_igw,
|
||||
queue_size,
|
||||
queue_timeout_secs,
|
||||
rate_limit_tokens_per_second,
|
||||
connection_mode,
|
||||
model_path,
|
||||
tokenizer_path,
|
||||
})
|
||||
}
|
||||
|
||||
fn start(&self) -> PyResult<()> {
|
||||
// Convert to RouterConfig and validate
|
||||
let router_config = self.to_router_config().map_err(|e| {
|
||||
pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
|
||||
})?;
|
||||
|
||||
// Validate the configuration
|
||||
router_config.validate().map_err(|e| {
|
||||
pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Configuration validation failed: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
// Create service discovery config if enabled
|
||||
let service_discovery_config = if self.service_discovery {
|
||||
Some(service_discovery::ServiceDiscoveryConfig {
|
||||
enabled: true,
|
||||
selector: self.selector.clone(),
|
||||
check_interval: std::time::Duration::from_secs(60),
|
||||
port: self.service_discovery_port,
|
||||
namespace: self.service_discovery_namespace.clone(),
|
||||
pd_mode: self.pd_disaggregation,
|
||||
prefill_selector: self.prefill_selector.clone(),
|
||||
decode_selector: self.decode_selector.clone(),
|
||||
bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Create Prometheus config if enabled
|
||||
let prometheus_config = Some(PrometheusConfig {
|
||||
port: self.prometheus_port.unwrap_or(29000),
|
||||
host: self
|
||||
.prometheus_host
|
||||
.clone()
|
||||
.unwrap_or_else(|| "127.0.0.1".to_string()),
|
||||
});
|
||||
|
||||
// Use tokio runtime instead of actix-web System for better compatibility
|
||||
let runtime = tokio::runtime::Runtime::new()
|
||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
|
||||
|
||||
// Block on the async startup function
|
||||
runtime.block_on(async move {
|
||||
server::startup(server::ServerConfig {
|
||||
host: self.host.clone(),
|
||||
port: self.port,
|
||||
router_config,
|
||||
max_payload_size: self.max_payload_size,
|
||||
log_dir: self.log_dir.clone(),
|
||||
log_level: self.log_level.clone(),
|
||||
service_discovery_config,
|
||||
prometheus_config,
|
||||
request_timeout_secs: self.request_timeout_secs,
|
||||
request_id_headers: self.request_id_headers.clone(),
|
||||
})
|
||||
.await
|
||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PolicyType>()?;
|
||||
m.add_class::<Router>()?;
|
||||
Ok(())
|
||||
}
|
||||
163
sgl-router/src/logging.rs
Normal file
163
sgl-router/src/logging.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
use std::path::PathBuf;
|
||||
use tracing::Level;
|
||||
use tracing_appender::non_blocking::WorkerGuard;
|
||||
use tracing_appender::rolling::{RollingFileAppender, Rotation};
|
||||
use tracing_log::LogTracer;
|
||||
use tracing_subscriber::fmt::time::ChronoUtc;
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use tracing_subscriber::{EnvFilter, Layer};
|
||||
|
||||
/// Configuration for the logging system
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LoggingConfig {
|
||||
/// Log level for the application (default: INFO)
|
||||
pub level: Level,
|
||||
/// Whether to use json format for logs (default: false)
|
||||
pub json_format: bool,
|
||||
/// Path to store log files. If None, logs will only go to stdout/stderr
|
||||
pub log_dir: Option<String>,
|
||||
/// Whether to colorize logs when output is a terminal (default: true)
|
||||
pub colorize: bool,
|
||||
/// Log file name to use if log_dir is specified (default: "sgl-router")
|
||||
pub log_file_name: String,
|
||||
/// Custom log targets to filter (default: "sglang_router_rs")
|
||||
pub log_targets: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl Default for LoggingConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
level: Level::INFO,
|
||||
json_format: false,
|
||||
log_dir: None,
|
||||
colorize: true,
|
||||
log_file_name: "sgl-router".to_string(),
|
||||
log_targets: Some(vec!["sglang_router_rs".to_string()]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Guard that keeps the file appender worker thread alive
|
||||
///
|
||||
/// This must be kept in scope for the duration of the program
|
||||
/// to ensure logs are properly written to files
|
||||
#[allow(dead_code)]
|
||||
pub struct LogGuard {
|
||||
_file_guard: Option<WorkerGuard>,
|
||||
}
|
||||
|
||||
/// Initialize the logging system with the given configuration
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `config` - Configuration for the logging system
|
||||
///
|
||||
/// # Returns
|
||||
/// A LogGuard that must be kept alive for the duration of the program
|
||||
///
|
||||
/// # Panics
|
||||
/// Will not panic, as initialization errors are handled gracefully
|
||||
pub fn init_logging(config: LoggingConfig) -> LogGuard {
|
||||
// Forward logs to tracing - ignore errors to allow for multiple initialization
|
||||
let _ = LogTracer::init();
|
||||
|
||||
// Convert log level to filter string
|
||||
let level_filter = match config.level {
|
||||
Level::TRACE => "trace",
|
||||
Level::DEBUG => "debug",
|
||||
Level::INFO => "info",
|
||||
Level::WARN => "warn",
|
||||
Level::ERROR => "error",
|
||||
};
|
||||
|
||||
// Create env filter
|
||||
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| {
|
||||
// Format: <target>=<level>,<target2>=<level2>,...
|
||||
let filter_string = if let Some(targets) = &config.log_targets {
|
||||
targets
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, target)| {
|
||||
if i > 0 {
|
||||
format!(",{}={}", target, level_filter)
|
||||
} else {
|
||||
format!("{}={}", target, level_filter)
|
||||
}
|
||||
})
|
||||
.collect::<String>()
|
||||
} else {
|
||||
format!("sglang_router_rs={}", level_filter)
|
||||
};
|
||||
|
||||
EnvFilter::new(filter_string)
|
||||
});
|
||||
|
||||
// Setup stdout/stderr layer
|
||||
let mut layers = Vec::new();
|
||||
|
||||
// Standard timestamp format: YYYY-MM-DD HH:MM:SS
|
||||
let time_format = "%Y-%m-%d %H:%M:%S".to_string();
|
||||
|
||||
// Configure the console stdout layer
|
||||
let stdout_layer = tracing_subscriber::fmt::layer()
|
||||
.with_ansi(config.colorize)
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
.with_timer(ChronoUtc::new(time_format.clone()));
|
||||
|
||||
let stdout_layer = if config.json_format {
|
||||
stdout_layer.json().flatten_event(true).boxed()
|
||||
} else {
|
||||
stdout_layer.boxed()
|
||||
};
|
||||
|
||||
layers.push(stdout_layer);
|
||||
|
||||
// Create a file appender if log_dir is specified
|
||||
let mut file_guard = None;
|
||||
|
||||
if let Some(log_dir) = &config.log_dir {
|
||||
let file_name = config.log_file_name.clone();
|
||||
let log_dir = PathBuf::from(log_dir);
|
||||
|
||||
// Create log directory if it doesn't exist
|
||||
if !log_dir.exists() {
|
||||
if let Err(e) = std::fs::create_dir_all(&log_dir) {
|
||||
eprintln!("Failed to create log directory: {}", e);
|
||||
return LogGuard { _file_guard: None };
|
||||
}
|
||||
}
|
||||
|
||||
let file_appender = RollingFileAppender::new(Rotation::DAILY, log_dir, file_name);
|
||||
|
||||
let (non_blocking, guard) = tracing_appender::non_blocking(file_appender);
|
||||
file_guard = Some(guard);
|
||||
|
||||
let file_layer = tracing_subscriber::fmt::layer()
|
||||
.with_ansi(false) // Never use ANSI colors in log files
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
.with_timer(ChronoUtc::new(time_format))
|
||||
.with_writer(non_blocking);
|
||||
|
||||
let file_layer = if config.json_format {
|
||||
file_layer.json().flatten_event(true).boxed()
|
||||
} else {
|
||||
file_layer.boxed()
|
||||
};
|
||||
|
||||
layers.push(file_layer);
|
||||
}
|
||||
|
||||
// Initialize the subscriber with all layers
|
||||
// Use try_init to handle errors gracefully in case another subscriber is already set
|
||||
let _ = tracing_subscriber::registry()
|
||||
.with(env_filter)
|
||||
.with(layers)
|
||||
.try_init();
|
||||
|
||||
// Return the guard to keep the file appender worker thread alive
|
||||
LogGuard {
|
||||
_file_guard: file_guard,
|
||||
}
|
||||
}
|
||||
636
sgl-router/src/main.rs
Normal file
636
sgl-router/src/main.rs
Normal file
@@ -0,0 +1,636 @@
|
||||
use clap::{ArgAction, Parser, ValueEnum};
|
||||
use sglang_router_rs::config::{
|
||||
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
|
||||
HealthCheckConfig, MetricsConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||
};
|
||||
use sglang_router_rs::metrics::PrometheusConfig;
|
||||
use sglang_router_rs::server::{self, ServerConfig};
|
||||
use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
|
||||
use std::collections::HashMap;
|
||||
|
||||
// Helper function to parse prefill arguments from command line
|
||||
fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let mut prefill_entries = Vec::new();
|
||||
let mut i = 0;
|
||||
|
||||
while i < args.len() {
|
||||
if args[i] == "--prefill" && i + 1 < args.len() {
|
||||
let url = args[i + 1].clone();
|
||||
let bootstrap_port = if i + 2 < args.len() && !args[i + 2].starts_with("--") {
|
||||
// Check if next arg is a port number
|
||||
if let Ok(port) = args[i + 2].parse::<u16>() {
|
||||
i += 1; // Skip the port argument
|
||||
Some(port)
|
||||
} else if args[i + 2].to_lowercase() == "none" {
|
||||
i += 1; // Skip the "none" argument
|
||||
None
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
prefill_entries.push((url, bootstrap_port));
|
||||
i += 2; // Skip --prefill and URL
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
prefill_entries
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
|
||||
pub enum Backend {
|
||||
#[value(name = "sglang")]
|
||||
Sglang,
|
||||
#[value(name = "vllm")]
|
||||
Vllm,
|
||||
#[value(name = "trtllm")]
|
||||
Trtllm,
|
||||
#[value(name = "openai")]
|
||||
Openai,
|
||||
#[value(name = "anthropic")]
|
||||
Anthropic,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Backend {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let s = match self {
|
||||
Backend::Sglang => "sglang",
|
||||
Backend::Vllm => "vllm",
|
||||
Backend::Trtllm => "trtllm",
|
||||
Backend::Openai => "openai",
|
||||
Backend::Anthropic => "anthropic",
|
||||
};
|
||||
write!(f, "{}", s)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "sglang-router")]
|
||||
#[command(about = "SGLang Router - High-performance request distribution across worker nodes")]
|
||||
#[command(long_about = r#"
|
||||
SGLang Router - High-performance request distribution across worker nodes
|
||||
|
||||
Usage:
|
||||
This launcher enables starting a router with individual worker instances. It is useful for
|
||||
multi-node setups or when you want to start workers and router separately.
|
||||
|
||||
Examples:
|
||||
# Regular mode
|
||||
sglang-router --worker-urls http://worker1:8000 http://worker2:8000
|
||||
|
||||
# PD disaggregated mode with same policy for both
|
||||
sglang-router --pd-disaggregation \
|
||||
--prefill http://127.0.0.1:30001 9001 \
|
||||
--prefill http://127.0.0.2:30002 9002 \
|
||||
--decode http://127.0.0.3:30003 \
|
||||
--decode http://127.0.0.4:30004 \
|
||||
--policy cache_aware
|
||||
|
||||
# PD mode with different policies for prefill and decode
|
||||
sglang-router --pd-disaggregation \
|
||||
--prefill http://127.0.0.1:30001 9001 \
|
||||
--prefill http://127.0.0.2:30002 \
|
||||
--decode http://127.0.0.3:30003 \
|
||||
--decode http://127.0.0.4:30004 \
|
||||
--prefill-policy cache_aware --decode-policy power_of_two
|
||||
|
||||
"#)]
|
||||
struct CliArgs {
|
||||
/// Host address to bind the router server
|
||||
#[arg(long, default_value = "127.0.0.1")]
|
||||
host: String,
|
||||
|
||||
/// Port number to bind the router server
|
||||
#[arg(long, default_value_t = 30000)]
|
||||
port: u16,
|
||||
|
||||
/// List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)
|
||||
#[arg(long, num_args = 0..)]
|
||||
worker_urls: Vec<String>,
|
||||
|
||||
/// Load balancing policy to use
|
||||
#[arg(long, default_value = "cache_aware", value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
|
||||
policy: String,
|
||||
|
||||
/// Enable PD (Prefill-Decode) disaggregated mode
|
||||
#[arg(long, default_value_t = false)]
|
||||
pd_disaggregation: bool,
|
||||
|
||||
/// Decode server URL (can be specified multiple times)
|
||||
#[arg(long, action = ArgAction::Append)]
|
||||
decode: Vec<String>,
|
||||
|
||||
/// Specific policy for prefill nodes in PD mode
|
||||
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
|
||||
prefill_policy: Option<String>,
|
||||
|
||||
/// Specific policy for decode nodes in PD mode
|
||||
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
|
||||
decode_policy: Option<String>,
|
||||
|
||||
/// Timeout in seconds for worker startup
|
||||
#[arg(long, default_value_t = 600)]
|
||||
worker_startup_timeout_secs: u64,
|
||||
|
||||
/// Interval in seconds between checks for worker startup
|
||||
#[arg(long, default_value_t = 30)]
|
||||
worker_startup_check_interval: u64,
|
||||
|
||||
/// Cache threshold (0.0-1.0) for cache-aware routing
|
||||
#[arg(long, default_value_t = 0.3)]
|
||||
cache_threshold: f32,
|
||||
|
||||
/// Absolute threshold for load balancing
|
||||
#[arg(long, default_value_t = 64)]
|
||||
balance_abs_threshold: usize,
|
||||
|
||||
/// Relative threshold for load balancing
|
||||
#[arg(long, default_value_t = 1.5)]
|
||||
balance_rel_threshold: f32,
|
||||
|
||||
/// Interval in seconds between cache eviction operations
|
||||
#[arg(long, default_value_t = 120)]
|
||||
eviction_interval: u64,
|
||||
|
||||
/// Maximum size of the approximation tree for cache-aware routing
|
||||
#[arg(long, default_value_t = 67108864)] // 2^26
|
||||
max_tree_size: usize,
|
||||
|
||||
/// Maximum payload size in bytes
|
||||
#[arg(long, default_value_t = 536870912)] // 512MB
|
||||
max_payload_size: usize,
|
||||
|
||||
/// Enable data parallelism aware schedule
|
||||
#[arg(long, default_value_t = false)]
|
||||
dp_aware: bool,
|
||||
|
||||
/// API key for worker authorization
|
||||
#[arg(long)]
|
||||
api_key: Option<String>,
|
||||
|
||||
/// Backend to route requests to (sglang, vllm, trtllm, openai, anthropic)
|
||||
#[arg(long, value_enum, default_value_t = Backend::Sglang, alias = "runtime")]
|
||||
backend: Backend,
|
||||
|
||||
/// Directory to store log files
|
||||
#[arg(long)]
|
||||
log_dir: Option<String>,
|
||||
|
||||
/// Set the logging level
|
||||
#[arg(long, default_value = "info", value_parser = ["debug", "info", "warn", "error"])]
|
||||
log_level: String,
|
||||
|
||||
/// Enable Kubernetes service discovery
|
||||
#[arg(long, default_value_t = false)]
|
||||
service_discovery: bool,
|
||||
|
||||
/// Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)
|
||||
#[arg(long, num_args = 0..)]
|
||||
selector: Vec<String>,
|
||||
|
||||
/// Port to use for discovered worker pods
|
||||
#[arg(long, default_value_t = 80)]
|
||||
service_discovery_port: u16,
|
||||
|
||||
/// Kubernetes namespace to watch for pods
|
||||
#[arg(long)]
|
||||
service_discovery_namespace: Option<String>,
|
||||
|
||||
/// Label selector for prefill server pods in PD mode
|
||||
#[arg(long, num_args = 0..)]
|
||||
prefill_selector: Vec<String>,
|
||||
|
||||
/// Label selector for decode server pods in PD mode
|
||||
#[arg(long, num_args = 0..)]
|
||||
decode_selector: Vec<String>,
|
||||
|
||||
/// Port to expose Prometheus metrics
|
||||
#[arg(long, default_value_t = 29000)]
|
||||
prometheus_port: u16,
|
||||
|
||||
/// Host address to bind the Prometheus metrics server
|
||||
#[arg(long, default_value = "127.0.0.1")]
|
||||
prometheus_host: String,
|
||||
|
||||
/// Custom HTTP headers to check for request IDs
|
||||
#[arg(long, num_args = 0..)]
|
||||
request_id_headers: Vec<String>,
|
||||
|
||||
/// Request timeout in seconds
|
||||
#[arg(long, default_value_t = 1800)]
|
||||
request_timeout_secs: u64,
|
||||
|
||||
/// Maximum number of concurrent requests allowed
|
||||
#[arg(long, default_value_t = 256)]
|
||||
max_concurrent_requests: usize,
|
||||
|
||||
/// CORS allowed origins
|
||||
#[arg(long, num_args = 0..)]
|
||||
cors_allowed_origins: Vec<String>,
|
||||
|
||||
// Retry configuration
|
||||
/// Maximum number of retries
|
||||
#[arg(long, default_value_t = 5)]
|
||||
retry_max_retries: u32,
|
||||
|
||||
/// Initial backoff in milliseconds for retries
|
||||
#[arg(long, default_value_t = 50)]
|
||||
retry_initial_backoff_ms: u64,
|
||||
|
||||
/// Maximum backoff in milliseconds for retries
|
||||
#[arg(long, default_value_t = 30000)]
|
||||
retry_max_backoff_ms: u64,
|
||||
|
||||
/// Backoff multiplier for exponential backoff
|
||||
#[arg(long, default_value_t = 1.5)]
|
||||
retry_backoff_multiplier: f32,
|
||||
|
||||
/// Jitter factor for retry backoff
|
||||
#[arg(long, default_value_t = 0.2)]
|
||||
retry_jitter_factor: f32,
|
||||
|
||||
/// Disable retries
|
||||
#[arg(long, default_value_t = false)]
|
||||
disable_retries: bool,
|
||||
|
||||
// Circuit breaker configuration
|
||||
/// Number of failures before circuit breaker opens
|
||||
#[arg(long, default_value_t = 10)]
|
||||
cb_failure_threshold: u32,
|
||||
|
||||
/// Number of successes before circuit breaker closes
|
||||
#[arg(long, default_value_t = 3)]
|
||||
cb_success_threshold: u32,
|
||||
|
||||
/// Timeout duration in seconds for circuit breaker
|
||||
#[arg(long, default_value_t = 60)]
|
||||
cb_timeout_duration_secs: u64,
|
||||
|
||||
/// Window duration in seconds for circuit breaker
|
||||
#[arg(long, default_value_t = 120)]
|
||||
cb_window_duration_secs: u64,
|
||||
|
||||
/// Disable circuit breaker
|
||||
#[arg(long, default_value_t = false)]
|
||||
disable_circuit_breaker: bool,
|
||||
|
||||
// Health check configuration
|
||||
/// Number of consecutive health check failures before marking worker unhealthy
|
||||
#[arg(long, default_value_t = 3)]
|
||||
health_failure_threshold: u32,
|
||||
|
||||
/// Number of consecutive health check successes before marking worker healthy
|
||||
#[arg(long, default_value_t = 2)]
|
||||
health_success_threshold: u32,
|
||||
|
||||
/// Timeout in seconds for health check requests
|
||||
#[arg(long, default_value_t = 5)]
|
||||
health_check_timeout_secs: u64,
|
||||
|
||||
/// Interval in seconds between runtime health checks
|
||||
#[arg(long, default_value_t = 60)]
|
||||
health_check_interval_secs: u64,
|
||||
|
||||
/// Health check endpoint path
|
||||
#[arg(long, default_value = "/health")]
|
||||
health_check_endpoint: String,
|
||||
|
||||
// IGW (Inference Gateway) configuration
|
||||
/// Enable Inference Gateway mode
|
||||
#[arg(long, default_value_t = false)]
|
||||
enable_igw: bool,
|
||||
|
||||
// Tokenizer configuration
|
||||
/// Model path for loading tokenizer (HuggingFace model ID or local path)
|
||||
#[arg(long)]
|
||||
model_path: Option<String>,
|
||||
|
||||
/// Explicit tokenizer path (overrides model_path tokenizer if provided)
|
||||
#[arg(long)]
|
||||
tokenizer_path: Option<String>,
|
||||
}
|
||||
|
||||
impl CliArgs {
|
||||
/// Determine connection mode from worker URLs
|
||||
fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode {
|
||||
// Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
|
||||
for url in worker_urls {
|
||||
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
|
||||
return ConnectionMode::Grpc;
|
||||
}
|
||||
}
|
||||
// Default to HTTP for all other cases (including http://, https://, or no scheme)
|
||||
ConnectionMode::Http
|
||||
}
|
||||
|
||||
/// Parse selector strings into HashMap
|
||||
fn parse_selector(selector_list: &[String]) -> HashMap<String, String> {
|
||||
let mut map = HashMap::new();
|
||||
for item in selector_list {
|
||||
if let Some(eq_pos) = item.find('=') {
|
||||
let key = item[..eq_pos].to_string();
|
||||
let value = item[eq_pos + 1..].to_string();
|
||||
map.insert(key, value);
|
||||
}
|
||||
}
|
||||
map
|
||||
}
|
||||
|
||||
/// Convert policy string to PolicyConfig
|
||||
fn parse_policy(&self, policy_str: &str) -> PolicyConfig {
|
||||
match policy_str {
|
||||
"random" => PolicyConfig::Random,
|
||||
"round_robin" => PolicyConfig::RoundRobin,
|
||||
"cache_aware" => PolicyConfig::CacheAware {
|
||||
cache_threshold: self.cache_threshold,
|
||||
balance_abs_threshold: self.balance_abs_threshold,
|
||||
balance_rel_threshold: self.balance_rel_threshold,
|
||||
eviction_interval_secs: self.eviction_interval,
|
||||
max_tree_size: self.max_tree_size,
|
||||
},
|
||||
"power_of_two" => PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 5, // Default value
|
||||
},
|
||||
_ => PolicyConfig::RoundRobin, // Fallback
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert CLI arguments to RouterConfig
|
||||
fn to_router_config(
|
||||
&self,
|
||||
prefill_urls: Vec<(String, Option<u16>)>,
|
||||
) -> ConfigResult<RouterConfig> {
|
||||
// Determine routing mode
|
||||
let mode = if self.enable_igw {
|
||||
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
}
|
||||
} else if matches!(self.backend, Backend::Openai) {
|
||||
// OpenAI backend mode - use worker_urls as base(s)
|
||||
RoutingMode::OpenAI {
|
||||
worker_urls: self.worker_urls.clone(),
|
||||
}
|
||||
} else if self.pd_disaggregation {
|
||||
let decode_urls = self.decode.clone();
|
||||
|
||||
// Validate PD configuration if not using service discovery
|
||||
if !self.service_discovery && (prefill_urls.is_empty() || decode_urls.is_empty()) {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "PD disaggregation mode requires --prefill and --decode URLs when not using service discovery".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy: self.prefill_policy.as_ref().map(|p| self.parse_policy(p)),
|
||||
decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)),
|
||||
}
|
||||
} else {
|
||||
// Regular mode
|
||||
if !self.service_discovery && self.worker_urls.is_empty() {
|
||||
return Err(ConfigError::ValidationFailed {
|
||||
reason: "Regular mode requires --worker-urls when not using service discovery"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
RoutingMode::Regular {
|
||||
worker_urls: self.worker_urls.clone(),
|
||||
}
|
||||
};
|
||||
|
||||
// Main policy
|
||||
let policy = self.parse_policy(&self.policy);
|
||||
|
||||
// Service discovery configuration
|
||||
let discovery = if self.service_discovery {
|
||||
Some(DiscoveryConfig {
|
||||
enabled: true,
|
||||
namespace: self.service_discovery_namespace.clone(),
|
||||
port: self.service_discovery_port,
|
||||
check_interval_secs: 60,
|
||||
selector: Self::parse_selector(&self.selector),
|
||||
prefill_selector: Self::parse_selector(&self.prefill_selector),
|
||||
decode_selector: Self::parse_selector(&self.decode_selector),
|
||||
bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Metrics configuration
|
||||
let metrics = Some(MetricsConfig {
|
||||
port: self.prometheus_port,
|
||||
host: self.prometheus_host.clone(),
|
||||
});
|
||||
|
||||
// Determine connection mode from all worker URLs
|
||||
let mut all_urls = Vec::new();
|
||||
match &mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
all_urls.extend(worker_urls.clone());
|
||||
}
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
..
|
||||
} => {
|
||||
for (url, _) in prefill_urls {
|
||||
all_urls.push(url.clone());
|
||||
}
|
||||
all_urls.extend(decode_urls.clone());
|
||||
}
|
||||
RoutingMode::OpenAI { .. } => {
|
||||
// For connection-mode detection, skip URLs; OpenAI forces HTTP below.
|
||||
}
|
||||
}
|
||||
let connection_mode = match &mode {
|
||||
RoutingMode::OpenAI { .. } => ConnectionMode::Http,
|
||||
_ => Self::determine_connection_mode(&all_urls),
|
||||
};
|
||||
|
||||
// Build RouterConfig
|
||||
Ok(RouterConfig {
|
||||
mode,
|
||||
policy,
|
||||
connection_mode,
|
||||
host: self.host.clone(),
|
||||
port: self.port,
|
||||
max_payload_size: self.max_payload_size,
|
||||
request_timeout_secs: self.request_timeout_secs,
|
||||
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
|
||||
worker_startup_check_interval_secs: self.worker_startup_check_interval,
|
||||
dp_aware: self.dp_aware,
|
||||
api_key: self.api_key.clone(),
|
||||
discovery,
|
||||
metrics,
|
||||
log_dir: self.log_dir.clone(),
|
||||
log_level: Some(self.log_level.clone()),
|
||||
request_id_headers: if self.request_id_headers.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(self.request_id_headers.clone())
|
||||
},
|
||||
max_concurrent_requests: self.max_concurrent_requests,
|
||||
queue_size: 100, // Default queue size
|
||||
queue_timeout_secs: 60, // Default timeout
|
||||
cors_allowed_origins: self.cors_allowed_origins.clone(),
|
||||
retry: RetryConfig {
|
||||
max_retries: self.retry_max_retries,
|
||||
initial_backoff_ms: self.retry_initial_backoff_ms,
|
||||
max_backoff_ms: self.retry_max_backoff_ms,
|
||||
backoff_multiplier: self.retry_backoff_multiplier,
|
||||
jitter_factor: self.retry_jitter_factor,
|
||||
},
|
||||
circuit_breaker: CircuitBreakerConfig {
|
||||
failure_threshold: self.cb_failure_threshold,
|
||||
success_threshold: self.cb_success_threshold,
|
||||
timeout_duration_secs: self.cb_timeout_duration_secs,
|
||||
window_duration_secs: self.cb_window_duration_secs,
|
||||
},
|
||||
disable_retries: self.disable_retries,
|
||||
disable_circuit_breaker: self.disable_circuit_breaker,
|
||||
health_check: HealthCheckConfig {
|
||||
failure_threshold: self.health_failure_threshold,
|
||||
success_threshold: self.health_success_threshold,
|
||||
timeout_secs: self.health_check_timeout_secs,
|
||||
check_interval_secs: self.health_check_interval_secs,
|
||||
endpoint: self.health_check_endpoint.clone(),
|
||||
},
|
||||
enable_igw: self.enable_igw,
|
||||
rate_limit_tokens_per_second: None,
|
||||
model_path: self.model_path.clone(),
|
||||
tokenizer_path: self.tokenizer_path.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Create ServerConfig from CLI args and RouterConfig
|
||||
fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig {
|
||||
// Create service discovery config if enabled
|
||||
let service_discovery_config = if self.service_discovery {
|
||||
Some(ServiceDiscoveryConfig {
|
||||
enabled: true,
|
||||
selector: Self::parse_selector(&self.selector),
|
||||
check_interval: std::time::Duration::from_secs(60),
|
||||
port: self.service_discovery_port,
|
||||
namespace: self.service_discovery_namespace.clone(),
|
||||
pd_mode: self.pd_disaggregation,
|
||||
prefill_selector: Self::parse_selector(&self.prefill_selector),
|
||||
decode_selector: Self::parse_selector(&self.decode_selector),
|
||||
bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Create Prometheus config
|
||||
let prometheus_config = Some(PrometheusConfig {
|
||||
port: self.prometheus_port,
|
||||
host: self.prometheus_host.clone(),
|
||||
});
|
||||
|
||||
ServerConfig {
|
||||
host: self.host.clone(),
|
||||
port: self.port,
|
||||
router_config,
|
||||
max_payload_size: self.max_payload_size,
|
||||
log_dir: self.log_dir.clone(),
|
||||
log_level: Some(self.log_level.clone()),
|
||||
service_discovery_config,
|
||||
prometheus_config,
|
||||
request_timeout_secs: self.request_timeout_secs,
|
||||
request_id_headers: if self.request_id_headers.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(self.request_id_headers.clone())
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Parse prefill arguments manually before clap parsing
|
||||
let prefill_urls = parse_prefill_args();
|
||||
|
||||
// Filter out prefill arguments and their values before passing to clap
|
||||
let mut filtered_args: Vec<String> = Vec::new();
|
||||
let raw_args: Vec<String> = std::env::args().collect();
|
||||
let mut i = 0;
|
||||
|
||||
while i < raw_args.len() {
|
||||
if raw_args[i] == "--prefill" && i + 1 < raw_args.len() {
|
||||
// Skip --prefill and its URL
|
||||
i += 2;
|
||||
// Also skip bootstrap port if present
|
||||
if i < raw_args.len()
|
||||
&& !raw_args[i].starts_with("--")
|
||||
&& (raw_args[i].parse::<u16>().is_ok() || raw_args[i].to_lowercase() == "none")
|
||||
{
|
||||
i += 1;
|
||||
}
|
||||
} else {
|
||||
filtered_args.push(raw_args[i].clone());
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Parse CLI arguments with clap using filtered args
|
||||
let cli_args = CliArgs::parse_from(filtered_args);
|
||||
|
||||
// Print startup info
|
||||
println!("SGLang Router starting...");
|
||||
println!("Host: {}:{}", cli_args.host, cli_args.port);
|
||||
let mode_str = if cli_args.enable_igw {
|
||||
"IGW (Inference Gateway)".to_string()
|
||||
} else if matches!(cli_args.backend, Backend::Openai) {
|
||||
"OpenAI Backend".to_string()
|
||||
} else if cli_args.pd_disaggregation {
|
||||
"PD Disaggregated".to_string()
|
||||
} else {
|
||||
format!("Regular ({})", cli_args.backend)
|
||||
};
|
||||
println!("Mode: {}", mode_str);
|
||||
|
||||
// Warn for runtimes that are parsed but not yet implemented
|
||||
match cli_args.backend {
|
||||
Backend::Vllm | Backend::Trtllm | Backend::Anthropic => {
|
||||
println!(
|
||||
"WARNING: runtime '{}' not implemented yet; falling back to regular routing. \
|
||||
Provide --worker-urls or PD flags as usual.",
|
||||
cli_args.backend
|
||||
);
|
||||
}
|
||||
Backend::Sglang | Backend::Openai => {}
|
||||
}
|
||||
|
||||
if !cli_args.enable_igw {
|
||||
println!("Policy: {}", cli_args.policy);
|
||||
|
||||
if cli_args.pd_disaggregation && !prefill_urls.is_empty() {
|
||||
println!("Prefill nodes: {:?}", prefill_urls);
|
||||
println!("Decode nodes: {:?}", cli_args.decode);
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to RouterConfig
|
||||
let router_config = cli_args.to_router_config(prefill_urls)?;
|
||||
|
||||
// Validate configuration
|
||||
router_config.validate()?;
|
||||
|
||||
// Create ServerConfig
|
||||
let server_config = cli_args.to_server_config(router_config);
|
||||
|
||||
// Create a new runtime for the server (like Python binding does)
|
||||
let runtime = tokio::runtime::Runtime::new()?;
|
||||
|
||||
// Block on the async startup function
|
||||
runtime.block_on(async move { server::startup(server_config).await })?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
535
sgl-router/src/mcp/client_manager.rs
Normal file
535
sgl-router/src/mcp/client_manager.rs
Normal file
@@ -0,0 +1,535 @@
|
||||
use backoff::ExponentialBackoffBuilder;
|
||||
use dashmap::DashMap;
|
||||
use rmcp::{
|
||||
model::{
|
||||
CallToolRequestParam, GetPromptRequestParam, GetPromptResult, Prompt,
|
||||
ReadResourceRequestParam, ReadResourceResult, Resource, Tool as McpTool,
|
||||
},
|
||||
service::RunningService,
|
||||
transport::{
|
||||
sse_client::SseClientConfig, streamable_http_client::StreamableHttpClientTransportConfig,
|
||||
ConfigureCommandExt, SseClientTransport, StreamableHttpClientTransport, TokioChildProcess,
|
||||
},
|
||||
RoleClient, ServiceExt,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{borrow::Cow, collections::HashMap, time::Duration};
|
||||
|
||||
use crate::mcp::{
|
||||
config::{McpConfig, McpServerConfig, McpTransport},
|
||||
error::{McpError, McpResult},
|
||||
};
|
||||
|
||||
/// Information about an available tool
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolInfo {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub server: String,
|
||||
pub parameters: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Information about an available prompt
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptInfo {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub server: String,
|
||||
pub arguments: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
/// Information about an available resource
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResourceInfo {
|
||||
pub uri: String,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub mime_type: Option<String>,
|
||||
pub server: String,
|
||||
}
|
||||
|
||||
/// Manages MCP client connections and tool execution
|
||||
pub struct McpClientManager {
|
||||
/// Map of server_name -> MCP client
|
||||
clients: HashMap<String, RunningService<RoleClient, ()>>,
|
||||
/// Map of tool_name -> (server_name, tool_definition)
|
||||
tools: DashMap<String, (String, McpTool)>,
|
||||
/// Map of prompt_name -> (server_name, prompt_definition)
|
||||
prompts: DashMap<String, (String, Prompt)>,
|
||||
/// Map of resource_uri -> (server_name, resource_definition)
|
||||
resources: DashMap<String, (String, Resource)>,
|
||||
}
|
||||
|
||||
impl McpClientManager {
|
||||
/// Create a new manager and connect to all configured servers
|
||||
pub async fn new(config: McpConfig) -> McpResult<Self> {
|
||||
let mut mgr = Self {
|
||||
clients: HashMap::new(),
|
||||
tools: DashMap::new(),
|
||||
prompts: DashMap::new(),
|
||||
resources: DashMap::new(),
|
||||
};
|
||||
|
||||
for server_config in config.servers {
|
||||
match Self::connect_server(&server_config).await {
|
||||
Ok(client) => {
|
||||
mgr.load_server_inventory(&server_config.name, &client)
|
||||
.await;
|
||||
mgr.clients.insert(server_config.name.clone(), client);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"Failed to connect to server '{}': {}",
|
||||
server_config.name,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mgr.clients.is_empty() {
|
||||
return Err(McpError::ConnectionFailed(
|
||||
"Failed to connect to any MCP servers".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(mgr)
|
||||
}
|
||||
|
||||
/// Discover and cache tools/prompts/resources for a connected server
|
||||
async fn load_server_inventory(
|
||||
&self,
|
||||
server_name: &str,
|
||||
client: &RunningService<RoleClient, ()>,
|
||||
) {
|
||||
// Tools
|
||||
match client.peer().list_all_tools().await {
|
||||
Ok(ts) => {
|
||||
tracing::info!("Discovered {} tools from '{}'", ts.len(), server_name);
|
||||
for t in ts {
|
||||
if self.tools.contains_key(t.name.as_ref()) {
|
||||
tracing::warn!(
|
||||
"Tool '{}' from server '{}' is overwriting an existing tool.",
|
||||
&t.name,
|
||||
server_name
|
||||
);
|
||||
}
|
||||
self.tools
|
||||
.insert(t.name.to_string(), (server_name.to_string(), t));
|
||||
}
|
||||
}
|
||||
Err(e) => tracing::warn!("Failed to list tools from '{}': {}", server_name, e),
|
||||
}
|
||||
|
||||
// Prompts
|
||||
match client.peer().list_all_prompts().await {
|
||||
Ok(ps) => {
|
||||
tracing::info!("Discovered {} prompts from '{}'", ps.len(), server_name);
|
||||
for p in ps {
|
||||
if self.prompts.contains_key(&p.name) {
|
||||
tracing::warn!(
|
||||
"Prompt '{}' from server '{}' is overwriting an existing prompt.",
|
||||
&p.name,
|
||||
server_name
|
||||
);
|
||||
}
|
||||
self.prompts
|
||||
.insert(p.name.clone(), (server_name.to_string(), p));
|
||||
}
|
||||
}
|
||||
Err(e) => tracing::debug!("No prompts or failed to list on '{}': {}", server_name, e),
|
||||
}
|
||||
|
||||
// Resources
|
||||
match client.peer().list_all_resources().await {
|
||||
Ok(rs) => {
|
||||
tracing::info!("Discovered {} resources from '{}'", rs.len(), server_name);
|
||||
for r in rs {
|
||||
if self.resources.contains_key(&r.uri) {
|
||||
tracing::warn!(
|
||||
"Resource '{}' from server '{}' is overwriting an existing resource.",
|
||||
&r.uri,
|
||||
server_name
|
||||
);
|
||||
}
|
||||
self.resources
|
||||
.insert(r.uri.clone(), (server_name.to_string(), r));
|
||||
}
|
||||
}
|
||||
Err(e) => tracing::debug!("No resources or failed to list on '{}': {}", server_name, e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to a single MCP server with retry logic for remote transports
|
||||
async fn connect_server(config: &McpServerConfig) -> McpResult<RunningService<RoleClient, ()>> {
|
||||
let needs_retry = matches!(
|
||||
&config.transport,
|
||||
McpTransport::Sse { .. } | McpTransport::Streamable { .. }
|
||||
);
|
||||
if needs_retry {
|
||||
Self::connect_server_with_retry(config).await
|
||||
} else {
|
||||
Self::connect_server_impl(config).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect with exponential backoff retry for remote servers
|
||||
async fn connect_server_with_retry(
|
||||
config: &McpServerConfig,
|
||||
) -> McpResult<RunningService<RoleClient, ()>> {
|
||||
let backoff = ExponentialBackoffBuilder::new()
|
||||
.with_initial_interval(Duration::from_secs(1))
|
||||
.with_max_interval(Duration::from_secs(30))
|
||||
.with_max_elapsed_time(Some(Duration::from_secs(120)))
|
||||
.build();
|
||||
|
||||
backoff::future::retry(backoff, || async {
|
||||
match Self::connect_server_impl(config).await {
|
||||
Ok(client) => Ok(client),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to connect to '{}', retrying: {}", config.name, e);
|
||||
Err(backoff::Error::transient(e))
|
||||
}
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Internal implementation of server connection
|
||||
async fn connect_server_impl(
|
||||
config: &McpServerConfig,
|
||||
) -> McpResult<RunningService<RoleClient, ()>> {
|
||||
tracing::info!(
|
||||
"Connecting to MCP server '{}' via {:?}",
|
||||
config.name,
|
||||
config.transport
|
||||
);
|
||||
|
||||
match &config.transport {
|
||||
McpTransport::Stdio {
|
||||
command,
|
||||
args,
|
||||
envs,
|
||||
} => {
|
||||
let transport = TokioChildProcess::new(
|
||||
tokio::process::Command::new(command).configure(|cmd| {
|
||||
cmd.args(args)
|
||||
.envs(envs.iter())
|
||||
.stderr(std::process::Stdio::inherit());
|
||||
}),
|
||||
)
|
||||
.map_err(|e| McpError::Transport(format!("create stdio transport: {}", e)))?;
|
||||
|
||||
let client = ().serve(transport).await.map_err(|e| {
|
||||
McpError::ConnectionFailed(format!("initialize stdio client: {}", e))
|
||||
})?;
|
||||
|
||||
tracing::info!("Connected to stdio server '{}'", config.name);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
McpTransport::Sse { url, token } => {
|
||||
let transport = if let Some(tok) = token {
|
||||
let client = reqwest::Client::builder()
|
||||
.default_headers({
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
format!("Bearer {}", tok).parse().map_err(|e| {
|
||||
McpError::Transport(format!("auth token: {}", e))
|
||||
})?,
|
||||
);
|
||||
headers
|
||||
})
|
||||
.build()
|
||||
.map_err(|e| McpError::Transport(format!("build HTTP client: {}", e)))?;
|
||||
|
||||
let cfg = SseClientConfig {
|
||||
sse_endpoint: url.clone().into(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
SseClientTransport::start_with_client(client, cfg)
|
||||
.await
|
||||
.map_err(|e| McpError::Transport(format!("create SSE transport: {}", e)))?
|
||||
} else {
|
||||
SseClientTransport::start(url.as_str())
|
||||
.await
|
||||
.map_err(|e| McpError::Transport(format!("create SSE transport: {}", e)))?
|
||||
};
|
||||
|
||||
let client = ().serve(transport).await.map_err(|e| {
|
||||
McpError::ConnectionFailed(format!("initialize SSE client: {}", e))
|
||||
})?;
|
||||
|
||||
tracing::info!("Connected to SSE server '{}' at {}", config.name, url);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
McpTransport::Streamable { url, token } => {
|
||||
let transport = if let Some(tok) = token {
|
||||
let mut cfg = StreamableHttpClientTransportConfig::with_uri(url.as_str());
|
||||
cfg.auth_header = Some(format!("Bearer {}", tok));
|
||||
StreamableHttpClientTransport::from_config(cfg)
|
||||
} else {
|
||||
StreamableHttpClientTransport::from_uri(url.as_str())
|
||||
};
|
||||
|
||||
let client = ().serve(transport).await.map_err(|e| {
|
||||
McpError::ConnectionFailed(format!("initialize streamable client: {}", e))
|
||||
})?;
|
||||
|
||||
tracing::info!(
|
||||
"Connected to streamable HTTP server '{}' at {}",
|
||||
config.name,
|
||||
url
|
||||
);
|
||||
Ok(client)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== Helpers =====
|
||||
|
||||
fn client_for(&self, server_name: &str) -> McpResult<&RunningService<RoleClient, ()>> {
|
||||
self.clients
|
||||
.get(server_name)
|
||||
.ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))
|
||||
}
|
||||
|
||||
fn tool_entry(&self, name: &str) -> McpResult<(String, McpTool)> {
|
||||
self.tools
|
||||
.get(name)
|
||||
.map(|e| e.value().clone())
|
||||
.ok_or_else(|| McpError::ToolNotFound(name.to_string()))
|
||||
}
|
||||
|
||||
fn prompt_entry(&self, name: &str) -> McpResult<(String, Prompt)> {
|
||||
self.prompts
|
||||
.get(name)
|
||||
.map(|e| e.value().clone())
|
||||
.ok_or_else(|| McpError::PromptNotFound(name.to_string()))
|
||||
}
|
||||
|
||||
fn resource_entry(&self, uri: &str) -> McpResult<(String, Resource)> {
|
||||
self.resources
|
||||
.get(uri)
|
||||
.map(|e| e.value().clone())
|
||||
.ok_or_else(|| McpError::ResourceNotFound(uri.to_string()))
|
||||
}
|
||||
|
||||
// ===== Tool Methods =====
|
||||
|
||||
/// Call a tool by name
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: Option<serde_json::Map<String, serde_json::Value>>,
|
||||
) -> McpResult<rmcp::model::CallToolResult> {
|
||||
let (server_name, _tool) = self.tool_entry(tool_name)?;
|
||||
let client = self.client_for(&server_name)?;
|
||||
|
||||
tracing::debug!("Calling tool '{}' on '{}'", tool_name, server_name);
|
||||
|
||||
client
|
||||
.peer()
|
||||
.call_tool(CallToolRequestParam {
|
||||
name: Cow::Owned(tool_name.to_string()),
|
||||
arguments,
|
||||
})
|
||||
.await
|
||||
.map_err(|e| McpError::ToolExecution(format!("Tool call failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Get all available tools
|
||||
pub fn list_tools(&self) -> Vec<ToolInfo> {
|
||||
self.tools
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
let tool_name = entry.key().clone();
|
||||
let (server_name, tool) = entry.value();
|
||||
ToolInfo {
|
||||
name: tool_name,
|
||||
description: tool.description.as_deref().unwrap_or_default().to_string(),
|
||||
server: server_name.clone(),
|
||||
parameters: Some(serde_json::Value::Object((*tool.input_schema).clone())),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get a specific tool by name
|
||||
pub fn get_tool(&self, name: &str) -> Option<ToolInfo> {
|
||||
self.tools.get(name).map(|entry| {
|
||||
let (server_name, tool) = entry.value();
|
||||
ToolInfo {
|
||||
name: name.to_string(),
|
||||
description: tool.description.as_deref().unwrap_or_default().to_string(),
|
||||
server: server_name.clone(),
|
||||
parameters: Some(serde_json::Value::Object((*tool.input_schema).clone())),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if a tool exists
|
||||
pub fn has_tool(&self, name: &str) -> bool {
|
||||
self.tools.contains_key(name)
|
||||
}
|
||||
|
||||
/// Get list of connected servers
|
||||
pub fn list_servers(&self) -> Vec<String> {
|
||||
self.clients.keys().cloned().collect()
|
||||
}
|
||||
|
||||
// ===== Prompt Methods =====
|
||||
|
||||
/// Get a prompt by name with arguments
|
||||
pub async fn get_prompt(
|
||||
&self,
|
||||
prompt_name: &str,
|
||||
arguments: Option<serde_json::Map<String, serde_json::Value>>,
|
||||
) -> McpResult<GetPromptResult> {
|
||||
let (server_name, _prompt) = self.prompt_entry(prompt_name)?;
|
||||
let client = self.client_for(&server_name)?;
|
||||
|
||||
tracing::debug!("Getting prompt '{}' from '{}'", prompt_name, server_name);
|
||||
|
||||
client
|
||||
.peer()
|
||||
.get_prompt(GetPromptRequestParam {
|
||||
name: prompt_name.to_string(),
|
||||
arguments,
|
||||
})
|
||||
.await
|
||||
.map_err(|e| McpError::ToolExecution(format!("Failed to get prompt: {}", e)))
|
||||
}
|
||||
|
||||
/// List all available prompts
|
||||
pub fn list_prompts(&self) -> Vec<PromptInfo> {
|
||||
self.prompts
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
let name = entry.key().clone();
|
||||
let (server_name, prompt) = entry.value();
|
||||
PromptInfo {
|
||||
name,
|
||||
description: prompt.description.clone(),
|
||||
server: server_name.clone(),
|
||||
arguments: prompt
|
||||
.arguments
|
||||
.clone()
|
||||
.map(|args| args.into_iter().map(|arg| serde_json::json!(arg)).collect()),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get a specific prompt info by name
|
||||
pub fn get_prompt_info(&self, name: &str) -> Option<PromptInfo> {
|
||||
self.prompts.get(name).map(|entry| {
|
||||
let (server_name, prompt) = entry.value();
|
||||
PromptInfo {
|
||||
name: name.to_string(),
|
||||
description: prompt.description.clone(),
|
||||
server: server_name.clone(),
|
||||
arguments: prompt
|
||||
.arguments
|
||||
.clone()
|
||||
.map(|args| args.into_iter().map(|arg| serde_json::json!(arg)).collect()),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ===== Resource Methods =====
|
||||
|
||||
/// Read a resource by URI
|
||||
pub async fn read_resource(&self, uri: &str) -> McpResult<ReadResourceResult> {
|
||||
let (server_name, _resource) = self.resource_entry(uri)?;
|
||||
let client = self.client_for(&server_name)?;
|
||||
|
||||
tracing::debug!("Reading resource '{}' from '{}'", uri, server_name);
|
||||
|
||||
client
|
||||
.peer()
|
||||
.read_resource(ReadResourceRequestParam {
|
||||
uri: uri.to_string(),
|
||||
})
|
||||
.await
|
||||
.map_err(|e| McpError::ToolExecution(format!("Failed to read resource: {}", e)))
|
||||
}
|
||||
|
||||
/// List all available resources
|
||||
pub fn list_resources(&self) -> Vec<ResourceInfo> {
|
||||
self.resources
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
let uri = entry.key().clone();
|
||||
let (server_name, resource) = entry.value();
|
||||
ResourceInfo {
|
||||
uri,
|
||||
name: resource.name.clone(),
|
||||
description: resource.description.clone(),
|
||||
mime_type: resource.mime_type.clone(),
|
||||
server: server_name.clone(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get a specific resource info by URI
|
||||
pub fn get_resource_info(&self, uri: &str) -> Option<ResourceInfo> {
|
||||
self.resources.get(uri).map(|entry| {
|
||||
let (server_name, resource) = entry.value();
|
||||
ResourceInfo {
|
||||
uri: uri.to_string(),
|
||||
name: resource.name.clone(),
|
||||
description: resource.description.clone(),
|
||||
mime_type: resource.mime_type.clone(),
|
||||
server: server_name.clone(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Subscribe to resource changes
|
||||
pub async fn subscribe_resource(&self, uri: &str) -> McpResult<()> {
|
||||
let (server_name, _resource) = self.resource_entry(uri)?;
|
||||
let client = self.client_for(&server_name)?;
|
||||
|
||||
tracing::debug!("Subscribing to '{}' on '{}'", uri, server_name);
|
||||
|
||||
client
|
||||
.peer()
|
||||
.subscribe(rmcp::model::SubscribeRequestParam {
|
||||
uri: uri.to_string(),
|
||||
})
|
||||
.await
|
||||
.map_err(|e| McpError::ToolExecution(format!("Failed to subscribe: {}", e)))
|
||||
}
|
||||
|
||||
/// Unsubscribe from resource changes
|
||||
pub async fn unsubscribe_resource(&self, uri: &str) -> McpResult<()> {
|
||||
let (server_name, _resource) = self.resource_entry(uri)?;
|
||||
let client = self.client_for(&server_name)?;
|
||||
|
||||
tracing::debug!("Unsubscribing from '{}' on '{}'", uri, server_name);
|
||||
|
||||
client
|
||||
.peer()
|
||||
.unsubscribe(rmcp::model::UnsubscribeRequestParam {
|
||||
uri: uri.to_string(),
|
||||
})
|
||||
.await
|
||||
.map_err(|e| McpError::ToolExecution(format!("Failed to unsubscribe: {}", e)))
|
||||
}
|
||||
|
||||
/// Disconnect from all servers (for cleanup)
|
||||
pub async fn shutdown(&mut self) {
|
||||
for (name, client) in self.clients.drain() {
|
||||
if let Err(e) = client.cancel().await {
|
||||
tracing::warn!("Error disconnecting from '{}': {}", name, e);
|
||||
}
|
||||
}
|
||||
self.tools.clear();
|
||||
self.prompts.clear();
|
||||
self.resources.clear();
|
||||
}
|
||||
}
|
||||
52
sgl-router/src/mcp/config.rs
Normal file
52
sgl-router/src/mcp/config.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct McpConfig {
|
||||
pub servers: Vec<McpServerConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct McpServerConfig {
|
||||
pub name: String,
|
||||
#[serde(flatten)]
|
||||
pub transport: McpTransport,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "protocol", rename_all = "lowercase")]
|
||||
pub enum McpTransport {
|
||||
Stdio {
|
||||
command: String,
|
||||
#[serde(default)]
|
||||
args: Vec<String>,
|
||||
#[serde(default)]
|
||||
envs: HashMap<String, String>,
|
||||
},
|
||||
Sse {
|
||||
url: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
token: Option<String>,
|
||||
},
|
||||
Streamable {
|
||||
url: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
token: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl McpConfig {
|
||||
/// Load configuration from a YAML file
|
||||
pub async fn from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let content = tokio::fs::read_to_string(path).await?;
|
||||
let config: Self = serde_yaml::from_str(&content)?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Load configuration from environment variables (optional)
|
||||
pub fn from_env() -> Option<Self> {
|
||||
// This could be expanded to read from env vars
|
||||
// For now, return None to indicate env config not implemented
|
||||
None
|
||||
}
|
||||
}
|
||||
42
sgl-router/src/mcp/error.rs
Normal file
42
sgl-router/src/mcp/error.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use thiserror::Error;
|
||||
|
||||
pub type McpResult<T> = Result<T, McpError>;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum McpError {
|
||||
#[error("Server not found: {0}")]
|
||||
ServerNotFound(String),
|
||||
|
||||
#[error("Tool not found: {0}")]
|
||||
ToolNotFound(String),
|
||||
|
||||
#[error("Transport error: {0}")]
|
||||
Transport(String),
|
||||
|
||||
#[error("Tool execution failed: {0}")]
|
||||
ToolExecution(String),
|
||||
|
||||
#[error("Connection failed: {0}")]
|
||||
ConnectionFailed(String),
|
||||
|
||||
#[error("Configuration error: {0}")]
|
||||
Config(String),
|
||||
|
||||
#[error("Authentication error: {0}")]
|
||||
Auth(String),
|
||||
|
||||
#[error("Resource not found: {0}")]
|
||||
ResourceNotFound(String),
|
||||
|
||||
#[error("Prompt not found: {0}")]
|
||||
PromptNotFound(String),
|
||||
|
||||
#[error(transparent)]
|
||||
Sdk(#[from] Box<rmcp::RmcpError>),
|
||||
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
Http(#[from] reqwest::Error),
|
||||
}
|
||||
18
sgl-router/src/mcp/mod.rs
Normal file
18
sgl-router/src/mcp/mod.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
// MCP Client for SGLang Router
|
||||
//
|
||||
// This module provides a complete MCP (Model Context Protocol) client implementation
|
||||
// supporting multiple transport types (stdio, SSE, HTTP) and all MCP features:
|
||||
// - Tools: Discovery and execution
|
||||
// - Prompts: Reusable templates for LLM interactions
|
||||
// - Resources: File/data access with subscription support
|
||||
// - OAuth: Secure authentication for remote servers
|
||||
|
||||
pub mod client_manager;
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
pub mod oauth;
|
||||
|
||||
// Re-export the main types for convenience
|
||||
pub use client_manager::{McpClientManager, PromptInfo, ResourceInfo, ToolInfo};
|
||||
pub use config::{McpConfig, McpServerConfig, McpTransport};
|
||||
pub use error::{McpError, McpResult};
|
||||
191
sgl-router/src/mcp/oauth.rs
Normal file
191
sgl-router/src/mcp/oauth.rs
Normal file
@@ -0,0 +1,191 @@
|
||||
// OAuth authentication support for MCP servers
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::Html,
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use serde::Deserialize;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use tokio::sync::{oneshot, Mutex};
|
||||
|
||||
use crate::mcp::error::{McpError, McpResult};
|
||||
|
||||
/// OAuth callback parameters
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CallbackParams {
|
||||
code: String,
|
||||
#[allow(dead_code)]
|
||||
state: Option<String>,
|
||||
}
|
||||
|
||||
/// State for the callback server
|
||||
#[derive(Clone)]
|
||||
struct CallbackState {
|
||||
code_receiver: Arc<Mutex<Option<oneshot::Sender<String>>>>,
|
||||
}
|
||||
|
||||
/// HTML page returned after successful OAuth callback
|
||||
const CALLBACK_HTML: &str = r#"
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>OAuth Success</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 100vh;
|
||||
margin: 0;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
}
|
||||
.container {
|
||||
background: white;
|
||||
padding: 40px;
|
||||
border-radius: 10px;
|
||||
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
|
||||
text-align: center;
|
||||
}
|
||||
h1 { color: #333; }
|
||||
p { color: #666; margin: 20px 0; }
|
||||
.success { color: #4CAF50; font-size: 48px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="success">✓</div>
|
||||
<h1>Authentication Successful!</h1>
|
||||
<p>You can now close this window and return to your application.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"#;
|
||||
|
||||
/// OAuth authentication helper for MCP servers
|
||||
pub struct OAuthHelper {
|
||||
server_url: String,
|
||||
redirect_uri: String,
|
||||
callback_port: u16,
|
||||
}
|
||||
|
||||
impl OAuthHelper {
|
||||
/// Create a new OAuth helper
|
||||
pub fn new(server_url: String, redirect_uri: String, callback_port: u16) -> Self {
|
||||
Self {
|
||||
server_url,
|
||||
redirect_uri,
|
||||
callback_port,
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform OAuth authentication flow
|
||||
pub async fn authenticate(
|
||||
&self,
|
||||
scopes: &[&str],
|
||||
) -> McpResult<rmcp::transport::auth::AuthorizationManager> {
|
||||
// Initialize OAuth state machine
|
||||
let mut oauth_state = OAuthState::new(&self.server_url, None)
|
||||
.await
|
||||
.map_err(|e| McpError::Auth(format!("Failed to initialize OAuth: {}", e)))?;
|
||||
|
||||
oauth_state
|
||||
.start_authorization(scopes, &self.redirect_uri)
|
||||
.await
|
||||
.map_err(|e| McpError::Auth(format!("Failed to start authorization: {}", e)))?;
|
||||
|
||||
// Get authorization URL
|
||||
let auth_url = oauth_state
|
||||
.get_authorization_url()
|
||||
.await
|
||||
.map_err(|e| McpError::Auth(format!("Failed to get authorization URL: {}", e)))?;
|
||||
|
||||
tracing::info!("OAuth authorization URL: {}", auth_url);
|
||||
|
||||
// Start callback server and wait for code
|
||||
let auth_code = self.start_callback_server().await?;
|
||||
|
||||
// Exchange code for token
|
||||
oauth_state
|
||||
.handle_callback(&auth_code)
|
||||
.await
|
||||
.map_err(|e| McpError::Auth(format!("Failed to handle OAuth callback: {}", e)))?;
|
||||
|
||||
// Get authorization manager
|
||||
oauth_state
|
||||
.into_authorization_manager()
|
||||
.ok_or_else(|| McpError::Auth("Failed to get authorization manager".to_string()))
|
||||
}
|
||||
|
||||
/// Start a local HTTP server to receive the OAuth callback
|
||||
async fn start_callback_server(&self) -> McpResult<String> {
|
||||
let (code_sender, code_receiver) = oneshot::channel::<String>();
|
||||
|
||||
let state = CallbackState {
|
||||
code_receiver: Arc::new(Mutex::new(Some(code_sender))),
|
||||
};
|
||||
|
||||
// Create router for callback
|
||||
let app = Router::new()
|
||||
.route("/callback", get(Self::callback_handler))
|
||||
.with_state(state);
|
||||
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], self.callback_port));
|
||||
|
||||
// Start server in background
|
||||
let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
|
||||
McpError::Auth(format!(
|
||||
"Failed to bind to callback port {}: {}",
|
||||
self.callback_port, e
|
||||
))
|
||||
})?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let _ = axum::serve(listener, app).await;
|
||||
});
|
||||
|
||||
tracing::info!(
|
||||
"OAuth callback server started on port {}",
|
||||
self.callback_port
|
||||
);
|
||||
|
||||
// Wait for authorization code
|
||||
code_receiver
|
||||
.await
|
||||
.map_err(|_| McpError::Auth("Failed to receive authorization code".to_string()))
|
||||
}
|
||||
|
||||
/// Handle OAuth callback
|
||||
async fn callback_handler(
|
||||
Query(params): Query<CallbackParams>,
|
||||
State(state): State<CallbackState>,
|
||||
) -> Html<String> {
|
||||
tracing::debug!("Received OAuth callback with code");
|
||||
|
||||
// Send code to waiting task
|
||||
if let Some(sender) = state.code_receiver.lock().await.take() {
|
||||
let _ = sender.send(params.code);
|
||||
}
|
||||
|
||||
Html(CALLBACK_HTML.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an OAuth-authenticated client
|
||||
pub async fn create_oauth_client(
|
||||
server_url: String,
|
||||
_sse_url: String,
|
||||
redirect_uri: String,
|
||||
callback_port: u16,
|
||||
scopes: &[&str],
|
||||
) -> McpResult<rmcp::transport::auth::AuthClient<reqwest::Client>> {
|
||||
let helper = OAuthHelper::new(server_url, redirect_uri, callback_port);
|
||||
let auth_manager = helper.authenticate(scopes).await?;
|
||||
|
||||
let client = rmcp::transport::auth::AuthClient::new(reqwest::Client::default(), auth_manager);
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
1011
sgl-router/src/metrics.rs
Normal file
1011
sgl-router/src/metrics.rs
Normal file
File diff suppressed because it is too large
Load Diff
502
sgl-router/src/middleware.rs
Normal file
502
sgl-router/src/middleware.rs
Normal file
@@ -0,0 +1,502 @@
|
||||
use axum::{
|
||||
extract::Request, extract::State, http::HeaderValue, http::StatusCode, middleware::Next,
|
||||
response::IntoResponse, response::Response,
|
||||
};
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tower::{Layer, Service};
|
||||
use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
|
||||
use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
|
||||
|
||||
pub use crate::core::token_bucket::TokenBucket;
|
||||
|
||||
use crate::server::AppState;
|
||||
|
||||
/// Generate OpenAI-compatible request ID based on endpoint
|
||||
fn generate_request_id(path: &str) -> String {
|
||||
let prefix = if path.contains("/chat/completions") {
|
||||
"chatcmpl-"
|
||||
} else if path.contains("/completions") {
|
||||
"cmpl-"
|
||||
} else if path.contains("/generate") {
|
||||
"gnt-"
|
||||
} else {
|
||||
"req-"
|
||||
};
|
||||
|
||||
// Generate a random string similar to OpenAI's format
|
||||
let chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
|
||||
let mut rng = rand::rng();
|
||||
let random_part: String = (0..24)
|
||||
.map(|_| {
|
||||
let idx = rng.random_range(0..chars.len());
|
||||
chars.chars().nth(idx).unwrap()
|
||||
})
|
||||
.collect();
|
||||
|
||||
format!("{}{}", prefix, random_part)
|
||||
}
|
||||
|
||||
/// Extension type for storing request ID
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RequestId(pub String);
|
||||
|
||||
/// Tower Layer for request ID middleware
|
||||
#[derive(Clone)]
|
||||
pub struct RequestIdLayer {
|
||||
headers: Arc<Vec<String>>,
|
||||
}
|
||||
|
||||
impl RequestIdLayer {
|
||||
pub fn new(headers: Vec<String>) -> Self {
|
||||
Self {
|
||||
headers: Arc::new(headers),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for RequestIdLayer {
|
||||
type Service = RequestIdMiddleware<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
RequestIdMiddleware {
|
||||
inner,
|
||||
headers: self.headers.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tower Service for request ID middleware
|
||||
#[derive(Clone)]
|
||||
pub struct RequestIdMiddleware<S> {
|
||||
inner: S,
|
||||
headers: Arc<Vec<String>>,
|
||||
}
|
||||
|
||||
impl<S> Service<Request> for RequestIdMiddleware<S>
|
||||
where
|
||||
S: Service<Request, Response = Response> + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = S::Error;
|
||||
type Future = std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
|
||||
>;
|
||||
|
||||
fn poll_ready(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, mut req: Request) -> Self::Future {
|
||||
let headers = self.headers.clone();
|
||||
|
||||
// Extract request ID from headers or generate new one
|
||||
let mut request_id = None;
|
||||
|
||||
for header_name in headers.iter() {
|
||||
if let Some(header_value) = req.headers().get(header_name) {
|
||||
if let Ok(value) = header_value.to_str() {
|
||||
request_id = Some(value.to_string());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let request_id = request_id.unwrap_or_else(|| generate_request_id(req.uri().path()));
|
||||
|
||||
// Insert request ID into request extensions
|
||||
req.extensions_mut().insert(RequestId(request_id.clone()));
|
||||
|
||||
// Create a span with the request ID for this request
|
||||
let span = tracing::info_span!(
|
||||
"http_request",
|
||||
method = %req.method(),
|
||||
uri = %req.uri(),
|
||||
version = ?req.version(),
|
||||
request_id = %request_id
|
||||
);
|
||||
|
||||
// Log within the span
|
||||
let _enter = span.enter();
|
||||
tracing::info!(
|
||||
target: "sglang_router_rs::request",
|
||||
"started processing request"
|
||||
);
|
||||
drop(_enter);
|
||||
|
||||
// Capture values we need in the async block
|
||||
let method = req.method().clone();
|
||||
let uri = req.uri().clone();
|
||||
let version = req.version();
|
||||
|
||||
// Call the inner service
|
||||
let future = self.inner.call(req);
|
||||
|
||||
Box::pin(async move {
|
||||
let start_time = Instant::now();
|
||||
let mut response = future.await?;
|
||||
let latency = start_time.elapsed();
|
||||
|
||||
// Add request ID to response headers
|
||||
response.headers_mut().insert(
|
||||
"x-request-id",
|
||||
HeaderValue::from_str(&request_id)
|
||||
.unwrap_or_else(|_| HeaderValue::from_static("invalid-request-id")),
|
||||
);
|
||||
|
||||
// Log the response with proper request ID in span
|
||||
let status = response.status();
|
||||
let span = tracing::info_span!(
|
||||
"http_request",
|
||||
method = %method,
|
||||
uri = %uri,
|
||||
version = ?version,
|
||||
request_id = %request_id,
|
||||
status = %status,
|
||||
latency = ?latency
|
||||
);
|
||||
|
||||
let _enter = span.enter();
|
||||
if status.is_server_error() {
|
||||
tracing::error!(
|
||||
target: "sglang_router_rs::response",
|
||||
"request failed with server error"
|
||||
);
|
||||
} else if status.is_client_error() {
|
||||
tracing::warn!(
|
||||
target: "sglang_router_rs::response",
|
||||
"request failed with client error"
|
||||
);
|
||||
} else {
|
||||
tracing::info!(
|
||||
target: "sglang_router_rs::response",
|
||||
"finished processing request"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Logging Middleware =============
|
||||
|
||||
/// Custom span maker that includes request ID
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RequestSpan;
|
||||
|
||||
impl<B> MakeSpan<B> for RequestSpan {
|
||||
fn make_span(&mut self, request: &Request<B>) -> Span {
|
||||
// Don't try to extract request ID here - it won't be available yet
|
||||
// The RequestIdLayer runs after TraceLayer creates the span
|
||||
info_span!(
|
||||
"http_request",
|
||||
method = %request.method(),
|
||||
uri = %request.uri(),
|
||||
version = ?request.version(),
|
||||
request_id = Empty, // Will be set later
|
||||
status_code = Empty,
|
||||
latency = Empty,
|
||||
error = Empty,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Custom on_request handler
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RequestLogger;
|
||||
|
||||
impl<B> OnRequest<B> for RequestLogger {
|
||||
fn on_request(&mut self, request: &Request<B>, span: &Span) {
|
||||
let _enter = span.enter();
|
||||
|
||||
// Try to get the request ID from extensions
|
||||
// This will work if RequestIdLayer has already run
|
||||
if let Some(request_id) = request.extensions().get::<RequestId>() {
|
||||
span.record("request_id", request_id.0.as_str());
|
||||
}
|
||||
|
||||
// Don't log here - we already log in RequestIdService with the proper request_id
|
||||
}
|
||||
}
|
||||
|
||||
/// Custom on_response handler
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ResponseLogger {
|
||||
_start_time: Instant,
|
||||
}
|
||||
|
||||
impl Default for ResponseLogger {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
_start_time: Instant::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> OnResponse<B> for ResponseLogger {
|
||||
fn on_response(self, response: &Response<B>, latency: std::time::Duration, span: &Span) {
|
||||
let status = response.status();
|
||||
|
||||
// Record these in the span for structured logging/observability tools
|
||||
span.record("status_code", status.as_u16());
|
||||
span.record("latency", format!("{:?}", latency));
|
||||
|
||||
// Don't log here - RequestIdService handles all logging with proper request IDs
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a configured TraceLayer for HTTP logging
|
||||
/// Note: Actual request/response logging with request IDs is done in RequestIdService
|
||||
pub fn create_logging_layer() -> TraceLayer<
|
||||
tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>,
|
||||
RequestSpan,
|
||||
RequestLogger,
|
||||
ResponseLogger,
|
||||
> {
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(RequestSpan)
|
||||
.on_request(RequestLogger)
|
||||
.on_response(ResponseLogger::default())
|
||||
}
|
||||
|
||||
/// Structured logging data for requests
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct RequestLogEntry {
|
||||
pub timestamp: String,
|
||||
pub request_id: String,
|
||||
pub method: String,
|
||||
pub uri: String,
|
||||
pub status: u16,
|
||||
pub latency_ms: u64,
|
||||
pub user_agent: Option<String>,
|
||||
pub remote_addr: Option<String>,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Log a request with structured data
|
||||
pub fn log_request(entry: RequestLogEntry) {
|
||||
if entry.status >= 500 {
|
||||
tracing::error!(
|
||||
target: "sglang_router_rs::http",
|
||||
request_id = %entry.request_id,
|
||||
method = %entry.method,
|
||||
uri = %entry.uri,
|
||||
status = entry.status,
|
||||
latency_ms = entry.latency_ms,
|
||||
user_agent = ?entry.user_agent,
|
||||
remote_addr = ?entry.remote_addr,
|
||||
error = ?entry.error,
|
||||
"HTTP request failed"
|
||||
);
|
||||
} else if entry.status >= 400 {
|
||||
tracing::warn!(
|
||||
target: "sglang_router_rs::http",
|
||||
request_id = %entry.request_id,
|
||||
method = %entry.method,
|
||||
uri = %entry.uri,
|
||||
status = entry.status,
|
||||
latency_ms = entry.latency_ms,
|
||||
user_agent = ?entry.user_agent,
|
||||
remote_addr = ?entry.remote_addr,
|
||||
"HTTP request client error"
|
||||
);
|
||||
} else {
|
||||
tracing::info!(
|
||||
target: "sglang_router_rs::http",
|
||||
request_id = %entry.request_id,
|
||||
method = %entry.method,
|
||||
uri = %entry.uri,
|
||||
status = entry.status,
|
||||
latency_ms = entry.latency_ms,
|
||||
user_agent = ?entry.user_agent,
|
||||
remote_addr = ?entry.remote_addr,
|
||||
"HTTP request completed"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Concurrency Limiting with Queue Support ============
|
||||
|
||||
/// Request queue entry
|
||||
pub struct QueuedRequest {
|
||||
/// Time when the request was queued
|
||||
queued_at: Instant,
|
||||
/// Channel to send the permit back when acquired
|
||||
permit_tx: oneshot::Sender<Result<(), StatusCode>>,
|
||||
}
|
||||
|
||||
/// Queue metrics for monitoring
|
||||
#[derive(Debug, Default)]
|
||||
pub struct QueueMetrics {
|
||||
pub total_queued: std::sync::atomic::AtomicU64,
|
||||
pub current_queued: std::sync::atomic::AtomicU64,
|
||||
pub total_timeout: std::sync::atomic::AtomicU64,
|
||||
pub total_rejected: std::sync::atomic::AtomicU64,
|
||||
}
|
||||
|
||||
/// Queue processor that handles queued requests
|
||||
pub struct QueueProcessor {
|
||||
token_bucket: Arc<TokenBucket>,
|
||||
queue_rx: mpsc::Receiver<QueuedRequest>,
|
||||
queue_timeout: Duration,
|
||||
}
|
||||
|
||||
impl QueueProcessor {
|
||||
pub fn new(
|
||||
token_bucket: Arc<TokenBucket>,
|
||||
queue_rx: mpsc::Receiver<QueuedRequest>,
|
||||
queue_timeout: Duration,
|
||||
) -> Self {
|
||||
Self {
|
||||
token_bucket,
|
||||
queue_rx,
|
||||
queue_timeout,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(mut self) {
|
||||
info!("Starting concurrency queue processor");
|
||||
|
||||
// Process requests in a single task to reduce overhead
|
||||
while let Some(queued) = self.queue_rx.recv().await {
|
||||
// Check timeout immediately
|
||||
let elapsed = queued.queued_at.elapsed();
|
||||
if elapsed >= self.queue_timeout {
|
||||
warn!("Request already timed out in queue");
|
||||
let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
|
||||
continue;
|
||||
}
|
||||
|
||||
let remaining_timeout = self.queue_timeout - elapsed;
|
||||
|
||||
// Try to acquire token for this request
|
||||
if self.token_bucket.try_acquire(1.0).await.is_ok() {
|
||||
// Got token immediately
|
||||
debug!("Queue: acquired token immediately for queued request");
|
||||
let _ = queued.permit_tx.send(Ok(()));
|
||||
} else {
|
||||
// Need to wait for token
|
||||
let token_bucket = self.token_bucket.clone();
|
||||
|
||||
// Spawn task only when we actually need to wait
|
||||
tokio::spawn(async move {
|
||||
if token_bucket
|
||||
.acquire_timeout(1.0, remaining_timeout)
|
||||
.await
|
||||
.is_ok()
|
||||
{
|
||||
debug!("Queue: acquired token after waiting");
|
||||
let _ = queued.permit_tx.send(Ok(()));
|
||||
} else {
|
||||
warn!("Queue: request timed out waiting for token");
|
||||
let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
warn!("Concurrency queue processor shutting down");
|
||||
}
|
||||
}
|
||||
|
||||
/// State for the concurrency limiter
|
||||
pub struct ConcurrencyLimiter {
|
||||
pub queue_tx: Option<mpsc::Sender<QueuedRequest>>,
|
||||
}
|
||||
|
||||
impl ConcurrencyLimiter {
|
||||
/// Create new concurrency limiter with optional queue
|
||||
pub fn new(
|
||||
token_bucket: Arc<TokenBucket>,
|
||||
queue_size: usize,
|
||||
queue_timeout: Duration,
|
||||
) -> (Self, Option<QueueProcessor>) {
|
||||
if queue_size > 0 {
|
||||
let (queue_tx, queue_rx) = mpsc::channel(queue_size);
|
||||
let processor = QueueProcessor::new(token_bucket, queue_rx, queue_timeout);
|
||||
|
||||
(
|
||||
Self {
|
||||
queue_tx: Some(queue_tx),
|
||||
},
|
||||
Some(processor),
|
||||
)
|
||||
} else {
|
||||
(Self { queue_tx: None }, None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Middleware function for concurrency limiting with optional queuing
|
||||
pub async fn concurrency_limit_middleware(
|
||||
State(app_state): State<Arc<AppState>>,
|
||||
request: Request<axum::body::Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let token_bucket = app_state.context.rate_limiter.clone();
|
||||
|
||||
// Try to acquire token immediately
|
||||
if token_bucket.try_acquire(1.0).await.is_ok() {
|
||||
debug!("Acquired token immediately");
|
||||
let response = next.run(request).await;
|
||||
|
||||
// Return the token to the bucket
|
||||
token_bucket.return_tokens(1.0).await;
|
||||
|
||||
response
|
||||
} else {
|
||||
// No tokens available, try to queue if enabled
|
||||
if let Some(queue_tx) = &app_state.concurrency_queue_tx {
|
||||
debug!("No tokens available, attempting to queue request");
|
||||
|
||||
// Create a channel for the token response
|
||||
let (permit_tx, permit_rx) = oneshot::channel();
|
||||
|
||||
let queued = QueuedRequest {
|
||||
queued_at: Instant::now(),
|
||||
permit_tx,
|
||||
};
|
||||
|
||||
// Try to send to queue
|
||||
match queue_tx.try_send(queued) {
|
||||
Ok(_) => {
|
||||
// Wait for token from queue processor
|
||||
match permit_rx.await {
|
||||
Ok(Ok(())) => {
|
||||
debug!("Acquired token from queue");
|
||||
let response = next.run(request).await;
|
||||
|
||||
// Return the token to the bucket
|
||||
token_bucket.return_tokens(1.0).await;
|
||||
|
||||
response
|
||||
}
|
||||
Ok(Err(status)) => {
|
||||
warn!("Queue returned error status: {}", status);
|
||||
status.into_response()
|
||||
}
|
||||
Err(_) => {
|
||||
error!("Queue response channel closed");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("Request queue is full, returning 429");
|
||||
StatusCode::TOO_MANY_REQUESTS.into_response()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!("No tokens available and queuing is disabled, returning 429");
|
||||
StatusCode::TOO_MANY_REQUESTS.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
423
sgl-router/src/policies/cache_aware.rs
Normal file
423
sgl-router/src/policies/cache_aware.rs
Normal file
@@ -0,0 +1,423 @@
|
||||
/*
|
||||
Cache-Aware Load Balancing Router
|
||||
|
||||
This router combines two strategies to optimize both cache utilization and request distribution:
|
||||
|
||||
1. Cache-Aware Routing (Approximate Tree)
|
||||
2. Load Balancing (Shortest Queue with Balance Thresholds)
|
||||
|
||||
The router dynamically switches between these strategies based on load conditions:
|
||||
- Uses load balancing when the system is imbalanced
|
||||
- Uses cache-aware routing when the system is balanced
|
||||
|
||||
A system is considered imbalanced if both conditions are met:
|
||||
1. (max - min) > abs_threshold
|
||||
2. max > rel_threshold * min
|
||||
|
||||
Strategy Details:
|
||||
|
||||
1. Cache-Aware Routing (Approximate Tree)
|
||||
-------------------------------------------
|
||||
This strategy maintains an approximate radix tree for each worker based on request history,
|
||||
eliminating the need for direct cache state queries. The tree stores raw text characters
|
||||
instead of token IDs to avoid tokenization overhead.
|
||||
|
||||
Process:
|
||||
a. For each request, find the worker with the highest prefix match
|
||||
b. If match rate > cache_threshold:
|
||||
Route to the worker with highest match (likely has relevant data cached)
|
||||
c. If match rate ≤ cache_threshold:
|
||||
Route to the worker with smallest tree size (most available cache capacity)
|
||||
d. Background maintenance:
|
||||
Periodically evict least recently used leaf nodes to prevent memory overflow
|
||||
|
||||
2. Load Balancing (Shortest Queue)
|
||||
-------------------------------------------
|
||||
This strategy tracks pending request counts per worker and routes new requests
|
||||
to the least busy worker when the system is detected to be imbalanced.
|
||||
|
||||
Configuration Parameters:
|
||||
------------------------
|
||||
1. cache_threshold: (float, 0.0 to 1.0)
|
||||
Minimum prefix match ratio to use highest-match routing.
|
||||
Below this threshold, routes to worker with most available cache space.
|
||||
|
||||
2. balance_abs_threshold: (integer)
|
||||
Absolute difference threshold for load imbalance detection.
|
||||
System is potentially imbalanced if (max_load - min_load) > abs_threshold
|
||||
|
||||
3. balance_rel_threshold: (float)
|
||||
Relative ratio threshold for load imbalance detection.
|
||||
System is potentially imbalanced if max_load > min_load * rel_threshold
|
||||
Used in conjunction with abs_threshold to determine final imbalance state.
|
||||
|
||||
4. eviction_interval_secs: (integer)
|
||||
Interval between LRU eviction cycles for the approximate trees.
|
||||
|
||||
5. max_tree_size: (integer)
|
||||
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
|
||||
during the next eviction cycle.
|
||||
*/
|
||||
|
||||
use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::tree::Tree;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
use tracing::debug;
|
||||
|
||||
/// Cache-aware routing policy
|
||||
///
|
||||
/// Routes requests based on cache affinity when load is balanced,
|
||||
/// switches to shortest-queue routing when load is imbalanced.
|
||||
#[derive(Debug)]
|
||||
pub struct CacheAwarePolicy {
|
||||
config: CacheAwareConfig,
|
||||
tree: Arc<Mutex<Tree>>,
|
||||
eviction_handle: Option<thread::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl CacheAwarePolicy {
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(CacheAwareConfig::default())
|
||||
}
|
||||
|
||||
pub fn with_config(config: CacheAwareConfig) -> Self {
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
|
||||
// Start background eviction thread if configured
|
||||
let eviction_handle = if config.eviction_interval_secs > 0 {
|
||||
let tree_clone = Arc::clone(&tree);
|
||||
let max_tree_size = config.max_tree_size;
|
||||
let interval = config.eviction_interval_secs;
|
||||
|
||||
Some(thread::spawn(move || loop {
|
||||
thread::sleep(Duration::from_secs(interval));
|
||||
|
||||
if let Ok(tree_guard) = tree_clone.lock() {
|
||||
tree_guard.evict_tenant_by_size(max_tree_size);
|
||||
debug!("Cache eviction completed, max_size: {}", max_tree_size);
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Self {
|
||||
config,
|
||||
tree,
|
||||
eviction_handle,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize the tree with worker URLs (used only during initial setup)
|
||||
pub fn init_workers(&self, workers: &[Box<dyn Worker>]) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
for worker in workers {
|
||||
tree.insert("", worker.url());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a single worker to the tree (incremental update)
|
||||
pub fn add_worker(&self, url: &str) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
tree.insert("", url);
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a worker from the tree
|
||||
pub fn remove_worker(&self, url: &str) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
tree.remove_tenant(url);
|
||||
}
|
||||
}
|
||||
|
||||
/// Run cache eviction to prevent unbounded growth
|
||||
pub fn evict_cache(&self, max_size: usize) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
tree.evict_tenant_by_size(max_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
|
||||
if healthy_indices.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Get current load statistics
|
||||
let loads: Vec<usize> = workers.iter().map(|w| w.load()).collect();
|
||||
let max_load = *loads.iter().max().unwrap_or(&0);
|
||||
let min_load = *loads.iter().min().unwrap_or(&0);
|
||||
|
||||
// Check if load is imbalanced
|
||||
let is_imbalanced = max_load.saturating_sub(min_load) > self.config.balance_abs_threshold
|
||||
&& (max_load as f32) > (min_load as f32 * self.config.balance_rel_threshold);
|
||||
|
||||
if is_imbalanced {
|
||||
// Log load balancing trigger
|
||||
let worker_loads: Vec<(String, usize)> = workers
|
||||
.iter()
|
||||
.map(|w| (w.url().to_string(), w.load()))
|
||||
.collect();
|
||||
|
||||
debug!(
|
||||
"Load balancing triggered | max: {} | min: {} | workers: {:?}",
|
||||
max_load, min_load, worker_loads
|
||||
);
|
||||
|
||||
RouterMetrics::record_load_balancing_event();
|
||||
RouterMetrics::set_load_range(max_load, min_load);
|
||||
|
||||
// Use shortest queue when imbalanced
|
||||
let min_load_idx = healthy_indices
|
||||
.iter()
|
||||
.min_by_key(|&&idx| workers[idx].load())
|
||||
.copied()?;
|
||||
|
||||
// Even in imbalanced mode, update the tree to maintain cache state
|
||||
if let Some(text) = request_text {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
tree.insert(text, workers[min_load_idx].url());
|
||||
}
|
||||
}
|
||||
|
||||
// Increment processed counter
|
||||
workers[min_load_idx].increment_processed();
|
||||
RouterMetrics::record_processed_request(workers[min_load_idx].url());
|
||||
RouterMetrics::record_policy_decision(self.name(), workers[min_load_idx].url());
|
||||
|
||||
return Some(min_load_idx);
|
||||
}
|
||||
|
||||
// Use cache-aware routing when balanced
|
||||
let text = request_text.unwrap_or("");
|
||||
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
let (matched_text, matched_worker) = tree.prefix_match(text);
|
||||
let match_rate = if text.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
matched_text.chars().count() as f32 / text.chars().count() as f32
|
||||
};
|
||||
|
||||
let selected_url = if match_rate > self.config.cache_threshold {
|
||||
RouterMetrics::record_cache_hit();
|
||||
matched_worker.to_string()
|
||||
} else {
|
||||
RouterMetrics::record_cache_miss();
|
||||
tree.get_smallest_tenant()
|
||||
};
|
||||
|
||||
// Find the index of the selected worker
|
||||
if let Some(selected_idx) = workers.iter().position(|w| w.url() == selected_url) {
|
||||
// Only proceed if the worker is healthy
|
||||
if workers[selected_idx].is_healthy() {
|
||||
// Update the tree with this request
|
||||
tree.insert(text, &selected_url);
|
||||
|
||||
// Increment processed counter
|
||||
workers[selected_idx].increment_processed();
|
||||
RouterMetrics::record_processed_request(&selected_url);
|
||||
|
||||
return Some(selected_idx);
|
||||
}
|
||||
} else {
|
||||
// Selected worker no longer exists, remove it from tree
|
||||
tree.remove_tenant(&selected_url);
|
||||
debug!("Removed stale worker {} from cache tree", selected_url);
|
||||
}
|
||||
|
||||
// Fallback to first healthy worker
|
||||
return healthy_indices.first().copied();
|
||||
}
|
||||
|
||||
// Fallback to first healthy worker if tree operations fail
|
||||
healthy_indices.first().copied()
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"cache_aware"
|
||||
}
|
||||
|
||||
fn needs_request_text(&self) -> bool {
|
||||
true // Cache-aware policy needs request text for cache affinity
|
||||
}
|
||||
|
||||
fn on_request_complete(&self, worker_url: &str, success: bool) {
|
||||
// Could track success rates per worker for more intelligent routing
|
||||
if !success {
|
||||
// Optionally reduce affinity for failed requests
|
||||
tracing::debug!(
|
||||
"Request to {} completed with success={}",
|
||||
worker_url,
|
||||
success
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn select_worker_pair(
|
||||
&self,
|
||||
prefill_workers: &[Box<dyn Worker>],
|
||||
decode_workers: &[Box<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<(usize, usize)> {
|
||||
// DEPRECATED: This method is no longer used when separate policies are configured.
|
||||
// The PD router now uses separate policies for prefill and decode selection.
|
||||
// This implementation remains for backward compatibility when a single policy is used.
|
||||
|
||||
// In PD mode with single policy:
|
||||
// - Prefill: Use cache-aware routing for better cache utilization
|
||||
// - Decode: Use least-load routing for better load distribution
|
||||
|
||||
// Select prefill worker using cache-aware logic
|
||||
let prefill_idx = self.select_worker(prefill_workers, request_text)?;
|
||||
|
||||
// Select decode worker using least-load logic
|
||||
let healthy_decode = get_healthy_worker_indices(decode_workers);
|
||||
if healthy_decode.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let decode_idx = healthy_decode
|
||||
.iter()
|
||||
.min_by_key(|&&idx| decode_workers[idx].load())
|
||||
.copied()?;
|
||||
|
||||
Some((prefill_idx, decode_idx))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CacheAwarePolicy {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CacheAwarePolicy {
|
||||
fn drop(&mut self) {
|
||||
// Note: We can't properly stop the eviction thread since it's in an infinite loop
|
||||
// In a production system, we'd use a channel or atomic flag to signal shutdown
|
||||
if let Some(handle) = self.eviction_handle.take() {
|
||||
// The thread will continue running until the program exits
|
||||
// This is acceptable for now since the router typically runs for the lifetime of the program
|
||||
drop(handle);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::{BasicWorker, WorkerType};
|
||||
|
||||
#[test]
|
||||
fn test_cache_aware_with_balanced_load() {
|
||||
// Create policy without eviction thread for testing
|
||||
let config = CacheAwareConfig {
|
||||
eviction_interval_secs: 0, // Disable eviction thread
|
||||
..Default::default()
|
||||
};
|
||||
let policy = CacheAwarePolicy::with_config(config);
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
];
|
||||
|
||||
// Initialize the policy with workers
|
||||
policy.init_workers(&workers);
|
||||
|
||||
// First request should be distributed
|
||||
let idx1 = policy.select_worker(&workers, Some("hello world")).unwrap();
|
||||
|
||||
// Same request should go to same worker (cache hit)
|
||||
let idx2 = policy.select_worker(&workers, Some("hello world")).unwrap();
|
||||
assert_eq!(idx1, idx2);
|
||||
|
||||
// Similar request should also go to same worker
|
||||
let idx3 = policy.select_worker(&workers, Some("hello")).unwrap();
|
||||
assert_eq!(idx1, idx3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_aware_with_imbalanced_load() {
|
||||
let policy = CacheAwarePolicy::with_config(CacheAwareConfig {
|
||||
cache_threshold: 0.5,
|
||||
balance_abs_threshold: 5,
|
||||
balance_rel_threshold: 2.0,
|
||||
eviction_interval_secs: 0, // Disable eviction thread
|
||||
max_tree_size: 10000,
|
||||
});
|
||||
|
||||
let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular);
|
||||
let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular);
|
||||
|
||||
// Create significant load imbalance
|
||||
for _ in 0..20 {
|
||||
worker1.increment_load();
|
||||
}
|
||||
// worker2 has load 0
|
||||
|
||||
let workers: Vec<Box<dyn Worker>> = vec![Box::new(worker1), Box::new(worker2)];
|
||||
policy.init_workers(&workers);
|
||||
|
||||
// Should select worker2 (lower load) despite cache affinity
|
||||
for _ in 0..5 {
|
||||
let idx = policy.select_worker(&workers, Some("test")).unwrap();
|
||||
assert_eq!(idx, 1); // Should always pick worker2
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_aware_worker_removal() {
|
||||
let config = CacheAwareConfig {
|
||||
eviction_interval_secs: 0, // Disable eviction thread
|
||||
..Default::default()
|
||||
};
|
||||
let policy = CacheAwarePolicy::with_config(config);
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
];
|
||||
|
||||
policy.init_workers(&workers);
|
||||
|
||||
// Route some requests
|
||||
policy.select_worker(&workers, Some("test1"));
|
||||
policy.select_worker(&workers, Some("test2"));
|
||||
|
||||
// Remove a worker
|
||||
policy.remove_worker("http://w1:8000");
|
||||
workers[0].set_healthy(false);
|
||||
|
||||
// All requests should now go to worker2
|
||||
let idx = policy.select_worker(&workers, Some("test1")).unwrap();
|
||||
assert_eq!(idx, 1);
|
||||
}
|
||||
}
|
||||
94
sgl-router/src/policies/factory.rs
Normal file
94
sgl-router/src/policies/factory.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
//! Factory for creating load balancing policies
|
||||
|
||||
use super::{
|
||||
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
|
||||
RoundRobinPolicy,
|
||||
};
|
||||
use crate::config::PolicyConfig;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Factory for creating policy instances
|
||||
pub struct PolicyFactory;
|
||||
|
||||
impl PolicyFactory {
|
||||
/// Create a policy from configuration
|
||||
pub fn create_from_config(config: &PolicyConfig) -> Arc<dyn LoadBalancingPolicy> {
|
||||
match config {
|
||||
PolicyConfig::Random => Arc::new(RandomPolicy::new()),
|
||||
PolicyConfig::RoundRobin => Arc::new(RoundRobinPolicy::new()),
|
||||
PolicyConfig::PowerOfTwo { .. } => Arc::new(PowerOfTwoPolicy::new()),
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold,
|
||||
balance_abs_threshold,
|
||||
balance_rel_threshold,
|
||||
eviction_interval_secs,
|
||||
max_tree_size,
|
||||
} => {
|
||||
let config = CacheAwareConfig {
|
||||
cache_threshold: *cache_threshold,
|
||||
balance_abs_threshold: *balance_abs_threshold,
|
||||
balance_rel_threshold: *balance_rel_threshold,
|
||||
eviction_interval_secs: *eviction_interval_secs,
|
||||
max_tree_size: *max_tree_size,
|
||||
};
|
||||
Arc::new(CacheAwarePolicy::with_config(config))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a policy by name (for dynamic loading)
|
||||
pub fn create_by_name(name: &str) -> Option<Arc<dyn LoadBalancingPolicy>> {
|
||||
match name.to_lowercase().as_str() {
|
||||
"random" => Some(Arc::new(RandomPolicy::new())),
|
||||
"round_robin" | "roundrobin" => Some(Arc::new(RoundRobinPolicy::new())),
|
||||
"power_of_two" | "poweroftwo" => Some(Arc::new(PowerOfTwoPolicy::new())),
|
||||
"cache_aware" | "cacheaware" => Some(Arc::new(CacheAwarePolicy::new())),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_create_from_config() {
|
||||
// Test Random
|
||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
|
||||
assert_eq!(policy.name(), "random");
|
||||
|
||||
// Test RoundRobin
|
||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::RoundRobin);
|
||||
assert_eq!(policy.name(), "round_robin");
|
||||
|
||||
// Test PowerOfTwo
|
||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 60,
|
||||
});
|
||||
assert_eq!(policy.name(), "power_of_two");
|
||||
|
||||
// Test CacheAware
|
||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::CacheAware {
|
||||
cache_threshold: 0.7,
|
||||
balance_abs_threshold: 10,
|
||||
balance_rel_threshold: 1.5,
|
||||
eviction_interval_secs: 30,
|
||||
max_tree_size: 1000,
|
||||
});
|
||||
assert_eq!(policy.name(), "cache_aware");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_by_name() {
|
||||
assert!(PolicyFactory::create_by_name("random").is_some());
|
||||
assert!(PolicyFactory::create_by_name("RANDOM").is_some());
|
||||
assert!(PolicyFactory::create_by_name("round_robin").is_some());
|
||||
assert!(PolicyFactory::create_by_name("RoundRobin").is_some());
|
||||
assert!(PolicyFactory::create_by_name("power_of_two").is_some());
|
||||
assert!(PolicyFactory::create_by_name("PowerOfTwo").is_some());
|
||||
assert!(PolicyFactory::create_by_name("cache_aware").is_some());
|
||||
assert!(PolicyFactory::create_by_name("CacheAware").is_some());
|
||||
assert!(PolicyFactory::create_by_name("unknown").is_none());
|
||||
}
|
||||
}
|
||||
148
sgl-router/src/policies/mod.rs
Normal file
148
sgl-router/src/policies/mod.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
//! Load balancing policies for SGLang router
|
||||
//!
|
||||
//! This module provides a unified abstraction for routing policies that work
|
||||
//! across both regular and prefill-decode (PD) routing modes.
|
||||
|
||||
use crate::core::Worker;
|
||||
use std::fmt::Debug;
|
||||
|
||||
mod cache_aware;
|
||||
mod factory;
|
||||
mod power_of_two;
|
||||
mod random;
|
||||
mod round_robin;
|
||||
|
||||
pub use cache_aware::CacheAwarePolicy;
|
||||
pub use factory::PolicyFactory;
|
||||
pub use power_of_two::PowerOfTwoPolicy;
|
||||
pub use random::RandomPolicy;
|
||||
pub use round_robin::RoundRobinPolicy;
|
||||
|
||||
/// Core trait for load balancing policies
|
||||
///
|
||||
/// This trait provides a unified interface for implementing routing algorithms
|
||||
/// that can work with both regular single-worker selection and PD dual-worker selection.
|
||||
pub trait LoadBalancingPolicy: Send + Sync + Debug {
|
||||
/// Select a single worker from the available workers
|
||||
///
|
||||
/// This is used for regular routing mode where requests go to a single worker.
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<usize>;
|
||||
|
||||
/// Select a pair of workers (prefill and decode) for PD routing
|
||||
///
|
||||
/// Returns indices of (prefill_worker, decode_worker) from their respective arrays.
|
||||
/// Default implementation uses select_worker for each array independently.
|
||||
fn select_worker_pair(
|
||||
&self,
|
||||
prefill_workers: &[Box<dyn Worker>],
|
||||
decode_workers: &[Box<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<(usize, usize)> {
|
||||
// Default implementation: independently select from each pool
|
||||
let prefill_idx = self.select_worker(prefill_workers, request_text)?;
|
||||
let decode_idx = self.select_worker(decode_workers, request_text)?;
|
||||
Some((prefill_idx, decode_idx))
|
||||
}
|
||||
|
||||
/// Update policy state after request completion
|
||||
///
|
||||
/// This is called when a request completes (successfully or not) to allow
|
||||
/// policies to update their internal state.
|
||||
fn on_request_complete(&self, _worker_url: &str, _success: bool) {
|
||||
// Default: no-op for stateless policies
|
||||
}
|
||||
|
||||
/// Get policy name for metrics and debugging
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// Check if this policy needs request text for routing decisions
|
||||
fn needs_request_text(&self) -> bool {
|
||||
false // Default: most policies don't need request text
|
||||
}
|
||||
|
||||
/// Update worker load information
|
||||
///
|
||||
/// This is called periodically with current load information for load-aware policies.
|
||||
fn update_loads(&self, _loads: &std::collections::HashMap<String, isize>) {
|
||||
// Default: no-op for policies that don't use load information
|
||||
}
|
||||
|
||||
/// Reset any internal state
|
||||
///
|
||||
/// This is useful for policies that maintain state (e.g., round-robin counters).
|
||||
fn reset(&self) {
|
||||
// Default: no-op for stateless policies
|
||||
}
|
||||
|
||||
/// Get as Any for downcasting
|
||||
fn as_any(&self) -> &dyn std::any::Any;
|
||||
}
|
||||
|
||||
/// Configuration for cache-aware policy
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CacheAwareConfig {
|
||||
pub cache_threshold: f32,
|
||||
pub balance_abs_threshold: usize,
|
||||
pub balance_rel_threshold: f32,
|
||||
pub eviction_interval_secs: u64,
|
||||
pub max_tree_size: usize,
|
||||
}
|
||||
|
||||
impl Default for CacheAwareConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
cache_threshold: 0.5,
|
||||
balance_abs_threshold: 32,
|
||||
balance_rel_threshold: 1.1,
|
||||
eviction_interval_secs: 30,
|
||||
max_tree_size: 10000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to filter healthy workers and return their indices
|
||||
pub(crate) fn get_healthy_worker_indices(workers: &[Box<dyn Worker>]) -> Vec<usize> {
|
||||
workers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, w)| w.is_healthy() && w.circuit_breaker().can_execute())
|
||||
.map(|(idx, _)| idx)
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::{BasicWorker, WorkerType};
|
||||
|
||||
#[test]
|
||||
fn test_get_healthy_worker_indices() {
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w3:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
];
|
||||
|
||||
// All healthy initially
|
||||
let indices = get_healthy_worker_indices(&workers);
|
||||
assert_eq!(indices, vec![0, 1, 2]);
|
||||
|
||||
// Mark one unhealthy
|
||||
workers[1].set_healthy(false);
|
||||
let indices = get_healthy_worker_indices(&workers);
|
||||
assert_eq!(indices, vec![0, 2]);
|
||||
}
|
||||
}
|
||||
201
sgl-router/src/policies/power_of_two.rs
Normal file
201
sgl-router/src/policies/power_of_two.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
//! Power-of-two choices load balancing policy
|
||||
|
||||
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use rand::Rng;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
use tracing::info;
|
||||
|
||||
/// Power-of-two choices policy
|
||||
///
|
||||
/// Randomly selects two workers and routes to the one with lower load.
|
||||
/// This provides good load distribution with minimal coordination overhead.
|
||||
#[derive(Debug)]
|
||||
pub struct PowerOfTwoPolicy {
|
||||
/// Cached load information from external monitoring
|
||||
cached_loads: RwLock<HashMap<String, isize>>,
|
||||
}
|
||||
|
||||
impl PowerOfTwoPolicy {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
cached_loads: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_worker_load(&self, worker: &dyn Worker) -> isize {
|
||||
// First check cached loads (from external monitoring)
|
||||
if let Ok(loads) = self.cached_loads.read() {
|
||||
if let Some(&load) = loads.get(worker.url()) {
|
||||
return load;
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to local load counter
|
||||
worker.load() as isize
|
||||
}
|
||||
}
|
||||
|
||||
impl LoadBalancingPolicy for PowerOfTwoPolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
_request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
|
||||
if healthy_indices.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if healthy_indices.len() == 1 {
|
||||
return Some(healthy_indices[0]);
|
||||
}
|
||||
|
||||
// Select two random workers
|
||||
let mut rng = rand::rng();
|
||||
let idx1 = rng.random_range(0..healthy_indices.len());
|
||||
let mut idx2 = rng.random_range(0..healthy_indices.len());
|
||||
|
||||
// Ensure we pick two different workers
|
||||
while idx2 == idx1 {
|
||||
idx2 = rng.random_range(0..healthy_indices.len());
|
||||
}
|
||||
|
||||
let worker_idx1 = healthy_indices[idx1];
|
||||
let worker_idx2 = healthy_indices[idx2];
|
||||
|
||||
// Compare loads and select the less loaded one
|
||||
let load1 = self.get_worker_load(workers[worker_idx1].as_ref());
|
||||
let load2 = self.get_worker_load(workers[worker_idx2].as_ref());
|
||||
|
||||
// Log selection for debugging
|
||||
let selected_idx = if load1 <= load2 {
|
||||
worker_idx1
|
||||
} else {
|
||||
worker_idx2
|
||||
};
|
||||
|
||||
info!(
|
||||
"Power-of-two selection: {}={} vs {}={} -> selected {}",
|
||||
workers[worker_idx1].url(),
|
||||
load1,
|
||||
workers[worker_idx2].url(),
|
||||
load2,
|
||||
workers[selected_idx].url()
|
||||
);
|
||||
|
||||
// Increment processed counter
|
||||
workers[selected_idx].increment_processed();
|
||||
RouterMetrics::record_processed_request(workers[selected_idx].url());
|
||||
RouterMetrics::record_policy_decision(self.name(), workers[selected_idx].url());
|
||||
|
||||
Some(selected_idx)
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"power_of_two"
|
||||
}
|
||||
|
||||
fn update_loads(&self, loads: &HashMap<String, isize>) {
|
||||
if let Ok(mut cached) = self.cached_loads.write() {
|
||||
*cached = loads.clone();
|
||||
}
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PowerOfTwoPolicy {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::{BasicWorker, WorkerType};
|
||||
|
||||
#[test]
|
||||
fn test_power_of_two_selection() {
|
||||
let policy = PowerOfTwoPolicy::new();
|
||||
let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular);
|
||||
let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular);
|
||||
let worker3 = BasicWorker::new("http://w3:8000".to_string(), WorkerType::Regular);
|
||||
|
||||
// Set different loads
|
||||
for _ in 0..10 {
|
||||
worker1.increment_load();
|
||||
}
|
||||
for _ in 0..5 {
|
||||
worker2.increment_load();
|
||||
}
|
||||
// worker3 has load 0
|
||||
|
||||
let workers: Vec<Box<dyn Worker>> =
|
||||
vec![Box::new(worker1), Box::new(worker2), Box::new(worker3)];
|
||||
|
||||
// Run multiple selections
|
||||
let mut selected_counts = [0; 3];
|
||||
for _ in 0..100 {
|
||||
if let Some(idx) = policy.select_worker(&workers, None) {
|
||||
selected_counts[idx] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Worker with lowest load (worker3) should be selected most often
|
||||
assert!(selected_counts[2] > selected_counts[1]);
|
||||
assert!(selected_counts[1] > selected_counts[0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_power_of_two_with_cached_loads() {
|
||||
let policy = PowerOfTwoPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
];
|
||||
|
||||
// Update cached loads
|
||||
let mut loads = HashMap::new();
|
||||
loads.insert("http://w1:8000".to_string(), 100);
|
||||
loads.insert("http://w2:8000".to_string(), 10);
|
||||
policy.update_loads(&loads);
|
||||
|
||||
// Should prefer worker2 with lower cached load
|
||||
let mut w2_selected = 0;
|
||||
for _ in 0..50 {
|
||||
if let Some(idx) = policy.select_worker(&workers, None) {
|
||||
if idx == 1 {
|
||||
w2_selected += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Worker2 should be selected significantly more often
|
||||
assert!(w2_selected > 35); // Should win most of the time
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_power_of_two_single_worker() {
|
||||
let policy = PowerOfTwoPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![Box::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
))];
|
||||
|
||||
// With single worker, should always select it
|
||||
assert_eq!(policy.select_worker(&workers, None), Some(0));
|
||||
}
|
||||
}
|
||||
121
sgl-router/src/policies/random.rs
Normal file
121
sgl-router/src/policies/random.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
//! Random load balancing policy
|
||||
|
||||
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use rand::Rng;
|
||||
|
||||
/// Random selection policy
|
||||
///
|
||||
/// Selects workers randomly with uniform distribution among healthy workers.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct RandomPolicy;
|
||||
|
||||
impl RandomPolicy {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl LoadBalancingPolicy for RandomPolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
_request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
|
||||
if healthy_indices.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut rng = rand::rng();
|
||||
let random_idx = rng.random_range(0..healthy_indices.len());
|
||||
let worker = workers[healthy_indices[random_idx]].url();
|
||||
|
||||
RouterMetrics::record_processed_request(worker);
|
||||
RouterMetrics::record_policy_decision(self.name(), worker);
|
||||
Some(healthy_indices[random_idx])
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"random"
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::{BasicWorker, WorkerType};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn test_random_selection() {
|
||||
let policy = RandomPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w3:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
];
|
||||
|
||||
// Test multiple selections to ensure randomness
|
||||
let mut counts = HashMap::new();
|
||||
for _ in 0..100 {
|
||||
if let Some(idx) = policy.select_worker(&workers, None) {
|
||||
*counts.entry(idx).or_insert(0) += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// All workers should be selected at least once
|
||||
assert_eq!(counts.len(), 3);
|
||||
assert!(counts.values().all(|&count| count > 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_random_with_unhealthy_workers() {
|
||||
let policy = RandomPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
];
|
||||
|
||||
// Mark first worker as unhealthy
|
||||
workers[0].set_healthy(false);
|
||||
|
||||
// Should always select the healthy worker (index 1)
|
||||
for _ in 0..10 {
|
||||
assert_eq!(policy.select_worker(&workers, None), Some(1));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_random_no_healthy_workers() {
|
||||
let policy = RandomPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![Box::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
))];
|
||||
|
||||
workers[0].set_healthy(false);
|
||||
assert_eq!(policy.select_worker(&workers, None), None);
|
||||
}
|
||||
}
|
||||
140
sgl-router/src/policies/round_robin.rs
Normal file
140
sgl-router/src/policies/round_robin.rs
Normal file
@@ -0,0 +1,140 @@
|
||||
//! Round-robin load balancing policy
|
||||
|
||||
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
/// Round-robin selection policy
|
||||
///
|
||||
/// Selects workers in sequential order, cycling through all healthy workers.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct RoundRobinPolicy {
|
||||
counter: AtomicUsize,
|
||||
}
|
||||
|
||||
impl RoundRobinPolicy {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
counter: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LoadBalancingPolicy for RoundRobinPolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
_request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
|
||||
if healthy_indices.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Get and increment counter atomically
|
||||
let count = self.counter.fetch_add(1, Ordering::Relaxed);
|
||||
let selected_idx = count % healthy_indices.len();
|
||||
let worker = workers[healthy_indices[selected_idx]].url();
|
||||
|
||||
RouterMetrics::record_processed_request(worker);
|
||||
RouterMetrics::record_policy_decision(self.name(), worker);
|
||||
Some(healthy_indices[selected_idx])
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"round_robin"
|
||||
}
|
||||
|
||||
fn reset(&self) {
|
||||
self.counter.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::{BasicWorker, WorkerType};
|
||||
|
||||
#[test]
|
||||
fn test_round_robin_selection() {
|
||||
let policy = RoundRobinPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w3:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
];
|
||||
|
||||
// Should select workers in order: 0, 1, 2, 0, 1, 2, ...
|
||||
assert_eq!(policy.select_worker(&workers, None), Some(0));
|
||||
assert_eq!(policy.select_worker(&workers, None), Some(1));
|
||||
assert_eq!(policy.select_worker(&workers, None), Some(2));
|
||||
assert_eq!(policy.select_worker(&workers, None), Some(0));
|
||||
assert_eq!(policy.select_worker(&workers, None), Some(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_round_robin_with_unhealthy_workers() {
|
||||
let policy = RoundRobinPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w3:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
];
|
||||
|
||||
// Mark middle worker as unhealthy
|
||||
workers[1].set_healthy(false);
|
||||
|
||||
// Should skip unhealthy worker: 0, 2, 0, 2, ...
|
||||
assert_eq!(policy.select_worker(&workers, None), Some(0));
|
||||
assert_eq!(policy.select_worker(&workers, None), Some(2));
|
||||
assert_eq!(policy.select_worker(&workers, None), Some(0));
|
||||
assert_eq!(policy.select_worker(&workers, None), Some(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_round_robin_reset() {
|
||||
let policy = RoundRobinPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
];
|
||||
|
||||
// Advance the counter
|
||||
assert_eq!(policy.select_worker(&workers, None), Some(0));
|
||||
assert_eq!(policy.select_worker(&workers, None), Some(1));
|
||||
|
||||
// Reset should start from beginning
|
||||
policy.reset();
|
||||
assert_eq!(policy.select_worker(&workers, None), Some(0));
|
||||
}
|
||||
}
|
||||
541
sgl-router/src/proto/sglang_scheduler.proto
Normal file
541
sgl-router/src/proto/sglang_scheduler.proto
Normal file
@@ -0,0 +1,541 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package sglang.grpc.scheduler;
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
import "google/protobuf/struct.proto";
|
||||
|
||||
// Service definition for SGLang scheduler communication
|
||||
// This protocol bridges the Rust router and Python scheduler
|
||||
service SglangScheduler {
|
||||
// Initialize connection and get model info
|
||||
rpc Initialize(InitializeRequest) returns (InitializeResponse);
|
||||
|
||||
// Submit a generation request (supports streaming)
|
||||
rpc Generate(GenerateRequest) returns (stream GenerateResponse);
|
||||
|
||||
// Submit an embedding request
|
||||
rpc Embed(EmbedRequest) returns (EmbedResponse);
|
||||
|
||||
// Health check and metrics
|
||||
rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);
|
||||
|
||||
// Abort a running request
|
||||
rpc Abort(AbortRequest) returns (AbortResponse);
|
||||
|
||||
// Flush KV cache
|
||||
rpc FlushCache(FlushCacheRequest) returns (FlushCacheResponse);
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Common Types
|
||||
// =====================
|
||||
|
||||
// Sampling parameters matching SGLang's SamplingParams
|
||||
message SamplingParams {
|
||||
float temperature = 1;
|
||||
float top_p = 2;
|
||||
int32 top_k = 3;
|
||||
float min_p = 4;
|
||||
float frequency_penalty = 5;
|
||||
float presence_penalty = 6;
|
||||
float repetition_penalty = 7;
|
||||
|
||||
int32 max_new_tokens = 8;
|
||||
repeated string stop = 9;
|
||||
repeated int32 stop_token_ids = 10;
|
||||
bool skip_special_tokens = 11;
|
||||
bool spaces_between_special_tokens = 12;
|
||||
|
||||
// Structured generation
|
||||
oneof constraint {
|
||||
string regex = 13;
|
||||
string json_schema = 14;
|
||||
string ebnf_grammar = 15;
|
||||
}
|
||||
|
||||
// LoRA adapter
|
||||
string lora_path = 16;
|
||||
|
||||
// Speculative decoding
|
||||
int32 n = 17; // Number of samples
|
||||
|
||||
// Token healing
|
||||
bool token_healing = 18;
|
||||
|
||||
// Additional parameters
|
||||
int32 min_new_tokens = 19;
|
||||
bool ignore_eos = 20;
|
||||
bool no_stop_trim = 21;
|
||||
int32 stream_interval = 22;
|
||||
map<string, float> logit_bias = 23;
|
||||
string structural_tag = 24;
|
||||
|
||||
// Custom parameters for extensibility
|
||||
google.protobuf.Struct custom_params = 25;
|
||||
}
|
||||
|
||||
// Session parameters for continual prompting
|
||||
message SessionParams {
|
||||
string session_id = 1;
|
||||
string request_id = 2;
|
||||
int32 offset = 3;
|
||||
bool replace = 4;
|
||||
bool drop_previous_output = 5;
|
||||
}
|
||||
|
||||
// Disaggregated serving parameters
|
||||
message DisaggregatedParams {
|
||||
string bootstrap_host = 1;
|
||||
int32 bootstrap_port = 2;
|
||||
int32 bootstrap_room = 3;
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Initialize
|
||||
// =====================
|
||||
|
||||
message InitializeRequest {
|
||||
string client_id = 1;
|
||||
string client_version = 2;
|
||||
|
||||
// Operating mode
|
||||
enum Mode {
|
||||
REGULAR = 0; // Normal mode with local scheduler
|
||||
PREFILL = 1; // Prefill-only mode for disaggregated serving
|
||||
DECODE = 2; // Decode-only mode for disaggregated serving
|
||||
}
|
||||
Mode mode = 3;
|
||||
}
|
||||
|
||||
message InitializeResponse {
|
||||
bool success = 1;
|
||||
string scheduler_version = 2;
|
||||
|
||||
// Model information
|
||||
ModelInfo model_info = 3;
|
||||
|
||||
// Server capabilities
|
||||
ServerCapabilities capabilities = 4;
|
||||
|
||||
// Error message if success is false
|
||||
string error_message = 5;
|
||||
}
|
||||
|
||||
message ModelInfo {
|
||||
string model_name = 1;
|
||||
int32 max_context_length = 2;
|
||||
int32 vocab_size = 3;
|
||||
bool supports_tool_calling = 4;
|
||||
bool supports_vision = 5;
|
||||
repeated string special_tokens = 6;
|
||||
|
||||
// Additional model metadata
|
||||
string model_type = 7;
|
||||
int32 num_layers = 8;
|
||||
int32 hidden_size = 9;
|
||||
int32 num_attention_heads = 10;
|
||||
int32 num_key_value_heads = 11;
|
||||
|
||||
// Tokenizer info
|
||||
string tokenizer_type = 12;
|
||||
repeated int32 eos_token_ids = 13;
|
||||
int32 pad_token_id = 14;
|
||||
int32 bos_token_id = 15;
|
||||
}
|
||||
|
||||
message ServerCapabilities {
|
||||
bool continuous_batching = 1;
|
||||
bool disaggregated_serving = 2;
|
||||
bool speculative_decoding = 3;
|
||||
int32 max_batch_size = 4;
|
||||
int32 max_num_batched_tokens = 5;
|
||||
int32 max_prefill_tokens = 6;
|
||||
string attention_backend = 7; // "flashinfer", "triton", "torch"
|
||||
|
||||
// Additional capabilities
|
||||
bool supports_lora = 8;
|
||||
bool supports_grammar = 9;
|
||||
bool supports_multimodal = 10;
|
||||
repeated string supported_modalities = 11; // ["image", "video", "audio"]
|
||||
bool supports_custom_logit_processor = 12;
|
||||
bool supports_session = 13;
|
||||
|
||||
// Hardware info
|
||||
int32 num_gpus = 14;
|
||||
string gpu_type = 15;
|
||||
int64 total_gpu_memory = 16;
|
||||
|
||||
// Parallelism info
|
||||
int32 tensor_parallel_size = 17;
|
||||
int32 pipeline_parallel_size = 18;
|
||||
int32 data_parallel_size = 19;
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Generate Request
|
||||
// =====================
|
||||
|
||||
message GenerateRequest {
|
||||
string request_id = 1;
|
||||
|
||||
// Input can be either text or tokenized
|
||||
oneof input {
|
||||
string text = 2;
|
||||
TokenizedInput tokenized = 3;
|
||||
}
|
||||
|
||||
// Multimodal inputs
|
||||
MultimodalInputs mm_inputs = 4;
|
||||
|
||||
// Generation parameters
|
||||
SamplingParams sampling_params = 5;
|
||||
|
||||
// Return options
|
||||
bool return_logprob = 6;
|
||||
int32 logprob_start_len = 7;
|
||||
int32 top_logprobs_num = 8;
|
||||
repeated int32 token_ids_logprob = 9;
|
||||
bool return_hidden_states = 10;
|
||||
|
||||
// Session management
|
||||
SessionParams session_params = 11;
|
||||
|
||||
// For disaggregated serving
|
||||
DisaggregatedParams disaggregated_params = 12;
|
||||
|
||||
// Custom logit processor (serialized)
|
||||
string custom_logit_processor = 13;
|
||||
|
||||
// Request metadata
|
||||
google.protobuf.Timestamp timestamp = 14;
|
||||
bool log_metrics = 15;
|
||||
|
||||
// Input embeddings (alternative to text/tokens)
|
||||
repeated float input_embeds = 16;
|
||||
|
||||
// LoRA adapter ID (if pre-loaded)
|
||||
string lora_id = 17;
|
||||
|
||||
// Data parallel routing
|
||||
int32 data_parallel_rank = 18;
|
||||
|
||||
// For load balancing
|
||||
int32 dp_balance_id = 19;
|
||||
}
|
||||
|
||||
message TokenizedInput {
|
||||
string original_text = 1; // For reference
|
||||
repeated int32 input_ids = 2;
|
||||
}
|
||||
|
||||
message MultimodalInputs {
|
||||
// Simplified multimodal handling - actual data processed by tokenizer
|
||||
repeated string image_urls = 1;
|
||||
repeated string video_urls = 2;
|
||||
repeated string audio_urls = 3;
|
||||
|
||||
// Pre-processed multimodal features (if available)
|
||||
google.protobuf.Struct processed_features = 4;
|
||||
|
||||
// Raw data for direct processing
|
||||
repeated bytes image_data = 5;
|
||||
repeated bytes video_data = 6;
|
||||
repeated bytes audio_data = 7;
|
||||
|
||||
// Modality metadata
|
||||
repeated string modalities = 8;
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Generate Response
|
||||
// =====================
|
||||
|
||||
message GenerateResponse {
|
||||
string request_id = 1;
|
||||
|
||||
// Response type
|
||||
oneof response {
|
||||
GenerateStreamChunk chunk = 2;
|
||||
GenerateComplete complete = 3;
|
||||
GenerateError error = 4;
|
||||
}
|
||||
}
|
||||
|
||||
message GenerateStreamChunk {
|
||||
// Generated token
|
||||
int32 token_id = 1;
|
||||
string text = 2;
|
||||
|
||||
// Cumulative counts
|
||||
int32 prompt_tokens = 3;
|
||||
int32 completion_tokens = 4;
|
||||
int32 cached_tokens = 5;
|
||||
|
||||
// Logprobs (if requested)
|
||||
LogProbs logprobs = 6;
|
||||
|
||||
// Hidden states (if requested)
|
||||
repeated float hidden_states = 7;
|
||||
|
||||
// Metadata
|
||||
float generation_time = 8; // Time to generate this token
|
||||
int32 queue_time = 9; // Time spent in queue
|
||||
}
|
||||
|
||||
message GenerateComplete {
|
||||
// Final output
|
||||
repeated int32 output_ids = 1;
|
||||
string output_text = 2;
|
||||
|
||||
// Finish reason
|
||||
enum FinishReason {
|
||||
// The model generated a stop sequence.
|
||||
STOP = 0;
|
||||
// The model reached the maximum generation length.
|
||||
LENGTH = 1;
|
||||
// The model generated an end-of-sequence (EOS) token.
|
||||
EOS_TOKEN = 2;
|
||||
// The model generated a user-provided stop string.
|
||||
STOP_STR = 3;
|
||||
// The request was aborted by the user or system.
|
||||
ABORT = 4;
|
||||
}
|
||||
FinishReason finish_reason = 3;
|
||||
|
||||
// Final counts
|
||||
int32 prompt_tokens = 4;
|
||||
int32 completion_tokens = 5;
|
||||
int32 cached_tokens = 6;
|
||||
|
||||
// Performance metrics
|
||||
float total_generation_time = 7;
|
||||
float time_to_first_token = 8;
|
||||
float tokens_per_second = 9;
|
||||
|
||||
// Spec decode metrics
|
||||
int32 spec_verify_count = 10;
|
||||
|
||||
// All logprobs if requested
|
||||
repeated LogProbs all_logprobs = 11;
|
||||
|
||||
// All hidden states if requested
|
||||
repeated HiddenStates all_hidden_states = 12;
|
||||
}
|
||||
|
||||
message GenerateError {
|
||||
string message = 1;
|
||||
string http_status_code = 2;
|
||||
string details = 3;
|
||||
}
|
||||
|
||||
message LogProbs {
|
||||
repeated float token_logprobs = 1;
|
||||
repeated int32 token_ids = 2;
|
||||
|
||||
// Top logprobs at each position
|
||||
repeated TopLogProbs top_logprobs = 3;
|
||||
|
||||
// Decoded text for tokens
|
||||
repeated string token_texts = 4;
|
||||
}
|
||||
|
||||
message TopLogProbs {
|
||||
repeated float values = 1;
|
||||
repeated int32 token_ids = 2;
|
||||
repeated string token_texts = 3;
|
||||
}
|
||||
|
||||
message HiddenStates {
|
||||
repeated float values = 1;
|
||||
int32 layer = 2;
|
||||
int32 position = 3;
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Embedding Request
|
||||
// =====================
|
||||
|
||||
message EmbedRequest {
|
||||
string request_id = 1;
|
||||
|
||||
oneof input {
|
||||
string text = 2;
|
||||
TokenizedInput tokenized = 3;
|
||||
}
|
||||
|
||||
// Multimodal inputs
|
||||
MultimodalInputs mm_inputs = 4;
|
||||
|
||||
// Dummy sampling params for compatibility
|
||||
// EmbedRequest doesn't use sampling_params
|
||||
SamplingParams sampling_params = 5;
|
||||
|
||||
bool log_metrics = 6;
|
||||
|
||||
// Token type IDs for models that require them
|
||||
repeated int32 token_type_ids = 7;
|
||||
|
||||
// Data parallel routing
|
||||
int32 data_parallel_rank = 8;
|
||||
|
||||
// For cross-encoder requests
|
||||
bool is_cross_encoder = 9;
|
||||
repeated string texts = 10; // For cross-encoder batch
|
||||
}
|
||||
|
||||
message EmbedResponse {
|
||||
string request_id = 1;
|
||||
|
||||
oneof response {
|
||||
EmbedComplete complete = 2;
|
||||
EmbedError error = 3;
|
||||
}
|
||||
}
|
||||
|
||||
message EmbedComplete {
|
||||
repeated float embedding = 1;
|
||||
int32 prompt_tokens = 2;
|
||||
int32 cached_tokens = 3;
|
||||
|
||||
// Additional metadata
|
||||
int32 embedding_dim = 4;
|
||||
float generation_time = 5;
|
||||
|
||||
// For batch embeddings
|
||||
repeated Embedding batch_embeddings = 6;
|
||||
}
|
||||
|
||||
message Embedding {
|
||||
repeated float values = 1;
|
||||
int32 index = 2;
|
||||
}
|
||||
|
||||
message EmbedError {
|
||||
string message = 1;
|
||||
string code = 2;
|
||||
string details = 3;
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Management Operations
|
||||
// =====================
|
||||
|
||||
message HealthCheckRequest {
|
||||
bool include_detailed_metrics = 1;
|
||||
}
|
||||
|
||||
message HealthCheckResponse {
|
||||
bool healthy = 1;
|
||||
|
||||
// Current load metrics
|
||||
int32 num_requests_running = 2;
|
||||
int32 num_requests_waiting = 3;
|
||||
float gpu_cache_usage = 4;
|
||||
float gpu_memory_usage = 5;
|
||||
|
||||
// KV cache metrics
|
||||
int32 kv_cache_total_blocks = 6;
|
||||
int32 kv_cache_used_blocks = 7;
|
||||
float kv_cache_hit_rate = 8;
|
||||
|
||||
// Additional metrics
|
||||
int32 num_grammar_queue_requests = 9;
|
||||
float generation_throughput = 10; // tokens/sec
|
||||
float average_queue_time = 11; // seconds
|
||||
float average_generation_time = 12; // seconds
|
||||
|
||||
// System metrics
|
||||
float cpu_usage = 13;
|
||||
int64 memory_usage = 14;
|
||||
|
||||
// Disaggregation metrics
|
||||
int32 num_prefill_requests = 15;
|
||||
int32 num_decode_requests = 16;
|
||||
|
||||
// Detailed metrics (optional)
|
||||
google.protobuf.Struct detailed_metrics = 17;
|
||||
}
|
||||
|
||||
message AbortRequest {
|
||||
string request_id = 1;
|
||||
string reason = 2;
|
||||
}
|
||||
|
||||
message AbortResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message FlushCacheRequest {
|
||||
bool flush_all = 1;
|
||||
repeated string session_ids = 2; // Flush specific sessions
|
||||
}
|
||||
|
||||
message FlushCacheResponse {
|
||||
bool success = 1;
|
||||
int32 num_entries_flushed = 2;
|
||||
int64 memory_freed = 3; // bytes
|
||||
string message = 4;
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Additional Operations (Future)
|
||||
// =====================
|
||||
|
||||
// Load LoRA adapter
|
||||
message LoadLoRARequest {
|
||||
string adapter_id = 1;
|
||||
string adapter_path = 2;
|
||||
int32 rank = 3;
|
||||
}
|
||||
|
||||
message LoadLoRAResponse {
|
||||
bool success = 1;
|
||||
string adapter_id = 2;
|
||||
string message = 3;
|
||||
}
|
||||
|
||||
// Unload LoRA adapter
|
||||
message UnloadLoRARequest {
|
||||
string adapter_id = 1;
|
||||
}
|
||||
|
||||
message UnloadLoRAResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
// Update weights
|
||||
message UpdateWeightsRequest {
|
||||
oneof source {
|
||||
string disk_path = 1;
|
||||
bytes tensor_data = 2;
|
||||
string remote_url = 3;
|
||||
}
|
||||
string weight_name = 4;
|
||||
}
|
||||
|
||||
message UpdateWeightsResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
// Get internal state for debugging
|
||||
message GetInternalStateRequest {
|
||||
repeated string state_keys = 1;
|
||||
}
|
||||
|
||||
message GetInternalStateResponse {
|
||||
google.protobuf.Struct state = 1;
|
||||
}
|
||||
|
||||
// Set internal state for testing
|
||||
message SetInternalStateRequest {
|
||||
google.protobuf.Struct state = 1;
|
||||
}
|
||||
|
||||
message SetInternalStateResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
5
sgl-router/src/protocols/mod.rs
Normal file
5
sgl-router/src/protocols/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
// Protocol definitions and validation for various LLM APIs
|
||||
// This module provides a structured approach to handling different API protocols
|
||||
|
||||
pub mod spec;
|
||||
pub mod validation;
|
||||
2681
sgl-router/src/protocols/spec.rs
Normal file
2681
sgl-router/src/protocols/spec.rs
Normal file
File diff suppressed because it is too large
Load Diff
1164
sgl-router/src/protocols/validation.rs
Normal file
1164
sgl-router/src/protocols/validation.rs
Normal file
File diff suppressed because it is too large
Load Diff
474
sgl-router/src/reasoning_parser/README.md
Normal file
474
sgl-router/src/reasoning_parser/README.md
Normal file
@@ -0,0 +1,474 @@
|
||||
# Reasoning Parser Architecture
|
||||
|
||||
## 1. Executive Summary
|
||||
|
||||
### High-Level Overview
|
||||
|
||||
The reasoning parser layer provides a unified interface for detecting and extracting reasoning content from Large Language Model (LLM) outputs, particularly from models that support Chain-of-Thought (CoT) reasoning with explicit thinking blocks. The architecture follows a trait-based design pattern enabling pluggable parser implementations while maintaining consistent APIs across different model families that use various reasoning token formats.
|
||||
|
||||
**Key Components:**
|
||||
- **Factory Pattern**: Registry-based creation and pooling of model-specific parsers
|
||||
- **Trait System**: `ReasoningParser` trait for implementation flexibility
|
||||
- **Parser Pooling**: Efficient reuse of parser instances across concurrent requests
|
||||
- **Streaming Support**: Incremental parsing with partial token buffering
|
||||
- **Model Detection**: Pattern-based matching for automatic parser selection
|
||||
- **State Management**: Stateful parsing for streaming scenarios with buffer management
|
||||
- **Thread Safety**: Arc<Mutex> based sharing for high-concurrency environments
|
||||
- **Extensibility**: Easy addition of new model-specific parsers
|
||||
|
||||
**Data Flow:**
|
||||
1. Request → Factory (model detection) → Pooled Parser Retrieval
|
||||
2. One-Shot: Text → Parser → ParserResult (normal + reasoning text)
|
||||
3. Streaming: Chunks → Parser (stateful) → Incremental ParserResult
|
||||
4. Buffer Management: Partial Tokens → Buffer → Complete Token Detection
|
||||
5. Reset: Parser State → Clear Buffers → Ready for Reuse
|
||||
|
||||
### Architecture Highlights
|
||||
|
||||
- **Model-Specific Parsers**: DeepSeek-R1, Qwen3, Kimi, GLM45, Step3 variants
|
||||
- **Parser Pooling**: Singleton instances per model type for memory efficiency
|
||||
- **High Concurrency**: Mutex-protected parsers handle 1000+ req/sec
|
||||
- **Buffer Overflow Protection**: Configurable max buffer size (default 64KB)
|
||||
- **Partial Token Detection**: Intelligent buffering for incomplete delimiters
|
||||
- **Passthrough Mode**: Graceful fallback for unknown models
|
||||
- **Zero-Copy Where Possible**: Efficient string handling in hot paths
|
||||
|
||||
## 2. Mermaid Diagrams
|
||||
|
||||
### Component Flow Diagram
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph Input
|
||||
R[Request] --> MID[Model ID]
|
||||
end
|
||||
|
||||
subgraph Factory Layer
|
||||
MID --> PF[ParserFactory]
|
||||
PF --> REG[ParserRegistry]
|
||||
REG --> PM[Pattern Matching]
|
||||
PM --> PP[Parser Pool]
|
||||
end
|
||||
|
||||
subgraph Parser Pool
|
||||
PP --> DS[DeepSeek-R1]
|
||||
PP --> QW[Qwen3]
|
||||
PP --> QWT[Qwen3-Thinking]
|
||||
PP --> KM[Kimi]
|
||||
PP --> GL[GLM45]
|
||||
PP --> S3[Step3]
|
||||
PP --> PT[Passthrough]
|
||||
end
|
||||
|
||||
subgraph Parser Instance
|
||||
DS --> BP[BaseReasoningParser]
|
||||
QW --> BP
|
||||
KM --> BP
|
||||
GL --> BP
|
||||
S3 --> BP
|
||||
end
|
||||
|
||||
subgraph Processing
|
||||
BP --> DAP[detect_and_parse]
|
||||
BP --> PSI[parse_streaming]
|
||||
BP --> RST[reset]
|
||||
end
|
||||
|
||||
subgraph State Management
|
||||
BP --> BUF[Buffer]
|
||||
BP --> IR[in_reasoning flag]
|
||||
BP --> STS[stripped_think_start]
|
||||
end
|
||||
|
||||
subgraph Output
|
||||
DAP --> PR[ParserResult]
|
||||
PSI --> PR
|
||||
PR --> NT[normal_text]
|
||||
PR --> RT[reasoning_text]
|
||||
end
|
||||
```
|
||||
|
||||
### Sequence Flow Diagram
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant C as Client
|
||||
participant F as ParserFactory
|
||||
participant R as Registry
|
||||
participant P as Parser Pool
|
||||
participant BP as BaseParser
|
||||
participant PR as ParserResult
|
||||
|
||||
C->>F: get_pooled("deepseek-r1-model")
|
||||
F->>R: find_pooled_parser_for_model()
|
||||
R->>R: pattern_match("deepseek-r1")
|
||||
R->>P: get_pooled_parser("deepseek_r1")
|
||||
|
||||
alt Parser exists in pool
|
||||
P-->>F: Arc<Mutex<Parser>>
|
||||
else Create new parser
|
||||
P->>BP: new DeepSeekR1Parser()
|
||||
P->>P: insert into pool
|
||||
P-->>F: Arc<Mutex<Parser>>
|
||||
end
|
||||
|
||||
F-->>C: PooledParser
|
||||
|
||||
C->>BP: lock().parse_reasoning_streaming_incremental()
|
||||
loop streaming chunks
|
||||
C->>BP: parse_reasoning_streaming_incremental(chunk)
|
||||
BP->>BP: buffer.push_str(chunk)
|
||||
BP->>BP: check partial tokens
|
||||
|
||||
alt Complete token found
|
||||
BP->>PR: create result
|
||||
BP->>BP: clear buffer
|
||||
BP-->>C: ParserResult
|
||||
else Partial token
|
||||
BP->>BP: keep buffering
|
||||
BP-->>C: ParserResult::default()
|
||||
end
|
||||
end
|
||||
|
||||
C->>BP: reset()
|
||||
BP->>BP: clear buffers & flags
|
||||
C->>BP: unlock()
|
||||
```
|
||||
|
||||
### Class/Type Diagram
|
||||
|
||||
```mermaid
|
||||
classDiagram
|
||||
class ReasoningParser {
|
||||
<<trait>>
|
||||
+detect_and_parse_reasoning(&mut self, text: &str) Result~ParserResult~
|
||||
+parse_reasoning_streaming_incremental(&mut self, text: &str) Result~ParserResult~
|
||||
+reset(&mut self)
|
||||
+model_type(&self) &str
|
||||
}
|
||||
|
||||
class ParserResult {
|
||||
+normal_text: String
|
||||
+reasoning_text: String
|
||||
+new(normal: String, reasoning: String) Self
|
||||
+normal(text: String) Self
|
||||
+reasoning(text: String) Self
|
||||
+is_empty() bool
|
||||
}
|
||||
|
||||
class ParserConfig {
|
||||
+think_start_token: String
|
||||
+think_end_token: String
|
||||
+stream_reasoning: bool
|
||||
+max_buffer_size: usize
|
||||
+initial_in_reasoning: bool
|
||||
+default() Self
|
||||
}
|
||||
|
||||
class BaseReasoningParser {
|
||||
-config: ParserConfig
|
||||
-in_reasoning: bool
|
||||
-buffer: String
|
||||
-stripped_think_start: bool
|
||||
-model_type: String
|
||||
+new(config: ParserConfig) Self
|
||||
+with_model_type(model: String) Self
|
||||
-is_partial_token(&self, text: &str) bool
|
||||
}
|
||||
|
||||
class DeepSeekR1Parser {
|
||||
-base: BaseReasoningParser
|
||||
+new() Self
|
||||
}
|
||||
|
||||
class Qwen3Parser {
|
||||
-base: BaseReasoningParser
|
||||
+new() Self
|
||||
}
|
||||
|
||||
class QwenThinkingParser {
|
||||
-base: BaseReasoningParser
|
||||
+new() Self
|
||||
}
|
||||
|
||||
class KimiParser {
|
||||
-base: BaseReasoningParser
|
||||
+new() Self
|
||||
}
|
||||
|
||||
class Glm45Parser {
|
||||
-base: BaseReasoningParser
|
||||
+new() Self
|
||||
}
|
||||
|
||||
class Step3Parser {
|
||||
-base: BaseReasoningParser
|
||||
+new() Self
|
||||
}
|
||||
|
||||
class ParserFactory {
|
||||
-registry: ParserRegistry
|
||||
+new() Self
|
||||
+get_pooled(model_id: &str) PooledParser
|
||||
+create(model_id: &str) Result~Box~dyn ReasoningParser~~
|
||||
+clear_pool()
|
||||
}
|
||||
|
||||
class ParserRegistry {
|
||||
-creators: Arc~RwLock~HashMap~~
|
||||
-pool: Arc~RwLock~HashMap~~
|
||||
-patterns: Arc~RwLock~Vec~~
|
||||
+register_parser(name: &str, creator: F)
|
||||
+register_pattern(pattern: &str, parser_name: &str)
|
||||
+get_pooled_parser(name: &str) Option~PooledParser~
|
||||
+find_pooled_parser_for_model(model: &str) Option~PooledParser~
|
||||
}
|
||||
|
||||
ReasoningParser <|.. BaseReasoningParser
|
||||
ReasoningParser <|.. DeepSeekR1Parser
|
||||
ReasoningParser <|.. Qwen3Parser
|
||||
ReasoningParser <|.. QwenThinkingParser
|
||||
ReasoningParser <|.. KimiParser
|
||||
ReasoningParser <|.. Glm45Parser
|
||||
ReasoningParser <|.. Step3Parser
|
||||
|
||||
DeepSeekR1Parser o-- BaseReasoningParser
|
||||
Qwen3Parser o-- BaseReasoningParser
|
||||
QwenThinkingParser o-- BaseReasoningParser
|
||||
KimiParser o-- BaseReasoningParser
|
||||
Glm45Parser o-- BaseReasoningParser
|
||||
Step3Parser o-- BaseReasoningParser
|
||||
|
||||
BaseReasoningParser o-- ParserConfig
|
||||
ParserFactory o-- ParserRegistry
|
||||
ParserRegistry o-- ReasoningParser
|
||||
```
|
||||
|
||||
## 3. Module-by-Module Deep Dive
|
||||
|
||||
### 3.1 mod.rs (Main Module)
|
||||
|
||||
**Key Responsibilities:**
|
||||
- Module organization and public API surface
|
||||
- Re-exports for convenient access to core types
|
||||
- Separation of concerns across submodules
|
||||
|
||||
**Module Structure:**
|
||||
- `factory`: Parser creation and pooling logic
|
||||
- `parsers`: Concrete parser implementations
|
||||
- `traits`: Core trait definitions and types
|
||||
|
||||
### 3.2 traits.rs (Trait Definitions)
|
||||
|
||||
**ParserResult Methods**:
|
||||
- `new()`: Create with both normal and reasoning text
|
||||
- `normal()`: Create with only normal text (convenience)
|
||||
- `reasoning()`: Create with only reasoning text (convenience)
|
||||
- `is_empty()`: Check if result contains any text
|
||||
|
||||
**ReasoningParser Trait**:
|
||||
- **`detect_and_parse_reasoning`**: One-shot parsing for complete text
|
||||
- **`parse_reasoning_streaming_incremental`**: Stateful streaming parser
|
||||
- **`reset`**: Clear state for parser reuse
|
||||
- **`model_type`**: Identify parser variant for debugging
|
||||
|
||||
**ParserConfig Defaults**:
|
||||
- Default tokens: `<think>` and `</think>`
|
||||
- Stream reasoning: true (immediate output)
|
||||
- Max buffer: 65536 bytes (64KB)
|
||||
- Initial state: false (explicit reasoning blocks)
|
||||
|
||||
### 3.3 factory.rs (Parser Creation & Pooling)
|
||||
|
||||
**ParserRegistry Methods**:
|
||||
|
||||
1. **`register_parser`**:
|
||||
- Register creator function for parser type
|
||||
- Lazy instantiation when requested
|
||||
- Thread-safe registration
|
||||
|
||||
2. **`register_pattern`**:
|
||||
- Map model ID patterns to parser names
|
||||
- First-match-wins ordering
|
||||
- Case-insensitive matching
|
||||
|
||||
3. **`get_pooled_parser`**:
|
||||
- Check pool for existing instance
|
||||
- Create and pool if not present
|
||||
- Return Arc<Mutex> for sharing
|
||||
|
||||
4. **`find_pooled_parser_for_model`**:
|
||||
- Pattern match against model ID
|
||||
- Delegate to get_pooled_parser
|
||||
- Case-insensitive comparison
|
||||
|
||||
**ParserFactory Methods**:
|
||||
|
||||
1. **`new()`**:
|
||||
- Register all built-in parsers
|
||||
- Setup model pattern mappings
|
||||
- Initialize empty pool
|
||||
|
||||
2. **`get_pooled`**:
|
||||
- Primary API for getting parsers
|
||||
- Automatic passthrough fallback
|
||||
- Guaranteed non-null return
|
||||
|
||||
3. **`create`**:
|
||||
- Create fresh parser instance
|
||||
- No pooling (for testing/isolation)
|
||||
- Returns Result for error handling
|
||||
|
||||
**Registered Parsers**:
|
||||
- `base`: Generic configurable parser
|
||||
- `deepseek_r1`: DeepSeek-R1 (initial_in_reasoning=true)
|
||||
- `qwen3`: Qwen3 base model (initial_in_reasoning=false)
|
||||
- `qwen3_thinking`: Qwen3 thinking variant (initial_in_reasoning=true)
|
||||
- `kimi`: Kimi with Unicode tokens
|
||||
- `glm45`: GLM-4.5 parser
|
||||
- `step3`: Step3 parser
|
||||
- `passthrough`: No-op fallback parser
|
||||
|
||||
**Model Pattern Mappings**:
|
||||
```
|
||||
"deepseek-r1" → "deepseek_r1"
|
||||
"qwen3-thinking" → "qwen3_thinking"
|
||||
"qwen-thinking" → "qwen3_thinking"
|
||||
"qwen3" → "qwen3"
|
||||
"qwen" → "qwen3"
|
||||
"glm45" → "glm45"
|
||||
"kimi" → "kimi"
|
||||
"step3" → "step3"
|
||||
```
|
||||
|
||||
### 3.4 parsers/base.rs (Base Implementation)
|
||||
|
||||
**Key Methods:**
|
||||
|
||||
**`detect_and_parse_reasoning`**:
|
||||
```
|
||||
Algorithm:
|
||||
1. Check buffer overflow protection
|
||||
2. Detect reasoning presence (in_reasoning OR contains start_token)
|
||||
3. If no reasoning → return as normal text
|
||||
4. Remove start token and trim
|
||||
5. If no end token → assume truncated reasoning
|
||||
6. Split on end token
|
||||
7. Extract reasoning and normal portions
|
||||
```
|
||||
|
||||
**`parse_reasoning_streaming_incremental`**:
|
||||
```
|
||||
Algorithm:
|
||||
1. Check buffer capacity
|
||||
2. Append text to buffer
|
||||
3. Check if buffer is partial token prefix
|
||||
4. If partial → buffer and return empty
|
||||
5. Strip start token if present
|
||||
6. Find end token position
|
||||
7. Handle based on state:
|
||||
- In reasoning + end found → split and return both
|
||||
- In reasoning + streaming → return accumulated reasoning
|
||||
- Not in reasoning → return as normal text
|
||||
- In reasoning + no end → continue buffering
|
||||
```
|
||||
|
||||
**Critical Features:**
|
||||
|
||||
1. **Partial Token Detection**:
|
||||
- Prevents premature token matching
|
||||
- Buffers incomplete delimiters
|
||||
- Essential for streaming correctness
|
||||
|
||||
2. **Buffer Management**:
|
||||
- Overflow protection
|
||||
- Accumulation for partial content
|
||||
- Clear on complete token detection
|
||||
|
||||
3. **State Tracking**:
|
||||
- `in_reasoning`: Current parsing state
|
||||
- `stripped_think_start`: Prevent double processing
|
||||
- `buffer`: Accumulated partial content
|
||||
|
||||
|
||||
## 4. Extensibility Guide
|
||||
|
||||
### Adding a New Parser
|
||||
|
||||
**Step 1: Create Parser Implementation**
|
||||
|
||||
```rust
|
||||
// src/reasoning_parser/parsers/mymodel.rs
|
||||
use crate::reasoning_parser::parsers::BaseReasoningParser;
|
||||
use crate::reasoning_parser::traits::{ParserConfig, ReasoningParser};
|
||||
|
||||
pub struct MyModelParser {
|
||||
base: BaseReasoningParser,
|
||||
}
|
||||
|
||||
impl MyModelParser {
|
||||
pub fn new() -> Self {
|
||||
let config = ParserConfig {
|
||||
think_start_token: "<reasoning>".to_string(),
|
||||
think_end_token: "</reasoning>".to_string(),
|
||||
stream_reasoning: true,
|
||||
max_buffer_size: 65536,
|
||||
initial_in_reasoning: false, // or true for implicit
|
||||
};
|
||||
|
||||
Self {
|
||||
base: BaseReasoningParser::new(config)
|
||||
.with_model_type("mymodel".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReasoningParser for MyModelParser {
|
||||
// Delegate to base or implement custom logic
|
||||
fn detect_and_parse_reasoning(&mut self, text: &str)
|
||||
-> Result<ParserResult, ParseError> {
|
||||
self.base.detect_and_parse_reasoning(text)
|
||||
}
|
||||
|
||||
// ... other trait methods
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: Register in Factory**
|
||||
|
||||
```rust
|
||||
// In factory.rs ParserFactory::new()
|
||||
registry.register_parser("mymodel", || {
|
||||
Box::new(MyModelParser::new())
|
||||
});
|
||||
|
||||
// Register patterns
|
||||
registry.register_pattern("my-model", "mymodel");
|
||||
registry.register_pattern("mymodel", "mymodel");
|
||||
```
|
||||
|
||||
**Step 3: Export from Module**
|
||||
|
||||
```rust
|
||||
// In parsers/mod.rs
|
||||
pub use self::mymodel::MyModelParser;
|
||||
|
||||
// In reasoning_parser/mod.rs
|
||||
pub use parsers::MyModelParser;
|
||||
```
|
||||
|
||||
### Custom Parsing Logic
|
||||
|
||||
For parsers requiring custom logic beyond configuration:
|
||||
|
||||
```rust
|
||||
impl ReasoningParser for CustomParser {
|
||||
fn parse_reasoning_streaming_incremental(&mut self, text: &str)
|
||||
-> Result<ParserResult, ParseError> {
|
||||
// Custom state machine
|
||||
// Custom token detection
|
||||
// Custom buffering strategy
|
||||
// Return appropriate ParserResult
|
||||
}
|
||||
}
|
||||
```
|
||||
566
sgl-router/src/reasoning_parser/factory.rs
Normal file
566
sgl-router/src/reasoning_parser/factory.rs
Normal file
@@ -0,0 +1,566 @@
|
||||
// Factory and registry for creating model-specific reasoning parsers.
|
||||
// Now with parser pooling support for efficient reuse across requests.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
|
||||
use crate::reasoning_parser::parsers::{
|
||||
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
|
||||
QwenThinkingParser, Step3Parser,
|
||||
};
|
||||
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};
|
||||
|
||||
/// Type alias for pooled parser instances.
|
||||
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;
|
||||
|
||||
/// Type alias for parser creator functions.
|
||||
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;
|
||||
|
||||
/// Registry for model-specific parsers with pooling support.
|
||||
#[derive(Clone)]
|
||||
pub struct ParserRegistry {
|
||||
/// Creator functions for parsers (used when pool is empty)
|
||||
creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
|
||||
/// Pooled parser instances for reuse
|
||||
pool: Arc<RwLock<HashMap<String, PooledParser>>>,
|
||||
/// Model pattern to parser name mappings
|
||||
patterns: Arc<RwLock<Vec<(String, String)>>>, // (pattern, parser_name)
|
||||
}
|
||||
|
||||
impl ParserRegistry {
|
||||
/// Create a new empty registry.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
creators: Arc::new(RwLock::new(HashMap::new())),
|
||||
pool: Arc::new(RwLock::new(HashMap::new())),
|
||||
patterns: Arc::new(RwLock::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a parser creator for a given parser type.
|
||||
pub fn register_parser<F>(&self, name: &str, creator: F)
|
||||
where
|
||||
F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
|
||||
{
|
||||
let mut creators = self.creators.write().unwrap();
|
||||
creators.insert(name.to_string(), Arc::new(creator));
|
||||
}
|
||||
|
||||
/// Register a model pattern to parser mapping.
|
||||
/// Patterns are checked in order, first match wins.
|
||||
pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
|
||||
let mut patterns = self.patterns.write().unwrap();
|
||||
patterns.push((pattern.to_string(), parser_name.to_string()));
|
||||
}
|
||||
|
||||
/// Get a pooled parser by exact name.
|
||||
/// Returns a shared parser instance from the pool, creating one if needed.
|
||||
pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
|
||||
// First check if we have a pooled instance
|
||||
{
|
||||
let pool = self.pool.read().unwrap();
|
||||
if let Some(parser) = pool.get(name) {
|
||||
return Some(Arc::clone(parser));
|
||||
}
|
||||
}
|
||||
|
||||
// If not in pool, create one and add to pool
|
||||
let creators = self.creators.read().unwrap();
|
||||
if let Some(creator) = creators.get(name) {
|
||||
let parser = Arc::new(Mutex::new(creator()));
|
||||
|
||||
// Add to pool for future use
|
||||
let mut pool = self.pool.write().unwrap();
|
||||
pool.insert(name.to_string(), Arc::clone(&parser));
|
||||
|
||||
Some(parser)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a parser by exact name (creates new instance, not pooled).
|
||||
/// Use this for compatibility or when you need a fresh instance.
|
||||
pub fn get_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
|
||||
let creators = self.creators.read().unwrap();
|
||||
creators.get(name).map(|creator| creator())
|
||||
}
|
||||
|
||||
/// Find a pooled parser for a given model ID by pattern matching.
|
||||
pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
|
||||
let patterns = self.patterns.read().unwrap();
|
||||
let model_lower = model_id.to_lowercase();
|
||||
|
||||
for (pattern, parser_name) in patterns.iter() {
|
||||
if model_lower.contains(&pattern.to_lowercase()) {
|
||||
return self.get_pooled_parser(parser_name);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Find a parser for a given model ID by pattern matching (creates new instance).
|
||||
pub fn find_parser_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
|
||||
let patterns = self.patterns.read().unwrap();
|
||||
let model_lower = model_id.to_lowercase();
|
||||
|
||||
for (pattern, parser_name) in patterns.iter() {
|
||||
if model_lower.contains(&pattern.to_lowercase()) {
|
||||
return self.get_parser(parser_name);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Clear the parser pool, forcing new instances to be created.
|
||||
/// Useful for testing or when parsers need to be reset globally.
|
||||
pub fn clear_pool(&self) {
|
||||
let mut pool = self.pool.write().unwrap();
|
||||
pool.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ParserRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Factory for creating reasoning parsers based on model type.
|
||||
#[derive(Clone)]
|
||||
pub struct ParserFactory {
|
||||
registry: ParserRegistry,
|
||||
}
|
||||
|
||||
impl ParserFactory {
|
||||
/// Create a new factory with default parsers registered.
|
||||
pub fn new() -> Self {
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
// Register base parser
|
||||
registry.register_parser("base", || {
|
||||
Box::new(BaseReasoningParser::new(ParserConfig::default()))
|
||||
});
|
||||
|
||||
// Register DeepSeek-R1 parser (starts with in_reasoning=true)
|
||||
registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));
|
||||
|
||||
// Register Qwen3 parser (starts with in_reasoning=false)
|
||||
registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));
|
||||
|
||||
// Register Qwen3-thinking parser (starts with in_reasoning=true)
|
||||
registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));
|
||||
|
||||
// Register Kimi parser with Unicode tokens (starts with in_reasoning=false)
|
||||
registry.register_parser("kimi", || Box::new(KimiParser::new()));
|
||||
|
||||
// Register GLM45 parser (same format as Qwen3 but separate for debugging)
|
||||
registry.register_parser("glm45", || Box::new(Glm45Parser::new()));
|
||||
|
||||
// Register Step3 parser (same format as DeepSeek-R1 but separate for debugging)
|
||||
registry.register_parser("step3", || Box::new(Step3Parser::new()));
|
||||
|
||||
// Register model patterns
|
||||
registry.register_pattern("deepseek-r1", "deepseek_r1");
|
||||
registry.register_pattern("qwen3-thinking", "qwen3_thinking");
|
||||
registry.register_pattern("qwen-thinking", "qwen3_thinking");
|
||||
registry.register_pattern("qwen3", "qwen3");
|
||||
registry.register_pattern("qwen", "qwen3");
|
||||
registry.register_pattern("glm45", "glm45");
|
||||
registry.register_pattern("kimi", "kimi");
|
||||
registry.register_pattern("step3", "step3");
|
||||
|
||||
Self { registry }
|
||||
}
|
||||
|
||||
/// Get a pooled parser for the given model ID.
|
||||
/// Returns a shared instance that can be used concurrently.
|
||||
/// Falls back to a passthrough parser if model is not recognized.
|
||||
pub fn get_pooled(&self, model_id: &str) -> PooledParser {
|
||||
// First try to find by pattern
|
||||
if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
|
||||
return parser;
|
||||
}
|
||||
|
||||
// Fall back to no-op parser (get or create passthrough in pool)
|
||||
self.registry
|
||||
.get_pooled_parser("passthrough")
|
||||
.unwrap_or_else(|| {
|
||||
// Register passthrough if not already registered
|
||||
self.registry.register_parser("passthrough", || {
|
||||
let config = ParserConfig {
|
||||
think_start_token: "".to_string(),
|
||||
think_end_token: "".to_string(),
|
||||
stream_reasoning: true,
|
||||
max_buffer_size: 65536,
|
||||
initial_in_reasoning: false,
|
||||
};
|
||||
Box::new(
|
||||
BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
|
||||
)
|
||||
});
|
||||
self.registry.get_pooled_parser("passthrough").unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a new parser instance for the given model ID.
|
||||
/// Returns a fresh instance (not pooled).
|
||||
/// Use this when you need an isolated parser instance.
|
||||
pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
|
||||
// First try to find by pattern
|
||||
if let Some(parser) = self.registry.find_parser_for_model(model_id) {
|
||||
return Ok(parser);
|
||||
}
|
||||
|
||||
// Fall back to no-op parser (base parser without reasoning detection)
|
||||
let config = ParserConfig {
|
||||
think_start_token: "".to_string(),
|
||||
think_end_token: "".to_string(),
|
||||
stream_reasoning: true,
|
||||
max_buffer_size: 65536,
|
||||
initial_in_reasoning: false,
|
||||
};
|
||||
Ok(Box::new(
|
||||
BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
|
||||
))
|
||||
}
|
||||
|
||||
/// Get the internal registry for custom registration.
|
||||
pub fn registry(&self) -> &ParserRegistry {
|
||||
&self.registry
|
||||
}
|
||||
|
||||
/// Clear the parser pool.
|
||||
/// Useful for testing or when parsers need to be reset globally.
|
||||
pub fn clear_pool(&self) {
|
||||
self.registry.clear_pool();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ParserFactory {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_factory_creates_deepseek_r1() {
|
||||
let factory = ParserFactory::new();
|
||||
let parser = factory.create("deepseek-r1-distill").unwrap();
|
||||
assert_eq!(parser.model_type(), "deepseek_r1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_factory_creates_qwen3() {
|
||||
let factory = ParserFactory::new();
|
||||
let parser = factory.create("qwen3-7b").unwrap();
|
||||
assert_eq!(parser.model_type(), "qwen3");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_factory_creates_kimi() {
|
||||
let factory = ParserFactory::new();
|
||||
let parser = factory.create("kimi-chat").unwrap();
|
||||
assert_eq!(parser.model_type(), "kimi");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_factory_fallback_to_passthrough() {
|
||||
let factory = ParserFactory::new();
|
||||
let parser = factory.create("unknown-model").unwrap();
|
||||
assert_eq!(parser.model_type(), "passthrough");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_case_insensitive_matching() {
|
||||
let factory = ParserFactory::new();
|
||||
let parser1 = factory.create("DeepSeek-R1").unwrap();
|
||||
let parser2 = factory.create("QWEN3").unwrap();
|
||||
let parser3 = factory.create("Kimi").unwrap();
|
||||
|
||||
assert_eq!(parser1.model_type(), "deepseek_r1");
|
||||
assert_eq!(parser2.model_type(), "qwen3");
|
||||
assert_eq!(parser3.model_type(), "kimi");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_step3_model() {
|
||||
let factory = ParserFactory::new();
|
||||
let step3 = factory.create("step3-model").unwrap();
|
||||
assert_eq!(step3.model_type(), "step3");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glm45_model() {
|
||||
let factory = ParserFactory::new();
|
||||
let glm45 = factory.create("glm45-v2").unwrap();
|
||||
assert_eq!(glm45.model_type(), "glm45");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pooled_parser_reuse() {
|
||||
let factory = ParserFactory::new();
|
||||
|
||||
// Get the same parser twice - should be the same instance
|
||||
let parser1 = factory.get_pooled("deepseek-r1");
|
||||
let parser2 = factory.get_pooled("deepseek-r1");
|
||||
|
||||
// Both should point to the same Arc
|
||||
assert!(Arc::ptr_eq(&parser1, &parser2));
|
||||
|
||||
// Different models should get different parsers
|
||||
let parser3 = factory.get_pooled("qwen3");
|
||||
assert!(!Arc::ptr_eq(&parser1, &parser3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pooled_parser_concurrent_access() {
|
||||
use std::thread;
|
||||
|
||||
let factory = ParserFactory::new();
|
||||
let parser = factory.get_pooled("deepseek-r1");
|
||||
|
||||
// Spawn multiple threads that use the same parser
|
||||
let mut handles = vec![];
|
||||
|
||||
for i in 0..3 {
|
||||
let parser_clone = Arc::clone(&parser);
|
||||
let handle = thread::spawn(move || {
|
||||
let mut parser = parser_clone.lock().unwrap();
|
||||
let input = format!("thread {} reasoning</think>answer", i);
|
||||
let result = parser.detect_and_parse_reasoning(&input).unwrap();
|
||||
assert_eq!(result.normal_text, "answer");
|
||||
assert!(result.reasoning_text.contains("reasoning"));
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all threads to complete
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pool_clearing() {
|
||||
let factory = ParserFactory::new();
|
||||
|
||||
// Get a pooled parser
|
||||
let parser1 = factory.get_pooled("deepseek-r1");
|
||||
|
||||
// Clear the pool
|
||||
factory.clear_pool();
|
||||
|
||||
// Get another parser - should be a new instance
|
||||
let parser2 = factory.get_pooled("deepseek-r1");
|
||||
|
||||
// They should be different instances (different Arc pointers)
|
||||
assert!(!Arc::ptr_eq(&parser1, &parser2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_passthrough_parser_pooling() {
|
||||
let factory = ParserFactory::new();
|
||||
|
||||
// Unknown models should get passthrough parser
|
||||
let parser1 = factory.get_pooled("unknown-model-1");
|
||||
let parser2 = factory.get_pooled("unknown-model-2");
|
||||
|
||||
// Both should use the same passthrough parser instance
|
||||
assert!(Arc::ptr_eq(&parser1, &parser2));
|
||||
|
||||
// Verify it's actually a passthrough parser
|
||||
let parser = parser1.lock().unwrap();
|
||||
assert_eq!(parser.model_type(), "passthrough");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_high_concurrency_parser_access() {
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::thread;
|
||||
use std::time::Instant;
|
||||
|
||||
let factory = ParserFactory::new();
|
||||
let num_threads = 100;
|
||||
let requests_per_thread = 50;
|
||||
let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];
|
||||
|
||||
// Track successful operations
|
||||
let success_count = Arc::new(AtomicUsize::new(0));
|
||||
let error_count = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let start = Instant::now();
|
||||
let mut handles = vec![];
|
||||
|
||||
for thread_id in 0..num_threads {
|
||||
let factory = factory.clone();
|
||||
let models = models.clone();
|
||||
let success_count = Arc::clone(&success_count);
|
||||
let error_count = Arc::clone(&error_count);
|
||||
|
||||
let handle = thread::spawn(move || {
|
||||
for request_id in 0..requests_per_thread {
|
||||
// Rotate through different models
|
||||
let model = &models[(thread_id + request_id) % models.len()];
|
||||
let parser = factory.get_pooled(model);
|
||||
|
||||
// Use blocking lock - this is the realistic scenario
|
||||
// In production, requests would wait for the parser to be available
|
||||
// Handle poisoned locks gracefully
|
||||
let mut p = match parser.lock() {
|
||||
Ok(guard) => guard,
|
||||
Err(_poisoned) => {
|
||||
// Lock was poisoned by a panicking thread
|
||||
// In production, we might want to recreate the parser
|
||||
// For testing, we'll just skip this iteration
|
||||
error_count.fetch_add(1, Ordering::Relaxed);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Simulate realistic parsing work with substantial text
|
||||
// Typical reasoning can be 500-5000 tokens
|
||||
let reasoning_text = format!(
|
||||
"Thread {} is processing request {}. Let me think through this step by step. \
|
||||
First, I need to understand the problem. The problem involves analyzing data \
|
||||
and making calculations. Let me break this down: \n\
|
||||
1. Initial analysis shows that we have multiple variables to consider. \
|
||||
2. The data suggests a pattern that needs further investigation. \
|
||||
3. Computing the values: {} * {} = {}. \
|
||||
4. Cross-referencing with previous results indicates consistency. \
|
||||
5. The mathematical proof follows from the axioms... \
|
||||
6. Considering edge cases and boundary conditions... \
|
||||
7. Validating against known constraints... \
|
||||
8. The conclusion follows logically from premises A, B, and C. \
|
||||
This reasoning chain demonstrates the validity of our approach.",
|
||||
thread_id, request_id, thread_id, request_id, thread_id * request_id
|
||||
);
|
||||
|
||||
let answer_text = format!(
|
||||
"Based on my analysis, the answer for thread {} request {} is: \
|
||||
The solution involves multiple steps as outlined in the reasoning. \
|
||||
The final result is {} with confidence level high. \
|
||||
This conclusion is supported by rigorous mathematical analysis \
|
||||
and has been validated against multiple test cases. \
|
||||
The implementation should handle edge cases appropriately.",
|
||||
thread_id,
|
||||
request_id,
|
||||
thread_id * request_id
|
||||
);
|
||||
|
||||
let input = format!("<think>{}</think>{}", reasoning_text, answer_text);
|
||||
|
||||
match p.detect_and_parse_reasoning(&input) {
|
||||
Ok(result) => {
|
||||
// Verify parsing worked correctly with substantial content
|
||||
// Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
|
||||
assert!(result
|
||||
.normal_text
|
||||
.contains(&format!("thread {}", thread_id)));
|
||||
|
||||
// For parsers that accumulate reasoning (stream_reasoning=false)
|
||||
// the reasoning_text should be populated
|
||||
if !result.reasoning_text.is_empty() {
|
||||
assert!(result
|
||||
.reasoning_text
|
||||
.contains(&format!("Thread {}", thread_id)));
|
||||
assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
|
||||
}
|
||||
|
||||
// Normal text should always be present
|
||||
assert!(result.normal_text.len() > 100); // Ensure substantial answer
|
||||
success_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Parse error: {:?}", e);
|
||||
error_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
// Explicitly drop the lock to release it quickly
|
||||
drop(p);
|
||||
}
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all threads
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
|
||||
let duration = start.elapsed();
|
||||
let total_requests = num_threads * requests_per_thread;
|
||||
let successes = success_count.load(Ordering::Relaxed);
|
||||
let errors = error_count.load(Ordering::Relaxed);
|
||||
|
||||
// Print stats for debugging
|
||||
println!(
|
||||
"High concurrency test: {} threads, {} requests each",
|
||||
num_threads, requests_per_thread
|
||||
);
|
||||
println!(
|
||||
"Completed in {:?}, {} successes, {} errors",
|
||||
duration, successes, errors
|
||||
);
|
||||
println!(
|
||||
"Throughput: {:.0} requests/sec",
|
||||
(total_requests as f64) / duration.as_secs_f64()
|
||||
);
|
||||
|
||||
// All requests should succeed
|
||||
assert_eq!(successes, total_requests);
|
||||
assert_eq!(errors, 0);
|
||||
|
||||
// Performance check: should handle at least 1000 req/sec
|
||||
let throughput = (total_requests as f64) / duration.as_secs_f64();
|
||||
assert!(
|
||||
throughput > 1000.0,
|
||||
"Throughput too low: {:.0} req/sec",
|
||||
throughput
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_concurrent_pool_modifications() {
|
||||
use std::thread;
|
||||
|
||||
let factory = ParserFactory::new();
|
||||
let mut handles = vec![];
|
||||
|
||||
// Thread 1: Continuously get parsers
|
||||
let factory1 = factory.clone();
|
||||
handles.push(thread::spawn(move || {
|
||||
for _ in 0..100 {
|
||||
let _parser = factory1.get_pooled("deepseek-r1");
|
||||
}
|
||||
}));
|
||||
|
||||
// Thread 2: Continuously clear pool
|
||||
let factory2 = factory.clone();
|
||||
handles.push(thread::spawn(move || {
|
||||
for _ in 0..10 {
|
||||
factory2.clear_pool();
|
||||
thread::sleep(std::time::Duration::from_micros(100));
|
||||
}
|
||||
}));
|
||||
|
||||
// Thread 3: Get different parsers
|
||||
let factory3 = factory.clone();
|
||||
handles.push(thread::spawn(move || {
|
||||
for i in 0..100 {
|
||||
let models = ["qwen3", "kimi", "unknown"];
|
||||
let _parser = factory3.get_pooled(models[i % 3]);
|
||||
}
|
||||
}));
|
||||
|
||||
// Wait for all threads - should not deadlock or panic
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
10
sgl-router/src/reasoning_parser/mod.rs
Normal file
10
sgl-router/src/reasoning_parser/mod.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
pub mod factory;
|
||||
pub mod parsers;
|
||||
pub mod traits;
|
||||
|
||||
pub use factory::{ParserFactory, ParserRegistry, PooledParser};
|
||||
pub use parsers::{
|
||||
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
|
||||
QwenThinkingParser, Step3Parser,
|
||||
};
|
||||
pub use traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
|
||||
386
sgl-router/src/reasoning_parser/parsers/base.rs
Normal file
386
sgl-router/src/reasoning_parser/parsers/base.rs
Normal file
@@ -0,0 +1,386 @@
|
||||
// Base implementation of reasoning parser that handles common logic
|
||||
// for detecting and extracting reasoning blocks from text.
|
||||
|
||||
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
|
||||
use tracing as log;
|
||||
|
||||
/// Base reasoning parser implementation.
|
||||
///
|
||||
/// This parser handles the common logic for detecting reasoning blocks
|
||||
/// delimited by start and end tokens (e.g., <think> and </think>).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BaseReasoningParser {
|
||||
config: ParserConfig,
|
||||
in_reasoning: bool,
|
||||
buffer: String,
|
||||
stripped_think_start: bool,
|
||||
model_type: String,
|
||||
}
|
||||
|
||||
impl BaseReasoningParser {
|
||||
/// Create a new BaseReasoningParser with the given configuration.
|
||||
pub fn new(config: ParserConfig) -> Self {
|
||||
let in_reasoning = config.initial_in_reasoning;
|
||||
Self {
|
||||
config,
|
||||
in_reasoning,
|
||||
buffer: String::new(),
|
||||
stripped_think_start: false,
|
||||
model_type: "base".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom model type identifier.
|
||||
pub fn with_model_type(mut self, model_type: String) -> Self {
|
||||
self.model_type = model_type;
|
||||
self
|
||||
}
|
||||
|
||||
/// Check if the current buffer is a prefix of one of the tokens.
|
||||
fn is_partial_token(&self, text: &str) -> bool {
|
||||
(self.config.think_start_token.starts_with(text) && self.config.think_start_token != text)
|
||||
|| (self.config.think_end_token.starts_with(text)
|
||||
&& self.config.think_end_token != text)
|
||||
}
|
||||
}
|
||||
|
||||
impl ReasoningParser for BaseReasoningParser {
|
||||
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
|
||||
log::debug!("detect_and_parse_reasoning called with text: {:?}", text);
|
||||
|
||||
// Check input size against buffer limit
|
||||
if text.len() > self.config.max_buffer_size {
|
||||
return Err(ParseError::BufferOverflow(text.len()));
|
||||
}
|
||||
|
||||
let in_reasoning = self.in_reasoning || text.contains(&self.config.think_start_token);
|
||||
log::debug!("in_reasoning: {}", in_reasoning);
|
||||
|
||||
if !in_reasoning {
|
||||
log::debug!("No reasoning detected, returning normal text.");
|
||||
return Ok(ParserResult::normal(text.to_string()));
|
||||
}
|
||||
|
||||
// The text is considered to be in a reasoning block.
|
||||
let processed_text = text
|
||||
.replace(&self.config.think_start_token, "")
|
||||
.trim()
|
||||
.to_string();
|
||||
log::debug!(
|
||||
"Processed text after removing think_start_token: {:?}",
|
||||
processed_text
|
||||
);
|
||||
|
||||
if !processed_text.contains(&self.config.think_end_token) {
|
||||
log::debug!(
|
||||
"Reasoning truncated, think_end_token not found. Returning reasoning text."
|
||||
);
|
||||
// Assume reasoning was truncated before end token
|
||||
return Ok(ParserResult::reasoning(processed_text));
|
||||
}
|
||||
|
||||
// Extract reasoning content
|
||||
let splits: Vec<&str> = processed_text
|
||||
.splitn(2, &self.config.think_end_token)
|
||||
.collect();
|
||||
let reasoning_text = splits.first().unwrap_or(&"").to_string();
|
||||
let normal_text = splits
|
||||
.get(1)
|
||||
.map(|s| s.trim().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
log::debug!("Extracted reasoning_text: {:?}", reasoning_text);
|
||||
log::debug!("Extracted normal_text: {:?}", normal_text);
|
||||
|
||||
Ok(ParserResult::new(normal_text, reasoning_text))
|
||||
}
|
||||
|
||||
fn parse_reasoning_streaming_incremental(
|
||||
&mut self,
|
||||
text: &str,
|
||||
) -> Result<ParserResult, ParseError> {
|
||||
// Check if adding this text would exceed buffer limit
|
||||
if self.buffer.len() + text.len() > self.config.max_buffer_size {
|
||||
return Err(ParseError::BufferOverflow(self.buffer.len() + text.len()));
|
||||
}
|
||||
|
||||
// Incrementally parse the streaming text
|
||||
self.buffer.push_str(text);
|
||||
let mut current_text = self.buffer.clone();
|
||||
|
||||
log::debug!(
|
||||
"parse_reasoning_streaming_incremental called with text: {:?}",
|
||||
text
|
||||
);
|
||||
log::debug!("current buffer: {:?}", self.buffer);
|
||||
log::debug!("current_text: {:?}", current_text);
|
||||
log::debug!(
|
||||
"in_reasoning: {}, stripped_think_start: {}, stream_reasoning: {}",
|
||||
self.in_reasoning,
|
||||
self.stripped_think_start,
|
||||
self.config.stream_reasoning
|
||||
);
|
||||
|
||||
// If the current text is a prefix of a token, keep buffering
|
||||
if self.is_partial_token(¤t_text) {
|
||||
return Ok(ParserResult::default());
|
||||
}
|
||||
|
||||
// Strip start token if present
|
||||
if !self.stripped_think_start && current_text.contains(&self.config.think_start_token) {
|
||||
current_text = current_text.replace(&self.config.think_start_token, "");
|
||||
self.buffer = current_text.clone();
|
||||
self.stripped_think_start = true;
|
||||
self.in_reasoning = true;
|
||||
}
|
||||
|
||||
// Handle end of reasoning block
|
||||
let think_end_idx = if self.in_reasoning {
|
||||
current_text
|
||||
.find(&self.config.think_end_token)
|
||||
.unwrap_or(current_text.len())
|
||||
} else {
|
||||
current_text.len()
|
||||
};
|
||||
|
||||
if self.in_reasoning && think_end_idx < current_text.len() {
|
||||
let reasoning_text = ¤t_text[..think_end_idx];
|
||||
self.buffer.clear();
|
||||
self.in_reasoning = false;
|
||||
let start_idx = think_end_idx + self.config.think_end_token.len();
|
||||
let normal_text = if start_idx < current_text.len() {
|
||||
¤t_text[start_idx..]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
return Ok(ParserResult::new(
|
||||
normal_text.to_string(),
|
||||
reasoning_text.trim().to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Continue with reasoning content
|
||||
if self.in_reasoning && self.config.stream_reasoning {
|
||||
// Stream the content immediately
|
||||
let reasoning_text = current_text;
|
||||
self.buffer.clear();
|
||||
Ok(ParserResult::reasoning(reasoning_text))
|
||||
} else if !self.in_reasoning {
|
||||
// If we're not in a reasoning block, return as normal text
|
||||
// CRITICAL FIX: Return current_text (with buffer) not just text
|
||||
// This prevents buffer loss when partial tokens are followed by normal text
|
||||
let normal_text = current_text;
|
||||
self.buffer.clear();
|
||||
Ok(ParserResult::normal(normal_text))
|
||||
} else {
|
||||
// If we are in a reasoning block but no end token is found, buffer it
|
||||
Ok(ParserResult::default())
|
||||
}
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.in_reasoning = self.config.initial_in_reasoning;
|
||||
self.buffer.clear();
|
||||
self.stripped_think_start = false;
|
||||
}
|
||||
|
||||
fn model_type(&self) -> &str {
|
||||
&self.model_type
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn create_test_parser(
|
||||
initial_in_reasoning: bool,
|
||||
stream_reasoning: bool,
|
||||
) -> BaseReasoningParser {
|
||||
let config = ParserConfig {
|
||||
think_start_token: "<think>".to_string(),
|
||||
think_end_token: "</think>".to_string(),
|
||||
stream_reasoning,
|
||||
max_buffer_size: 65536,
|
||||
initial_in_reasoning,
|
||||
};
|
||||
BaseReasoningParser::new(config)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_and_parse_reasoning() {
|
||||
let mut parser = create_test_parser(false, true);
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("<think>with reasoning</think> and more text.")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "and more text.");
|
||||
assert_eq!(result.reasoning_text, "with reasoning");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_and_parse_no_reasoning() {
|
||||
let mut parser = create_test_parser(false, true);
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("This is a test without reasoning.")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "This is a test without reasoning.");
|
||||
assert_eq!(result.reasoning_text, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_and_parse_truncated_reasoning() {
|
||||
let mut parser = create_test_parser(false, true);
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("<think>with truncated reasoning")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "");
|
||||
assert_eq!(result.reasoning_text, "with truncated reasoning");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_streaming_partial_token() {
|
||||
let mut parser = create_test_parser(false, true);
|
||||
let result = parser
|
||||
.parse_reasoning_streaming_incremental("<thi")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "");
|
||||
assert_eq!(result.reasoning_text, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_streaming_complete() {
|
||||
let mut parser = create_test_parser(false, true);
|
||||
let result = parser
|
||||
.parse_reasoning_streaming_incremental("<think>with reasoning</think> and more text.")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, " and more text.");
|
||||
assert_eq!(result.reasoning_text, "with reasoning");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_streaming_no_end_token() {
|
||||
let mut parser = create_test_parser(true, true);
|
||||
let result = parser
|
||||
.parse_reasoning_streaming_incremental("<think>with reasoning")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "");
|
||||
assert_eq!(result.reasoning_text, "with reasoning");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_initial_in_reasoning_true() {
|
||||
// Parser starts with in_reasoning=true (like DeepSeek-R1)
|
||||
let mut parser = create_test_parser(true, true);
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("no think tags here")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "");
|
||||
assert_eq!(result.reasoning_text, "no think tags here");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_buffer_loss_bug_fix() {
|
||||
// Critical test for buffer preservation
|
||||
let mut parser = create_test_parser(false, true);
|
||||
|
||||
// Step 1: Send partial end tag when not in reasoning mode
|
||||
let result1 = parser.parse_reasoning_streaming_incremental("</").unwrap();
|
||||
assert_eq!(result1.normal_text, "");
|
||||
assert_eq!(result1.reasoning_text, "");
|
||||
|
||||
// Step 2: Send normal text that doesn't complete the end tag
|
||||
// Must return "</answer" not just "answer"
|
||||
let result2 = parser
|
||||
.parse_reasoning_streaming_incremental("answer")
|
||||
.unwrap();
|
||||
assert_eq!(result2.normal_text, "</answer");
|
||||
assert_eq!(result2.reasoning_text, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_streaming_with_stream_reasoning_enabled() {
|
||||
let mut parser = create_test_parser(false, true);
|
||||
|
||||
// Start reasoning block
|
||||
let result1 = parser
|
||||
.parse_reasoning_streaming_incremental("<think>reasoning ")
|
||||
.unwrap();
|
||||
assert_eq!(result1.normal_text, "");
|
||||
assert_eq!(result1.reasoning_text, "reasoning ");
|
||||
|
||||
// Continue streaming reasoning
|
||||
let result2 = parser
|
||||
.parse_reasoning_streaming_incremental("content ")
|
||||
.unwrap();
|
||||
assert_eq!(result2.normal_text, "");
|
||||
assert_eq!(result2.reasoning_text, "content ");
|
||||
|
||||
// End reasoning block
|
||||
let result3 = parser
|
||||
.parse_reasoning_streaming_incremental("more</think> normal")
|
||||
.unwrap();
|
||||
assert_eq!(result3.normal_text, " normal");
|
||||
assert_eq!(result3.reasoning_text, "more");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reset_state() {
|
||||
let mut parser = create_test_parser(false, true);
|
||||
|
||||
// Process some text
|
||||
parser
|
||||
.parse_reasoning_streaming_incremental("<think>reasoning</think> normal")
|
||||
.unwrap();
|
||||
|
||||
// Reset and verify state
|
||||
parser.reset();
|
||||
assert!(!parser.in_reasoning);
|
||||
assert!(parser.buffer.is_empty());
|
||||
assert!(!parser.stripped_think_start);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_buffer_overflow_detect_and_parse() {
|
||||
let config = ParserConfig {
|
||||
max_buffer_size: 10, // Set a very small buffer
|
||||
..Default::default()
|
||||
};
|
||||
let mut parser = BaseReasoningParser::new(config);
|
||||
|
||||
let large_text = "a".repeat(20);
|
||||
let result = parser.detect_and_parse_reasoning(&large_text);
|
||||
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(ParseError::BufferOverflow(size)) => {
|
||||
assert_eq!(size, 20);
|
||||
}
|
||||
_ => panic!("Expected BufferOverflow error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_buffer_overflow_streaming() {
|
||||
let config = ParserConfig {
|
||||
max_buffer_size: 10, // Set a very small buffer
|
||||
..Default::default()
|
||||
};
|
||||
let mut parser = BaseReasoningParser::new(config);
|
||||
|
||||
// Send a partial token that will be buffered
|
||||
let result1 = parser.parse_reasoning_streaming_incremental("<thi");
|
||||
assert!(result1.is_ok());
|
||||
assert_eq!(result1.unwrap().normal_text, "");
|
||||
|
||||
// Second chunk would exceed buffer
|
||||
// Buffer has "<thi" (4 chars) + "this_is_too_large" (17 chars) = 21 total
|
||||
let result2 = parser.parse_reasoning_streaming_incremental("this_is_too_large");
|
||||
assert!(result2.is_err());
|
||||
match result2 {
|
||||
Err(ParseError::BufferOverflow(size)) => {
|
||||
assert_eq!(size, 21); // 4 + 17
|
||||
}
|
||||
_ => panic!("Expected BufferOverflow error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
112
sgl-router/src/reasoning_parser/parsers/deepseek_r1.rs
Normal file
112
sgl-router/src/reasoning_parser/parsers/deepseek_r1.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
// DeepSeek-R1 specific reasoning parser.
|
||||
// This parser starts with in_reasoning=true, assuming all text is reasoning
|
||||
// until an end token is encountered.
|
||||
|
||||
use crate::reasoning_parser::parsers::BaseReasoningParser;
|
||||
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
|
||||
|
||||
/// DeepSeek-R1 reasoning parser.
|
||||
///
|
||||
/// This parser assumes reasoning from the start of text (in_reasoning=true)
|
||||
/// and uses <think> and </think> tokens.
|
||||
pub struct DeepSeekR1Parser {
|
||||
base: BaseReasoningParser,
|
||||
}
|
||||
|
||||
impl DeepSeekR1Parser {
|
||||
/// Create a new DeepSeek-R1 parser.
|
||||
pub fn new() -> Self {
|
||||
let config = ParserConfig {
|
||||
think_start_token: "<think>".to_string(),
|
||||
think_end_token: "</think>".to_string(),
|
||||
stream_reasoning: true,
|
||||
max_buffer_size: 65536,
|
||||
initial_in_reasoning: true, // Always starts with reasoning
|
||||
};
|
||||
|
||||
Self {
|
||||
base: BaseReasoningParser::new(config).with_model_type("deepseek_r1".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DeepSeekR1Parser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReasoningParser for DeepSeekR1Parser {
|
||||
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
|
||||
self.base.detect_and_parse_reasoning(text)
|
||||
}
|
||||
|
||||
fn parse_reasoning_streaming_incremental(
|
||||
&mut self,
|
||||
text: &str,
|
||||
) -> Result<ParserResult, ParseError> {
|
||||
self.base.parse_reasoning_streaming_incremental(text)
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.base.reset()
|
||||
}
|
||||
|
||||
fn model_type(&self) -> &str {
|
||||
self.base.model_type()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_deepseek_r1_initial_state() {
|
||||
let mut parser = DeepSeekR1Parser::new();
|
||||
|
||||
// Should treat text as reasoning even without start token
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("This is reasoning content")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "");
|
||||
assert_eq!(result.reasoning_text, "This is reasoning content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deepseek_r1_with_end_token() {
|
||||
let mut parser = DeepSeekR1Parser::new();
|
||||
|
||||
// Should extract reasoning until end token
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("reasoning content</think>normal content")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "normal content");
|
||||
assert_eq!(result.reasoning_text, "reasoning content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deepseek_r1_streaming() {
|
||||
let mut parser = DeepSeekR1Parser::new();
|
||||
|
||||
// First chunk - all reasoning
|
||||
let result1 = parser
|
||||
.parse_reasoning_streaming_incremental("thinking about")
|
||||
.unwrap();
|
||||
assert_eq!(result1.reasoning_text, "thinking about");
|
||||
assert_eq!(result1.normal_text, "");
|
||||
|
||||
// Second chunk - ends reasoning
|
||||
let result2 = parser
|
||||
.parse_reasoning_streaming_incremental(" the problem</think>answer")
|
||||
.unwrap();
|
||||
assert_eq!(result2.reasoning_text, "the problem"); // Text is trimmed
|
||||
assert_eq!(result2.normal_text, "answer");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_type() {
|
||||
let parser = DeepSeekR1Parser::new();
|
||||
assert_eq!(parser.model_type(), "deepseek_r1");
|
||||
}
|
||||
}
|
||||
118
sgl-router/src/reasoning_parser/parsers/glm45.rs
Normal file
118
sgl-router/src/reasoning_parser/parsers/glm45.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
// GLM45 specific reasoning parser.
|
||||
// Uses the same format as Qwen3 but has its own implementation for debugging.
|
||||
|
||||
use crate::reasoning_parser::parsers::BaseReasoningParser;
|
||||
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
|
||||
|
||||
/// GLM45 reasoning parser.
|
||||
///
|
||||
/// This parser uses the same format as Qwen3 (<think>...</think>) but has
|
||||
/// its own implementation for better debugging and potential future customization.
|
||||
pub struct Glm45Parser {
|
||||
base: BaseReasoningParser,
|
||||
}
|
||||
|
||||
impl Glm45Parser {
|
||||
/// Create a new GLM45 parser.
|
||||
pub fn new() -> Self {
|
||||
let config = ParserConfig {
|
||||
think_start_token: "<think>".to_string(),
|
||||
think_end_token: "</think>".to_string(),
|
||||
stream_reasoning: true,
|
||||
max_buffer_size: 65536,
|
||||
initial_in_reasoning: false, // Requires explicit start token like Qwen3
|
||||
};
|
||||
|
||||
Self {
|
||||
base: BaseReasoningParser::new(config).with_model_type("glm45".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Glm45Parser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReasoningParser for Glm45Parser {
|
||||
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
|
||||
self.base.detect_and_parse_reasoning(text)
|
||||
}
|
||||
|
||||
fn parse_reasoning_streaming_incremental(
|
||||
&mut self,
|
||||
text: &str,
|
||||
) -> Result<ParserResult, ParseError> {
|
||||
self.base.parse_reasoning_streaming_incremental(text)
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.base.reset()
|
||||
}
|
||||
|
||||
fn model_type(&self) -> &str {
|
||||
self.base.model_type()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_glm45_initial_state() {
|
||||
let mut parser = Glm45Parser::new();
|
||||
|
||||
// Should NOT treat text as reasoning without start token
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("This is normal content")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "This is normal content");
|
||||
assert_eq!(result.reasoning_text, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glm45_with_tokens() {
|
||||
let mut parser = Glm45Parser::new();
|
||||
|
||||
// Should extract reasoning with proper tokens
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("<think>reasoning content</think>answer")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "answer");
|
||||
assert_eq!(result.reasoning_text, "reasoning content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glm45_streaming() {
|
||||
let mut parser = Glm45Parser::new();
|
||||
|
||||
// First chunk - normal text
|
||||
let result1 = parser
|
||||
.parse_reasoning_streaming_incremental("normal text ")
|
||||
.unwrap();
|
||||
assert_eq!(result1.normal_text, "normal text ");
|
||||
assert_eq!(result1.reasoning_text, "");
|
||||
|
||||
// Second chunk - enters reasoning
|
||||
let result2 = parser
|
||||
.parse_reasoning_streaming_incremental("<think>reasoning")
|
||||
.unwrap();
|
||||
assert_eq!(result2.normal_text, "");
|
||||
assert_eq!(result2.reasoning_text, "reasoning");
|
||||
|
||||
// Third chunk - exits reasoning
|
||||
let result3 = parser
|
||||
.parse_reasoning_streaming_incremental("</think>answer")
|
||||
.unwrap();
|
||||
assert_eq!(result3.normal_text, "answer");
|
||||
assert_eq!(result3.reasoning_text, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_type() {
|
||||
let parser = Glm45Parser::new();
|
||||
assert_eq!(parser.model_type(), "glm45");
|
||||
}
|
||||
}
|
||||
137
sgl-router/src/reasoning_parser/parsers/kimi.rs
Normal file
137
sgl-router/src/reasoning_parser/parsers/kimi.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
// Kimi specific reasoning parser.
|
||||
// This parser uses Unicode tokens and starts with in_reasoning=false.
|
||||
|
||||
use crate::reasoning_parser::parsers::BaseReasoningParser;
|
||||
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
|
||||
|
||||
/// Kimi reasoning parser.
|
||||
///
|
||||
/// This parser uses Unicode tokens (◁think▷ and ◁/think▷) and requires
|
||||
/// explicit start tokens to enter reasoning mode.
|
||||
pub struct KimiParser {
|
||||
base: BaseReasoningParser,
|
||||
}
|
||||
|
||||
impl KimiParser {
|
||||
/// Create a new Kimi parser.
|
||||
pub fn new() -> Self {
|
||||
let config = ParserConfig {
|
||||
think_start_token: "◁think▷".to_string(),
|
||||
think_end_token: "◁/think▷".to_string(),
|
||||
stream_reasoning: true,
|
||||
max_buffer_size: 65536,
|
||||
initial_in_reasoning: false, // Requires explicit start token
|
||||
};
|
||||
|
||||
Self {
|
||||
base: BaseReasoningParser::new(config).with_model_type("kimi".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for KimiParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReasoningParser for KimiParser {
|
||||
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
|
||||
self.base.detect_and_parse_reasoning(text)
|
||||
}
|
||||
|
||||
fn parse_reasoning_streaming_incremental(
|
||||
&mut self,
|
||||
text: &str,
|
||||
) -> Result<ParserResult, ParseError> {
|
||||
self.base.parse_reasoning_streaming_incremental(text)
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.base.reset()
|
||||
}
|
||||
|
||||
fn model_type(&self) -> &str {
|
||||
self.base.model_type()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_kimi_initial_state() {
|
||||
let mut parser = KimiParser::new();
|
||||
|
||||
// Should NOT treat text as reasoning without start token
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("This is normal content")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "This is normal content");
|
||||
assert_eq!(result.reasoning_text, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kimi_with_unicode_tokens() {
|
||||
let mut parser = KimiParser::new();
|
||||
|
||||
// Should extract reasoning with Unicode tokens
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("◁think▷reasoning content◁/think▷answer")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "answer");
|
||||
assert_eq!(result.reasoning_text, "reasoning content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kimi_partial_unicode() {
|
||||
let mut parser = KimiParser::new();
|
||||
|
||||
// Test partial Unicode token buffering
|
||||
let result1 = parser
|
||||
.parse_reasoning_streaming_incremental("◁thi")
|
||||
.unwrap();
|
||||
assert_eq!(result1.normal_text, "");
|
||||
assert_eq!(result1.reasoning_text, "");
|
||||
|
||||
// Complete the token
|
||||
let result2 = parser
|
||||
.parse_reasoning_streaming_incremental("nk▷reasoning")
|
||||
.unwrap();
|
||||
assert_eq!(result2.normal_text, "");
|
||||
assert_eq!(result2.reasoning_text, "reasoning");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kimi_streaming() {
|
||||
let mut parser = KimiParser::new();
|
||||
|
||||
// Normal text first
|
||||
let result1 = parser
|
||||
.parse_reasoning_streaming_incremental("normal ")
|
||||
.unwrap();
|
||||
assert_eq!(result1.normal_text, "normal ");
|
||||
assert_eq!(result1.reasoning_text, "");
|
||||
|
||||
// Enter reasoning with Unicode token
|
||||
let result2 = parser
|
||||
.parse_reasoning_streaming_incremental("◁think▷thinking")
|
||||
.unwrap();
|
||||
assert_eq!(result2.normal_text, "");
|
||||
assert_eq!(result2.reasoning_text, "thinking");
|
||||
|
||||
// Exit reasoning
|
||||
let result3 = parser
|
||||
.parse_reasoning_streaming_incremental("◁/think▷answer")
|
||||
.unwrap();
|
||||
assert_eq!(result3.normal_text, "answer");
|
||||
assert_eq!(result3.reasoning_text, ""); // Already returned in stream mode
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_type() {
|
||||
let parser = KimiParser::new();
|
||||
assert_eq!(parser.model_type(), "kimi");
|
||||
}
|
||||
}
|
||||
13
sgl-router/src/reasoning_parser/parsers/mod.rs
Normal file
13
sgl-router/src/reasoning_parser/parsers/mod.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
pub mod base;
|
||||
pub mod deepseek_r1;
|
||||
pub mod glm45;
|
||||
pub mod kimi;
|
||||
pub mod qwen3;
|
||||
pub mod step3;
|
||||
|
||||
pub use base::BaseReasoningParser;
|
||||
pub use deepseek_r1::DeepSeekR1Parser;
|
||||
pub use glm45::Glm45Parser;
|
||||
pub use kimi::KimiParser;
|
||||
pub use qwen3::{Qwen3Parser, QwenThinkingParser};
|
||||
pub use step3::Step3Parser;
|
||||
178
sgl-router/src/reasoning_parser/parsers/qwen3.rs
Normal file
178
sgl-router/src/reasoning_parser/parsers/qwen3.rs
Normal file
@@ -0,0 +1,178 @@
|
||||
// Qwen3 specific reasoning parser.
|
||||
// This parser starts with in_reasoning=false, requiring an explicit
|
||||
// start token to enter reasoning mode.
|
||||
|
||||
use crate::reasoning_parser::parsers::BaseReasoningParser;
|
||||
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
|
||||
|
||||
/// Qwen3 reasoning parser.
|
||||
///
|
||||
/// This parser requires explicit <think> tokens to enter reasoning mode
|
||||
/// (in_reasoning=false initially).
|
||||
pub struct Qwen3Parser {
|
||||
base: BaseReasoningParser,
|
||||
}
|
||||
|
||||
impl Qwen3Parser {
|
||||
/// Create a new Qwen3 parser.
|
||||
pub fn new() -> Self {
|
||||
let config = ParserConfig {
|
||||
think_start_token: "<think>".to_string(),
|
||||
think_end_token: "</think>".to_string(),
|
||||
stream_reasoning: true,
|
||||
max_buffer_size: 65536,
|
||||
initial_in_reasoning: false, // Requires explicit start token
|
||||
};
|
||||
|
||||
Self {
|
||||
base: BaseReasoningParser::new(config).with_model_type("qwen3".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Qwen3Parser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReasoningParser for Qwen3Parser {
|
||||
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
|
||||
self.base.detect_and_parse_reasoning(text)
|
||||
}
|
||||
|
||||
fn parse_reasoning_streaming_incremental(
|
||||
&mut self,
|
||||
text: &str,
|
||||
) -> Result<ParserResult, ParseError> {
|
||||
self.base.parse_reasoning_streaming_incremental(text)
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.base.reset()
|
||||
}
|
||||
|
||||
fn model_type(&self) -> &str {
|
||||
self.base.model_type()
|
||||
}
|
||||
}
|
||||
|
||||
/// QwenThinking parser - variant that assumes reasoning from start.
|
||||
///
|
||||
/// This is for qwen*thinking models that behave like DeepSeek-R1.
|
||||
pub struct QwenThinkingParser {
|
||||
base: BaseReasoningParser,
|
||||
}
|
||||
|
||||
impl QwenThinkingParser {
|
||||
/// Create a new QwenThinking parser.
|
||||
pub fn new() -> Self {
|
||||
let config = ParserConfig {
|
||||
think_start_token: "<think>".to_string(),
|
||||
think_end_token: "</think>".to_string(),
|
||||
stream_reasoning: true,
|
||||
max_buffer_size: 65536,
|
||||
initial_in_reasoning: true, // Assumes reasoning from start
|
||||
};
|
||||
|
||||
Self {
|
||||
base: BaseReasoningParser::new(config).with_model_type("qwen_thinking".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QwenThinkingParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReasoningParser for QwenThinkingParser {
|
||||
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
|
||||
self.base.detect_and_parse_reasoning(text)
|
||||
}
|
||||
|
||||
fn parse_reasoning_streaming_incremental(
|
||||
&mut self,
|
||||
text: &str,
|
||||
) -> Result<ParserResult, ParseError> {
|
||||
self.base.parse_reasoning_streaming_incremental(text)
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.base.reset()
|
||||
}
|
||||
|
||||
fn model_type(&self) -> &str {
|
||||
self.base.model_type()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_qwen3_initial_state() {
|
||||
let mut parser = Qwen3Parser::new();
|
||||
|
||||
// Should NOT treat text as reasoning without start token
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("This is normal content")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "This is normal content");
|
||||
assert_eq!(result.reasoning_text, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_qwen3_with_tokens() {
|
||||
let mut parser = Qwen3Parser::new();
|
||||
|
||||
// Should extract reasoning with proper tokens
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("<think>reasoning</think>answer")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "answer");
|
||||
assert_eq!(result.reasoning_text, "reasoning");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_qwen_thinking_initial_state() {
|
||||
let mut parser = QwenThinkingParser::new();
|
||||
|
||||
// Should treat text as reasoning even without start token
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("This is reasoning content")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "");
|
||||
assert_eq!(result.reasoning_text, "This is reasoning content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_qwen3_streaming() {
|
||||
let mut parser = Qwen3Parser::new();
|
||||
|
||||
// First chunk - normal text (no start token yet)
|
||||
let result1 = parser
|
||||
.parse_reasoning_streaming_incremental("normal text ")
|
||||
.unwrap();
|
||||
assert_eq!(result1.normal_text, "normal text ");
|
||||
assert_eq!(result1.reasoning_text, "");
|
||||
|
||||
// Second chunk - enters reasoning
|
||||
let result2 = parser
|
||||
.parse_reasoning_streaming_incremental("<think>reasoning")
|
||||
.unwrap();
|
||||
assert_eq!(result2.normal_text, "");
|
||||
assert_eq!(result2.reasoning_text, "reasoning");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_types() {
|
||||
let qwen3 = Qwen3Parser::new();
|
||||
assert_eq!(qwen3.model_type(), "qwen3");
|
||||
|
||||
let qwen_thinking = QwenThinkingParser::new();
|
||||
assert_eq!(qwen_thinking.model_type(), "qwen_thinking");
|
||||
}
|
||||
}
|
||||
123
sgl-router/src/reasoning_parser/parsers/step3.rs
Normal file
123
sgl-router/src/reasoning_parser/parsers/step3.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
// Step3 specific reasoning parser.
|
||||
// Uses the same format as DeepSeek-R1 but has its own implementation for debugging.
|
||||
|
||||
use crate::reasoning_parser::parsers::BaseReasoningParser;
|
||||
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
|
||||
|
||||
/// Step3 reasoning parser.
|
||||
///
|
||||
/// This parser uses the same format as DeepSeek-R1 (<think>...</think>) but has
|
||||
/// its own implementation for better debugging and potential future customization.
|
||||
pub struct Step3Parser {
|
||||
base: BaseReasoningParser,
|
||||
}
|
||||
|
||||
impl Step3Parser {
|
||||
/// Create a new Step3 parser.
|
||||
pub fn new() -> Self {
|
||||
let config = ParserConfig {
|
||||
think_start_token: "<think>".to_string(),
|
||||
think_end_token: "</think>".to_string(),
|
||||
stream_reasoning: true,
|
||||
max_buffer_size: 65536,
|
||||
initial_in_reasoning: true, // Assumes reasoning from start like DeepSeek-R1
|
||||
};
|
||||
|
||||
Self {
|
||||
base: BaseReasoningParser::new(config).with_model_type("step3".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Step3Parser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReasoningParser for Step3Parser {
|
||||
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
|
||||
self.base.detect_and_parse_reasoning(text)
|
||||
}
|
||||
|
||||
fn parse_reasoning_streaming_incremental(
|
||||
&mut self,
|
||||
text: &str,
|
||||
) -> Result<ParserResult, ParseError> {
|
||||
self.base.parse_reasoning_streaming_incremental(text)
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.base.reset()
|
||||
}
|
||||
|
||||
fn model_type(&self) -> &str {
|
||||
self.base.model_type()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_step3_initial_state() {
|
||||
let mut parser = Step3Parser::new();
|
||||
|
||||
// Should treat text as reasoning even without start token
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("This is reasoning content")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "");
|
||||
assert_eq!(result.reasoning_text, "This is reasoning content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_step3_with_end_token() {
|
||||
let mut parser = Step3Parser::new();
|
||||
|
||||
// Should handle text with end token
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("reasoning content</think>answer")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "answer");
|
||||
assert_eq!(result.reasoning_text, "reasoning content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_step3_with_both_tokens() {
|
||||
let mut parser = Step3Parser::new();
|
||||
|
||||
// Should handle both start and end tokens
|
||||
let result = parser
|
||||
.detect_and_parse_reasoning("<think>reasoning content</think>answer")
|
||||
.unwrap();
|
||||
assert_eq!(result.normal_text, "answer");
|
||||
assert_eq!(result.reasoning_text, "reasoning content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_step3_streaming() {
|
||||
let mut parser = Step3Parser::new();
|
||||
|
||||
// First chunk - treated as reasoning (initial_in_reasoning=true)
|
||||
let result1 = parser
|
||||
.parse_reasoning_streaming_incremental("reasoning text ")
|
||||
.unwrap();
|
||||
assert_eq!(result1.normal_text, "");
|
||||
assert_eq!(result1.reasoning_text, "reasoning text ");
|
||||
|
||||
// Second chunk - continues reasoning until end token
|
||||
let result2 = parser
|
||||
.parse_reasoning_streaming_incremental("more reasoning</think>answer")
|
||||
.unwrap();
|
||||
assert_eq!(result2.normal_text, "answer");
|
||||
assert_eq!(result2.reasoning_text, "more reasoning");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_type() {
|
||||
let parser = Step3Parser::new();
|
||||
assert_eq!(parser.model_type(), "step3");
|
||||
}
|
||||
}
|
||||
130
sgl-router/src/reasoning_parser/traits.rs
Normal file
130
sgl-router/src/reasoning_parser/traits.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use std::fmt;
|
||||
|
||||
/// Result of parsing text for reasoning content.
|
||||
#[derive(Debug, Clone, Default, PartialEq)]
|
||||
pub struct ParserResult {
|
||||
/// The normal text outside of reasoning blocks.
|
||||
pub normal_text: String,
|
||||
|
||||
/// The extracted reasoning text from within reasoning blocks.
|
||||
pub reasoning_text: String,
|
||||
}
|
||||
|
||||
impl ParserResult {
|
||||
/// Create a new ParserResult with the given normal and reasoning text.
|
||||
pub fn new(normal_text: String, reasoning_text: String) -> Self {
|
||||
Self {
|
||||
normal_text,
|
||||
reasoning_text,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a result with only normal text.
|
||||
pub fn normal(text: String) -> Self {
|
||||
Self {
|
||||
normal_text: text,
|
||||
reasoning_text: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a result with only reasoning text.
|
||||
pub fn reasoning(text: String) -> Self {
|
||||
Self {
|
||||
normal_text: String::new(),
|
||||
reasoning_text: text,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this result contains any text.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.normal_text.is_empty() && self.reasoning_text.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for parsing reasoning content from LLM outputs.
|
||||
pub trait ReasoningParser: Send + Sync {
|
||||
/// Detects and parses reasoning from the input text (one-time parsing).
|
||||
///
|
||||
/// This method is used for non-streaming scenarios where the complete
|
||||
/// text is available at once.
|
||||
///
|
||||
/// Returns an error if the text exceeds buffer limits or contains invalid UTF-8.
|
||||
fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError>;
|
||||
|
||||
/// Parses reasoning incrementally from streaming input.
|
||||
///
|
||||
/// This method maintains internal state across calls to handle partial
|
||||
/// tokens and chunk boundaries correctly.
|
||||
///
|
||||
/// Returns an error if the buffer exceeds max_buffer_size.
|
||||
fn parse_reasoning_streaming_incremental(
|
||||
&mut self,
|
||||
text: &str,
|
||||
) -> Result<ParserResult, ParseError>;
|
||||
|
||||
/// Reset the parser state for reuse.
|
||||
///
|
||||
/// This should clear any buffers and reset flags to initial state.
|
||||
fn reset(&mut self);
|
||||
|
||||
/// Get the model type this parser is designed for.
|
||||
fn model_type(&self) -> &str;
|
||||
}
|
||||
|
||||
/// Error types for reasoning parsing operations.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ParseError {
|
||||
#[error("Invalid UTF-8 in stream: {0}")]
|
||||
Utf8Error(#[from] std::str::Utf8Error),
|
||||
|
||||
#[error("Buffer overflow: {0} bytes exceeds maximum")]
|
||||
BufferOverflow(usize),
|
||||
|
||||
#[error("Unknown model type: {0}")]
|
||||
UnknownModel(String),
|
||||
|
||||
#[error("Parser configuration error: {0}")]
|
||||
ConfigError(String),
|
||||
}
|
||||
|
||||
/// Configuration for parser behavior.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ParserConfig {
|
||||
/// The token that marks the start of reasoning content.
|
||||
pub think_start_token: String,
|
||||
|
||||
/// The token that marks the end of reasoning content.
|
||||
pub think_end_token: String,
|
||||
|
||||
/// Whether to stream reasoning content as it arrives.
|
||||
pub stream_reasoning: bool,
|
||||
|
||||
/// Maximum buffer size in bytes.
|
||||
pub max_buffer_size: usize,
|
||||
|
||||
/// Initial state for in_reasoning flag (fixed per parser type).
|
||||
pub initial_in_reasoning: bool,
|
||||
}
|
||||
|
||||
impl Default for ParserConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
think_start_token: "<think>".to_string(),
|
||||
think_end_token: "</think>".to_string(),
|
||||
stream_reasoning: true,
|
||||
max_buffer_size: 65536, // 64KB default
|
||||
initial_in_reasoning: false, // Default to false (explicit reasoning)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ParserResult {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"ParserResult {{ normal: {} chars, reasoning: {} chars }}",
|
||||
self.normal_text.len(),
|
||||
self.reasoning_text.len()
|
||||
)
|
||||
}
|
||||
}
|
||||
195
sgl-router/src/routers/factory.rs
Normal file
195
sgl-router/src/routers/factory.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
//! Factory for creating router instances
|
||||
|
||||
use super::{
|
||||
http::{openai_router::OpenAIRouter, pd_router::PDRouter, router::Router},
|
||||
RouterTrait,
|
||||
};
|
||||
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
|
||||
use crate::policies::PolicyFactory;
|
||||
use crate::server::AppContext;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Factory for creating router instances based on configuration
|
||||
pub struct RouterFactory;
|
||||
|
||||
impl RouterFactory {
|
||||
/// Create a router instance from application context
|
||||
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Check if IGW mode is enabled
|
||||
if ctx.router_config.enable_igw {
|
||||
return Self::create_igw_router(ctx).await;
|
||||
}
|
||||
|
||||
// Check connection mode and route to appropriate implementation
|
||||
match ctx.router_config.connection_mode {
|
||||
ConnectionMode::Grpc => {
|
||||
// Route to gRPC implementation based on routing mode
|
||||
match &ctx.router_config.mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
Self::create_grpc_router(worker_urls, &ctx.router_config.policy, ctx).await
|
||||
}
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
} => {
|
||||
Self::create_grpc_pd_router(
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy.as_ref(),
|
||||
decode_policy.as_ref(),
|
||||
&ctx.router_config.policy,
|
||||
ctx,
|
||||
)
|
||||
.await
|
||||
}
|
||||
RoutingMode::OpenAI { .. } => {
|
||||
Err("OpenAI mode requires HTTP connection_mode".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
ConnectionMode::Http => {
|
||||
// Route to HTTP implementation based on routing mode
|
||||
match &ctx.router_config.mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
|
||||
.await
|
||||
}
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
} => {
|
||||
Self::create_pd_router(
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy.as_ref(),
|
||||
decode_policy.as_ref(),
|
||||
&ctx.router_config.policy,
|
||||
ctx,
|
||||
)
|
||||
.await
|
||||
}
|
||||
RoutingMode::OpenAI { worker_urls, .. } => {
|
||||
Self::create_openai_router(worker_urls.clone(), ctx).await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a regular router with injected policy
|
||||
async fn create_regular_router(
|
||||
worker_urls: &[String],
|
||||
policy_config: &PolicyConfig,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Create policy
|
||||
let policy = PolicyFactory::create_from_config(policy_config);
|
||||
|
||||
// Create regular router with injected policy and context
|
||||
let router = Router::new(worker_urls.to_vec(), policy, ctx).await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create a PD router with injected policy
|
||||
async fn create_pd_router(
|
||||
prefill_urls: &[(String, Option<u16>)],
|
||||
decode_urls: &[String],
|
||||
prefill_policy_config: Option<&PolicyConfig>,
|
||||
decode_policy_config: Option<&PolicyConfig>,
|
||||
main_policy_config: &PolicyConfig,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Create policies - use specific policies if provided, otherwise fall back to main policy
|
||||
let prefill_policy =
|
||||
PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
|
||||
let decode_policy =
|
||||
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
|
||||
|
||||
// Create PD router with separate policies and context
|
||||
let router = PDRouter::new(
|
||||
prefill_urls.to_vec(),
|
||||
decode_urls.to_vec(),
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
ctx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create a gRPC router with injected policy
|
||||
pub async fn create_grpc_router(
|
||||
worker_urls: &[String],
|
||||
policy_config: &PolicyConfig,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
use super::grpc::router::GrpcRouter;
|
||||
|
||||
// Create policy
|
||||
let policy = PolicyFactory::create_from_config(policy_config);
|
||||
|
||||
// Create gRPC router with context
|
||||
let router = GrpcRouter::new(worker_urls.to_vec(), policy, ctx).await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create a gRPC PD router with tokenizer and worker configuration
|
||||
pub async fn create_grpc_pd_router(
|
||||
prefill_urls: &[(String, Option<u16>)],
|
||||
decode_urls: &[String],
|
||||
prefill_policy_config: Option<&PolicyConfig>,
|
||||
decode_policy_config: Option<&PolicyConfig>,
|
||||
main_policy_config: &PolicyConfig,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
use super::grpc::pd_router::GrpcPDRouter;
|
||||
|
||||
// Create policies - use specific policies if provided, otherwise fall back to main policy
|
||||
let prefill_policy =
|
||||
PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
|
||||
let decode_policy =
|
||||
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
|
||||
|
||||
// Create gRPC PD router with context
|
||||
let router = GrpcPDRouter::new(
|
||||
prefill_urls.to_vec(),
|
||||
decode_urls.to_vec(),
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
ctx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create an OpenAI router
|
||||
async fn create_openai_router(
|
||||
worker_urls: Vec<String>,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Use the first worker URL as the OpenAI-compatible base
|
||||
let base_url = worker_urls
|
||||
.first()
|
||||
.cloned()
|
||||
.ok_or_else(|| "OpenAI mode requires at least one worker URL".to_string())?;
|
||||
|
||||
let router =
|
||||
OpenAIRouter::new(base_url, Some(ctx.router_config.circuit_breaker.clone())).await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create an IGW router (placeholder for future implementation)
|
||||
async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// For now, return an error indicating IGW is not yet implemented
|
||||
Err("IGW mode is not yet implemented".to_string())
|
||||
}
|
||||
}
|
||||
4
sgl-router/src/routers/grpc/mod.rs
Normal file
4
sgl-router/src/routers/grpc/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! gRPC router implementations
|
||||
|
||||
pub mod pd_router;
|
||||
pub mod router;
|
||||
328
sgl-router/src/routers/grpc/pd_router.rs
Normal file
328
sgl-router/src/routers/grpc/pd_router.rs
Normal file
@@ -0,0 +1,328 @@
|
||||
// PD (Prefill-Decode) gRPC Router Implementation
|
||||
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
|
||||
};
|
||||
use crate::grpc::SglangSchedulerClient;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::LoadBalancingPolicy;
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ParserRegistry;
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Duration;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// gRPC PD (Prefill-Decode) router implementation for SGLang
|
||||
#[allow(dead_code)] // Fields will be used once implementation is complete
|
||||
pub struct GrpcPDRouter {
|
||||
/// Prefill worker connections
|
||||
prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
/// Decode worker connections
|
||||
decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
/// gRPC clients for prefill workers
|
||||
prefill_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
|
||||
/// gRPC clients for decode workers
|
||||
decode_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
|
||||
/// Load balancing policy for prefill
|
||||
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
/// Load balancing policy for decode
|
||||
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
/// Tokenizer for handling text encoding/decoding
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
/// Reasoning parser factory for structured reasoning outputs
|
||||
reasoning_parser_factory: ParserFactory,
|
||||
/// Tool parser registry for function/tool calls
|
||||
tool_parser_registry: &'static ParserRegistry,
|
||||
/// Worker health checkers
|
||||
_prefill_health_checker: Option<HealthChecker>,
|
||||
_decode_health_checker: Option<HealthChecker>,
|
||||
/// Configuration
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
retry_config: RetryConfig,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
}
|
||||
|
||||
impl GrpcPDRouter {
|
||||
/// Create a new gRPC PD router
|
||||
pub async fn new(
|
||||
prefill_urls: Vec<(String, Option<u16>)>,
|
||||
decode_urls: Vec<String>,
|
||||
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
ctx: &Arc<crate::server::AppContext>,
|
||||
) -> Result<Self, String> {
|
||||
// Update metrics
|
||||
RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len());
|
||||
|
||||
// Extract necessary components from context
|
||||
let tokenizer = ctx
|
||||
.tokenizer
|
||||
.as_ref()
|
||||
.ok_or_else(|| "gRPC PD router requires tokenizer".to_string())?
|
||||
.clone();
|
||||
let reasoning_parser_factory = ctx
|
||||
.reasoning_parser_factory
|
||||
.as_ref()
|
||||
.ok_or_else(|| "gRPC PD router requires reasoning parser factory".to_string())?
|
||||
.clone();
|
||||
let tool_parser_registry = ctx
|
||||
.tool_parser_registry
|
||||
.ok_or_else(|| "gRPC PD router requires tool parser registry".to_string())?;
|
||||
|
||||
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
|
||||
let core_cb_config = CircuitBreakerConfig {
|
||||
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||
success_threshold: circuit_breaker_config.success_threshold,
|
||||
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
|
||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||
};
|
||||
|
||||
// Create gRPC clients for prefill workers
|
||||
let mut prefill_grpc_clients = HashMap::new();
|
||||
for (url, _bootstrap_port) in &prefill_urls {
|
||||
match SglangSchedulerClient::connect(url).await {
|
||||
Ok(client) => {
|
||||
prefill_grpc_clients.insert(url.clone(), client);
|
||||
info!("Connected to gRPC prefill worker at {}", url);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to gRPC prefill worker at {}: {}", url, e);
|
||||
// Continue with other workers
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create gRPC clients for decode workers
|
||||
let mut decode_grpc_clients = HashMap::new();
|
||||
for url in &decode_urls {
|
||||
match SglangSchedulerClient::connect(url).await {
|
||||
Ok(client) => {
|
||||
decode_grpc_clients.insert(url.clone(), client);
|
||||
info!("Connected to gRPC decode worker at {}", url);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to gRPC decode worker at {}: {}", url, e);
|
||||
// Continue with other workers
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if prefill_grpc_clients.is_empty() && decode_grpc_clients.is_empty() {
|
||||
return Err("Failed to connect to any gRPC workers".to_string());
|
||||
}
|
||||
|
||||
// Create Prefill Worker trait objects with gRPC connection mode
|
||||
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
|
||||
.iter()
|
||||
.map(|(url, bootstrap_port)| {
|
||||
let worker = BasicWorker::with_connection_mode(
|
||||
url.clone(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: *bootstrap_port,
|
||||
},
|
||||
crate::core::ConnectionMode::Grpc {
|
||||
port: *bootstrap_port,
|
||||
},
|
||||
)
|
||||
.with_circuit_breaker_config(core_cb_config.clone())
|
||||
.with_health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
});
|
||||
Box::new(worker) as Box<dyn Worker>
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Create Decode Worker trait objects with gRPC connection mode
|
||||
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
|
||||
.iter()
|
||||
.map(|url| {
|
||||
let worker = BasicWorker::with_connection_mode(
|
||||
url.clone(),
|
||||
WorkerType::Decode,
|
||||
crate::core::ConnectionMode::Grpc { port: None },
|
||||
)
|
||||
.with_circuit_breaker_config(core_cb_config.clone())
|
||||
.with_health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
});
|
||||
Box::new(worker) as Box<dyn Worker>
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Initialize policies with workers if needed
|
||||
if let Some(cache_aware) = prefill_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.init_workers(&prefill_workers);
|
||||
}
|
||||
|
||||
if let Some(cache_aware) = decode_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.init_workers(&decode_workers);
|
||||
}
|
||||
|
||||
let prefill_workers = Arc::new(RwLock::new(prefill_workers));
|
||||
let decode_workers = Arc::new(RwLock::new(decode_workers));
|
||||
|
||||
let prefill_health_checker = crate::core::start_health_checker(
|
||||
Arc::clone(&prefill_workers),
|
||||
ctx.router_config.worker_startup_check_interval_secs,
|
||||
);
|
||||
let decode_health_checker = crate::core::start_health_checker(
|
||||
Arc::clone(&decode_workers),
|
||||
ctx.router_config.worker_startup_check_interval_secs,
|
||||
);
|
||||
|
||||
Ok(GrpcPDRouter {
|
||||
prefill_workers,
|
||||
decode_workers,
|
||||
prefill_grpc_clients: Arc::new(RwLock::new(prefill_grpc_clients)),
|
||||
decode_grpc_clients: Arc::new(RwLock::new(decode_grpc_clients)),
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
tokenizer,
|
||||
reasoning_parser_factory,
|
||||
tool_parser_registry,
|
||||
_prefill_health_checker: Some(prefill_health_checker),
|
||||
_decode_health_checker: Some(decode_health_checker),
|
||||
timeout_secs: ctx.router_config.worker_startup_timeout_secs,
|
||||
interval_secs: ctx.router_config.worker_startup_check_interval_secs,
|
||||
dp_aware: ctx.router_config.dp_aware,
|
||||
api_key: ctx.router_config.api_key.clone(),
|
||||
retry_config: ctx.router_config.effective_retry_config(),
|
||||
circuit_breaker_config: core_cb_config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for GrpcPDRouter {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("GrpcPDRouter")
|
||||
.field(
|
||||
"prefill_workers_count",
|
||||
&self.prefill_workers.read().unwrap().len(),
|
||||
)
|
||||
.field(
|
||||
"decode_workers_count",
|
||||
&self.decode_workers.read().unwrap().len(),
|
||||
)
|
||||
.field("timeout_secs", &self.timeout_secs)
|
||||
.field("interval_secs", &self.interval_secs)
|
||||
.field("dp_aware", &self.dp_aware)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RouterTrait for GrpcPDRouter {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
async fn health(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_models(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_generate(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::GenerateRequest,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_chat(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::ChatCompletionRequest,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_completion(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::CompletionRequest,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn flush_cache(&self) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
"grpc_pd"
|
||||
}
|
||||
|
||||
fn readiness(&self) -> Response {
|
||||
(StatusCode::SERVICE_UNAVAILABLE).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerManagement for GrpcPDRouter {
|
||||
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||
Err("Not implemented".to_string())
|
||||
}
|
||||
|
||||
fn remove_worker(&self, _worker_url: &str) {}
|
||||
|
||||
fn get_worker_urls(&self) -> Vec<String> {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
266
sgl-router/src/routers/grpc/router.rs
Normal file
266
sgl-router/src/routers/grpc/router.rs
Normal file
@@ -0,0 +1,266 @@
|
||||
// gRPC Router Implementation
|
||||
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
|
||||
};
|
||||
use crate::grpc::SglangSchedulerClient;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::LoadBalancingPolicy;
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ParserRegistry;
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Duration;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// gRPC router implementation for SGLang
|
||||
#[allow(dead_code)] // Fields will be used once implementation is complete
|
||||
pub struct GrpcRouter {
|
||||
/// Worker connections
|
||||
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
/// gRPC clients for each worker
|
||||
grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
|
||||
/// Load balancing policy
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
/// Tokenizer for handling text encoding/decoding
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
/// Reasoning parser factory for structured reasoning outputs
|
||||
reasoning_parser_factory: ParserFactory,
|
||||
/// Tool parser registry for function/tool calls
|
||||
tool_parser_registry: &'static ParserRegistry,
|
||||
/// Worker health checker
|
||||
_health_checker: Option<HealthChecker>,
|
||||
/// Configuration
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
retry_config: RetryConfig,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
}
|
||||
|
||||
impl GrpcRouter {
|
||||
/// Create a new gRPC router
|
||||
pub async fn new(
|
||||
worker_urls: Vec<String>,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
ctx: &Arc<crate::server::AppContext>,
|
||||
) -> Result<Self, String> {
|
||||
// Update metrics
|
||||
RouterMetrics::set_active_workers(worker_urls.len());
|
||||
|
||||
// Extract necessary components from context
|
||||
let tokenizer = ctx
|
||||
.tokenizer
|
||||
.as_ref()
|
||||
.ok_or_else(|| "gRPC router requires tokenizer".to_string())?
|
||||
.clone();
|
||||
let reasoning_parser_factory = ctx
|
||||
.reasoning_parser_factory
|
||||
.as_ref()
|
||||
.ok_or_else(|| "gRPC router requires reasoning parser factory".to_string())?
|
||||
.clone();
|
||||
let tool_parser_registry = ctx
|
||||
.tool_parser_registry
|
||||
.ok_or_else(|| "gRPC router requires tool parser registry".to_string())?;
|
||||
|
||||
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
|
||||
let core_cb_config = CircuitBreakerConfig {
|
||||
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||
success_threshold: circuit_breaker_config.success_threshold,
|
||||
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
|
||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||
};
|
||||
|
||||
// Create gRPC clients for each worker
|
||||
let mut grpc_clients = HashMap::new();
|
||||
for url in &worker_urls {
|
||||
match SglangSchedulerClient::connect(url).await {
|
||||
Ok(client) => {
|
||||
grpc_clients.insert(url.clone(), client);
|
||||
info!("Connected to gRPC worker at {}", url);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to gRPC worker at {}: {}", url, e);
|
||||
// Continue with other workers
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if grpc_clients.is_empty() {
|
||||
return Err("Failed to connect to any gRPC workers".to_string());
|
||||
}
|
||||
|
||||
// Create Worker trait objects with gRPC connection mode
|
||||
let mut workers: Vec<Box<dyn Worker>> = Vec::new();
|
||||
|
||||
// Move clients from the HashMap to the workers
|
||||
for url in &worker_urls {
|
||||
if let Some(client) = grpc_clients.remove(url) {
|
||||
let worker = BasicWorker::with_connection_mode(
|
||||
url.clone(),
|
||||
WorkerType::Regular,
|
||||
crate::core::ConnectionMode::Grpc { port: None },
|
||||
)
|
||||
.with_circuit_breaker_config(core_cb_config.clone())
|
||||
.with_health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
})
|
||||
.with_grpc_client(client);
|
||||
|
||||
workers.push(Box::new(worker) as Box<dyn Worker>);
|
||||
} else {
|
||||
warn!("No gRPC client for worker {}, skipping", url);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize policy with workers if needed
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.init_workers(&workers);
|
||||
}
|
||||
|
||||
let workers = Arc::new(RwLock::new(workers));
|
||||
let health_checker = crate::core::start_health_checker(
|
||||
Arc::clone(&workers),
|
||||
ctx.router_config.worker_startup_check_interval_secs,
|
||||
);
|
||||
|
||||
Ok(GrpcRouter {
|
||||
workers,
|
||||
grpc_clients: Arc::new(RwLock::new(grpc_clients)),
|
||||
policy,
|
||||
tokenizer,
|
||||
reasoning_parser_factory,
|
||||
tool_parser_registry,
|
||||
_health_checker: Some(health_checker),
|
||||
timeout_secs: ctx.router_config.worker_startup_timeout_secs,
|
||||
interval_secs: ctx.router_config.worker_startup_check_interval_secs,
|
||||
dp_aware: ctx.router_config.dp_aware,
|
||||
api_key: ctx.router_config.api_key.clone(),
|
||||
retry_config: ctx.router_config.effective_retry_config(),
|
||||
circuit_breaker_config: core_cb_config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for GrpcRouter {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("GrpcRouter")
|
||||
.field("workers_count", &self.workers.read().unwrap().len())
|
||||
.field("timeout_secs", &self.timeout_secs)
|
||||
.field("interval_secs", &self.interval_secs)
|
||||
.field("dp_aware", &self.dp_aware)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RouterTrait for GrpcRouter {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
async fn health(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_models(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_generate(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::GenerateRequest,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_chat(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::ChatCompletionRequest,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_completion(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::CompletionRequest,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn flush_cache(&self) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
"grpc"
|
||||
}
|
||||
|
||||
fn readiness(&self) -> Response {
|
||||
(StatusCode::SERVICE_UNAVAILABLE).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerManagement for GrpcRouter {
|
||||
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||
Err("Not implemented".to_string())
|
||||
}
|
||||
|
||||
fn remove_worker(&self, _worker_url: &str) {}
|
||||
|
||||
fn get_worker_urls(&self) -> Vec<String> {
|
||||
self.workers
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
53
sgl-router/src/routers/header_utils.rs
Normal file
53
sgl-router/src/routers/header_utils.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use axum::body::Body;
|
||||
use axum::extract::Request;
|
||||
use axum::http::HeaderMap;
|
||||
|
||||
/// Copy request headers to a Vec of name-value string pairs
|
||||
/// Used for forwarding headers to backend workers
|
||||
pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
|
||||
req.headers()
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
// Convert header value to string, skipping non-UTF8 headers
|
||||
value
|
||||
.to_str()
|
||||
.ok()
|
||||
.map(|v| (name.to_string(), v.to_string()))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Convert headers from reqwest Response to axum HeaderMap
|
||||
/// Filters out hop-by-hop headers that shouldn't be forwarded
|
||||
pub fn preserve_response_headers(reqwest_headers: &HeaderMap) -> HeaderMap {
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
for (name, value) in reqwest_headers.iter() {
|
||||
// Skip hop-by-hop headers that shouldn't be forwarded
|
||||
let name_str = name.as_str().to_lowercase();
|
||||
if should_forward_header(&name_str) {
|
||||
// The original name and value are already valid, so we can just clone them
|
||||
headers.insert(name.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
headers
|
||||
}
|
||||
|
||||
/// Determine if a header should be forwarded from backend to client
|
||||
fn should_forward_header(name: &str) -> bool {
|
||||
// List of headers that should NOT be forwarded (hop-by-hop headers)
|
||||
!matches!(
|
||||
name,
|
||||
"connection" |
|
||||
"keep-alive" |
|
||||
"proxy-authenticate" |
|
||||
"proxy-authorization" |
|
||||
"te" |
|
||||
"trailers" |
|
||||
"transfer-encoding" |
|
||||
"upgrade" |
|
||||
"content-encoding" | // Let axum/hyper handle encoding
|
||||
"host" // Should not forward the backend's host header
|
||||
)
|
||||
}
|
||||
6
sgl-router/src/routers/http/mod.rs
Normal file
6
sgl-router/src/routers/http/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
//! HTTP router implementations
|
||||
|
||||
pub mod openai_router;
|
||||
pub mod pd_router;
|
||||
pub mod pd_types;
|
||||
pub mod router;
|
||||
379
sgl-router/src/routers/http/openai_router.rs
Normal file
379
sgl-router/src/routers/http/openai_router.rs
Normal file
@@ -0,0 +1,379 @@
|
||||
//! OpenAI router implementation (reqwest-based)
|
||||
|
||||
use crate::config::CircuitBreakerConfig;
|
||||
use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig};
|
||||
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use futures_util::StreamExt;
|
||||
use std::{
|
||||
any::Any,
|
||||
sync::atomic::{AtomicBool, Ordering},
|
||||
};
|
||||
|
||||
/// Router for OpenAI backend
|
||||
#[derive(Debug)]
|
||||
pub struct OpenAIRouter {
|
||||
/// HTTP client for upstream OpenAI-compatible API
|
||||
client: reqwest::Client,
|
||||
/// Base URL for identification (no trailing slash)
|
||||
base_url: String,
|
||||
/// Circuit breaker
|
||||
circuit_breaker: CircuitBreaker,
|
||||
/// Health status
|
||||
healthy: AtomicBool,
|
||||
}
|
||||
|
||||
impl OpenAIRouter {
|
||||
/// Create a new OpenAI router
|
||||
pub async fn new(
|
||||
base_url: String,
|
||||
circuit_breaker_config: Option<CircuitBreakerConfig>,
|
||||
) -> Result<Self, String> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
let base_url = base_url.trim_end_matches('/').to_string();
|
||||
|
||||
// Convert circuit breaker config
|
||||
let core_cb_config = circuit_breaker_config
|
||||
.map(|cb| CoreCircuitBreakerConfig {
|
||||
failure_threshold: cb.failure_threshold,
|
||||
success_threshold: cb.success_threshold,
|
||||
timeout_duration: std::time::Duration::from_secs(cb.timeout_duration_secs),
|
||||
window_duration: std::time::Duration::from_secs(cb.window_duration_secs),
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let circuit_breaker = CircuitBreaker::with_config(core_cb_config);
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
base_url,
|
||||
circuit_breaker,
|
||||
healthy: AtomicBool::new(true),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::super::WorkerManagement for OpenAIRouter {
|
||||
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||
Err("Cannot add workers to OpenAI router".to_string())
|
||||
}
|
||||
|
||||
fn remove_worker(&self, _worker_url: &str) {
|
||||
// No-op for OpenAI router
|
||||
}
|
||||
|
||||
fn get_worker_urls(&self) -> Vec<String> {
|
||||
vec![self.base_url.clone()]
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::super::RouterTrait for OpenAIRouter {
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
|
||||
async fn health(&self, _req: Request<Body>) -> Response {
|
||||
// Simple upstream probe: GET {base}/v1/models without auth
|
||||
let url = format!("{}/v1/models", self.base_url);
|
||||
match self
|
||||
.client
|
||||
.get(&url)
|
||||
.timeout(std::time::Duration::from_secs(2))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
let code = resp.status();
|
||||
// Treat success and auth-required as healthy (endpoint reachable)
|
||||
if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
|
||||
(StatusCode::OK, "OK").into_response()
|
||||
} else {
|
||||
(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
format!("Upstream status: {}", code),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
Err(e) => (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
format!("Upstream error: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||
// For OpenAI, health_generate is the same as health
|
||||
self.health(_req).await
|
||||
}
|
||||
|
||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||
let info = serde_json::json!({
|
||||
"router_type": "openai",
|
||||
"workers": 1,
|
||||
"base_url": &self.base_url
|
||||
});
|
||||
(StatusCode::OK, info.to_string()).into_response()
|
||||
}
|
||||
|
||||
async fn get_models(&self, req: Request<Body>) -> Response {
|
||||
// Proxy to upstream /v1/models; forward Authorization header if provided
|
||||
let headers = req.headers();
|
||||
|
||||
let mut upstream = self.client.get(format!("{}/v1/models", self.base_url));
|
||||
|
||||
if let Some(auth) = headers
|
||||
.get("authorization")
|
||||
.or_else(|| headers.get("Authorization"))
|
||||
{
|
||||
upstream = upstream.header("Authorization", auth);
|
||||
}
|
||||
|
||||
match upstream.send().await {
|
||||
Ok(res) => {
|
||||
let status = StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let content_type = res.headers().get(CONTENT_TYPE).cloned();
|
||||
match res.bytes().await {
|
||||
Ok(body) => {
|
||||
let mut response = Response::new(axum::body::Body::from(body));
|
||||
*response.status_mut() = status;
|
||||
if let Some(ct) = content_type {
|
||||
response.headers_mut().insert(CONTENT_TYPE, ct);
|
||||
}
|
||||
response
|
||||
}
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to read upstream response: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
Err(e) => (
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Failed to contact upstream: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
||||
// Not directly supported without model param; return 501
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"get_model_info not implemented for OpenAI router",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn route_generate(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &GenerateRequest,
|
||||
) -> Response {
|
||||
// Generate endpoint is SGLang-specific, not supported for OpenAI backend
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Generate endpoint not supported for OpenAI backend",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn route_chat(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ChatCompletionRequest,
|
||||
) -> Response {
|
||||
if !self.circuit_breaker.can_execute() {
|
||||
return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
|
||||
}
|
||||
|
||||
// Serialize request body, removing SGLang-only fields
|
||||
let mut payload = match serde_json::to_value(body) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Failed to serialize request: {}", e),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
if let Some(obj) = payload.as_object_mut() {
|
||||
for key in [
|
||||
"top_k",
|
||||
"min_p",
|
||||
"min_tokens",
|
||||
"regex",
|
||||
"ebnf",
|
||||
"stop_token_ids",
|
||||
"no_stop_trim",
|
||||
"ignore_eos",
|
||||
"continue_final_message",
|
||||
"skip_special_tokens",
|
||||
"lora_path",
|
||||
"session_params",
|
||||
"separate_reasoning",
|
||||
"stream_reasoning",
|
||||
"chat_template_kwargs",
|
||||
"return_hidden_states",
|
||||
"repetition_penalty",
|
||||
] {
|
||||
obj.remove(key);
|
||||
}
|
||||
}
|
||||
|
||||
let url = format!("{}/v1/chat/completions", self.base_url);
|
||||
let mut req = self.client.post(&url).json(&payload);
|
||||
|
||||
// Forward Authorization header if provided
|
||||
if let Some(h) = headers {
|
||||
if let Some(auth) = h.get("authorization").or_else(|| h.get("Authorization")) {
|
||||
req = req.header("Authorization", auth);
|
||||
}
|
||||
}
|
||||
|
||||
// Accept SSE when stream=true
|
||||
if body.stream {
|
||||
req = req.header("Accept", "text/event-stream");
|
||||
}
|
||||
|
||||
let resp = match req.send().await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
self.circuit_breaker.record_failure();
|
||||
return (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
format!("Failed to contact upstream: {}", e),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let status = StatusCode::from_u16(resp.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
if !body.stream {
|
||||
// Capture Content-Type before consuming response body
|
||||
let content_type = resp.headers().get(CONTENT_TYPE).cloned();
|
||||
match resp.bytes().await {
|
||||
Ok(body) => {
|
||||
self.circuit_breaker.record_success();
|
||||
let mut response = Response::new(axum::body::Body::from(body));
|
||||
*response.status_mut() = status;
|
||||
if let Some(ct) = content_type {
|
||||
response.headers_mut().insert(CONTENT_TYPE, ct);
|
||||
}
|
||||
response
|
||||
}
|
||||
Err(e) => {
|
||||
self.circuit_breaker.record_failure();
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to read response: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Stream SSE bytes to client
|
||||
let stream = resp.bytes_stream();
|
||||
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
tokio::spawn(async move {
|
||||
let mut s = stream;
|
||||
while let Some(chunk) = s.next().await {
|
||||
match chunk {
|
||||
Ok(bytes) => {
|
||||
if tx.send(Ok(bytes)).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(format!("Stream error: {}", e)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
let mut response = Response::new(Body::from_stream(
|
||||
tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
|
||||
));
|
||||
*response.status_mut() = status;
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||
response
|
||||
}
|
||||
}
|
||||
|
||||
async fn route_completion(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &CompletionRequest,
|
||||
) -> Response {
|
||||
// Completion endpoint not implemented for OpenAI backend
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Completion endpoint not implemented for OpenAI backend",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn flush_cache(&self) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"flush_cache not supported for OpenAI router",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"get_worker_loads not supported for OpenAI router",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
"openai"
|
||||
}
|
||||
|
||||
fn readiness(&self) -> Response {
|
||||
if self.healthy.load(Ordering::Acquire) && self.circuit_breaker.can_execute() {
|
||||
(StatusCode::OK, "Ready").into_response()
|
||||
} else {
|
||||
(StatusCode::SERVICE_UNAVAILABLE, "Not ready").into_response()
|
||||
}
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Embeddings endpoint not implemented for OpenAI backend",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Rerank endpoint not implemented for OpenAI backend",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
2500
sgl-router/src/routers/http/pd_router.rs
Normal file
2500
sgl-router/src/routers/http/pd_router.rs
Normal file
File diff suppressed because it is too large
Load Diff
81
sgl-router/src/routers/http/pd_types.rs
Normal file
81
sgl-router/src/routers/http/pd_types.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
// Custom error type for PD router operations
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PDRouterError {
|
||||
#[error("Worker already exists: {url}")]
|
||||
WorkerAlreadyExists { url: String },
|
||||
|
||||
#[error("Worker not found: {url}")]
|
||||
WorkerNotFound { url: String },
|
||||
|
||||
#[error("Lock acquisition failed: {operation}")]
|
||||
LockError { operation: String },
|
||||
|
||||
#[error("Health check failed for worker: {url}")]
|
||||
HealthCheckFailed { url: String },
|
||||
|
||||
#[error("Invalid worker configuration: {reason}")]
|
||||
InvalidConfiguration { reason: String },
|
||||
|
||||
#[error("Network error: {message}")]
|
||||
NetworkError { message: String },
|
||||
|
||||
#[error("Timeout waiting for worker: {url}")]
|
||||
Timeout { url: String },
|
||||
}
|
||||
|
||||
// Helper functions for workers
|
||||
pub fn api_path(url: &str, api_path: &str) -> String {
|
||||
if api_path.starts_with("/") {
|
||||
format!("{}{}", url, api_path)
|
||||
} else {
|
||||
format!("{}/{}", url, api_path)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_hostname(url: &str) -> String {
|
||||
// Simple hostname extraction without external dependencies
|
||||
let url = url
|
||||
.trim_start_matches("http://")
|
||||
.trim_start_matches("https://");
|
||||
url.split(':').next().unwrap_or("localhost").to_string()
|
||||
}
|
||||
|
||||
use serde::Serialize;
|
||||
|
||||
// Optimized bootstrap wrapper for single requests
|
||||
#[derive(Serialize)]
|
||||
pub struct RequestWithBootstrap<'a, T: Serialize> {
|
||||
#[serde(flatten)]
|
||||
pub original: &'a T,
|
||||
pub bootstrap_host: String,
|
||||
pub bootstrap_port: Option<u16>,
|
||||
pub bootstrap_room: u64,
|
||||
}
|
||||
|
||||
// Optimized bootstrap wrapper for batch requests
|
||||
#[derive(Serialize)]
|
||||
pub struct BatchRequestWithBootstrap<'a, T: Serialize> {
|
||||
#[serde(flatten)]
|
||||
pub original: &'a T,
|
||||
pub bootstrap_host: Vec<String>,
|
||||
pub bootstrap_port: Vec<Option<u16>>,
|
||||
pub bootstrap_room: Vec<u64>,
|
||||
}
|
||||
|
||||
// Helper to generate bootstrap room ID
|
||||
pub fn generate_room_id() -> u64 {
|
||||
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
|
||||
rand::random::<u64>() & (i64::MAX as u64)
|
||||
}
|
||||
|
||||
// PD-specific routing policies
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum PDSelectionPolicy {
|
||||
Random,
|
||||
PowerOfTwo,
|
||||
CacheAware {
|
||||
cache_threshold: f32,
|
||||
balance_abs_threshold: usize,
|
||||
balance_rel_threshold: f32,
|
||||
},
|
||||
}
|
||||
1387
sgl-router/src/routers/http/router.rs
Normal file
1387
sgl-router/src/routers/http/router.rs
Normal file
File diff suppressed because it is too large
Load Diff
107
sgl-router/src/routers/mod.rs
Normal file
107
sgl-router/src/routers/mod.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
//! Router implementations
|
||||
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
|
||||
pub mod factory;
|
||||
pub mod grpc;
|
||||
pub mod header_utils;
|
||||
pub mod http;
|
||||
|
||||
pub use factory::RouterFactory;
|
||||
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
|
||||
pub use http::{openai_router, pd_router, pd_types, router};
|
||||
|
||||
/// Worker management trait for administrative operations
|
||||
///
|
||||
/// This trait is separate from RouterTrait to allow Send futures
|
||||
/// for use in service discovery and other background tasks
|
||||
#[async_trait]
|
||||
pub trait WorkerManagement: Send + Sync {
|
||||
/// Add a worker to the router
|
||||
async fn add_worker(&self, worker_url: &str) -> Result<String, String>;
|
||||
|
||||
/// Remove a worker from the router
|
||||
fn remove_worker(&self, worker_url: &str);
|
||||
|
||||
/// Get all worker URLs
|
||||
fn get_worker_urls(&self) -> Vec<String>;
|
||||
}
|
||||
|
||||
/// Core trait for all router implementations
|
||||
///
|
||||
/// This trait provides a unified interface for routing requests,
|
||||
/// regardless of whether it's a regular router or PD router.
|
||||
#[async_trait]
|
||||
pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
||||
/// Get a reference to self as Any for downcasting
|
||||
fn as_any(&self) -> &dyn std::any::Any;
|
||||
|
||||
/// Route a health check request
|
||||
async fn health(&self, req: Request<Body>) -> Response;
|
||||
|
||||
/// Route a health generate request
|
||||
async fn health_generate(&self, req: Request<Body>) -> Response;
|
||||
|
||||
/// Get server information
|
||||
async fn get_server_info(&self, req: Request<Body>) -> Response;
|
||||
|
||||
/// Get available models
|
||||
async fn get_models(&self, req: Request<Body>) -> Response;
|
||||
|
||||
/// Get model information
|
||||
async fn get_model_info(&self, req: Request<Body>) -> Response;
|
||||
|
||||
/// Route a generate request
|
||||
async fn route_generate(&self, headers: Option<&HeaderMap>, body: &GenerateRequest)
|
||||
-> Response;
|
||||
|
||||
/// Route a chat completion request
|
||||
async fn route_chat(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ChatCompletionRequest,
|
||||
) -> Response;
|
||||
|
||||
/// Route a completion request
|
||||
async fn route_completion(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &CompletionRequest,
|
||||
) -> Response;
|
||||
|
||||
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
|
||||
|
||||
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
|
||||
|
||||
/// Flush cache on all workers
|
||||
async fn flush_cache(&self) -> Response;
|
||||
|
||||
/// Get worker loads (for monitoring)
|
||||
async fn get_worker_loads(&self) -> Response;
|
||||
|
||||
/// Get router type name
|
||||
fn router_type(&self) -> &'static str;
|
||||
|
||||
/// Check if this is a PD router
|
||||
fn is_pd_mode(&self) -> bool {
|
||||
self.router_type() == "pd"
|
||||
}
|
||||
|
||||
/// Server liveness check - is the server process running
|
||||
fn liveness(&self) -> Response {
|
||||
// Simple liveness check - if we can respond, we're alive
|
||||
(StatusCode::OK, "OK").into_response()
|
||||
}
|
||||
|
||||
/// Server readiness check - is the server ready to handle requests
|
||||
fn readiness(&self) -> Response;
|
||||
}
|
||||
474
sgl-router/src/server.rs
Normal file
474
sgl-router/src/server.rs
Normal file
@@ -0,0 +1,474 @@
|
||||
use crate::config::RouterConfig;
|
||||
use crate::logging::{self, LoggingConfig};
|
||||
use crate::metrics::{self, PrometheusConfig};
|
||||
use crate::middleware::TokenBucket;
|
||||
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
use crate::routers::{RouterFactory, RouterTrait};
|
||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||
use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer};
|
||||
use crate::tool_parser::ParserRegistry;
|
||||
use axum::{
|
||||
extract::{Query, Request, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use reqwest::Client;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::signal;
|
||||
use tokio::spawn;
|
||||
use tracing::{error, info, warn, Level};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppContext {
|
||||
pub client: Client,
|
||||
pub router_config: RouterConfig,
|
||||
pub rate_limiter: Arc<TokenBucket>,
|
||||
pub tokenizer: Option<Arc<dyn Tokenizer>>,
|
||||
pub reasoning_parser_factory: Option<ParserFactory>,
|
||||
pub tool_parser_registry: Option<&'static ParserRegistry>,
|
||||
}
|
||||
|
||||
impl AppContext {
|
||||
pub fn new(
|
||||
router_config: RouterConfig,
|
||||
client: Client,
|
||||
max_concurrent_requests: usize,
|
||||
rate_limit_tokens_per_second: Option<usize>,
|
||||
) -> Result<Self, String> {
|
||||
let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
|
||||
let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
|
||||
|
||||
// Initialize gRPC-specific components only when in gRPC mode
|
||||
let (tokenizer, reasoning_parser_factory, tool_parser_registry) =
|
||||
if router_config.connection_mode == crate::config::ConnectionMode::Grpc {
|
||||
// Get tokenizer path (required for gRPC mode)
|
||||
let tokenizer_path = router_config
|
||||
.tokenizer_path
|
||||
.clone()
|
||||
.or_else(|| router_config.model_path.clone())
|
||||
.ok_or_else(|| {
|
||||
"gRPC mode requires either --tokenizer-path or --model-path to be specified"
|
||||
.to_string()
|
||||
})?;
|
||||
|
||||
// Initialize all gRPC components
|
||||
let tokenizer = Some(
|
||||
tokenizer_factory::create_tokenizer(&tokenizer_path)
|
||||
.map_err(|e| format!("Failed to create tokenizer: {}", e))?,
|
||||
);
|
||||
let reasoning_parser_factory = Some(ParserFactory::new());
|
||||
let tool_parser_registry = Some(ParserRegistry::new());
|
||||
|
||||
(tokenizer, reasoning_parser_factory, tool_parser_registry)
|
||||
} else {
|
||||
// HTTP mode doesn't need these components
|
||||
(None, None, None)
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
router_config,
|
||||
rate_limiter,
|
||||
tokenizer,
|
||||
reasoning_parser_factory,
|
||||
tool_parser_registry,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub router: Arc<dyn RouterTrait>,
|
||||
pub context: Arc<AppContext>,
|
||||
pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<crate::middleware::QueuedRequest>>,
|
||||
}
|
||||
|
||||
// Fallback handler for unmatched routes
|
||||
async fn sink_handler() -> Response {
|
||||
StatusCode::NOT_FOUND.into_response()
|
||||
}
|
||||
|
||||
// Health check endpoints
|
||||
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
|
||||
state.router.liveness()
|
||||
}
|
||||
|
||||
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
|
||||
state.router.readiness()
|
||||
}
|
||||
|
||||
async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||
state.router.health(req).await
|
||||
}
|
||||
|
||||
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||
state.router.health_generate(req).await
|
||||
}
|
||||
|
||||
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||
state.router.get_server_info(req).await
|
||||
}
|
||||
|
||||
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||
state.router.get_models(req).await
|
||||
}
|
||||
|
||||
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||
state.router.get_model_info(req).await
|
||||
}
|
||||
|
||||
// Generation endpoints
|
||||
// The RouterTrait now accepts optional headers and typed body directly
|
||||
async fn generate(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<GenerateRequest>,
|
||||
) -> Response {
|
||||
state.router.route_generate(Some(&headers), &body).await
|
||||
}
|
||||
|
||||
async fn v1_chat_completions(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<ChatCompletionRequest>,
|
||||
) -> Response {
|
||||
state.router.route_chat(Some(&headers), &body).await
|
||||
}
|
||||
|
||||
async fn v1_completions(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<CompletionRequest>,
|
||||
) -> Response {
|
||||
state.router.route_completion(Some(&headers), &body).await
|
||||
}
|
||||
|
||||
// Worker management endpoints
|
||||
async fn add_worker(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> Response {
|
||||
let worker_url = match params.get("url") {
|
||||
Some(url) => url.to_string(),
|
||||
None => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Worker URL required. Provide 'url' query parameter",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
match state.router.add_worker(&worker_url).await {
|
||||
Ok(message) => (StatusCode::OK, message).into_response(),
|
||||
Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
|
||||
let worker_list = state.router.get_worker_urls();
|
||||
Json(serde_json::json!({ "urls": worker_list })).into_response()
|
||||
}
|
||||
|
||||
async fn remove_worker(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> Response {
|
||||
let worker_url = match params.get("url") {
|
||||
Some(url) => url.to_string(),
|
||||
None => return StatusCode::BAD_REQUEST.into_response(),
|
||||
};
|
||||
|
||||
state.router.remove_worker(&worker_url);
|
||||
(
|
||||
StatusCode::OK,
|
||||
format!("Successfully removed worker: {}", worker_url),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
||||
state.router.flush_cache().await
|
||||
}
|
||||
|
||||
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
||||
state.router.get_worker_loads().await
|
||||
}
|
||||
|
||||
pub struct ServerConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub router_config: RouterConfig,
|
||||
pub max_payload_size: usize,
|
||||
pub log_dir: Option<String>,
|
||||
pub log_level: Option<String>,
|
||||
pub service_discovery_config: Option<ServiceDiscoveryConfig>,
|
||||
pub prometheus_config: Option<PrometheusConfig>,
|
||||
pub request_timeout_secs: u64,
|
||||
pub request_id_headers: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Build the Axum application with all routes and middleware
|
||||
pub fn build_app(
|
||||
app_state: Arc<AppState>,
|
||||
max_payload_size: usize,
|
||||
request_id_headers: Vec<String>,
|
||||
cors_allowed_origins: Vec<String>,
|
||||
) -> Router {
|
||||
// Create routes
|
||||
let protected_routes = Router::new()
|
||||
.route("/generate", post(generate))
|
||||
.route("/v1/chat/completions", post(v1_chat_completions))
|
||||
.route("/v1/completions", post(v1_completions))
|
||||
.route_layer(axum::middleware::from_fn_with_state(
|
||||
app_state.clone(),
|
||||
crate::middleware::concurrency_limit_middleware,
|
||||
));
|
||||
|
||||
let public_routes = Router::new()
|
||||
.route("/liveness", get(liveness))
|
||||
.route("/readiness", get(readiness))
|
||||
.route("/health", get(health))
|
||||
.route("/health_generate", get(health_generate))
|
||||
.route("/v1/models", get(v1_models))
|
||||
.route("/get_model_info", get(get_model_info))
|
||||
.route("/get_server_info", get(get_server_info));
|
||||
|
||||
let admin_routes = Router::new()
|
||||
.route("/add_worker", post(add_worker))
|
||||
.route("/remove_worker", post(remove_worker))
|
||||
.route("/list_workers", get(list_workers))
|
||||
.route("/flush_cache", post(flush_cache))
|
||||
.route("/get_loads", get(get_loads));
|
||||
|
||||
// Build app with all routes and middleware
|
||||
Router::new()
|
||||
.merge(protected_routes)
|
||||
.merge(public_routes)
|
||||
.merge(admin_routes)
|
||||
// Request body size limiting
|
||||
.layer(tower_http::limit::RequestBodyLimitLayer::new(
|
||||
max_payload_size,
|
||||
))
|
||||
// Request ID layer - must be added AFTER logging layer in the code
|
||||
// so it executes BEFORE logging layer at runtime (layers execute bottom-up)
|
||||
.layer(crate::middleware::RequestIdLayer::new(request_id_headers))
|
||||
// Custom logging layer that can now see request IDs from extensions
|
||||
.layer(crate::middleware::create_logging_layer())
|
||||
// CORS (should be outermost)
|
||||
.layer(create_cors_layer(cors_allowed_origins))
|
||||
// Fallback
|
||||
.fallback(sink_handler)
|
||||
// State - apply last to get Router<Arc<AppState>>
|
||||
.with_state(app_state)
|
||||
}
|
||||
|
||||
pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Only initialize logging if not already done (for Python bindings support)
|
||||
static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
|
||||
Some(logging::init_logging(LoggingConfig {
|
||||
level: config
|
||||
.log_level
|
||||
.as_deref()
|
||||
.and_then(|s| match s.to_uppercase().parse::<Level>() {
|
||||
Ok(l) => Some(l),
|
||||
Err(_) => {
|
||||
warn!("Invalid log level string: '{}'. Defaulting to INFO.", s);
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or(Level::INFO),
|
||||
json_format: false,
|
||||
log_dir: config.log_dir.clone(),
|
||||
colorize: true,
|
||||
log_file_name: "sgl-router".to_string(),
|
||||
log_targets: None,
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Initialize prometheus metrics exporter
|
||||
if let Some(prometheus_config) = config.prometheus_config {
|
||||
metrics::start_prometheus(prometheus_config);
|
||||
}
|
||||
|
||||
info!(
|
||||
"Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
|
||||
config.host,
|
||||
config.port,
|
||||
config.router_config.mode,
|
||||
config.router_config.policy,
|
||||
config.max_payload_size / (1024 * 1024)
|
||||
);
|
||||
|
||||
let client = Client::builder()
|
||||
.pool_idle_timeout(Some(Duration::from_secs(50)))
|
||||
.pool_max_idle_per_host(500) // Increase to 500 connections per host
|
||||
.timeout(Duration::from_secs(config.request_timeout_secs))
|
||||
.connect_timeout(Duration::from_secs(10)) // Separate connection timeout
|
||||
.tcp_nodelay(true)
|
||||
.tcp_keepalive(Some(Duration::from_secs(30))) // Keep connections alive
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
// Create the application context with all dependencies
|
||||
let app_context = Arc::new(AppContext::new(
|
||||
config.router_config.clone(),
|
||||
client.clone(),
|
||||
config.router_config.max_concurrent_requests,
|
||||
config.router_config.rate_limit_tokens_per_second,
|
||||
)?);
|
||||
|
||||
// Create router with the context
|
||||
let router = RouterFactory::create_router(&app_context).await?;
|
||||
|
||||
// Set up concurrency limiter with queue if configured
|
||||
let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new(
|
||||
app_context.rate_limiter.clone(),
|
||||
config.router_config.queue_size,
|
||||
Duration::from_secs(config.router_config.queue_timeout_secs),
|
||||
);
|
||||
|
||||
// Start queue processor if enabled
|
||||
if let Some(processor) = processor {
|
||||
tokio::spawn(processor.run());
|
||||
info!(
|
||||
"Started request queue with size: {}, timeout: {}s",
|
||||
config.router_config.queue_size, config.router_config.queue_timeout_secs
|
||||
);
|
||||
}
|
||||
|
||||
// Create app state with router and context
|
||||
let app_state = Arc::new(AppState {
|
||||
router: Arc::from(router),
|
||||
context: app_context.clone(),
|
||||
concurrency_queue_tx: limiter.queue_tx.clone(),
|
||||
});
|
||||
let router_arc = Arc::clone(&app_state.router);
|
||||
|
||||
// Start the service discovery if enabled
|
||||
if let Some(service_discovery_config) = config.service_discovery_config {
|
||||
if service_discovery_config.enabled {
|
||||
match start_service_discovery(service_discovery_config, router_arc).await {
|
||||
Ok(handle) => {
|
||||
info!("Service discovery started");
|
||||
// Spawn a task to handle the service discovery thread
|
||||
spawn(async move {
|
||||
if let Err(e) = handle.await {
|
||||
error!("Service discovery task failed: {:?}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to start service discovery: {}", e);
|
||||
warn!("Continuing without service discovery");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
"Router ready | workers: {:?}",
|
||||
app_state.router.get_worker_urls()
|
||||
);
|
||||
|
||||
// Configure request ID headers
|
||||
let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
|
||||
vec![
|
||||
"x-request-id".to_string(),
|
||||
"x-correlation-id".to_string(),
|
||||
"x-trace-id".to_string(),
|
||||
"request-id".to_string(),
|
||||
]
|
||||
});
|
||||
|
||||
// Build the application
|
||||
let app = build_app(
|
||||
app_state,
|
||||
config.max_payload_size,
|
||||
request_id_headers,
|
||||
config.router_config.cors_allowed_origins.clone(),
|
||||
);
|
||||
|
||||
// Create TCP listener - use the configured host
|
||||
let addr = format!("{}:{}", config.host, config.port);
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
|
||||
// Start server with graceful shutdown
|
||||
info!("Starting server on {}", addr);
|
||||
|
||||
// Serve the application with graceful shutdown
|
||||
axum::serve(listener, app)
|
||||
.with_graceful_shutdown(shutdown_signal())
|
||||
.await
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Graceful shutdown handler
|
||||
async fn shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install Ctrl+C handler");
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install signal handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
tokio::select! {
|
||||
_ = ctrl_c => {
|
||||
info!("Received Ctrl+C, starting graceful shutdown");
|
||||
},
|
||||
_ = terminate => {
|
||||
info!("Received terminate signal, starting graceful shutdown");
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CORS Layer Creation
|
||||
fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
|
||||
use tower_http::cors::Any;
|
||||
|
||||
let cors = if allowed_origins.is_empty() {
|
||||
// Allow all origins if none specified
|
||||
tower_http::cors::CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any)
|
||||
.expose_headers(Any)
|
||||
} else {
|
||||
// Restrict to specific origins
|
||||
let origins: Vec<http::HeaderValue> = allowed_origins
|
||||
.into_iter()
|
||||
.filter_map(|origin| origin.parse().ok())
|
||||
.collect();
|
||||
|
||||
tower_http::cors::CorsLayer::new()
|
||||
.allow_origin(origins)
|
||||
.allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
|
||||
.allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
|
||||
.expose_headers([http::header::HeaderName::from_static("x-request-id")])
|
||||
};
|
||||
|
||||
cors.max_age(Duration::from_secs(3600))
|
||||
}
|
||||
1168
sgl-router/src/service_discovery.rs
Normal file
1168
sgl-router/src/service_discovery.rs
Normal file
File diff suppressed because it is too large
Load Diff
1021
sgl-router/src/tokenizer/README.md
Normal file
1021
sgl-router/src/tokenizer/README.md
Normal file
File diff suppressed because it is too large
Load Diff
182
sgl-router/src/tokenizer/chat_template.rs
Normal file
182
sgl-router/src/tokenizer/chat_template.rs
Normal file
@@ -0,0 +1,182 @@
|
||||
//! Chat template support for tokenizers using Jinja2 templates
|
||||
//!
|
||||
//! This module provides functionality to apply chat templates to messages,
|
||||
//! similar to HuggingFace transformers' apply_chat_template method.
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use minijinja::{context, Environment, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
|
||||
/// Represents a chat message with role and content
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
ChatMessage {
|
||||
role: role.into(),
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self::new("system", content)
|
||||
}
|
||||
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self::new("user", content)
|
||||
}
|
||||
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self::new("assistant", content)
|
||||
}
|
||||
}
|
||||
|
||||
/// Chat template processor using Jinja2
|
||||
pub struct ChatTemplateProcessor {
|
||||
template: String,
|
||||
bos_token: Option<String>,
|
||||
eos_token: Option<String>,
|
||||
}
|
||||
|
||||
impl ChatTemplateProcessor {
|
||||
/// Create a new chat template processor
|
||||
pub fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
||||
ChatTemplateProcessor {
|
||||
template,
|
||||
bos_token,
|
||||
eos_token,
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply the chat template to a list of messages
|
||||
///
|
||||
/// This mimics the behavior of HuggingFace's apply_chat_template method
|
||||
/// but returns the formatted string instead of token IDs.
|
||||
pub fn apply_chat_template(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
add_generation_prompt: bool,
|
||||
) -> Result<String> {
|
||||
let mut env = Environment::new();
|
||||
|
||||
// Register the template
|
||||
env.add_template("chat", &self.template)
|
||||
.map_err(|e| anyhow!("Failed to add template: {}", e))?;
|
||||
|
||||
// Get the template
|
||||
let tmpl = env
|
||||
.get_template("chat")
|
||||
.map_err(|e| anyhow!("Failed to get template: {}", e))?;
|
||||
|
||||
// Convert messages to a format Jinja can work with
|
||||
let messages_value: Vec<Value> = messages
|
||||
.iter()
|
||||
.map(|msg| {
|
||||
context! {
|
||||
role => msg.role.clone(),
|
||||
content => msg.content.clone()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Render the template
|
||||
let rendered = tmpl
|
||||
.render(context! {
|
||||
messages => messages_value,
|
||||
add_generation_prompt => add_generation_prompt,
|
||||
bos_token => self.bos_token.clone().unwrap_or_default(),
|
||||
eos_token => self.eos_token.clone().unwrap_or_default()
|
||||
})
|
||||
.map_err(|e| anyhow!("Failed to render template: {}", e))?;
|
||||
|
||||
Ok(rendered)
|
||||
}
|
||||
}
|
||||
|
||||
/// Load chat template from tokenizer config JSON
|
||||
pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
|
||||
use std::fs;
|
||||
|
||||
let content = fs::read_to_string(config_path)?;
|
||||
let config: serde_json::Value = serde_json::from_str(&content)?;
|
||||
|
||||
// Look for chat_template in the config
|
||||
if let Some(template) = config.get("chat_template") {
|
||||
if let Some(template_str) = template.as_str() {
|
||||
return Ok(Some(template_str.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_creation() {
|
||||
let msg = ChatMessage::system("You are a helpful assistant");
|
||||
assert_eq!(msg.role, "system");
|
||||
assert_eq!(msg.content, "You are a helpful assistant");
|
||||
|
||||
let user_msg = ChatMessage::user("Hello!");
|
||||
assert_eq!(user_msg.role, "user");
|
||||
|
||||
let assistant_msg = ChatMessage::assistant("Hi there!");
|
||||
assert_eq!(assistant_msg.role, "assistant");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_chat_template() {
|
||||
// Simple template that formats messages
|
||||
let template = r#"
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}
|
||||
{% endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
assistant:
|
||||
{%- endif -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are helpful"),
|
||||
ChatMessage::user("Hello"),
|
||||
];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, true).unwrap();
|
||||
assert!(result.contains("system: You are helpful"));
|
||||
assert!(result.contains("user: Hello"));
|
||||
assert!(result.contains("assistant:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_with_tokens() {
|
||||
// Template that uses special tokens
|
||||
let template = r#"
|
||||
{{ bos_token }}
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}{{ eos_token }}
|
||||
{% endfor -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::user("Test")];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
assert!(result.contains("<s>"));
|
||||
assert!(result.contains("</s>"));
|
||||
}
|
||||
}
|
||||
318
sgl-router/src/tokenizer/factory.rs
Normal file
318
sgl-router/src/tokenizer/factory.rs
Normal file
@@ -0,0 +1,318 @@
|
||||
use super::traits;
|
||||
use anyhow::{Error, Result};
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::huggingface::HuggingFaceTokenizer;
|
||||
use super::tiktoken::TiktokenTokenizer;
|
||||
use crate::tokenizer::hub::download_tokenizer_from_hf;
|
||||
|
||||
/// Represents the type of tokenizer being used
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TokenizerType {
|
||||
HuggingFace(String),
|
||||
Mock,
|
||||
Tiktoken(String),
|
||||
// Future: SentencePiece, GGUF
|
||||
}
|
||||
|
||||
/// Create a tokenizer from a file path to a tokenizer file.
|
||||
/// The file extension is used to determine the tokenizer type.
|
||||
/// Supported file types are:
|
||||
/// - json: HuggingFace tokenizer
|
||||
/// - For testing: can return mock tokenizer
|
||||
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||
create_tokenizer_with_chat_template(file_path, None)
|
||||
}
|
||||
|
||||
/// Create a tokenizer from a file path with an optional chat template
|
||||
pub fn create_tokenizer_with_chat_template(
|
||||
file_path: &str,
|
||||
chat_template_path: Option<&str>,
|
||||
) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||
// Special case for testing
|
||||
if file_path == "mock" || file_path == "test" {
|
||||
return Ok(Arc::new(super::mock::MockTokenizer::new()));
|
||||
}
|
||||
|
||||
let path = Path::new(file_path);
|
||||
|
||||
// Check if file exists
|
||||
if !path.exists() {
|
||||
return Err(Error::msg(format!("File not found: {}", file_path)));
|
||||
}
|
||||
|
||||
// Try to determine tokenizer type from extension
|
||||
let extension = path
|
||||
.extension()
|
||||
.and_then(std::ffi::OsStr::to_str)
|
||||
.map(|s| s.to_lowercase());
|
||||
|
||||
let result = match extension.as_deref() {
|
||||
Some("json") => {
|
||||
let tokenizer =
|
||||
HuggingFaceTokenizer::from_file_with_chat_template(file_path, chat_template_path)?;
|
||||
|
||||
Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
|
||||
}
|
||||
Some("model") => {
|
||||
// SentencePiece model file
|
||||
Err(Error::msg("SentencePiece models not yet supported"))
|
||||
}
|
||||
Some("gguf") => {
|
||||
// GGUF format
|
||||
Err(Error::msg("GGUF format not yet supported"))
|
||||
}
|
||||
_ => {
|
||||
// Try to auto-detect by reading file content
|
||||
auto_detect_tokenizer(file_path)
|
||||
}
|
||||
};
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Auto-detect tokenizer type by examining file content
|
||||
fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||
let mut file = File::open(file_path)?;
|
||||
let mut buffer = vec![0u8; 512]; // Read first 512 bytes for detection
|
||||
let bytes_read = file.read(&mut buffer)?;
|
||||
buffer.truncate(bytes_read);
|
||||
|
||||
// Check for JSON (HuggingFace format)
|
||||
if is_likely_json(&buffer) {
|
||||
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
|
||||
return Ok(Arc::new(tokenizer));
|
||||
}
|
||||
|
||||
// Check for GGUF magic number
|
||||
if buffer.len() >= 4 && &buffer[0..4] == b"GGUF" {
|
||||
return Err(Error::msg("GGUF format detected but not yet supported"));
|
||||
}
|
||||
|
||||
// Check for SentencePiece model
|
||||
if is_likely_sentencepiece(&buffer) {
|
||||
return Err(Error::msg(
|
||||
"SentencePiece model detected but not yet supported",
|
||||
));
|
||||
}
|
||||
|
||||
Err(Error::msg(format!(
|
||||
"Unable to determine tokenizer type for file: {}",
|
||||
file_path
|
||||
)))
|
||||
}
|
||||
|
||||
/// Check if the buffer likely contains JSON data
|
||||
fn is_likely_json(buffer: &[u8]) -> bool {
|
||||
// Skip UTF-8 BOM if present
|
||||
let content = if buffer.len() >= 3 && buffer[0..3] == [0xEF, 0xBB, 0xBF] {
|
||||
&buffer[3..]
|
||||
} else {
|
||||
buffer
|
||||
};
|
||||
|
||||
// Find first non-whitespace character without allocation
|
||||
if let Some(first_byte) = content.iter().find(|&&b| !b.is_ascii_whitespace()) {
|
||||
*first_byte == b'{' || *first_byte == b'['
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the buffer likely contains a SentencePiece model
|
||||
fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
|
||||
// SentencePiece models often start with specific patterns
|
||||
// This is a simplified check
|
||||
buffer.len() >= 12
|
||||
&& (buffer.starts_with(b"\x0a\x09")
|
||||
|| buffer.starts_with(b"\x08\x00")
|
||||
|| buffer.windows(4).any(|w| w == b"<unk")
|
||||
|| buffer.windows(4).any(|w| w == b"<s>")
|
||||
|| buffer.windows(4).any(|w| w == b"</s>"))
|
||||
}
|
||||
|
||||
/// Factory function to create tokenizer from a model name or path (async version)
|
||||
pub async fn create_tokenizer_async(
|
||||
model_name_or_path: &str,
|
||||
) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||
// Check if it's a file path
|
||||
let path = Path::new(model_name_or_path);
|
||||
if path.exists() {
|
||||
return create_tokenizer_from_file(model_name_or_path);
|
||||
}
|
||||
|
||||
// Check if it's a GPT model name that should use Tiktoken
|
||||
if model_name_or_path.contains("gpt-")
|
||||
|| model_name_or_path.contains("davinci")
|
||||
|| model_name_or_path.contains("curie")
|
||||
|| model_name_or_path.contains("babbage")
|
||||
|| model_name_or_path.contains("ada")
|
||||
{
|
||||
let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
|
||||
return Ok(Arc::new(tokenizer));
|
||||
}
|
||||
|
||||
// Try to download tokenizer files from HuggingFace
|
||||
match download_tokenizer_from_hf(model_name_or_path).await {
|
||||
Ok(cache_dir) => {
|
||||
// Look for tokenizer.json in the cache directory
|
||||
let tokenizer_path = cache_dir.join("tokenizer.json");
|
||||
if tokenizer_path.exists() {
|
||||
create_tokenizer_from_file(tokenizer_path.to_str().unwrap())
|
||||
} else {
|
||||
// Try other common tokenizer file names
|
||||
let possible_files = ["tokenizer_config.json", "vocab.json"];
|
||||
for file_name in &possible_files {
|
||||
let file_path = cache_dir.join(file_name);
|
||||
if file_path.exists() {
|
||||
return create_tokenizer_from_file(file_path.to_str().unwrap());
|
||||
}
|
||||
}
|
||||
Err(Error::msg(format!(
|
||||
"Downloaded model '{}' but couldn't find a suitable tokenizer file",
|
||||
model_name_or_path
|
||||
)))
|
||||
}
|
||||
}
|
||||
Err(e) => Err(Error::msg(format!(
|
||||
"Failed to download tokenizer from HuggingFace: {}",
|
||||
e
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Factory function to create tokenizer from a model name or path (blocking version)
|
||||
pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||
// Check if it's a file path
|
||||
let path = Path::new(model_name_or_path);
|
||||
if path.exists() {
|
||||
return create_tokenizer_from_file(model_name_or_path);
|
||||
}
|
||||
|
||||
// Check if it's a GPT model name that should use Tiktoken
|
||||
if model_name_or_path.contains("gpt-")
|
||||
|| model_name_or_path.contains("davinci")
|
||||
|| model_name_or_path.contains("curie")
|
||||
|| model_name_or_path.contains("babbage")
|
||||
|| model_name_or_path.contains("ada")
|
||||
{
|
||||
let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
|
||||
return Ok(Arc::new(tokenizer));
|
||||
}
|
||||
|
||||
// Only use tokio for HuggingFace downloads
|
||||
// Check if we're already in a tokio runtime
|
||||
if let Ok(handle) = tokio::runtime::Handle::try_current() {
|
||||
// We're in a runtime, use block_in_place
|
||||
tokio::task::block_in_place(|| handle.block_on(create_tokenizer_async(model_name_or_path)))
|
||||
} else {
|
||||
// No runtime, create a temporary one
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
rt.block_on(create_tokenizer_async(model_name_or_path))
|
||||
}
|
||||
}
|
||||
|
||||
/// Get information about a tokenizer file
|
||||
pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
|
||||
let path = Path::new(file_path);
|
||||
|
||||
if !path.exists() {
|
||||
return Err(Error::msg(format!("File not found: {}", file_path)));
|
||||
}
|
||||
|
||||
let extension = path
|
||||
.extension()
|
||||
.and_then(std::ffi::OsStr::to_str)
|
||||
.map(|s| s.to_lowercase());
|
||||
|
||||
match extension.as_deref() {
|
||||
Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
|
||||
_ => {
|
||||
// Try auto-detection
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
|
||||
let mut file = File::open(file_path)?;
|
||||
let mut buffer = vec![0u8; 512];
|
||||
let bytes_read = file.read(&mut buffer)?;
|
||||
buffer.truncate(bytes_read);
|
||||
|
||||
if is_likely_json(&buffer) {
|
||||
Ok(TokenizerType::HuggingFace(file_path.to_string()))
|
||||
} else {
|
||||
Err(Error::msg("Unknown tokenizer type"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_json_detection() {
|
||||
assert!(is_likely_json(b"{\"test\": \"value\"}"));
|
||||
assert!(is_likely_json(b" \n\t{\"test\": \"value\"}"));
|
||||
assert!(is_likely_json(b"[1, 2, 3]"));
|
||||
assert!(!is_likely_json(b"not json"));
|
||||
assert!(!is_likely_json(b""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mock_tokenizer_creation() {
|
||||
let tokenizer = create_tokenizer_from_file("mock").unwrap();
|
||||
assert_eq!(tokenizer.vocab_size(), 8); // Mock tokenizer has 8 tokens
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_file_not_found() {
|
||||
let result = create_tokenizer_from_file("/nonexistent/file.json");
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert!(e.to_string().contains("File not found"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_tiktoken_tokenizer() {
|
||||
// Test creating tokenizer for GPT models
|
||||
let tokenizer = create_tokenizer("gpt-4").unwrap();
|
||||
assert!(tokenizer.vocab_size() > 0);
|
||||
|
||||
// Test encoding and decoding
|
||||
let text = "Hello, world!";
|
||||
let encoding = tokenizer.encode(text).unwrap();
|
||||
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
|
||||
assert_eq!(decoded, text);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_download_tokenizer_from_hf() {
|
||||
// Test with a small model that should have tokenizer files
|
||||
// Skip this test if HF_TOKEN is not set and we're in CI
|
||||
if std::env::var("CI").is_ok() && std::env::var("HF_TOKEN").is_err() {
|
||||
println!("Skipping HF download test in CI without HF_TOKEN");
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to create tokenizer for a known small model
|
||||
let result = create_tokenizer_async("bert-base-uncased").await;
|
||||
|
||||
// The test might fail due to network issues or rate limiting
|
||||
// so we just check that the function executes without panic
|
||||
match result {
|
||||
Ok(tokenizer) => {
|
||||
assert!(tokenizer.vocab_size() > 0);
|
||||
println!("Successfully downloaded and created tokenizer");
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Download failed (this might be expected): {}", e);
|
||||
// Don't fail the test - network issues shouldn't break CI
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
238
sgl-router/src/tokenizer/hub.rs
Normal file
238
sgl-router/src/tokenizer/hub.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
use hf_hub::api::tokio::ApiBuilder;
|
||||
use std::env;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
const IGNORED: [&str; 5] = [
|
||||
".gitattributes",
|
||||
"LICENSE",
|
||||
"LICENSE.txt",
|
||||
"README.md",
|
||||
"USE_POLICY.md",
|
||||
];
|
||||
|
||||
const HF_TOKEN_ENV_VAR: &str = "HF_TOKEN";
|
||||
|
||||
/// Checks if a file is a model weight file
|
||||
fn is_weight_file(filename: &str) -> bool {
|
||||
filename.ends_with(".bin")
|
||||
|| filename.ends_with(".safetensors")
|
||||
|| filename.ends_with(".h5")
|
||||
|| filename.ends_with(".msgpack")
|
||||
|| filename.ends_with(".ckpt.index")
|
||||
}
|
||||
|
||||
/// Checks if a file is an image file
|
||||
fn is_image(filename: &str) -> bool {
|
||||
filename.ends_with(".png")
|
||||
|| filename.ends_with("PNG")
|
||||
|| filename.ends_with(".jpg")
|
||||
|| filename.ends_with("JPG")
|
||||
|| filename.ends_with(".jpeg")
|
||||
|| filename.ends_with("JPEG")
|
||||
}
|
||||
|
||||
/// Checks if a file is a tokenizer file
|
||||
fn is_tokenizer_file(filename: &str) -> bool {
|
||||
filename.ends_with("tokenizer.json")
|
||||
|| filename.ends_with("tokenizer_config.json")
|
||||
|| filename.ends_with("special_tokens_map.json")
|
||||
|| filename.ends_with("vocab.json")
|
||||
|| filename.ends_with("merges.txt")
|
||||
|| filename.ends_with(".model") // SentencePiece models
|
||||
|| filename.ends_with(".tiktoken")
|
||||
}
|
||||
|
||||
/// Attempt to download tokenizer files from Hugging Face
|
||||
/// Returns the directory containing the downloaded tokenizer files
|
||||
pub async fn download_tokenizer_from_hf(model_id: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
|
||||
let model_id = model_id.as_ref();
|
||||
let token = env::var(HF_TOKEN_ENV_VAR).ok();
|
||||
let api = ApiBuilder::new()
|
||||
.with_progress(true)
|
||||
.with_token(token)
|
||||
.build()?;
|
||||
let model_name = model_id.display().to_string();
|
||||
|
||||
let repo = api.model(model_name.clone());
|
||||
|
||||
let info = match repo.info().await {
|
||||
Ok(info) => info,
|
||||
Err(e) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to fetch model '{}' from HuggingFace: {}. Is this a valid HuggingFace ID?",
|
||||
model_name,
|
||||
e
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if info.siblings.is_empty() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Model '{}' exists but contains no downloadable files.",
|
||||
model_name
|
||||
));
|
||||
}
|
||||
|
||||
let mut cache_dir = None;
|
||||
let mut tokenizer_files_found = false;
|
||||
|
||||
// First, identify all tokenizer files to download
|
||||
let tokenizer_files: Vec<_> = info
|
||||
.siblings
|
||||
.iter()
|
||||
.filter(|sib| {
|
||||
!IGNORED.contains(&sib.rfilename.as_str())
|
||||
&& !is_image(&sib.rfilename)
|
||||
&& !is_weight_file(&sib.rfilename)
|
||||
&& is_tokenizer_file(&sib.rfilename)
|
||||
})
|
||||
.collect();
|
||||
|
||||
if tokenizer_files.is_empty() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"No tokenizer files found for model '{}'.",
|
||||
model_name
|
||||
));
|
||||
}
|
||||
|
||||
// Download all tokenizer files
|
||||
for sib in tokenizer_files {
|
||||
match repo.get(&sib.rfilename).await {
|
||||
Ok(path) => {
|
||||
if cache_dir.is_none() {
|
||||
cache_dir = path.parent().map(|p| p.to_path_buf());
|
||||
}
|
||||
tokenizer_files_found = true;
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to download tokenizer file '{}' from model '{}': {}",
|
||||
sib.rfilename,
|
||||
model_name,
|
||||
e
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !tokenizer_files_found {
|
||||
return Err(anyhow::anyhow!(
|
||||
"No tokenizer files could be downloaded for model '{}'.",
|
||||
model_name
|
||||
));
|
||||
}
|
||||
|
||||
match cache_dir {
|
||||
Some(dir) => Ok(dir),
|
||||
None => Err(anyhow::anyhow!(
|
||||
"Invalid HF cache path for model '{}'",
|
||||
model_name
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to download a model from Hugging Face (including weights)
|
||||
/// Returns the directory it is in
|
||||
/// If ignore_weights is true, model weight files will be skipped
|
||||
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
|
||||
let name = name.as_ref();
|
||||
let token = env::var(HF_TOKEN_ENV_VAR).ok();
|
||||
let api = ApiBuilder::new()
|
||||
.with_progress(true)
|
||||
.with_token(token)
|
||||
.build()?;
|
||||
let model_name = name.display().to_string();
|
||||
|
||||
let repo = api.model(model_name.clone());
|
||||
|
||||
let info = match repo.info().await {
|
||||
Ok(info) => info,
|
||||
Err(e) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to fetch model '{}' from HuggingFace: {}. Is this a valid HuggingFace ID?",
|
||||
model_name,
|
||||
e
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if info.siblings.is_empty() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Model '{}' exists but contains no downloadable files.",
|
||||
model_name
|
||||
));
|
||||
}
|
||||
|
||||
let mut p = PathBuf::new();
|
||||
let mut files_downloaded = false;
|
||||
|
||||
for sib in info.siblings {
|
||||
if IGNORED.contains(&sib.rfilename.as_str()) || is_image(&sib.rfilename) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If ignore_weights is true, skip weight files
|
||||
if ignore_weights && is_weight_file(&sib.rfilename) {
|
||||
continue;
|
||||
}
|
||||
|
||||
match repo.get(&sib.rfilename).await {
|
||||
Ok(path) => {
|
||||
p = path;
|
||||
files_downloaded = true;
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to download file '{}' from model '{}': {}",
|
||||
sib.rfilename,
|
||||
model_name,
|
||||
e
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !files_downloaded {
|
||||
let file_type = if ignore_weights {
|
||||
"non-weight"
|
||||
} else {
|
||||
"valid"
|
||||
};
|
||||
return Err(anyhow::anyhow!(
|
||||
"No {} files found for model '{}'.",
|
||||
file_type,
|
||||
model_name
|
||||
));
|
||||
}
|
||||
|
||||
match p.parent() {
|
||||
Some(p) => Ok(p.to_path_buf()),
|
||||
None => Err(anyhow::anyhow!("Invalid HF cache path: {}", p.display())),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_is_tokenizer_file() {
|
||||
assert!(is_tokenizer_file("tokenizer.json"));
|
||||
assert!(is_tokenizer_file("tokenizer_config.json"));
|
||||
assert!(is_tokenizer_file("special_tokens_map.json"));
|
||||
assert!(is_tokenizer_file("vocab.json"));
|
||||
assert!(is_tokenizer_file("merges.txt"));
|
||||
assert!(is_tokenizer_file("spiece.model"));
|
||||
assert!(!is_tokenizer_file("model.bin"));
|
||||
assert!(!is_tokenizer_file("README.md"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_weight_file() {
|
||||
assert!(is_weight_file("model.bin"));
|
||||
assert!(is_weight_file("model.safetensors"));
|
||||
assert!(is_weight_file("pytorch_model.bin"));
|
||||
assert!(!is_weight_file("tokenizer.json"));
|
||||
assert!(!is_weight_file("config.json"));
|
||||
}
|
||||
}
|
||||
234
sgl-router/src/tokenizer/huggingface.rs
Normal file
234
sgl-router/src/tokenizer/huggingface.rs
Normal file
@@ -0,0 +1,234 @@
|
||||
use super::traits::{
|
||||
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
||||
};
|
||||
use anyhow::{Error, Result};
|
||||
use std::collections::HashMap;
|
||||
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
||||
|
||||
use super::chat_template::{ChatMessage, ChatTemplateProcessor};
|
||||
|
||||
/// HuggingFace tokenizer wrapper
|
||||
pub struct HuggingFaceTokenizer {
|
||||
tokenizer: HfTokenizer,
|
||||
special_tokens: SpecialTokens,
|
||||
vocab: HashMap<String, TokenIdType>,
|
||||
reverse_vocab: HashMap<TokenIdType, String>,
|
||||
chat_template: Option<String>,
|
||||
}
|
||||
|
||||
impl HuggingFaceTokenizer {
|
||||
/// Create a tokenizer from a HuggingFace tokenizer JSON file
|
||||
pub fn from_file(file_path: &str) -> Result<Self> {
|
||||
Self::from_file_with_chat_template(file_path, None)
|
||||
}
|
||||
|
||||
/// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
|
||||
pub fn from_file_with_chat_template(
|
||||
file_path: &str,
|
||||
chat_template_path: Option<&str>,
|
||||
) -> Result<Self> {
|
||||
let tokenizer = HfTokenizer::from_file(file_path)
|
||||
.map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;
|
||||
|
||||
// Extract special tokens
|
||||
let special_tokens = Self::extract_special_tokens(&tokenizer);
|
||||
|
||||
// Build vocab mappings
|
||||
let vocab = tokenizer.get_vocab(false);
|
||||
let reverse_vocab: HashMap<TokenIdType, String> = vocab
|
||||
.iter()
|
||||
.map(|(token, &id)| (id, token.clone()))
|
||||
.collect();
|
||||
|
||||
// Load chat template
|
||||
let chat_template = if let Some(template_path) = chat_template_path {
|
||||
// Load from specified .jinja file
|
||||
Self::load_chat_template_from_file(template_path)?
|
||||
} else {
|
||||
// Try to load from tokenizer_config.json
|
||||
Self::load_chat_template(file_path)
|
||||
};
|
||||
|
||||
Ok(HuggingFaceTokenizer {
|
||||
tokenizer,
|
||||
special_tokens,
|
||||
vocab,
|
||||
reverse_vocab,
|
||||
chat_template,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create from an existing HuggingFace tokenizer
|
||||
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
|
||||
let special_tokens = Self::extract_special_tokens(&tokenizer);
|
||||
let vocab = tokenizer.get_vocab(false);
|
||||
let reverse_vocab: HashMap<TokenIdType, String> = vocab
|
||||
.iter()
|
||||
.map(|(token, &id)| (id, token.clone()))
|
||||
.collect();
|
||||
|
||||
HuggingFaceTokenizer {
|
||||
tokenizer,
|
||||
special_tokens,
|
||||
vocab,
|
||||
reverse_vocab,
|
||||
chat_template: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract special tokens from the tokenizer
|
||||
fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
|
||||
// Try to get special tokens from the tokenizer
|
||||
// This is a simplified version - actual implementation would need to handle various formats
|
||||
let vocab = tokenizer.get_vocab(true);
|
||||
|
||||
let find_token = |patterns: &[&str]| -> Option<String> {
|
||||
for pattern in patterns {
|
||||
if vocab.contains_key(*pattern) {
|
||||
return Some(pattern.to_string());
|
||||
}
|
||||
}
|
||||
None
|
||||
};
|
||||
|
||||
SpecialTokens {
|
||||
bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
|
||||
eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
|
||||
unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
|
||||
sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
|
||||
pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
|
||||
cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
|
||||
mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
|
||||
additional_special_tokens: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to load chat template from tokenizer_config.json
|
||||
fn load_chat_template(tokenizer_path: &str) -> Option<String> {
|
||||
// Try to find tokenizer_config.json in the same directory
|
||||
let path = std::path::Path::new(tokenizer_path);
|
||||
let dir = path.parent()?;
|
||||
let config_path = dir.join("tokenizer_config.json");
|
||||
|
||||
if config_path.exists() {
|
||||
if let Ok(template) =
|
||||
super::chat_template::load_chat_template_from_config(config_path.to_str()?)
|
||||
{
|
||||
return template;
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Load chat template from a .jinja file
|
||||
fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
|
||||
use std::fs;
|
||||
|
||||
let content = fs::read_to_string(template_path)
|
||||
.map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;
|
||||
|
||||
// Clean up the template (similar to Python implementation)
|
||||
let template = content.trim().replace("\\n", "\n");
|
||||
|
||||
Ok(Some(template))
|
||||
}
|
||||
|
||||
/// Set or override the chat template
|
||||
pub fn set_chat_template(&mut self, template: String) {
|
||||
self.chat_template = Some(template);
|
||||
}
|
||||
|
||||
/// Apply chat template if available
|
||||
pub fn apply_chat_template(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
add_generation_prompt: bool,
|
||||
) -> Result<String> {
|
||||
if let Some(ref template) = self.chat_template {
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.clone(),
|
||||
self.special_tokens.bos_token.clone(),
|
||||
self.special_tokens.eos_token.clone(),
|
||||
);
|
||||
processor.apply_chat_template(messages, add_generation_prompt)
|
||||
} else {
|
||||
// Fallback to simple formatting if no template is available
|
||||
let mut result = String::new();
|
||||
for msg in messages {
|
||||
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
|
||||
}
|
||||
if add_generation_prompt {
|
||||
result.push_str("assistant: ");
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder for HuggingFaceTokenizer {
|
||||
fn encode(&self, input: &str) -> Result<Encoding> {
|
||||
self.tokenizer
|
||||
.encode(input, false)
|
||||
.map_err(|e| Error::msg(format!("Encoding failed: {}", e)))
|
||||
.map(|encoding| Encoding::Hf(Box::new(encoding)))
|
||||
}
|
||||
|
||||
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
|
||||
let encodings = self
|
||||
.tokenizer
|
||||
.encode_batch(inputs.to_vec(), false)
|
||||
.map_err(|e| Error::msg(format!("Batch encoding failed: {}", e)))?;
|
||||
|
||||
Ok(encodings
|
||||
.into_iter()
|
||||
.map(|e| Encoding::Hf(Box::new(e)))
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for HuggingFaceTokenizer {
|
||||
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
|
||||
self.tokenizer
|
||||
.decode(token_ids, skip_special_tokens)
|
||||
.map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenizerTrait for HuggingFaceTokenizer {
|
||||
fn vocab_size(&self) -> usize {
|
||||
self.tokenizer.get_vocab_size(false)
|
||||
}
|
||||
|
||||
fn get_special_tokens(&self) -> &SpecialTokens {
|
||||
&self.special_tokens
|
||||
}
|
||||
|
||||
fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
|
||||
self.vocab.get(token).copied()
|
||||
}
|
||||
|
||||
fn id_to_token(&self, id: TokenIdType) -> Option<String> {
|
||||
self.reverse_vocab.get(&id).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ChatMessage;
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_creation() {
|
||||
let msg = ChatMessage::system("You are a helpful assistant");
|
||||
assert_eq!(msg.role, "system");
|
||||
assert_eq!(msg.content, "You are a helpful assistant");
|
||||
|
||||
let user_msg = ChatMessage::user("Hello!");
|
||||
assert_eq!(user_msg.role, "user");
|
||||
|
||||
let assistant_msg = ChatMessage::assistant("Hi there!");
|
||||
assert_eq!(assistant_msg.role, "assistant");
|
||||
}
|
||||
|
||||
// Note: Actual tokenizer tests would require a real tokenizer file
|
||||
// These would be integration tests rather than unit tests
|
||||
}
|
||||
112
sgl-router/src/tokenizer/mock.rs
Normal file
112
sgl-router/src/tokenizer/mock.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
//! Mock tokenizer implementation for testing
|
||||
|
||||
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||
use anyhow::Result;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Mock tokenizer for testing purposes
|
||||
pub struct MockTokenizer {
|
||||
vocab: HashMap<String, u32>,
|
||||
reverse_vocab: HashMap<u32, String>,
|
||||
special_tokens: SpecialTokens,
|
||||
}
|
||||
|
||||
impl Default for MockTokenizer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl MockTokenizer {
|
||||
pub fn new() -> Self {
|
||||
let mut vocab = HashMap::new();
|
||||
let mut reverse_vocab = HashMap::new();
|
||||
|
||||
// Add some basic tokens
|
||||
let tokens = vec![
|
||||
("Hello", 1),
|
||||
("world", 2),
|
||||
("test", 3),
|
||||
("token", 4),
|
||||
(" ", 5),
|
||||
(".", 6),
|
||||
("<eos>", 999),
|
||||
("<bos>", 1000),
|
||||
];
|
||||
|
||||
for (token, id) in tokens {
|
||||
vocab.insert(token.to_string(), id);
|
||||
reverse_vocab.insert(id, token.to_string());
|
||||
}
|
||||
|
||||
let special_tokens = SpecialTokens {
|
||||
bos_token: Some("<bos>".to_string()),
|
||||
eos_token: Some("<eos>".to_string()),
|
||||
unk_token: Some("<unk>".to_string()),
|
||||
sep_token: None,
|
||||
pad_token: None,
|
||||
cls_token: None,
|
||||
mask_token: None,
|
||||
additional_special_tokens: vec![],
|
||||
};
|
||||
|
||||
Self {
|
||||
vocab,
|
||||
reverse_vocab,
|
||||
special_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder for MockTokenizer {
|
||||
fn encode(&self, input: &str) -> Result<Encoding> {
|
||||
// Simple word-based tokenization for testing
|
||||
let tokens: Vec<u32> = input
|
||||
.split_whitespace()
|
||||
.filter_map(|word| self.vocab.get(word).copied())
|
||||
.collect();
|
||||
|
||||
Ok(Encoding::Sp(tokens))
|
||||
}
|
||||
|
||||
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
|
||||
inputs.iter().map(|input| self.encode(input)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for MockTokenizer {
|
||||
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
|
||||
let tokens: Vec<String> = token_ids
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
self.reverse_vocab.get(id).and_then(|token| {
|
||||
if skip_special_tokens && (token == "<eos>" || token == "<bos>") {
|
||||
None
|
||||
} else {
|
||||
Some(token.clone())
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(tokens.join(" "))
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenizerTrait for MockTokenizer {
|
||||
fn vocab_size(&self) -> usize {
|
||||
self.vocab.len()
|
||||
}
|
||||
|
||||
fn get_special_tokens(&self) -> &SpecialTokens {
|
||||
&self.special_tokens
|
||||
}
|
||||
|
||||
fn token_to_id(&self, token: &str) -> Option<u32> {
|
||||
self.vocab.get(token).copied()
|
||||
}
|
||||
|
||||
fn id_to_token(&self, id: u32) -> Option<String> {
|
||||
self.reverse_vocab.get(&id).cloned()
|
||||
}
|
||||
}
|
||||
123
sgl-router/src/tokenizer/mod.rs
Normal file
123
sgl-router/src/tokenizer/mod.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
use anyhow::Result;
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub mod factory;
|
||||
pub mod hub;
|
||||
pub mod mock;
|
||||
pub mod sequence;
|
||||
pub mod stop;
|
||||
pub mod stream;
|
||||
pub mod traits;
|
||||
|
||||
// Feature-gated modules
|
||||
|
||||
pub mod chat_template;
|
||||
|
||||
pub mod huggingface;
|
||||
|
||||
pub mod tiktoken;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
// Re-exports
|
||||
pub use factory::{
|
||||
create_tokenizer, create_tokenizer_async, create_tokenizer_from_file,
|
||||
create_tokenizer_with_chat_template, TokenizerType,
|
||||
};
|
||||
pub use sequence::Sequence;
|
||||
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
|
||||
pub use stream::DecodeStream;
|
||||
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||
|
||||
pub use huggingface::HuggingFaceTokenizer;
|
||||
|
||||
pub use chat_template::ChatMessage;
|
||||
|
||||
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
|
||||
|
||||
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
|
||||
#[derive(Clone)]
|
||||
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);
|
||||
|
||||
impl Tokenizer {
|
||||
/// Create a tokenizer from a file path
|
||||
pub fn from_file(file_path: &str) -> Result<Tokenizer> {
|
||||
Ok(Tokenizer(factory::create_tokenizer_from_file(file_path)?))
|
||||
}
|
||||
|
||||
/// Create a tokenizer from a file path with an optional chat template
|
||||
pub fn from_file_with_chat_template(
|
||||
file_path: &str,
|
||||
chat_template_path: Option<&str>,
|
||||
) -> Result<Tokenizer> {
|
||||
Ok(Tokenizer(factory::create_tokenizer_with_chat_template(
|
||||
file_path,
|
||||
chat_template_path,
|
||||
)?))
|
||||
}
|
||||
|
||||
/// Create a tokenizer from an Arc<dyn Tokenizer>
|
||||
pub fn from_arc(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
|
||||
Tokenizer(tokenizer)
|
||||
}
|
||||
|
||||
/// Create a stateful sequence object for decoding token_ids into text
|
||||
pub fn decode_stream(
|
||||
&self,
|
||||
prompt_token_ids: &[u32],
|
||||
skip_special_tokens: bool,
|
||||
) -> DecodeStream {
|
||||
DecodeStream::new(self.0.clone(), prompt_token_ids, skip_special_tokens)
|
||||
}
|
||||
|
||||
/// Direct encode method
|
||||
pub fn encode(&self, input: &str) -> Result<Encoding> {
|
||||
self.0.encode(input)
|
||||
}
|
||||
|
||||
/// Direct batch encode method
|
||||
pub fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
|
||||
self.0.encode_batch(inputs)
|
||||
}
|
||||
|
||||
/// Direct decode method
|
||||
pub fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
|
||||
self.0.decode(token_ids, skip_special_tokens)
|
||||
}
|
||||
|
||||
/// Get vocabulary size
|
||||
pub fn vocab_size(&self) -> usize {
|
||||
self.0.vocab_size()
|
||||
}
|
||||
|
||||
/// Get special tokens
|
||||
pub fn get_special_tokens(&self) -> &SpecialTokens {
|
||||
self.0.get_special_tokens()
|
||||
}
|
||||
|
||||
/// Convert token string to ID
|
||||
pub fn token_to_id(&self, token: &str) -> Option<u32> {
|
||||
self.0.token_to_id(token)
|
||||
}
|
||||
|
||||
/// Convert ID to token string
|
||||
pub fn id_to_token(&self, id: u32) -> Option<String> {
|
||||
self.0.id_to_token(id)
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Tokenizer {
|
||||
type Target = Arc<dyn traits::Tokenizer>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Arc<dyn traits::Tokenizer>> for Tokenizer {
|
||||
fn from(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
|
||||
Tokenizer(tokenizer)
|
||||
}
|
||||
}
|
||||
238
sgl-router/src/tokenizer/sequence.rs
Normal file
238
sgl-router/src/tokenizer/sequence.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
use super::traits::{TokenIdType, Tokenizer as TokenizerTrait};
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Maintains state for an ongoing sequence of tokens and their decoded text
|
||||
/// This provides a cleaner abstraction for managing token sequences
|
||||
pub struct Sequence {
|
||||
/// The tokenizer used for encoding/decoding
|
||||
tokenizer: Arc<dyn TokenizerTrait>,
|
||||
|
||||
/// The current sequence of token ids
|
||||
token_ids: Vec<TokenIdType>,
|
||||
|
||||
/// The position in the current sequence the last decoded token completed
|
||||
prefix_offset: usize,
|
||||
|
||||
/// Current position in the sequence
|
||||
read_offset: usize,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Sequence {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Sequence")
|
||||
.field("tokenizer", &"Arc<dyn Tokenizer>")
|
||||
.field(
|
||||
"token_ids",
|
||||
&format_args!("{}", {
|
||||
let token_ids = self.token_ids();
|
||||
if token_ids.len() <= 20 {
|
||||
format!("{:?}", token_ids)
|
||||
} else {
|
||||
let first_ten = &token_ids[..10];
|
||||
let last_ten = &token_ids[token_ids.len() - 10..];
|
||||
format!("{:?} ... {:?}", first_ten, last_ten)
|
||||
}
|
||||
}),
|
||||
)
|
||||
.field("prefix_offset", &self.prefix_offset)
|
||||
.field("read_offset", &self.read_offset)
|
||||
.field("token count", &self.token_ids.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Sequence {
|
||||
/// Create a new empty sequence
|
||||
pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
|
||||
Self {
|
||||
tokenizer,
|
||||
token_ids: Vec::new(),
|
||||
prefix_offset: 0,
|
||||
read_offset: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a sequence with initial tokens
|
||||
pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
|
||||
let len = token_ids.len();
|
||||
Self {
|
||||
tokenizer,
|
||||
token_ids,
|
||||
prefix_offset: 0,
|
||||
read_offset: len,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the sequence is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.token_ids.is_empty()
|
||||
}
|
||||
|
||||
/// Get the length of the sequence
|
||||
pub fn len(&self) -> usize {
|
||||
self.token_ids.len()
|
||||
}
|
||||
|
||||
/// Clear the sequence
|
||||
pub fn clear(&mut self) {
|
||||
self.token_ids.clear();
|
||||
self.prefix_offset = 0;
|
||||
self.read_offset = 0;
|
||||
}
|
||||
|
||||
/// Append text to the sequence by encoding it
|
||||
pub fn append_text(&mut self, input: &str) -> Result<()> {
|
||||
let encoding = self.tokenizer.encode(input)?;
|
||||
self.token_ids.extend(encoding.token_ids());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Append a single token to the sequence and return newly decoded text
|
||||
/// Based on HuggingFace TGI incremental decoding
|
||||
pub fn append_token(&mut self, token_id: TokenIdType) -> Result<String> {
|
||||
// Store the old read offset before adding the new token
|
||||
let old_read_offset = self.read_offset;
|
||||
|
||||
self.token_ids.push(token_id);
|
||||
self.read_offset = self.token_ids.len();
|
||||
|
||||
// If this is the first token or we're at the beginning, decode everything
|
||||
if self.prefix_offset == 0 && old_read_offset == 0 {
|
||||
let text = self.tokenizer.decode(&self.token_ids, false)?;
|
||||
if text.ends_with("<EFBFBD>") {
|
||||
// Incomplete UTF-8 sequence, wait for more tokens
|
||||
return Ok(String::new());
|
||||
}
|
||||
self.prefix_offset = 0;
|
||||
return Ok(text);
|
||||
}
|
||||
|
||||
// Decode the text up to the previous position
|
||||
let prefix_text = self
|
||||
.tokenizer
|
||||
.decode(&self.token_ids[self.prefix_offset..old_read_offset], false)?;
|
||||
|
||||
// Decode the text including the new token
|
||||
let new_text = self
|
||||
.tokenizer
|
||||
.decode(&self.token_ids[self.prefix_offset..], false)?;
|
||||
|
||||
// Handle multi-byte character boundaries
|
||||
let mut prefix_text_len = prefix_text.len();
|
||||
while !new_text.is_char_boundary(prefix_text_len) && prefix_text_len > 0 {
|
||||
prefix_text_len -= 1;
|
||||
}
|
||||
|
||||
if new_text.len() > prefix_text.len() {
|
||||
if new_text.ends_with("<EFBFBD>") {
|
||||
// Incomplete UTF-8 sequence, wait for more tokens
|
||||
return Ok(String::new());
|
||||
} else {
|
||||
// Return the new text portion
|
||||
let incremental_text = new_text[prefix_text_len..].to_string().replace("<EFBFBD>", "");
|
||||
self.prefix_offset = old_read_offset;
|
||||
return Ok(incremental_text);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
/// Get a reference to the tokenizer
|
||||
pub fn tokenizer(&self) -> &Arc<dyn TokenizerTrait> {
|
||||
&self.tokenizer
|
||||
}
|
||||
|
||||
/// Get the current token ids
|
||||
pub fn token_ids(&self) -> &[TokenIdType] {
|
||||
&self.token_ids
|
||||
}
|
||||
|
||||
/// Decode the entire sequence to text
|
||||
pub fn text(&self) -> Result<String> {
|
||||
self.tokenizer.decode(&self.token_ids, false)
|
||||
}
|
||||
|
||||
/// Get the prefix offset
|
||||
pub fn prefix_offset(&self) -> usize {
|
||||
self.prefix_offset
|
||||
}
|
||||
|
||||
/// Get the read offset
|
||||
pub fn read_offset(&self) -> usize {
|
||||
self.read_offset
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tokenizer::mock::MockTokenizer;
|
||||
|
||||
#[test]
|
||||
fn test_sequence_new() {
|
||||
let tokenizer = Arc::new(MockTokenizer::new());
|
||||
let seq = Sequence::new(tokenizer);
|
||||
assert!(seq.is_empty());
|
||||
assert_eq!(seq.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sequence_append_text() {
|
||||
let tokenizer = Arc::new(MockTokenizer::new());
|
||||
let mut seq = Sequence::new(tokenizer);
|
||||
|
||||
seq.append_text("Hello").unwrap();
|
||||
assert!(!seq.is_empty());
|
||||
assert!(!seq.is_empty());
|
||||
|
||||
let text = seq.text().unwrap();
|
||||
assert_eq!(text, "Hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sequence_append_token() {
|
||||
let tokenizer = Arc::new(MockTokenizer::new());
|
||||
let mut seq = Sequence::new(tokenizer.clone());
|
||||
|
||||
// Start with an empty sequence and append token 1 ("Hello")
|
||||
let text1 = seq.append_token(1).unwrap();
|
||||
assert_eq!(text1, "Hello");
|
||||
|
||||
// Now append token 2 ("world")
|
||||
// The mock tokenizer will decode [1, 2] as "Hello world" (with a space)
|
||||
let text2 = seq.append_token(2).unwrap();
|
||||
// The incremental text should be " world" (with the space that the mock tokenizer adds)
|
||||
assert_eq!(text2, " world");
|
||||
|
||||
// Verify the full text
|
||||
assert_eq!(seq.text().unwrap(), "Hello world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sequence_clear() {
|
||||
let tokenizer = Arc::new(MockTokenizer::new());
|
||||
let mut seq = Sequence::new(tokenizer);
|
||||
|
||||
seq.append_text("Hello world").unwrap();
|
||||
assert!(!seq.is_empty());
|
||||
|
||||
seq.clear();
|
||||
assert!(seq.is_empty());
|
||||
assert_eq!(seq.len(), 0);
|
||||
assert_eq!(seq.prefix_offset(), 0);
|
||||
assert_eq!(seq.read_offset(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sequence_debug() {
|
||||
let tokenizer = Arc::new(MockTokenizer::new());
|
||||
let mut seq = Sequence::new(tokenizer);
|
||||
|
||||
seq.append_text("Test").unwrap();
|
||||
let debug_str = format!("{:?}", seq);
|
||||
assert!(debug_str.contains("Sequence"));
|
||||
assert!(debug_str.contains("token count"));
|
||||
}
|
||||
}
|
||||
506
sgl-router/src/tokenizer/stop.rs
Normal file
506
sgl-router/src/tokenizer/stop.rs
Normal file
@@ -0,0 +1,506 @@
|
||||
use super::traits::{self, TokenIdType};
|
||||
use anyhow::Result;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Output from the sequence decoder
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum SequenceDecoderOutput {
|
||||
/// Normal text output
|
||||
Text(String),
|
||||
/// Text is being held due to partial stop sequence match
|
||||
Held,
|
||||
/// Stop sequence matched (hidden - not included in output)
|
||||
Stopped,
|
||||
/// Stop sequence matched with text (visible - included in output)
|
||||
StoppedWithText(String),
|
||||
}
|
||||
|
||||
/// Configuration for stop sequences
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct StopSequenceConfig {
|
||||
/// Token IDs that trigger a stop
|
||||
pub stop_tokens: HashSet<TokenIdType>,
|
||||
/// String sequences that trigger a stop
|
||||
pub stop_sequences: Vec<String>,
|
||||
/// Token IDs for visible stops (included in output)
|
||||
pub visible_stop_tokens: HashSet<TokenIdType>,
|
||||
/// String sequences for visible stops (included in output)
|
||||
pub visible_stop_sequences: Vec<String>,
|
||||
}
|
||||
|
||||
impl StopSequenceConfig {
|
||||
/// Builder pattern - add a stop token
|
||||
pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||
self.stop_tokens.insert(token_id);
|
||||
self
|
||||
}
|
||||
|
||||
/// Builder pattern - add a stop sequence
|
||||
pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||||
self.stop_sequences.push(sequence.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Builder pattern - add a visible stop token
|
||||
pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||
self.visible_stop_tokens.insert(token_id);
|
||||
self
|
||||
}
|
||||
|
||||
/// Builder pattern - add a visible stop sequence
|
||||
pub fn with_visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||||
self.visible_stop_sequences.push(sequence.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Decoder that handles stop sequences
|
||||
pub struct StopSequenceDecoder {
|
||||
tokenizer: Arc<dyn traits::Tokenizer>,
|
||||
config: StopSequenceConfig,
|
||||
/// Buffer for partial matches (the "jail")
|
||||
jail_buffer: String,
|
||||
/// Accumulated tokens
|
||||
token_buffer: Vec<TokenIdType>,
|
||||
/// Offset where the prefix text starts (for context)
|
||||
prefix_offset: usize,
|
||||
/// Offset marking the end of previously decoded text
|
||||
read_offset: usize,
|
||||
/// Whether we've stopped
|
||||
stopped: bool,
|
||||
skip_special_tokens: bool,
|
||||
}
|
||||
|
||||
impl StopSequenceDecoder {
|
||||
/// Create a new stop sequence decoder
|
||||
pub fn new(
|
||||
tokenizer: Arc<dyn traits::Tokenizer>,
|
||||
config: StopSequenceConfig,
|
||||
skip_special_tokens: bool,
|
||||
) -> Self {
|
||||
StopSequenceDecoder {
|
||||
tokenizer,
|
||||
config,
|
||||
jail_buffer: String::new(),
|
||||
token_buffer: Vec::new(),
|
||||
prefix_offset: 0,
|
||||
read_offset: 0,
|
||||
stopped: false,
|
||||
skip_special_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a single token
|
||||
pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
|
||||
if self.stopped {
|
||||
return Ok(SequenceDecoderOutput::Stopped);
|
||||
}
|
||||
|
||||
// Check for token-level stops first
|
||||
if self.config.stop_tokens.contains(&token_id) {
|
||||
self.stopped = true;
|
||||
|
||||
// Flush any jailed text before stopping
|
||||
if !self.jail_buffer.is_empty() {
|
||||
let output = self.jail_buffer.clone();
|
||||
self.jail_buffer.clear();
|
||||
return Ok(SequenceDecoderOutput::StoppedWithText(output));
|
||||
}
|
||||
return Ok(SequenceDecoderOutput::Stopped);
|
||||
}
|
||||
|
||||
if self.config.visible_stop_tokens.contains(&token_id) {
|
||||
self.stopped = true;
|
||||
|
||||
// Include jailed text plus the stop token
|
||||
let stop_text = self
|
||||
.tokenizer
|
||||
.decode(&[token_id], self.skip_special_tokens)?;
|
||||
let output = format!("{}{}", self.jail_buffer, stop_text);
|
||||
self.jail_buffer.clear();
|
||||
return Ok(SequenceDecoderOutput::StoppedWithText(output));
|
||||
}
|
||||
|
||||
// Add token to buffer
|
||||
self.token_buffer.push(token_id);
|
||||
|
||||
// Use incremental decoding like DecodeStream
|
||||
// First decode the previous context (what we've already output)
|
||||
let prefix_text = if self.read_offset > self.prefix_offset {
|
||||
self.tokenizer.decode(
|
||||
&self.token_buffer[self.prefix_offset..self.read_offset],
|
||||
self.skip_special_tokens,
|
||||
)?
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Now decode from prefix to current position
|
||||
let new_full_text = self.tokenizer.decode(
|
||||
&self.token_buffer[self.prefix_offset..],
|
||||
self.skip_special_tokens,
|
||||
)?;
|
||||
|
||||
// Check for incomplete UTF-8 sequence
|
||||
if new_full_text.ends_with("<EFBFBD>") {
|
||||
// Wait for more tokens to complete the sequence
|
||||
return Ok(SequenceDecoderOutput::Held);
|
||||
}
|
||||
|
||||
// Calculate only the NEW text since last successful decode
|
||||
let new_text = if new_full_text.len() > prefix_text.len() {
|
||||
&new_full_text[prefix_text.len()..]
|
||||
} else {
|
||||
// No new text produced (can happen with special tokens)
|
||||
return Ok(SequenceDecoderOutput::Held);
|
||||
};
|
||||
|
||||
// Combine jail buffer with new text for checking
|
||||
let check_text = format!("{}{}", self.jail_buffer, new_text);
|
||||
|
||||
// Check for complete stop sequences
|
||||
for stop_seq in &self.config.stop_sequences {
|
||||
if let Some(pos) = check_text.find(stop_seq) {
|
||||
self.stopped = true;
|
||||
|
||||
// Output text before the stop sequence
|
||||
let output = check_text[..pos].to_string();
|
||||
self.jail_buffer.clear();
|
||||
return Ok(if output.is_empty() {
|
||||
SequenceDecoderOutput::Stopped
|
||||
} else {
|
||||
SequenceDecoderOutput::StoppedWithText(output)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Check for visible stop sequences
|
||||
for stop_seq in &self.config.visible_stop_sequences {
|
||||
if let Some(pos) = check_text.find(stop_seq) {
|
||||
self.stopped = true;
|
||||
|
||||
// Include the stop sequence in output
|
||||
let end_pos = pos + stop_seq.len();
|
||||
let output = check_text[..end_pos].to_string();
|
||||
self.jail_buffer.clear();
|
||||
return Ok(SequenceDecoderOutput::StoppedWithText(output));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for partial matches at the end of check_text
|
||||
let mut partial_match_len = 0;
|
||||
for stop_seq in self
|
||||
.config
|
||||
.stop_sequences
|
||||
.iter()
|
||||
.chain(&self.config.visible_stop_sequences)
|
||||
{
|
||||
// Check all possible suffixes that could be a prefix of stop_seq
|
||||
for i in 1..=check_text.len().min(stop_seq.len() - 1) {
|
||||
let suffix = &check_text[check_text.len() - i..];
|
||||
if stop_seq.starts_with(suffix) {
|
||||
partial_match_len = partial_match_len.max(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if partial_match_len > 0 {
|
||||
// Split: output safe text, jail the potential match
|
||||
let safe_end = check_text.len() - partial_match_len;
|
||||
let safe_text = &check_text[..safe_end];
|
||||
self.jail_buffer = check_text[safe_end..].to_string();
|
||||
|
||||
// Update offsets for next iteration
|
||||
self.prefix_offset = self.read_offset;
|
||||
self.read_offset = self.token_buffer.len();
|
||||
|
||||
if safe_text.is_empty() {
|
||||
Ok(SequenceDecoderOutput::Held)
|
||||
} else {
|
||||
Ok(SequenceDecoderOutput::Text(safe_text.to_string()))
|
||||
}
|
||||
} else {
|
||||
// No partial matches - output everything
|
||||
self.jail_buffer.clear();
|
||||
|
||||
// Update offsets for next iteration
|
||||
self.prefix_offset = self.read_offset;
|
||||
self.read_offset = self.token_buffer.len();
|
||||
|
||||
Ok(SequenceDecoderOutput::Text(check_text))
|
||||
}
|
||||
}
|
||||
|
||||
/// Process multiple tokens
|
||||
pub fn process_tokens(
|
||||
&mut self,
|
||||
token_ids: &[TokenIdType],
|
||||
) -> Result<Vec<SequenceDecoderOutput>> {
|
||||
let mut outputs = Vec::new();
|
||||
for &token_id in token_ids {
|
||||
outputs.push(self.process_token(token_id)?);
|
||||
}
|
||||
Ok(outputs)
|
||||
}
|
||||
|
||||
/// Flush any held text
|
||||
pub fn flush(&mut self) -> SequenceDecoderOutput {
|
||||
if !self.jail_buffer.is_empty() {
|
||||
let output = self.jail_buffer.clone();
|
||||
self.jail_buffer.clear();
|
||||
SequenceDecoderOutput::Text(output)
|
||||
} else {
|
||||
SequenceDecoderOutput::Text(String::new())
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if decoding has stopped
|
||||
pub fn is_stopped(&self) -> bool {
|
||||
self.stopped
|
||||
}
|
||||
|
||||
/// Reset the decoder state
|
||||
pub fn reset(&mut self) {
|
||||
self.jail_buffer.clear();
|
||||
self.token_buffer.clear();
|
||||
self.prefix_offset = 0;
|
||||
self.read_offset = 0;
|
||||
self.stopped = false;
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for StopSequenceDecoder
|
||||
pub struct StopSequenceDecoderBuilder {
|
||||
tokenizer: Arc<dyn traits::Tokenizer>,
|
||||
config: StopSequenceConfig,
|
||||
skip_special_tokens: bool,
|
||||
}
|
||||
|
||||
impl StopSequenceDecoderBuilder {
|
||||
pub fn new(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
|
||||
StopSequenceDecoderBuilder {
|
||||
tokenizer,
|
||||
config: StopSequenceConfig::default(),
|
||||
skip_special_tokens: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||
self.config.stop_tokens.insert(token_id);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||||
self.config.stop_sequences.push(sequence.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||
self.config.visible_stop_tokens.insert(token_id);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||||
self.config.visible_stop_sequences.push(sequence.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn skip_special_tokens(mut self, skip: bool) -> Self {
|
||||
self.skip_special_tokens = skip;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> StopSequenceDecoder {
|
||||
StopSequenceDecoder::new(self.tokenizer, self.config, self.skip_special_tokens)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tokenizer::mock::MockTokenizer;
|
||||
|
||||
#[test]
|
||||
fn test_stop_token_detection() {
|
||||
let tokenizer = Arc::new(MockTokenizer::new());
|
||||
let config = StopSequenceConfig::default().with_stop_token(999); // <eos> token
|
||||
|
||||
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
|
||||
|
||||
// Process tokens before stop
|
||||
let result = decoder.process_token(1).unwrap(); // "Hello"
|
||||
assert!(matches!(result, SequenceDecoderOutput::Text(_)));
|
||||
|
||||
// Process stop token
|
||||
let result = decoder.process_token(999).unwrap(); // <eos>
|
||||
assert_eq!(result, SequenceDecoderOutput::Stopped);
|
||||
|
||||
// Further tokens should also return Stopped
|
||||
let result = decoder.process_token(2).unwrap();
|
||||
assert_eq!(result, SequenceDecoderOutput::Stopped);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_visible_stop_token() {
|
||||
let tokenizer = Arc::new(MockTokenizer::new());
|
||||
let config = StopSequenceConfig::default().with_visible_stop_token(999);
|
||||
|
||||
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
|
||||
|
||||
let result = decoder.process_token(999).unwrap();
|
||||
assert!(matches!(result, SequenceDecoderOutput::StoppedWithText(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builder_pattern() {
|
||||
let tokenizer = Arc::new(MockTokenizer::new());
|
||||
|
||||
let decoder = StopSequenceDecoderBuilder::new(tokenizer)
|
||||
.stop_token(999)
|
||||
.stop_sequence("STOP")
|
||||
.visible_stop_token(1000)
|
||||
.skip_special_tokens(true)
|
||||
.build();
|
||||
|
||||
assert!(!decoder.is_stopped());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_incremental_decoding_no_repetition() {
|
||||
// This test verifies the critical fix: no repeated output
|
||||
let tokenizer = Arc::new(MockTokenizer::new());
|
||||
let config = StopSequenceConfig::default();
|
||||
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
|
||||
|
||||
// Process tokens one by one and collect outputs
|
||||
let mut outputs = Vec::new();
|
||||
|
||||
// Token 1: "Hello"
|
||||
let result = decoder.process_token(1).unwrap();
|
||||
if let SequenceDecoderOutput::Text(text) = result {
|
||||
outputs.push(text.clone());
|
||||
}
|
||||
|
||||
// Token 2: "world"
|
||||
let result = decoder.process_token(2).unwrap();
|
||||
if let SequenceDecoderOutput::Text(text) = result {
|
||||
outputs.push(text.clone());
|
||||
}
|
||||
|
||||
// Token 3: "test"
|
||||
let result = decoder.process_token(3).unwrap();
|
||||
if let SequenceDecoderOutput::Text(text) = result {
|
||||
outputs.push(text.clone());
|
||||
}
|
||||
|
||||
// CRITICAL: Each output should be unique (no accumulation)
|
||||
// The fix ensures we only output NEW text, not accumulated text
|
||||
assert_eq!(outputs.len(), 3);
|
||||
|
||||
// Verify no text is repeated
|
||||
for i in 0..outputs.len() {
|
||||
for j in i + 1..outputs.len() {
|
||||
// No output should contain another (no accumulation)
|
||||
assert!(!outputs[j].contains(&outputs[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stop_sequence_detection() {
|
||||
let tokenizer = Arc::new(MockTokenizer::new());
|
||||
let config = StopSequenceConfig::default().with_stop_sequence("test");
|
||||
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
|
||||
|
||||
// Process "Hello world"
|
||||
decoder.process_token(1).unwrap(); // "Hello"
|
||||
decoder.process_token(2).unwrap(); // "world"
|
||||
|
||||
// Process "test" which should trigger stop
|
||||
let result = decoder.process_token(3).unwrap(); // "test"
|
||||
|
||||
// Should stop when we hit "test"
|
||||
assert!(matches!(
|
||||
result,
|
||||
SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flush_after_partial() {
|
||||
let tokenizer = Arc::new(MockTokenizer::new());
|
||||
let config = StopSequenceConfig::default().with_stop_sequence("NEVER_MATCH");
|
||||
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
|
||||
|
||||
// Process a token
|
||||
decoder.process_token(1).unwrap(); // "Hello"
|
||||
|
||||
// Flush should return any remaining text in jail
|
||||
let result = decoder.flush();
|
||||
|
||||
// After processing, flush should work
|
||||
assert!(matches!(result, SequenceDecoderOutput::Text(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reset_functionality() {
|
||||
let tokenizer = Arc::new(MockTokenizer::new());
|
||||
let config = StopSequenceConfig::default().with_stop_token(999);
|
||||
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
|
||||
|
||||
// Process and stop
|
||||
decoder.process_token(1).unwrap();
|
||||
decoder.process_token(999).unwrap();
|
||||
assert!(decoder.is_stopped());
|
||||
|
||||
// Reset should clear everything
|
||||
decoder.reset();
|
||||
assert!(!decoder.is_stopped());
|
||||
|
||||
// Should be able to process again
|
||||
let result = decoder.process_token(2).unwrap();
|
||||
assert!(matches!(result, SequenceDecoderOutput::Text(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_visible_stop_sequence() {
|
||||
let tokenizer = Arc::new(MockTokenizer::new());
|
||||
let config = StopSequenceConfig::default().with_visible_stop_sequence("world");
|
||||
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
|
||||
|
||||
// Process "Hello"
|
||||
decoder.process_token(1).unwrap();
|
||||
|
||||
// Process "world" - should include it in output
|
||||
let result = decoder.process_token(2).unwrap();
|
||||
|
||||
if let SequenceDecoderOutput::StoppedWithText(text) = result {
|
||||
// Should include "world" in the output
|
||||
assert!(text.contains("world"));
|
||||
} else {
|
||||
panic!("Expected StoppedWithText with visible stop sequence");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_tokens_processing() {
|
||||
let tokenizer = Arc::new(MockTokenizer::new());
|
||||
let config = StopSequenceConfig::default();
|
||||
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
|
||||
|
||||
// Process multiple tokens at once
|
||||
let results = decoder.process_tokens(&[1, 2, 3]).unwrap();
|
||||
|
||||
// Should get results for each token
|
||||
assert_eq!(results.len(), 3);
|
||||
|
||||
// Each result should be Text (no stops configured)
|
||||
for result in results {
|
||||
assert!(matches!(
|
||||
result,
|
||||
SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
105
sgl-router/src/tokenizer/stream.rs
Normal file
105
sgl-router/src/tokenizer/stream.rs
Normal file
@@ -0,0 +1,105 @@
|
||||
// src/tokenizer/stream.rs
|
||||
|
||||
use super::traits::{self, TokenIdType};
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
|
||||
const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;
|
||||
|
||||
/// DecodeStream will keep the state necessary to produce individual chunks of
|
||||
/// strings given an input stream of token_ids
|
||||
pub struct DecodeStream {
|
||||
/// The tokenizer used to decode token_ids
|
||||
tokenizer: Arc<dyn traits::Tokenizer>,
|
||||
|
||||
skip_special_tokens: bool,
|
||||
|
||||
/// A temporary buffer of the necessary token_ids needed
|
||||
/// to produce valid string chunks
|
||||
all_token_ids: Vec<TokenIdType>,
|
||||
|
||||
prefix_offset: usize,
|
||||
read_offset: usize,
|
||||
}
|
||||
|
||||
impl DecodeStream {
|
||||
pub fn new(
|
||||
tokenizer: Arc<dyn traits::Tokenizer>,
|
||||
prompt_token_ids: &[TokenIdType],
|
||||
skip_special_tokens: bool,
|
||||
) -> Self {
|
||||
let num_input_tokens = prompt_token_ids.len();
|
||||
let prompt_token_ids = prompt_token_ids.to_vec();
|
||||
Self {
|
||||
tokenizer,
|
||||
skip_special_tokens,
|
||||
all_token_ids: prompt_token_ids,
|
||||
prefix_offset: num_input_tokens
|
||||
.saturating_sub(INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET),
|
||||
read_offset: num_input_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
/// Step appends a token_id to the internal state and tries to produce a text chunk.
|
||||
/// Returning `None` means the given id is not enough to produce a chunk.
|
||||
pub fn step(&mut self, id: TokenIdType) -> Result<Option<String>> {
|
||||
self.all_token_ids.push(id);
|
||||
|
||||
let prefix_text = self.tokenizer.decode(
|
||||
&self.all_token_ids[self.prefix_offset..self.read_offset],
|
||||
self.skip_special_tokens,
|
||||
)?;
|
||||
|
||||
let new_text = self.tokenizer.decode(
|
||||
&self.all_token_ids[self.prefix_offset..],
|
||||
self.skip_special_tokens,
|
||||
)?;
|
||||
|
||||
if new_text.len() > prefix_text.len() && !new_text.ends_with("<EFBFBD>") {
|
||||
let new_text = new_text[prefix_text.len()..].to_string();
|
||||
|
||||
self.prefix_offset = self.read_offset;
|
||||
self.read_offset = self.all_token_ids.len();
|
||||
|
||||
Ok(Some(new_text))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Process multiple tokens at once
|
||||
pub fn step_batch(&mut self, token_ids: &[u32]) -> Result<Vec<String>> {
|
||||
let mut chunks = Vec::new();
|
||||
|
||||
for &token_id in token_ids {
|
||||
if let Some(text) = self.step(token_id)? {
|
||||
chunks.push(text);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(chunks)
|
||||
}
|
||||
|
||||
/// Force flush any remaining text
|
||||
pub fn flush(&mut self) -> Result<Option<String>> {
|
||||
if self.read_offset < self.all_token_ids.len() {
|
||||
let remaining = self.tokenizer.decode(
|
||||
&self.all_token_ids[self.read_offset..],
|
||||
self.skip_special_tokens,
|
||||
)?;
|
||||
|
||||
self.read_offset = self.all_token_ids.len();
|
||||
|
||||
if !remaining.is_empty() {
|
||||
return Ok(Some(remaining));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Get all tokens processed so far
|
||||
pub fn tokens(&self) -> &[u32] {
|
||||
&self.all_token_ids
|
||||
}
|
||||
}
|
||||
143
sgl-router/src/tokenizer/tests.rs
Normal file
143
sgl-router/src/tokenizer/tests.rs
Normal file
@@ -0,0 +1,143 @@
|
||||
#[cfg(test)]
|
||||
use super::*;
|
||||
#[cfg(test)]
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn test_mock_tokenizer_encode() {
|
||||
let tokenizer = mock::MockTokenizer::new();
|
||||
let encoding = tokenizer.encode("Hello world").unwrap();
|
||||
let token_ids = encoding.token_ids();
|
||||
assert_eq!(token_ids, &[1, 2]); // "Hello" -> 1, "world" -> 2
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mock_tokenizer_decode() {
|
||||
let tokenizer = mock::MockTokenizer::new();
|
||||
let text = tokenizer.decode(&[1, 2], false).unwrap();
|
||||
assert_eq!(text, "Hello world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mock_tokenizer_decode_skip_special() {
|
||||
let tokenizer = mock::MockTokenizer::new();
|
||||
|
||||
// With special tokens
|
||||
let text = tokenizer.decode(&[1000, 1, 2, 999], false).unwrap();
|
||||
assert_eq!(text, "<bos> Hello world <eos>");
|
||||
|
||||
// Without special tokens
|
||||
let text = tokenizer.decode(&[1000, 1, 2, 999], true).unwrap();
|
||||
assert_eq!(text, "Hello world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tokenizer_wrapper() {
|
||||
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
|
||||
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
|
||||
|
||||
// Test encoding
|
||||
let encoding = tokenizer.encode("Hello world").unwrap();
|
||||
assert_eq!(encoding.token_ids(), &[1, 2]);
|
||||
|
||||
// Test decoding
|
||||
let text = tokenizer.decode(&[1, 2], false).unwrap();
|
||||
assert_eq!(text, "Hello world");
|
||||
|
||||
// Test vocab size
|
||||
assert_eq!(tokenizer.vocab_size(), 8);
|
||||
|
||||
// Test token to ID
|
||||
assert_eq!(tokenizer.token_to_id("Hello"), Some(1));
|
||||
assert_eq!(tokenizer.token_to_id("unknown"), None);
|
||||
|
||||
// Test ID to token
|
||||
assert_eq!(tokenizer.id_to_token(1), Some("Hello".to_string()));
|
||||
assert_eq!(tokenizer.id_to_token(9999), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_stream_basic() {
|
||||
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
|
||||
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
|
||||
|
||||
// Create a decode stream with initial tokens
|
||||
let initial_tokens = vec![1, 2]; // "Hello world"
|
||||
let mut stream = tokenizer.decode_stream(&initial_tokens, false);
|
||||
|
||||
// Add a new token
|
||||
let result = stream.step(3).unwrap(); // "test"
|
||||
// Since we're using a mock, the actual incremental behavior depends on implementation
|
||||
// For now, we just verify it doesn't crash
|
||||
assert!(result.is_some() || result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_stream_flush() {
|
||||
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
|
||||
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
|
||||
|
||||
let initial_tokens = vec![1];
|
||||
let mut stream = tokenizer.decode_stream(&initial_tokens, false);
|
||||
|
||||
// Add tokens
|
||||
stream.step(2).unwrap();
|
||||
stream.step(3).unwrap();
|
||||
|
||||
// Flush remaining
|
||||
let flushed = stream.flush().unwrap();
|
||||
// The flush behavior depends on the implementation
|
||||
assert!(flushed.is_some() || flushed.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_special_tokens() {
|
||||
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
|
||||
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
|
||||
|
||||
let special_tokens = tokenizer.get_special_tokens();
|
||||
assert_eq!(special_tokens.bos_token, Some("<bos>".to_string()));
|
||||
assert_eq!(special_tokens.eos_token, Some("<eos>".to_string()));
|
||||
assert_eq!(special_tokens.unk_token, Some("<unk>".to_string()));
|
||||
assert!(special_tokens.sep_token.is_none());
|
||||
assert!(special_tokens.pad_token.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batch_encode() {
|
||||
let tokenizer = mock::MockTokenizer::new();
|
||||
let inputs = vec!["Hello", "world", "test"];
|
||||
let encodings = tokenizer.encode_batch(&inputs).unwrap();
|
||||
|
||||
assert_eq!(encodings.len(), 3);
|
||||
assert_eq!(encodings[0].token_ids(), &[1]); // "Hello" -> 1
|
||||
assert_eq!(encodings[1].token_ids(), &[2]); // "world" -> 2
|
||||
assert_eq!(encodings[2].token_ids(), &[3]); // "test" -> 3
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thread_safety() {
|
||||
use std::thread;
|
||||
|
||||
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
|
||||
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
|
||||
|
||||
// Spawn multiple threads that use the same tokenizer
|
||||
let handles: Vec<_> = (0..10)
|
||||
.map(|i| {
|
||||
let tokenizer_clone = tokenizer.clone();
|
||||
thread::spawn(move || {
|
||||
let text = "Hello test".to_string();
|
||||
let encoding = tokenizer_clone.encode(&text).unwrap();
|
||||
let decoded = tokenizer_clone.decode(encoding.token_ids(), false).unwrap();
|
||||
assert!(decoded.contains("Hello") || decoded.contains("test"));
|
||||
i
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Wait for all threads to complete
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
}
|
||||
276
sgl-router/src/tokenizer/tiktoken.rs
Normal file
276
sgl-router/src/tokenizer/tiktoken.rs
Normal file
@@ -0,0 +1,276 @@
|
||||
use super::traits::{
|
||||
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
||||
};
|
||||
use anyhow::{Error, Result};
|
||||
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
|
||||
|
||||
/// Tiktoken tokenizer wrapper for OpenAI GPT models
|
||||
pub struct TiktokenTokenizer {
|
||||
tokenizer: CoreBPE,
|
||||
#[allow(dead_code)]
|
||||
model: TiktokenModel,
|
||||
special_tokens: SpecialTokens,
|
||||
vocab_size: usize,
|
||||
}
|
||||
|
||||
/// Supported Tiktoken models
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum TiktokenModel {
|
||||
/// GPT-4, GPT-3.5-turbo, text-embedding-ada-002
|
||||
Cl100kBase,
|
||||
/// Codex models, text-davinci-002, text-davinci-003
|
||||
P50kBase,
|
||||
/// Use for edit models like text-davinci-edit-001, code-davinci-edit-001
|
||||
P50kEdit,
|
||||
/// GPT-3 models like davinci
|
||||
R50kBase,
|
||||
}
|
||||
|
||||
impl TiktokenTokenizer {
|
||||
/// Create a new Tiktoken tokenizer for the specified model
|
||||
pub fn new(model: TiktokenModel) -> Result<Self> {
|
||||
let tokenizer =
|
||||
match model {
|
||||
TiktokenModel::Cl100kBase => cl100k_base()
|
||||
.map_err(|e| Error::msg(format!("Failed to load cl100k_base: {}", e)))?,
|
||||
TiktokenModel::P50kBase => p50k_base()
|
||||
.map_err(|e| Error::msg(format!("Failed to load p50k_base: {}", e)))?,
|
||||
TiktokenModel::P50kEdit => p50k_edit()
|
||||
.map_err(|e| Error::msg(format!("Failed to load p50k_edit: {}", e)))?,
|
||||
TiktokenModel::R50kBase => r50k_base()
|
||||
.map_err(|e| Error::msg(format!("Failed to load r50k_base: {}", e)))?,
|
||||
};
|
||||
|
||||
// Extract special tokens (tiktoken-rs doesn't expose them directly)
|
||||
// We'll use common ones for GPT models
|
||||
let special_tokens = Self::get_special_tokens_for_model(model);
|
||||
|
||||
// Get vocabulary size (this is an approximation)
|
||||
let vocab_size = match model {
|
||||
TiktokenModel::Cl100kBase => 100256, // cl100k has ~100k tokens
|
||||
TiktokenModel::P50kBase | TiktokenModel::P50kEdit => 50281, // p50k has ~50k tokens
|
||||
TiktokenModel::R50kBase => 50257, // r50k has ~50k tokens
|
||||
};
|
||||
|
||||
Ok(TiktokenTokenizer {
|
||||
tokenizer,
|
||||
model,
|
||||
special_tokens,
|
||||
vocab_size,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a tokenizer from a model string (e.g., "gpt-4", "gpt-3.5-turbo")
|
||||
pub fn from_model_name(model_name: &str) -> Result<Self> {
|
||||
let model = Self::model_from_name(model_name)?;
|
||||
Self::new(model)
|
||||
}
|
||||
|
||||
/// Determine the appropriate model from a model name
|
||||
fn model_from_name(model_name: &str) -> Result<TiktokenModel> {
|
||||
// Based on OpenAI's model-to-encoding mapping
|
||||
if model_name.contains("gpt-4")
|
||||
|| model_name.contains("gpt-3.5")
|
||||
|| model_name.contains("turbo")
|
||||
{
|
||||
Ok(TiktokenModel::Cl100kBase)
|
||||
} else if model_name.contains("davinci-002")
|
||||
|| model_name.contains("davinci-003")
|
||||
|| model_name.contains("codex")
|
||||
{
|
||||
Ok(TiktokenModel::P50kBase)
|
||||
} else if model_name.contains("edit") {
|
||||
Ok(TiktokenModel::P50kEdit)
|
||||
} else if model_name.contains("davinci")
|
||||
|| model_name.contains("curie")
|
||||
|| model_name.contains("babbage")
|
||||
|| model_name.contains("ada")
|
||||
{
|
||||
Ok(TiktokenModel::R50kBase)
|
||||
} else {
|
||||
// Return an error for unrecognized model names to prevent silent failures
|
||||
Err(anyhow::anyhow!(
|
||||
"Unrecognized OpenAI model name: '{}'. Expected GPT-3, GPT-3.5, GPT-4, or related model names",
|
||||
model_name
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Get special tokens for a specific model
|
||||
fn get_special_tokens_for_model(model: TiktokenModel) -> SpecialTokens {
|
||||
// These are common special tokens for GPT models
|
||||
// The actual token IDs might vary by model
|
||||
match model {
|
||||
TiktokenModel::Cl100kBase => SpecialTokens {
|
||||
bos_token: Some("<|endoftext|>".to_string()),
|
||||
eos_token: Some("<|endoftext|>".to_string()),
|
||||
unk_token: None,
|
||||
sep_token: None,
|
||||
pad_token: Some("<|endoftext|>".to_string()),
|
||||
cls_token: None,
|
||||
mask_token: None,
|
||||
additional_special_tokens: vec![
|
||||
"<|fim_prefix|>".to_string(),
|
||||
"<|fim_middle|>".to_string(),
|
||||
"<|fim_suffix|>".to_string(),
|
||||
"<|endofprompt|>".to_string(),
|
||||
],
|
||||
},
|
||||
_ => SpecialTokens {
|
||||
bos_token: Some("<|endoftext|>".to_string()),
|
||||
eos_token: Some("<|endoftext|>".to_string()),
|
||||
unk_token: None,
|
||||
sep_token: None,
|
||||
pad_token: Some("<|endoftext|>".to_string()),
|
||||
cls_token: None,
|
||||
mask_token: None,
|
||||
additional_special_tokens: vec![],
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder for TiktokenTokenizer {
|
||||
fn encode(&self, input: &str) -> Result<Encoding> {
|
||||
let tokens = self.tokenizer.encode_ordinary(input);
|
||||
Ok(Encoding::Tiktoken(tokens))
|
||||
}
|
||||
|
||||
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
|
||||
inputs.iter().map(|input| self.encode(input)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for TiktokenTokenizer {
|
||||
fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
|
||||
// tiktoken-rs 0.7.0 now uses u32 (Rank type)
|
||||
self.tokenizer
|
||||
.decode(token_ids.to_vec())
|
||||
.map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenizerTrait for TiktokenTokenizer {
|
||||
fn vocab_size(&self) -> usize {
|
||||
self.vocab_size
|
||||
}
|
||||
|
||||
fn get_special_tokens(&self) -> &SpecialTokens {
|
||||
&self.special_tokens
|
||||
}
|
||||
|
||||
fn token_to_id(&self, _token: &str) -> Option<TokenIdType> {
|
||||
// Tiktoken doesn't provide direct token-to-id mapping
|
||||
// We'd need to encode the token and check if it produces a single ID
|
||||
None
|
||||
}
|
||||
|
||||
fn id_to_token(&self, _id: TokenIdType) -> Option<String> {
|
||||
// Tiktoken doesn't provide direct id-to-token mapping
|
||||
// We can only decode IDs to text
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_tiktoken_creation() {
|
||||
let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
|
||||
assert_eq!(tokenizer.vocab_size(), 100256);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_from_name() {
|
||||
assert!(matches!(
|
||||
TiktokenTokenizer::model_from_name("gpt-4").unwrap(),
|
||||
TiktokenModel::Cl100kBase
|
||||
));
|
||||
assert!(matches!(
|
||||
TiktokenTokenizer::model_from_name("gpt-3.5-turbo").unwrap(),
|
||||
TiktokenModel::Cl100kBase
|
||||
));
|
||||
assert!(matches!(
|
||||
TiktokenTokenizer::model_from_name("text-davinci-003").unwrap(),
|
||||
TiktokenModel::P50kBase
|
||||
));
|
||||
assert!(matches!(
|
||||
TiktokenTokenizer::model_from_name("text-davinci-edit-001").unwrap(),
|
||||
TiktokenModel::P50kEdit
|
||||
));
|
||||
assert!(matches!(
|
||||
TiktokenTokenizer::model_from_name("davinci").unwrap(),
|
||||
TiktokenModel::R50kBase
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_decode() {
|
||||
let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
|
||||
|
||||
let text = "Hello, world!";
|
||||
let encoding = tokenizer.encode(text).unwrap();
|
||||
|
||||
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
|
||||
assert_eq!(decoded, text);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batch_encode() {
|
||||
let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
|
||||
|
||||
let texts = vec!["Hello", "World", "Test"];
|
||||
let encodings = tokenizer.encode_batch(&texts).unwrap();
|
||||
|
||||
assert_eq!(encodings.len(), 3);
|
||||
for (i, encoding) in encodings.iter().enumerate() {
|
||||
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
|
||||
assert_eq!(decoded, texts[i]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_special_tokens() {
|
||||
let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
|
||||
let special_tokens = tokenizer.get_special_tokens();
|
||||
|
||||
assert!(special_tokens.eos_token.is_some());
|
||||
assert_eq!(special_tokens.eos_token.as_ref().unwrap(), "<|endoftext|>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unrecognized_model_name_returns_error() {
|
||||
// Test that unrecognized model names return an error
|
||||
let result = TiktokenTokenizer::from_model_name("distilgpt-2");
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert!(e.to_string().contains("Unrecognized OpenAI model name"));
|
||||
}
|
||||
|
||||
let result = TiktokenTokenizer::from_model_name("bert-base-uncased");
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert!(e.to_string().contains("Unrecognized OpenAI model name"));
|
||||
}
|
||||
|
||||
let result = TiktokenTokenizer::from_model_name("llama-7b");
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert!(e.to_string().contains("Unrecognized OpenAI model name"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recognized_model_names() {
|
||||
// Test that recognized model names work correctly
|
||||
assert!(TiktokenTokenizer::from_model_name("gpt-4").is_ok());
|
||||
assert!(TiktokenTokenizer::from_model_name("gpt-3.5-turbo").is_ok());
|
||||
assert!(TiktokenTokenizer::from_model_name("text-davinci-003").is_ok());
|
||||
assert!(TiktokenTokenizer::from_model_name("code-davinci-002").is_ok());
|
||||
assert!(TiktokenTokenizer::from_model_name("text-curie-001").is_ok());
|
||||
assert!(TiktokenTokenizer::from_model_name("text-babbage-001").is_ok());
|
||||
assert!(TiktokenTokenizer::from_model_name("text-ada-001").is_ok());
|
||||
}
|
||||
}
|
||||
83
sgl-router/src/tokenizer/traits.rs
Normal file
83
sgl-router/src/tokenizer/traits.rs
Normal file
@@ -0,0 +1,83 @@
|
||||
use anyhow::Result;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
/// Type alias for token IDs
|
||||
pub type TokenIdType = u32;
|
||||
|
||||
/// Core encoding trait - separate from decoding for modularity
|
||||
pub trait Encoder: Send + Sync {
|
||||
fn encode(&self, input: &str) -> Result<Encoding>;
|
||||
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>>;
|
||||
}
|
||||
|
||||
/// Core decoding trait - can be implemented independently
|
||||
pub trait Decoder: Send + Sync {
|
||||
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
|
||||
}
|
||||
|
||||
/// Combined tokenizer trait
|
||||
pub trait Tokenizer: Encoder + Decoder {
|
||||
fn vocab_size(&self) -> usize;
|
||||
fn get_special_tokens(&self) -> &SpecialTokens;
|
||||
fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
|
||||
fn id_to_token(&self, id: TokenIdType) -> Option<String>;
|
||||
}
|
||||
|
||||
/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Encoding {
|
||||
/// Hugging Face
|
||||
Hf(Box<tokenizers::tokenizer::Encoding>),
|
||||
/// Sentence Piece
|
||||
Sp(Vec<TokenIdType>),
|
||||
/// Tiktoken (for GPT models) - now uses u32 in tiktoken-rs 0.7.0
|
||||
Tiktoken(Vec<TokenIdType>),
|
||||
}
|
||||
|
||||
impl Encoding {
|
||||
/// Returns a reference to token IDs - zero-copy operation
|
||||
pub fn token_ids(&self) -> &[TokenIdType] {
|
||||
match self {
|
||||
Encoding::Hf(inner) => inner.get_ids(),
|
||||
Encoding::Sp(inner) => inner,
|
||||
Encoding::Tiktoken(inner) => inner,
|
||||
}
|
||||
}
|
||||
|
||||
/// Deprecated: Use token_ids() instead (kept for compatibility)
|
||||
#[deprecated(since = "0.1.0", note = "Use token_ids() instead")]
|
||||
pub fn token_ids_ref(&self) -> &[TokenIdType] {
|
||||
self.token_ids()
|
||||
}
|
||||
|
||||
/// Get a hash of the token IDs for caching purposes
|
||||
pub fn get_hash(&self) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
self.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Hash implementation for Encoding
|
||||
impl Hash for Encoding {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
match self {
|
||||
Encoding::Hf(inner) => inner.get_ids().hash(state),
|
||||
Encoding::Sp(inner) => inner.hash(state),
|
||||
Encoding::Tiktoken(inner) => inner.hash(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SpecialTokens {
|
||||
pub bos_token: Option<String>,
|
||||
pub eos_token: Option<String>,
|
||||
pub unk_token: Option<String>,
|
||||
pub sep_token: Option<String>,
|
||||
pub pad_token: Option<String>,
|
||||
pub cls_token: Option<String>,
|
||||
pub mask_token: Option<String>,
|
||||
pub additional_special_tokens: Vec<String>,
|
||||
}
|
||||
32
sgl-router/src/tool_parser/errors.rs
Normal file
32
sgl-router/src/tool_parser/errors.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type for tool parser operations
|
||||
pub type ToolParserResult<T> = Result<T, ToolParserError>;
|
||||
|
||||
/// Errors that can occur during tool parsing
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ToolParserError {
|
||||
#[error("Parsing failed: {0}")]
|
||||
ParsingFailed(String),
|
||||
|
||||
#[error("Model not supported: {0}")]
|
||||
ModelNotSupported(String),
|
||||
|
||||
#[error("Parse depth exceeded: max {0}")]
|
||||
DepthExceeded(usize),
|
||||
|
||||
#[error("Invalid JSON: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
|
||||
#[error("Regex error: {0}")]
|
||||
RegexError(#[from] regex::Error),
|
||||
|
||||
#[error("Incomplete tool call")]
|
||||
Incomplete,
|
||||
|
||||
#[error("Invalid tool name: {0}")]
|
||||
InvalidToolName(String),
|
||||
|
||||
#[error("Token not found: {0}")]
|
||||
TokenNotFound(String),
|
||||
}
|
||||
30
sgl-router/src/tool_parser/mod.rs
Normal file
30
sgl-router/src/tool_parser/mod.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
/// Tool parser module for handling function/tool calls in model outputs
|
||||
///
|
||||
/// This module provides infrastructure for parsing tool calls from various model formats.
|
||||
// Core modules
|
||||
pub mod errors;
|
||||
pub mod partial_json;
|
||||
pub mod python_literal_parser;
|
||||
pub mod registry;
|
||||
pub mod state;
|
||||
pub mod traits;
|
||||
pub mod types;
|
||||
|
||||
// Parser implementations
|
||||
pub mod parsers;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use errors::{ToolParserError, ToolParserResult};
|
||||
pub use registry::ParserRegistry;
|
||||
pub use state::{ParsePhase, ParseState};
|
||||
pub use traits::{PartialJsonParser, ToolParser};
|
||||
pub use types::{FunctionCall, PartialToolCall, StreamResult, TokenConfig, ToolCall};
|
||||
|
||||
// Re-export parsers for convenience
|
||||
pub use parsers::{
|
||||
DeepSeekParser, Glm4MoeParser, GptOssParser, JsonParser, KimiK2Parser, LlamaParser,
|
||||
MistralParser, PythonicParser, QwenParser, Step3Parser,
|
||||
};
|
||||
277
sgl-router/src/tool_parser/parsers/deepseek_parser.rs
Normal file
277
sgl-router/src/tool_parser/parsers/deepseek_parser.rs
Normal file
@@ -0,0 +1,277 @@
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
};
|
||||
|
||||
/// DeepSeek V3 format parser for tool calls
|
||||
///
|
||||
/// Handles the DeepSeek V3 specific format that uses Unicode tokens:
|
||||
/// `<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>{name}\n```json\n{args}\n```<|tool▁call▁end|><|tool▁calls▁end|>`
|
||||
///
|
||||
/// Features:
|
||||
/// - Unicode token delimiters
|
||||
/// - JSON arguments in code blocks
|
||||
/// - Support for multiple sequential tool calls
|
||||
pub struct DeepSeekParser {
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
/// Regex for extracting complete tool calls
|
||||
tool_call_extractor: Regex,
|
||||
/// Regex for extracting function details
|
||||
func_detail_extractor: Regex,
|
||||
}
|
||||
|
||||
impl DeepSeekParser {
|
||||
/// Create a new DeepSeek parser
|
||||
pub fn new() -> Self {
|
||||
// Use (?s) flag for DOTALL mode to handle newlines
|
||||
let tool_call_pattern = r"(?s)<|tool▁call▁begin|>.*?<|tool▁call▁end|>";
|
||||
let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
|
||||
|
||||
let func_detail_pattern = r"(?s)<|tool▁call▁begin|>(.*?)<|tool▁sep|>(.*?)\n```json\n(.*?)\n```<|tool▁call▁end|>";
|
||||
let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");
|
||||
|
||||
Self {
|
||||
partial_json: PartialJson::default(),
|
||||
tool_call_extractor,
|
||||
func_detail_extractor,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if text contains DeepSeek tool markers
|
||||
fn has_tool_markers(&self, text: &str) -> bool {
|
||||
text.contains("<|tool▁calls▁begin|>")
|
||||
}
|
||||
|
||||
/// Extract all tool call blocks from text
|
||||
fn extract_tool_calls<'a>(&self, text: &'a str) -> Vec<&'a str> {
|
||||
self.tool_call_extractor
|
||||
.find_iter(text)
|
||||
.map(|m| m.as_str())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Parse a single tool call block
|
||||
fn parse_tool_call(&self, block: &str) -> ToolParserResult<Option<ToolCall>> {
|
||||
if let Some(captures) = self.func_detail_extractor.captures(block) {
|
||||
// Get function type (should be "function")
|
||||
let func_type = captures.get(1).map_or("", |m| m.as_str());
|
||||
if func_type != "function" {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Get function name
|
||||
let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
|
||||
|
||||
// Get JSON arguments
|
||||
let json_args = captures.get(3).map_or("{}", |m| m.as_str()).trim();
|
||||
|
||||
// Parse JSON arguments
|
||||
match serde_json::from_str::<Value>(json_args) {
|
||||
Ok(value) => {
|
||||
// Create arguments object
|
||||
let args = if value.is_object() {
|
||||
value
|
||||
} else {
|
||||
// If not an object, wrap it
|
||||
serde_json::json!({ "value": value })
|
||||
};
|
||||
|
||||
let arguments = serde_json::to_string(&args)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
// Generate ID
|
||||
let id = format!("deepseek_call_{}", uuid::Uuid::new_v4());
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: func_name.to_string(),
|
||||
arguments,
|
||||
},
|
||||
}))
|
||||
}
|
||||
Err(_) => Ok(None),
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DeepSeekParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for DeepSeekParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
// Check if text contains DeepSeek format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
// Extract all tool call blocks
|
||||
let tool_blocks = self.extract_tool_calls(text);
|
||||
let mut tools = Vec::new();
|
||||
|
||||
for block in tool_blocks {
|
||||
if let Some(tool) = self.parse_tool_call(block)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
|
||||
// Check for tool markers
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
// No markers found, return as incomplete
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
|
||||
// Look for start of tool calls
|
||||
if let Some(start_pos) = state.buffer.find("<|tool▁calls▁begin|>") {
|
||||
// Look for individual tool call start
|
||||
let search_from = start_pos + "<|tool▁calls▁begin|>".len();
|
||||
if let Some(call_start) = state.buffer[search_from..].find("<|tool▁call▁begin|>")
|
||||
{
|
||||
let call_start_abs = search_from + call_start;
|
||||
|
||||
// Look for the end of this tool call
|
||||
let search_end_from = call_start_abs + "<|tool▁call▁begin|>".len();
|
||||
if let Some(call_end) = state.buffer[search_end_from..].find("<|tool▁call▁end|>")
|
||||
{
|
||||
let call_end_abs = search_end_from + call_end + "<|tool▁call▁end|>".len();
|
||||
|
||||
// Extract and parse the complete tool call
|
||||
let tool_call_text = &state.buffer[call_start_abs..call_end_abs];
|
||||
|
||||
if let Some(tool) = self.parse_tool_call(tool_call_text)? {
|
||||
// Remove the processed part from buffer
|
||||
state.buffer.drain(..call_end_abs);
|
||||
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
} else {
|
||||
// Tool call not complete yet, try to extract partial info
|
||||
let partial = &state.buffer[search_end_from..];
|
||||
|
||||
// Try to extract function name
|
||||
if let Some(sep_pos) = partial.find("<|tool▁sep|>") {
|
||||
if let Some(_func_start) = partial[..sep_pos].rfind("function") {
|
||||
// We have the function type marker
|
||||
let after_sep = &partial[sep_pos + "<|tool▁sep|>".len()..];
|
||||
|
||||
// Look for function name (ends at newline before ```json)
|
||||
if let Some(name_end) = after_sep.find("\n```json\n") {
|
||||
let func_name = after_sep[..name_end].trim();
|
||||
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Mark name as sent
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: func_name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Try to extract partial arguments
|
||||
let args_start = name_end + "\n```json\n".len();
|
||||
let partial_args = &after_sep[args_start..];
|
||||
|
||||
// Check if we can parse partial JSON
|
||||
if !partial_args.is_empty() {
|
||||
match self.partial_json.parse_value(partial_args) {
|
||||
Ok((value, _consumed)) => {
|
||||
let args_str = serde_json::to_string(&value)
|
||||
.unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
// Can't parse yet, keep buffering
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
self.has_tool_markers(text)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_deepseek_single_tool() {
|
||||
let parser = DeepSeekParser::new();
|
||||
let input = r#"Some text
|
||||
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
|
||||
```json
|
||||
{"location": "Tokyo", "units": "celsius"}
|
||||
```<|tool▁call▁end|><|tool▁calls▁end|>More text"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Tokyo"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_deepseek_multiple_tools() {
|
||||
let parser = DeepSeekParser::new();
|
||||
let input = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
|
||||
```json
|
||||
{"location": "Tokyo"}
|
||||
```<|tool▁call▁end|>
|
||||
<|tool▁call▁begin|>function<|tool▁sep|>get_weather
|
||||
```json
|
||||
{"location": "Paris"}
|
||||
```<|tool▁call▁end|><|tool▁calls▁end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Tokyo"));
|
||||
assert!(result[1].function.arguments.contains("Paris"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = DeepSeekParser::new();
|
||||
assert!(parser.detect_format("<|tool▁calls▁begin|>"));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format("[TOOL_CALLS]"));
|
||||
}
|
||||
}
|
||||
292
sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs
Normal file
292
sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs
Normal file
@@ -0,0 +1,292 @@
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
};
|
||||
|
||||
/// GLM-4 MoE format parser for tool calls
|
||||
///
|
||||
/// Handles the GLM-4 MoE specific format:
|
||||
/// `<tool_call>{name}\n<arg_key>{key}</arg_key>\n<arg_value>{value}</arg_value>\n</tool_call>`
|
||||
///
|
||||
/// Features:
|
||||
/// - XML-style tags for tool calls
|
||||
/// - Key-value pairs for arguments
|
||||
/// - Support for multiple sequential tool calls
|
||||
pub struct Glm4MoeParser {
|
||||
/// Regex for extracting complete tool calls
|
||||
tool_call_extractor: Regex,
|
||||
/// Regex for extracting function details
|
||||
func_detail_extractor: Regex,
|
||||
/// Regex for extracting argument key-value pairs
|
||||
arg_extractor: Regex,
|
||||
}
|
||||
|
||||
impl Glm4MoeParser {
|
||||
/// Create a new GLM-4 MoE parser
|
||||
pub fn new() -> Self {
|
||||
// Use (?s) flag for DOTALL mode to handle newlines
|
||||
let tool_call_pattern = r"(?s)<tool_call>.*?</tool_call>";
|
||||
let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
|
||||
|
||||
let func_detail_pattern = r"(?s)<tool_call>([^\n]*)\n(.*)</tool_call>";
|
||||
let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");
|
||||
|
||||
let arg_pattern = r"(?s)<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>";
|
||||
let arg_extractor = Regex::new(arg_pattern).expect("Valid regex pattern");
|
||||
|
||||
Self {
|
||||
tool_call_extractor,
|
||||
func_detail_extractor,
|
||||
arg_extractor,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if text contains GLM-4 MoE tool markers
|
||||
fn has_tool_markers(&self, text: &str) -> bool {
|
||||
text.contains("<tool_call>")
|
||||
}
|
||||
|
||||
/// Parse arguments from key-value pairs
|
||||
fn parse_arguments(&self, args_text: &str) -> ToolParserResult<serde_json::Map<String, Value>> {
|
||||
let mut arguments = serde_json::Map::new();
|
||||
|
||||
for capture in self.arg_extractor.captures_iter(args_text) {
|
||||
let key = capture.get(1).map_or("", |m| m.as_str()).trim();
|
||||
let value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
|
||||
|
||||
// Try to parse the value as JSON first, fallback to string
|
||||
let value = if let Ok(json_val) = serde_json::from_str::<Value>(value_str) {
|
||||
json_val
|
||||
} else {
|
||||
// Try parsing as Python literal (similar to Python's ast.literal_eval)
|
||||
if value_str == "true" || value_str == "True" {
|
||||
Value::Bool(true)
|
||||
} else if value_str == "false" || value_str == "False" {
|
||||
Value::Bool(false)
|
||||
} else if value_str == "null" || value_str == "None" {
|
||||
Value::Null
|
||||
} else if let Ok(num) = value_str.parse::<i64>() {
|
||||
Value::Number(num.into())
|
||||
} else if let Ok(num) = value_str.parse::<f64>() {
|
||||
if let Some(n) = serde_json::Number::from_f64(num) {
|
||||
Value::Number(n)
|
||||
} else {
|
||||
Value::String(value_str.to_string())
|
||||
}
|
||||
} else {
|
||||
Value::String(value_str.to_string())
|
||||
}
|
||||
};
|
||||
|
||||
arguments.insert(key.to_string(), value);
|
||||
}
|
||||
|
||||
Ok(arguments)
|
||||
}
|
||||
|
||||
/// Parse a single tool call block
|
||||
fn parse_tool_call(&self, block: &str) -> ToolParserResult<Option<ToolCall>> {
|
||||
if let Some(captures) = self.func_detail_extractor.captures(block) {
|
||||
// Get function name
|
||||
let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
|
||||
|
||||
// Get arguments text
|
||||
let args_text = captures.get(2).map_or("", |m| m.as_str());
|
||||
|
||||
// Parse arguments
|
||||
let arguments = self.parse_arguments(args_text)?;
|
||||
|
||||
let arguments_str = serde_json::to_string(&arguments)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
// Generate ID
|
||||
let id = format!("glm4_call_{}", uuid::Uuid::new_v4());
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: func_name.to_string(),
|
||||
arguments: arguments_str,
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Glm4MoeParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for Glm4MoeParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
// Check if text contains GLM-4 MoE format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
// Extract all tool call blocks
|
||||
let mut tools = Vec::new();
|
||||
for mat in self.tool_call_extractor.find_iter(text) {
|
||||
if let Some(tool) = self.parse_tool_call(mat.as_str())? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
|
||||
// Check for tool markers
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
// No markers found, return as incomplete
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
|
||||
// Look for start of tool call
|
||||
if let Some(start_pos) = state.buffer.find("<tool_call>") {
|
||||
// Look for the end of this tool call
|
||||
let search_from = start_pos + "<tool_call>".len();
|
||||
if let Some(end_pos) = state.buffer[search_from..].find("</tool_call>") {
|
||||
let end_abs = search_from + end_pos + "</tool_call>".len();
|
||||
|
||||
// Extract and parse the complete tool call
|
||||
let tool_call_text = &state.buffer[start_pos..end_abs];
|
||||
|
||||
if let Some(tool) = self.parse_tool_call(tool_call_text)? {
|
||||
// Remove the processed part from buffer
|
||||
state.buffer.drain(..end_abs);
|
||||
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
} else {
|
||||
// Tool call not complete yet, try to extract partial info
|
||||
let partial = &state.buffer[search_from..];
|
||||
|
||||
// Try to extract function name (first line after <tool_call>)
|
||||
if let Some(name_end) = partial.find('\n') {
|
||||
let func_name = partial[..name_end].trim();
|
||||
|
||||
if !func_name.is_empty() && !state.in_string {
|
||||
state.in_string = true; // Mark name as sent
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: func_name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Try to extract partial arguments
|
||||
let args_text = &partial[name_end + 1..];
|
||||
let partial_args = self.parse_arguments(args_text)?;
|
||||
|
||||
if !partial_args.is_empty() {
|
||||
let args_str = serde_json::to_string(&partial_args)
|
||||
.unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
self.has_tool_markers(text)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_glm4_single_tool() {
|
||||
let parser = Glm4MoeParser::new();
|
||||
let input = r#"Some text
|
||||
<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Beijing</arg_value>
|
||||
<arg_key>date</arg_key>
|
||||
<arg_value>2024-06-27</arg_value>
|
||||
</tool_call>More text"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Beijing"));
|
||||
assert!(result[0].function.arguments.contains("2024-06-27"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_glm4_multiple_tools() {
|
||||
let parser = Glm4MoeParser::new();
|
||||
let input = r#"<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Beijing</arg_value>
|
||||
</tool_call>
|
||||
<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Shanghai</arg_value>
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Beijing"));
|
||||
assert!(result[1].function.arguments.contains("Shanghai"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_glm4_mixed_types() {
|
||||
let parser = Glm4MoeParser::new();
|
||||
let input = r#"<tool_call>process_data
|
||||
<arg_key>count</arg_key>
|
||||
<arg_value>42</arg_value>
|
||||
<arg_key>active</arg_key>
|
||||
<arg_value>true</arg_value>
|
||||
<arg_key>name</arg_key>
|
||||
<arg_value>test</arg_value>
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process_data");
|
||||
|
||||
// Parse arguments to check types
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["count"], 42);
|
||||
assert_eq!(args["active"], true);
|
||||
assert_eq!(args["name"], "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = Glm4MoeParser::new();
|
||||
assert!(parser.detect_format("<tool_call>"));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format("[TOOL_CALLS]"));
|
||||
}
|
||||
}
|
||||
292
sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs
Normal file
292
sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs
Normal file
@@ -0,0 +1,292 @@
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
};
|
||||
|
||||
/// GPT-OSS format parser for tool calls
|
||||
///
|
||||
/// Handles the GPT-OSS specific channel format:
|
||||
/// `<|channel|>commentary to={namespace.function}<|constrain|>json<|message|>{json_args}<|call|>`
|
||||
///
|
||||
/// Features:
|
||||
/// - Channel-based format with commentary
|
||||
/// - Namespaced function calls
|
||||
/// - JSON arguments
|
||||
pub struct GptOssParser {
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
/// Regex for extracting complete function calls
|
||||
function_call_extractor: Regex,
|
||||
/// Regex for extracting streaming function calls
|
||||
streaming_extractor: Regex,
|
||||
}
|
||||
|
||||
impl GptOssParser {
|
||||
/// Create a new GPT-OSS parser
|
||||
pub fn new() -> Self {
|
||||
// Pattern for complete function calls with to= parameter
|
||||
// Handles optional <|start|>assistant prefix and whitespace after function name
|
||||
let function_call_pattern = r"(?s)(?:<\|start\|>assistant)?<\|channel\|>commentary to=([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*<\|constrain\|>json<\|message\|>(.*?)<\|call\|>(?:commentary)?";
|
||||
let function_call_extractor =
|
||||
Regex::new(function_call_pattern).expect("Valid regex pattern");
|
||||
|
||||
// Pattern for streaming function calls (incomplete)
|
||||
let streaming_pattern = r"(?s)(?:<\|start\|>assistant)?<\|channel\|>commentary to=([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*<\|constrain\|>json<\|message\|>(.*)";
|
||||
let streaming_extractor = Regex::new(streaming_pattern).expect("Valid regex pattern");
|
||||
|
||||
Self {
|
||||
partial_json: PartialJson::default(),
|
||||
function_call_extractor,
|
||||
streaming_extractor,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if text contains GPT-OSS tool markers
|
||||
fn has_tool_markers(&self, text: &str) -> bool {
|
||||
text.contains("<|channel|>commentary to=")
|
||||
}
|
||||
|
||||
/// Extract function name from full namespace (e.g., "functions.get_weather" -> "get_weather")
|
||||
fn extract_function_name(&self, full_name: &str) -> String {
|
||||
if let Some(dot_pos) = full_name.rfind('.') {
|
||||
full_name[dot_pos + 1..].to_string()
|
||||
} else {
|
||||
full_name.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GptOssParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for GptOssParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
// Check if text contains GPT-OSS format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
let mut tools = Vec::new();
|
||||
let mut _tool_index = 0;
|
||||
|
||||
// Extract all function calls
|
||||
for captures in self.function_call_extractor.captures_iter(text) {
|
||||
if let (Some(name_match), Some(args_match)) = (captures.get(1), captures.get(2)) {
|
||||
let full_function_name = name_match.as_str();
|
||||
let args_content = args_match.as_str().trim();
|
||||
|
||||
// Extract actual function name
|
||||
let function_name = self.extract_function_name(full_function_name);
|
||||
|
||||
// Parse JSON arguments
|
||||
let arguments = if args_content.is_empty() {
|
||||
"{}".to_string()
|
||||
} else {
|
||||
match serde_json::from_str::<Value>(args_content) {
|
||||
Ok(value) => serde_json::to_string(&value)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?,
|
||||
Err(_) => {
|
||||
// Skip malformed JSON
|
||||
continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Generate unique ID
|
||||
let id = format!("gpt_oss_call_{}", uuid::Uuid::new_v4());
|
||||
|
||||
tools.push(ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: function_name,
|
||||
arguments,
|
||||
},
|
||||
});
|
||||
|
||||
_tool_index += 1;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
|
||||
// Check for tool markers
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
// No markers found, clear buffer and return
|
||||
state.buffer.clear();
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
|
||||
// Try to match streaming pattern
|
||||
if let Some(captures) = self.streaming_extractor.captures(&state.buffer) {
|
||||
if let (Some(name_match), Some(args_match)) = (captures.get(1), captures.get(2)) {
|
||||
let full_function_name = name_match.as_str();
|
||||
let partial_args = args_match.as_str();
|
||||
|
||||
// Extract actual function name
|
||||
let function_name = self.extract_function_name(full_function_name);
|
||||
|
||||
// Send function name if not sent yet
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Mark name as sent
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: function_name.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check if we have a complete function call
|
||||
if let Some(complete_match) = self.function_call_extractor.captures(&state.buffer) {
|
||||
if let Some(args_match) = complete_match.get(2) {
|
||||
let args_content = args_match.as_str().trim();
|
||||
|
||||
// Parse JSON arguments
|
||||
let arguments = if args_content.is_empty() {
|
||||
"{}".to_string()
|
||||
} else {
|
||||
match serde_json::from_str::<Value>(args_content) {
|
||||
Ok(value) => serde_json::to_string(&value)
|
||||
.unwrap_or_else(|_| "{}".to_string()),
|
||||
Err(_) => "{}".to_string(),
|
||||
}
|
||||
};
|
||||
|
||||
// Generate unique ID
|
||||
let id = format!("gpt_oss_call_{}", uuid::Uuid::new_v4());
|
||||
|
||||
let tool = ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: function_name,
|
||||
arguments,
|
||||
},
|
||||
};
|
||||
|
||||
// Remove the processed part from buffer
|
||||
let complete_end = complete_match.get(0).unwrap().end();
|
||||
state.buffer.drain(..complete_end);
|
||||
|
||||
// Reset state for next tool
|
||||
state.in_string = false;
|
||||
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
} else {
|
||||
// Try to parse partial JSON for streaming arguments
|
||||
if !partial_args.is_empty() {
|
||||
// Look for the end of JSON (before <|call|>)
|
||||
let json_part = if let Some(call_pos) = partial_args.find("<|call|>") {
|
||||
&partial_args[..call_pos]
|
||||
} else {
|
||||
partial_args
|
||||
};
|
||||
|
||||
match self.partial_json.parse_value(json_part) {
|
||||
Ok((value, _consumed)) => {
|
||||
let args_str = serde_json::to_string(&value)
|
||||
.unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
// Can't parse yet, keep buffering
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
self.has_tool_markers(text) || text.contains("<|channel|>commentary")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_gpt_oss_single_tool() {
|
||||
let parser = GptOssParser::new();
|
||||
let input = r#"Some text
|
||||
<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "San Francisco"}<|call|>
|
||||
More text"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("San Francisco"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_gpt_oss_multiple_tools() {
|
||||
let parser = GptOssParser::new();
|
||||
let input = r#"<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "Paris"}<|call|>commentary
|
||||
<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "Paris tourism"}<|call|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "search");
|
||||
assert!(result[0].function.arguments.contains("Paris"));
|
||||
assert!(result[1].function.arguments.contains("Paris tourism"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_gpt_oss_with_prefix() {
|
||||
let parser = GptOssParser::new();
|
||||
let input = r#"<|start|>assistant<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"key": "value"}<|call|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_gpt_oss_empty_args() {
|
||||
let parser = GptOssParser::new();
|
||||
let input =
|
||||
r#"<|channel|>commentary to=functions.get_time<|constrain|>json<|message|>{}<|call|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_time");
|
||||
assert_eq!(result[0].function.arguments, "{}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = GptOssParser::new();
|
||||
assert!(parser.detect_format("<|channel|>commentary to="));
|
||||
assert!(parser.detect_format("<|channel|>commentary"));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format("[TOOL_CALLS]"));
|
||||
}
|
||||
}
|
||||
619
sgl-router/src/tool_parser/parsers/json_parser.rs
Normal file
619
sgl-router/src/tool_parser/parsers/json_parser.rs
Normal file
@@ -0,0 +1,619 @@
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, TokenConfig, ToolCall},
|
||||
};
|
||||
|
||||
/// JSON format parser for tool calls
|
||||
///
|
||||
/// Handles various JSON formats for function calling:
|
||||
/// - Single tool call: {"name": "fn", "arguments": {...}}
|
||||
/// - Multiple tool calls: [{"name": "fn1", "arguments": {...}}, ...]
|
||||
/// - With parameters instead of arguments: {"name": "fn", "parameters": {...}}
|
||||
///
|
||||
/// Supports configurable token markers for different models
|
||||
pub struct JsonParser {
|
||||
/// Token configuration for parsing
|
||||
token_config: TokenConfig,
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
/// Regex patterns for extracting content between tokens
|
||||
extractors: Vec<Regex>,
|
||||
}
|
||||
|
||||
impl JsonParser {
|
||||
/// Create a new JSON parser with default configuration
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(TokenConfig {
|
||||
start_tokens: vec![],
|
||||
end_tokens: vec![],
|
||||
separator: ", ".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a parser with custom token configuration
|
||||
pub fn with_config(config: TokenConfig) -> Self {
|
||||
// Build extraction patterns for each token pair
|
||||
let extractors: Vec<Regex> = config
|
||||
.iter_pairs()
|
||||
.filter_map(|(start, end)| {
|
||||
if !start.is_empty() && !end.is_empty() {
|
||||
// Use (?s) flag to enable DOTALL mode so . matches newlines
|
||||
let pattern =
|
||||
format!(r"(?s){}(.*?){}", regex::escape(start), regex::escape(end));
|
||||
Regex::new(&pattern).ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
token_config: config,
|
||||
partial_json: PartialJson::default(),
|
||||
extractors,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract JSON content from text, handling wrapper tokens if configured
|
||||
fn extract_json_content<'a>(&self, text: &'a str) -> &'a str {
|
||||
let mut content = text;
|
||||
|
||||
// Try each extractor pattern (for tokens with both start and end)
|
||||
for extractor in &self.extractors {
|
||||
if let Some(captures) = extractor.captures(content) {
|
||||
if let Some(matched) = captures.get(1) {
|
||||
return matched.as_str().trim();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle special case where there's a start token but no end token
|
||||
for (start, end) in self.token_config.iter_pairs() {
|
||||
if !start.is_empty() && end.is_empty() {
|
||||
// Find the start token and extract everything after it
|
||||
if let Some(pos) = content.find(start) {
|
||||
content = &content[pos + start.len()..];
|
||||
return content.trim();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
content.trim()
|
||||
}
|
||||
|
||||
/// Try to extract a JSON object or array from text that may contain other content
|
||||
fn extract_json_from_text(&self, text: &str) -> Option<String> {
|
||||
// Look for JSON object starting with {
|
||||
if let Some(start) = text.find('{') {
|
||||
let mut depth = 0;
|
||||
let mut in_string = false;
|
||||
let mut escape_next = false;
|
||||
|
||||
for (i, ch) in text[start..].char_indices() {
|
||||
if escape_next {
|
||||
escape_next = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
match ch {
|
||||
'\\' if in_string => escape_next = true,
|
||||
'"' if !in_string => in_string = true,
|
||||
'"' if in_string => in_string = false,
|
||||
'{' if !in_string => depth += 1,
|
||||
'}' if !in_string => {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
return Some(text[start..start + i + 1].to_string());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Look for JSON array starting with [
|
||||
if let Some(start) = text.find('[') {
|
||||
let mut depth = 0;
|
||||
let mut in_string = false;
|
||||
let mut escape_next = false;
|
||||
|
||||
for (i, ch) in text[start..].char_indices() {
|
||||
if escape_next {
|
||||
escape_next = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
match ch {
|
||||
'\\' if in_string => escape_next = true,
|
||||
'"' if !in_string => in_string = true,
|
||||
'"' if in_string => in_string = false,
|
||||
'[' if !in_string => depth += 1,
|
||||
']' if !in_string => {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
return Some(text[start..start + i + 1].to_string());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Parse a single JSON object into a ToolCall
|
||||
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
|
||||
// Check if this looks like a tool call
|
||||
let name = obj
|
||||
.get("name")
|
||||
.or_else(|| obj.get("function"))
|
||||
.and_then(|v| v.as_str());
|
||||
|
||||
if let Some(name) = name {
|
||||
// Get arguments - support both "arguments" and "parameters" keys
|
||||
let empty_obj = Value::Object(serde_json::Map::new());
|
||||
let args = obj
|
||||
.get("arguments")
|
||||
.or_else(|| obj.get("parameters"))
|
||||
.unwrap_or(&empty_obj);
|
||||
|
||||
// Convert arguments to JSON string
|
||||
let arguments = serde_json::to_string(args)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
// Generate a unique ID if not provided
|
||||
let id = obj
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from)
|
||||
.unwrap_or_else(|| format!("call_{}", uuid::Uuid::new_v4()));
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: name.to_string(),
|
||||
arguments,
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse JSON value(s) into tool calls
|
||||
fn parse_json_value(&self, value: &Value) -> ToolParserResult<Vec<ToolCall>> {
|
||||
let mut tools = Vec::new();
|
||||
|
||||
match value {
|
||||
Value::Array(arr) => {
|
||||
// Parse each element in the array
|
||||
for item in arr {
|
||||
if let Some(tool) = self.parse_single_object(item)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
}
|
||||
Value::Object(_) => {
|
||||
// Single tool call
|
||||
if let Some(tool) = self.parse_single_object(value)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Not a valid tool call format
|
||||
return Ok(vec![]);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
/// Check if text contains potential tool call markers
|
||||
fn has_tool_markers(&self, text: &str) -> bool {
|
||||
// If no start tokens configured, check for JSON structure
|
||||
if self.token_config.start_tokens.is_empty() {
|
||||
// For JSON, we just need to see the start of an object or array
|
||||
return text.contains('{') || text.contains('[');
|
||||
}
|
||||
|
||||
// Check for any start token
|
||||
self.token_config
|
||||
.start_tokens
|
||||
.iter()
|
||||
.any(|token| text.contains(token))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for JsonParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for JsonParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
// Check if we have multiple start tokens (e.g., multiple <|python_tag|> markers)
|
||||
if !self.token_config.start_tokens.is_empty() {
|
||||
let start_token = &self.token_config.start_tokens[0];
|
||||
if !start_token.is_empty() && text.matches(start_token).count() > 1 {
|
||||
// We have multiple occurrences of the start token
|
||||
let mut all_tools = Vec::new();
|
||||
let mut remaining = text;
|
||||
|
||||
while let Some(start_pos) = remaining.find(start_token.as_str()) {
|
||||
// Extract content after this start token
|
||||
let after_token = &remaining[start_pos + start_token.len()..];
|
||||
|
||||
// Find where this JSON ends (look for the next start token or end of string)
|
||||
let end_pos = if let Some(next_start) = after_token.find(start_token.as_str()) {
|
||||
next_start
|
||||
} else {
|
||||
after_token.len()
|
||||
};
|
||||
|
||||
let json_content = &after_token[..end_pos];
|
||||
|
||||
// Try to extract and parse JSON from this segment
|
||||
if let Some(extracted) = self.extract_json_from_text(json_content) {
|
||||
if let Ok(value) = serde_json::from_str::<Value>(&extracted) {
|
||||
if let Ok(tools) = self.parse_json_value(&value) {
|
||||
all_tools.extend(tools);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Move to the next segment
|
||||
remaining = &remaining[start_pos + start_token.len() + end_pos..];
|
||||
if remaining.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !all_tools.is_empty() {
|
||||
return Ok(all_tools);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract JSON content from wrapper tokens if present
|
||||
let json_content = self.extract_json_content(text);
|
||||
|
||||
// Try to parse as JSON first
|
||||
match serde_json::from_str::<Value>(json_content) {
|
||||
Ok(value) => self.parse_json_value(&value),
|
||||
Err(_) => {
|
||||
// If parse failed, check if we have multiple JSON objects separated by the configured separator
|
||||
// This handles cases like: {"name": "func1", ...};{"name": "func2", ...}
|
||||
if !self.token_config.separator.is_empty()
|
||||
&& json_content.contains(&self.token_config.separator)
|
||||
{
|
||||
let mut all_tools = Vec::new();
|
||||
|
||||
// Split by separator and try to parse each part
|
||||
let parts: Vec<&str> =
|
||||
json_content.split(&self.token_config.separator).collect();
|
||||
for part in parts {
|
||||
let trimmed = part.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Try to parse this part as JSON
|
||||
if let Ok(value) = serde_json::from_str::<Value>(trimmed) {
|
||||
if let Ok(tools) = self.parse_json_value(&value) {
|
||||
all_tools.extend(tools);
|
||||
}
|
||||
} else if let Some(extracted) = self.extract_json_from_text(trimmed) {
|
||||
// Try extracting JSON from this part
|
||||
if let Ok(value) = serde_json::from_str::<Value>(&extracted) {
|
||||
if let Ok(tools) = self.parse_json_value(&value) {
|
||||
all_tools.extend(tools);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !all_tools.is_empty() {
|
||||
return Ok(all_tools);
|
||||
}
|
||||
}
|
||||
|
||||
// If no wrapper tokens configured and parse failed,
|
||||
// try to extract JSON from mixed text
|
||||
if self.token_config.start_tokens.is_empty() {
|
||||
if let Some(extracted) = self.extract_json_from_text(text) {
|
||||
if let Ok(value) = serde_json::from_str::<Value>(&extracted) {
|
||||
return self.parse_json_value(&value);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Not valid JSON, return empty
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
|
||||
// Check if we have potential tool calls
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
// No tool markers, return as incomplete
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
|
||||
// Extract JSON content first to check for separators
|
||||
let extracted_json = self.extract_json_content(&state.buffer);
|
||||
|
||||
// Handle multiple JSON objects with separators
|
||||
// Check if we have a separator and potentially multiple JSON objects
|
||||
let separator = &self.token_config.separator;
|
||||
if !separator.is_empty() && extracted_json.contains(separator.as_str()) {
|
||||
// Try to find a complete JSON object before the separator
|
||||
if let Some(separator_pos) = extracted_json.find(separator.as_str()) {
|
||||
// Get JSON before separator
|
||||
let before_separator = &extracted_json[..separator_pos];
|
||||
|
||||
// Try to parse the JSON before the separator
|
||||
match serde_json::from_str::<Value>(before_separator) {
|
||||
Ok(value) => {
|
||||
// Parse tool calls from this JSON
|
||||
let tools = self.parse_json_value(&value)?;
|
||||
if !tools.is_empty() {
|
||||
// We need to figure out how much to remove from the original buffer
|
||||
// Find where the separator is in the original buffer and remove up to and including it
|
||||
if let Some(sep_in_original) = state.buffer.find(separator.as_str()) {
|
||||
let remaining =
|
||||
state.buffer[sep_in_original + separator.len()..].to_string();
|
||||
state.buffer = remaining;
|
||||
}
|
||||
|
||||
// Return the first tool as complete
|
||||
if let Some(tool) = tools.into_iter().next() {
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Failed to parse, continue to try other methods
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle multiple start tokens (e.g., multiple <|python_tag|> markers)
|
||||
if !self.token_config.start_tokens.is_empty() {
|
||||
let start_token = &self.token_config.start_tokens[0];
|
||||
if !start_token.is_empty() {
|
||||
// Find all occurrences of start token
|
||||
let occurrences: Vec<_> =
|
||||
state.buffer.match_indices(start_token.as_str()).collect();
|
||||
if occurrences.len() > 1 {
|
||||
// We have multiple start tokens, try to process the first complete one
|
||||
let first_pos = occurrences[0].0;
|
||||
let second_pos = occurrences[1].0;
|
||||
|
||||
// Extract content between first and second start token
|
||||
let first_json_section = &state.buffer[first_pos..second_pos];
|
||||
let json_content = self.extract_json_content(first_json_section);
|
||||
|
||||
// Try to parse this as complete JSON
|
||||
if let Ok(value) = serde_json::from_str::<Value>(json_content) {
|
||||
// Parse tool calls from this JSON
|
||||
let tools = self.parse_json_value(&value)?;
|
||||
if !tools.is_empty() {
|
||||
// Remove the processed section from buffer
|
||||
let remaining = state.buffer[second_pos..].to_string();
|
||||
state.buffer = remaining;
|
||||
|
||||
// Return the first tool as complete
|
||||
if let Some(tool) = tools.into_iter().next() {
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Regular single JSON parsing
|
||||
// Extract JSON content
|
||||
let json_content = self.extract_json_content(&state.buffer);
|
||||
|
||||
// Try to parse with partial JSON parser
|
||||
match self.partial_json.parse_value(json_content) {
|
||||
Ok((value, consumed)) => {
|
||||
// Check if we have a complete JSON structure
|
||||
if consumed == json_content.len() {
|
||||
// Check if this is truly complete or just has null from incomplete parsing
|
||||
// We need to ensure the JSON actually ends properly (not cut off mid-key)
|
||||
let trimmed = json_content.trim();
|
||||
let looks_complete = trimmed.ends_with('}') || trimmed.ends_with(']');
|
||||
|
||||
if looks_complete {
|
||||
// Complete JSON, parse tool calls
|
||||
let tools = self.parse_json_value(&value)?;
|
||||
if !tools.is_empty() {
|
||||
// Clear buffer since we consumed everything
|
||||
state.buffer.clear();
|
||||
|
||||
// Return the first tool as complete
|
||||
// TODO simplified version, address more complex version
|
||||
if let Some(tool) = tools.into_iter().next() {
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Partial JSON, try to extract tool name
|
||||
if let Some(name) = value.get("name").and_then(|v| v.as_str()) {
|
||||
// TODO simplified version, address more complex version
|
||||
// Just return the tool name once we see it
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Use as a flag for "name sent"
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check for complete arguments
|
||||
if let Some(args) =
|
||||
value.get("arguments").or_else(|| value.get("parameters"))
|
||||
{
|
||||
if let Ok(args_str) = serde_json::to_string(args) {
|
||||
// Return arguments as a single update
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Failed to parse even as partial JSON
|
||||
// Keep buffering
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
// Check if text contains JSON-like structure
|
||||
if self.has_tool_markers(text) {
|
||||
// Try to extract and parse
|
||||
let json_content = self.extract_json_content(text);
|
||||
|
||||
// Check if it looks like valid JSON for tool calls
|
||||
if let Ok(value) = serde_json::from_str::<Value>(json_content) {
|
||||
match value {
|
||||
Value::Object(ref obj) => {
|
||||
// Check for tool call structure
|
||||
obj.contains_key("name") || obj.contains_key("function")
|
||||
}
|
||||
Value::Array(ref arr) => {
|
||||
// Check if array contains tool-like objects
|
||||
arr.iter().any(|v| {
|
||||
if let Some(obj) = v.as_object() {
|
||||
obj.contains_key("name") || obj.contains_key("function")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_single_tool_call() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_multiple_tool_calls() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"[
|
||||
{"name": "get_weather", "arguments": {"location": "SF"}},
|
||||
{"name": "search", "arguments": {"query": "news"}}
|
||||
]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "search");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_parameters_key() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
assert!(result[0].function.arguments.contains("10"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_wrapper_tokens() {
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<tool>".to_string()],
|
||||
end_tokens: vec!["</tool>".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
let input = r#"<tool>{"name": "test", "arguments": {}}</tool>"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(parser.detect_format(r#"[{"name": "test"}]"#));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format(r#"{"key": "value"}"#));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_parse() {
|
||||
// Just verify that streaming eventually produces a complete tool call
|
||||
let parser = JsonParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Send complete JSON in one go
|
||||
// TODO simplified version, address more complex version
|
||||
let full_json = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;
|
||||
|
||||
let result = parser
|
||||
.parse_incremental(full_json, &mut state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should get a complete tool immediately with complete JSON
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
assert!(tool.function.arguments.contains("SF"));
|
||||
}
|
||||
_ => panic!("Expected ToolComplete for complete JSON input"),
|
||||
}
|
||||
}
|
||||
}
|
||||
270
sgl-router/src/tool_parser/parsers/kimik2_parser.rs
Normal file
270
sgl-router/src/tool_parser/parsers/kimik2_parser.rs
Normal file
@@ -0,0 +1,270 @@
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::ToolParserResult,
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
};
|
||||
|
||||
/// Kimi K2 format parser for tool calls
|
||||
///
|
||||
/// Handles the Kimi K2 specific format:
|
||||
/// `<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:{index}<|tool_call_argument_begin|>{json_args}<|tool_call_end|><|tool_calls_section_end|>`
|
||||
///
|
||||
/// Features:
|
||||
/// - Token-based delimiters
|
||||
/// - Function calls with explicit indexing
|
||||
/// - JSON arguments
|
||||
pub struct KimiK2Parser {
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
/// Regex for extracting complete tool calls
|
||||
tool_call_extractor: Regex,
|
||||
/// Regex for extracting partial tool calls (streaming)
|
||||
stream_tool_call_extractor: Regex,
|
||||
}
|
||||
|
||||
impl KimiK2Parser {
|
||||
/// Create a new Kimi K2 parser
|
||||
pub fn new() -> Self {
|
||||
// Pattern for complete tool calls
|
||||
let tool_call_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>";
|
||||
let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
|
||||
|
||||
// Pattern for streaming (partial) tool calls
|
||||
let stream_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)";
|
||||
let stream_tool_call_extractor = Regex::new(stream_pattern).expect("Valid regex pattern");
|
||||
|
||||
Self {
|
||||
partial_json: PartialJson::default(),
|
||||
tool_call_extractor,
|
||||
stream_tool_call_extractor,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if text contains Kimi K2 tool markers
|
||||
fn has_tool_markers(&self, text: &str) -> bool {
|
||||
text.contains("<|tool_calls_section_begin|>")
|
||||
}
|
||||
|
||||
/// Parse function ID to extract name and index
|
||||
fn parse_function_id(&self, id: &str) -> Option<(String, usize)> {
|
||||
// Format: functions.{name}:{index} or namespace.functions.{name}:{index}
|
||||
// Extract everything after the last dot before the colon as the function name
|
||||
if let Some(colon_pos) = id.rfind(':') {
|
||||
let before_colon = &id[..colon_pos];
|
||||
let index_str = &id[colon_pos + 1..];
|
||||
|
||||
// Find the last dot to extract the function name
|
||||
if let Some(dot_pos) = before_colon.rfind('.') {
|
||||
let func_name = &before_colon[dot_pos + 1..];
|
||||
|
||||
if let Ok(index) = index_str.parse::<usize>() {
|
||||
return Some((func_name.to_string(), index));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for KimiK2Parser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for KimiK2Parser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
// Check if text contains Kimi K2 format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
let mut tools = Vec::new();
|
||||
|
||||
// Extract all tool calls
|
||||
for captures in self.tool_call_extractor.captures_iter(text) {
|
||||
if let (Some(id_match), Some(args_match)) = (
|
||||
captures.name("tool_call_id"),
|
||||
captures.name("function_arguments"),
|
||||
) {
|
||||
let function_id = id_match.as_str();
|
||||
let function_args = args_match.as_str();
|
||||
|
||||
// Parse function ID
|
||||
if let Some((func_name, _index)) = self.parse_function_id(function_id) {
|
||||
// Validate JSON arguments
|
||||
if serde_json::from_str::<serde_json::Value>(function_args).is_ok() {
|
||||
// Generate unique ID
|
||||
let id = format!("kimi_call_{}", uuid::Uuid::new_v4());
|
||||
|
||||
tools.push(ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: func_name,
|
||||
arguments: function_args.to_string(),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
|
||||
// Check for tool markers
|
||||
let has_tool_call =
|
||||
self.has_tool_markers(&state.buffer) || state.buffer.contains("<|tool_call_begin|>");
|
||||
|
||||
if !has_tool_call {
|
||||
// No markers found, clear buffer and return
|
||||
state.buffer.clear();
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
|
||||
// Try to match streaming pattern
|
||||
if let Some(captures) = self.stream_tool_call_extractor.captures(&state.buffer) {
|
||||
if let (Some(id_match), Some(args_match)) = (
|
||||
captures.name("tool_call_id"),
|
||||
captures.name("function_arguments"),
|
||||
) {
|
||||
let function_id = id_match.as_str();
|
||||
let partial_args = args_match.as_str();
|
||||
|
||||
// Parse function ID
|
||||
if let Some((func_name, _index)) = self.parse_function_id(function_id) {
|
||||
// Send function name if not sent yet
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Mark name as sent
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: func_name.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check if we have a complete tool call
|
||||
if let Some(end_pos) = partial_args.find("<|tool_call_end|>") {
|
||||
// Extract just the JSON part
|
||||
let json_args = &partial_args[..end_pos];
|
||||
|
||||
// Validate and parse JSON
|
||||
if serde_json::from_str::<serde_json::Value>(json_args).is_ok() {
|
||||
// Generate unique ID
|
||||
let id = format!("kimi_call_{}", uuid::Uuid::new_v4());
|
||||
|
||||
let tool = ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: func_name,
|
||||
arguments: json_args.to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
// Find where this tool call ends in the buffer
|
||||
if let Some(tool_end) = state.buffer.find("<|tool_call_end|>") {
|
||||
let end_pos = tool_end + "<|tool_call_end|>".len();
|
||||
state.buffer.drain(..end_pos);
|
||||
}
|
||||
|
||||
// Reset state for next tool
|
||||
state.in_string = false;
|
||||
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
} else {
|
||||
// Try to parse partial JSON for streaming arguments
|
||||
match self.partial_json.parse_value(partial_args) {
|
||||
Ok((value, _consumed)) => {
|
||||
let args_str = serde_json::to_string(&value)
|
||||
.unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
// Can't parse yet, keep buffering
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
self.has_tool_markers(text) || text.contains("<|tool_call_begin|>")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_kimi_single_tool() {
|
||||
let parser = KimiK2Parser::new();
|
||||
let input = r#"Some text
|
||||
<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Tokyo", "units": "celsius"}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>More text"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Tokyo"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_kimi_multiple_tools() {
|
||||
let parser = KimiK2Parser::new();
|
||||
let input = r#"<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{"query": "rust"}<|tool_call_end|>
|
||||
<|tool_call_begin|>functions.calculate:1<|tool_call_argument_begin|>{"expression": "2+2"}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "calculate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_kimi_with_whitespace() {
|
||||
let parser = KimiK2Parser::new();
|
||||
let input = r#"<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value"} <|tool_call_end|>
|
||||
<|tool_calls_section_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = KimiK2Parser::new();
|
||||
assert!(parser.detect_format("<|tool_calls_section_begin|>"));
|
||||
assert!(parser.detect_format("<|tool_call_begin|>"));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format("[TOOL_CALLS]"));
|
||||
}
|
||||
}
|
||||
156
sgl-router/src/tool_parser/parsers/llama_parser.rs
Normal file
156
sgl-router/src/tool_parser/parsers/llama_parser.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::json_parser::JsonParser;
|
||||
use crate::tool_parser::{
|
||||
errors::ToolParserResult,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{StreamResult, TokenConfig, ToolCall},
|
||||
};
|
||||
|
||||
/// Llama 3.2 format parser for tool calls
|
||||
///
|
||||
/// Handles the Llama 3.2 specific format:
|
||||
/// `<|python_tag|>{"name": "func", "arguments": {...}}`
|
||||
///
|
||||
/// Also supports plain JSON without the python_tag prefix
|
||||
pub struct LlamaParser {
|
||||
/// Underlying JSON parser with Llama-specific configuration
|
||||
json_parser: JsonParser,
|
||||
}
|
||||
|
||||
impl LlamaParser {
|
||||
/// Create a new Llama parser
|
||||
pub fn new() -> Self {
|
||||
// Configure JSON parser with Llama's python_tag token
|
||||
// Note: No end token for python_tag format
|
||||
let json_parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<|python_tag|>".to_string()],
|
||||
end_tokens: vec!["".to_string()], // Empty end token
|
||||
separator: ";".to_string(), // Llama uses semicolon for multiple calls (though not well supported)
|
||||
});
|
||||
|
||||
Self { json_parser }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LlamaParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for LlamaParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
// First try with the configured python_tag parser
|
||||
let result = self.json_parser.parse_complete(text).await?;
|
||||
|
||||
if !result.is_empty() {
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
// If no results and text starts with '{', try plain JSON
|
||||
if text.trim_start().starts_with('{') {
|
||||
// Create a temporary plain JSON parser
|
||||
let plain_parser = JsonParser::new();
|
||||
return plain_parser.parse_complete(text).await;
|
||||
}
|
||||
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
// Try with the python_tag parser first
|
||||
let result = self.json_parser.parse_incremental(chunk, state).await?;
|
||||
|
||||
// If we get Incomplete and buffer starts with '{', might be plain JSON
|
||||
if matches!(result, StreamResult::Incomplete) && state.buffer.trim_start().starts_with('{')
|
||||
{
|
||||
// Check if we have python_tag in the buffer
|
||||
if !state.buffer.contains("<|python_tag|>") {
|
||||
// Likely plain JSON, create temporary parser
|
||||
let plain_parser = JsonParser::new();
|
||||
return plain_parser.parse_incremental("", state).await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
// Llama format if contains python_tag or starts with JSON object
|
||||
text.contains("<|python_tag|>")
|
||||
|| (text.trim_start().starts_with('{')
|
||||
&& (text.contains(r#""name""#) || text.contains(r#""function""#)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_python_tag() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert!(result[0].function.arguments.contains("weather"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_plain_json() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_text_before() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_time");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = LlamaParser::new();
|
||||
|
||||
assert!(parser.detect_format(r#"<|python_tag|>{"name": "test"}"#));
|
||||
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_single_call_with_semicolon() {
|
||||
let parser = LlamaParser::new();
|
||||
// Note: Llama 3.2 doesn't handle multiple calls well
|
||||
// Test that we can at least parse a single call followed by semicolon
|
||||
let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
// We expect this to either parse the first JSON object or fail gracefully
|
||||
// Since the semicolon makes it invalid JSON, it will likely return empty
|
||||
// This is acceptable as Llama 3.2 doesn't reliably support parallel calls
|
||||
|
||||
// If it parses anything, it should be func1
|
||||
if !result.is_empty() {
|
||||
assert_eq!(result[0].function.name, "func1");
|
||||
}
|
||||
}
|
||||
}
|
||||
347
sgl-router/src/tool_parser/parsers/mistral_parser.rs
Normal file
347
sgl-router/src/tool_parser/parsers/mistral_parser.rs
Normal file
@@ -0,0 +1,347 @@
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
};
|
||||
|
||||
/// Mistral format parser for tool calls
|
||||
///
|
||||
/// Handles the Mistral-specific format:
|
||||
/// `[TOOL_CALLS] [{"name": "func", "arguments": {...}}, ...]`
|
||||
///
|
||||
/// Features:
|
||||
/// - Bracket counting for proper JSON array extraction
|
||||
/// - Support for multiple tool calls in a single array
|
||||
/// - String-aware parsing to handle nested brackets in JSON
|
||||
pub struct MistralParser {
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
}
|
||||
|
||||
impl MistralParser {
|
||||
/// Create a new Mistral parser
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
partial_json: PartialJson::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract JSON array using bracket counting
|
||||
///
|
||||
/// Handles nested brackets in JSON content by tracking:
|
||||
/// - String boundaries (quotes)
|
||||
/// - Escape sequences
|
||||
/// - Bracket depth
|
||||
fn extract_json_array<'a>(&self, text: &'a str) -> Option<&'a str> {
|
||||
const BOT_TOKEN: &str = "[TOOL_CALLS] [";
|
||||
|
||||
// Find the start of the token
|
||||
let start_idx = text.find(BOT_TOKEN)?;
|
||||
|
||||
// Start from the opening bracket after [TOOL_CALLS]
|
||||
// The -1 is to include the opening bracket that's part of the token
|
||||
let json_start = start_idx + BOT_TOKEN.len() - 1;
|
||||
|
||||
let mut bracket_count = 0;
|
||||
let mut in_string = false;
|
||||
let mut escape_next = false;
|
||||
|
||||
let bytes = text.as_bytes();
|
||||
|
||||
for i in json_start..text.len() {
|
||||
let char = bytes[i];
|
||||
|
||||
if escape_next {
|
||||
escape_next = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
if char == b'\\' {
|
||||
escape_next = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if char == b'"' && !escape_next {
|
||||
in_string = !in_string;
|
||||
continue;
|
||||
}
|
||||
|
||||
if !in_string {
|
||||
if char == b'[' {
|
||||
bracket_count += 1;
|
||||
} else if char == b']' {
|
||||
bracket_count -= 1;
|
||||
if bracket_count == 0 {
|
||||
// Found the matching closing bracket
|
||||
return Some(&text[json_start..=i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Incomplete array (no matching closing bracket found)
|
||||
None
|
||||
}
|
||||
|
||||
/// Parse tool calls from a JSON array
|
||||
fn parse_json_array(&self, json_str: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
let value: Value = serde_json::from_str(json_str)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
let mut tools = Vec::new();
|
||||
|
||||
if let Value::Array(arr) = value {
|
||||
for (index, item) in arr.iter().enumerate() {
|
||||
if let Some(tool) = self.parse_single_object(item, index)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Single object case (shouldn't happen with Mistral format, but handle it)
|
||||
if let Some(tool) = self.parse_single_object(&value, 0)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
/// Parse a single JSON object into a ToolCall
|
||||
fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult<Option<ToolCall>> {
|
||||
let name = obj.get("name").and_then(|v| v.as_str());
|
||||
|
||||
if let Some(name) = name {
|
||||
// Get arguments - Mistral uses "arguments" key
|
||||
let empty_obj = Value::Object(serde_json::Map::new());
|
||||
let args = obj.get("arguments").unwrap_or(&empty_obj);
|
||||
|
||||
// Convert arguments to JSON string
|
||||
let arguments = serde_json::to_string(args)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
// Generate ID with index for multiple tools
|
||||
let id = format!("mistral_call_{}", index);
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: name.to_string(),
|
||||
arguments,
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if text contains Mistral tool markers
|
||||
fn has_tool_markers(&self, text: &str) -> bool {
|
||||
text.contains("[TOOL_CALLS]")
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MistralParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for MistralParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
// Check if text contains Mistral format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
// Extract JSON array from Mistral format
|
||||
if let Some(json_array) = self.extract_json_array(text) {
|
||||
self.parse_json_array(json_array)
|
||||
} else {
|
||||
// Markers present but no complete array found
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
|
||||
// Check if we have the start marker
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
|
||||
// Try to extract complete JSON array
|
||||
if let Some(json_array) = self.extract_json_array(&state.buffer) {
|
||||
// Parse with partial JSON to handle incomplete content
|
||||
match self.partial_json.parse_value(json_array) {
|
||||
Ok((value, consumed)) => {
|
||||
// Check if we have a complete JSON structure
|
||||
if consumed == json_array.len() {
|
||||
// Complete JSON, parse tool calls
|
||||
let tools = if let Value::Array(arr) = value {
|
||||
let mut result = Vec::new();
|
||||
for (index, item) in arr.iter().enumerate() {
|
||||
if let Some(tool) = self.parse_single_object(item, index)? {
|
||||
result.push(tool);
|
||||
}
|
||||
}
|
||||
result
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
if !tools.is_empty() {
|
||||
// Clear buffer since we consumed everything
|
||||
state.buffer.clear();
|
||||
|
||||
// Return the first tool (simplified for Phase 3)
|
||||
// Full multi-tool streaming will be implemented later
|
||||
if let Some(tool) = tools.into_iter().next() {
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Partial JSON - try to extract tool name for streaming
|
||||
if let Value::Array(arr) = value {
|
||||
if let Some(first_tool) = arr.first() {
|
||||
if let Some(name) = first_tool.get("name").and_then(|v| v.as_str())
|
||||
{
|
||||
// Check if we've already sent the name
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Use as flag for "name sent"
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check for arguments
|
||||
if let Some(args) = first_tool.get("arguments") {
|
||||
if let Ok(args_str) = serde_json::to_string(args) {
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Failed to parse even as partial JSON
|
||||
// Keep buffering
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
// Check if text contains Mistral-specific markers
|
||||
if self.has_tool_markers(text) {
|
||||
// Try to extract and validate the array
|
||||
if let Some(json_array) = self.extract_json_array(text) {
|
||||
// Check if it's valid JSON
|
||||
if let Ok(value) = serde_json::from_str::<Value>(json_array) {
|
||||
// Check if it contains tool-like structures
|
||||
match value {
|
||||
Value::Array(ref arr) => arr.iter().any(|v| {
|
||||
v.as_object().is_some_and(|o| {
|
||||
o.contains_key("name") && o.contains_key("arguments")
|
||||
})
|
||||
}),
|
||||
Value::Object(ref obj) => {
|
||||
obj.contains_key("name") && obj.contains_key("arguments")
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
// Has markers but no complete array - might be streaming
|
||||
true
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_mistral_format() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "Paris", "units": "celsius"}}]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Paris"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_multiple_tools() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [
|
||||
{"name": "search", "arguments": {"query": "rust programming"}},
|
||||
{"name": "calculate", "arguments": {"expression": "2 + 2"}}
|
||||
]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "calculate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_nested_brackets_in_json() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "process", "arguments": {"data": [1, 2, [3, 4]], "config": {"nested": [5, 6]}}}]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process");
|
||||
// JSON serialization removes spaces, so check for [3,4] without spaces
|
||||
assert!(result[0].function.arguments.contains("[3,4]"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_escaped_quotes_in_strings() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"message": "He said \"Hello [World]\""}}]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "echo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = MistralParser::new();
|
||||
|
||||
assert!(parser.detect_format(r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]"#));
|
||||
assert!(
|
||||
parser.detect_format(r#"Some text [TOOL_CALLS] [{"name": "test", "arguments": {}}]"#)
|
||||
);
|
||||
assert!(!parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
}
|
||||
}
|
||||
27
sgl-router/src/tool_parser/parsers/mod.rs
Normal file
27
sgl-router/src/tool_parser/parsers/mod.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
/// Parser implementations for different model formats
|
||||
///
|
||||
/// This module contains concrete parser implementations for various model-specific
|
||||
/// tool/function call formats.
|
||||
// Individual parser modules
|
||||
pub mod deepseek_parser;
|
||||
pub mod glm4_moe_parser;
|
||||
pub mod gpt_oss_parser;
|
||||
pub mod json_parser;
|
||||
pub mod kimik2_parser;
|
||||
pub mod llama_parser;
|
||||
pub mod mistral_parser;
|
||||
pub mod pythonic_parser;
|
||||
pub mod qwen_parser;
|
||||
pub mod step3_parser;
|
||||
|
||||
// Re-export parser types for convenience
|
||||
pub use deepseek_parser::DeepSeekParser;
|
||||
pub use glm4_moe_parser::Glm4MoeParser;
|
||||
pub use gpt_oss_parser::GptOssParser;
|
||||
pub use json_parser::JsonParser;
|
||||
pub use kimik2_parser::KimiK2Parser;
|
||||
pub use llama_parser::LlamaParser;
|
||||
pub use mistral_parser::MistralParser;
|
||||
pub use pythonic_parser::PythonicParser;
|
||||
pub use qwen_parser::QwenParser;
|
||||
pub use step3_parser::Step3Parser;
|
||||
434
sgl-router/src/tool_parser/parsers/pythonic_parser.rs
Normal file
434
sgl-router/src/tool_parser/parsers/pythonic_parser.rs
Normal file
@@ -0,0 +1,434 @@
|
||||
/// Pythonic format parser for tool calls
|
||||
///
|
||||
/// Handles Python function call syntax within square brackets:
|
||||
/// ```text
|
||||
/// [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
|
||||
/// ```
|
||||
///
|
||||
/// This format is used by Llama-4 models and uses Python literals
|
||||
/// rather than JSON for arguments.
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::ToolParserResult,
|
||||
python_literal_parser::parse_python_literal,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
};
|
||||
|
||||
/// Parser for Pythonic tool call format
|
||||
pub struct PythonicParser {
|
||||
/// Regex to detect tool calls in Pythonic format
|
||||
tool_call_regex: Regex,
|
||||
/// Regex to parse function calls - cached for reuse
|
||||
call_regex: Regex,
|
||||
}
|
||||
|
||||
impl PythonicParser {
|
||||
/// Create a new Pythonic parser
|
||||
pub fn new() -> Self {
|
||||
// Simple regex to detect start of Pythonic tool calls
|
||||
// We'll use manual parsing for the actual extraction
|
||||
let pattern = r"\[[a-zA-Z_]\w*\(";
|
||||
let tool_call_regex = Regex::new(pattern).expect("Valid regex pattern");
|
||||
|
||||
// Compile the function call regex once
|
||||
let call_regex = Regex::new(r"(?s)^([a-zA-Z_]\w*)\((.*)\)$").expect("Valid regex pattern");
|
||||
|
||||
Self {
|
||||
tool_call_regex,
|
||||
call_regex,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract tool calls using bracket counting (similar to MistralParser)
|
||||
fn extract_tool_calls(&self, text: &str) -> Option<String> {
|
||||
// Find the start of a tool call list - look for [ followed by a function name
|
||||
let chars: Vec<char> = text.chars().collect();
|
||||
|
||||
for start_idx in 0..chars.len() {
|
||||
if chars[start_idx] != '[' {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if this looks like a tool call
|
||||
// Skip whitespace after [
|
||||
let mut check_idx = start_idx + 1;
|
||||
while check_idx < chars.len() && chars[check_idx].is_whitespace() {
|
||||
check_idx += 1;
|
||||
}
|
||||
|
||||
// Check if we have a function name (starts with letter or underscore)
|
||||
if check_idx >= chars.len()
|
||||
|| (!chars[check_idx].is_alphabetic() && chars[check_idx] != '_')
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Now count brackets to find the matching ]
|
||||
let mut bracket_count = 0;
|
||||
let mut _paren_count = 0;
|
||||
let mut _brace_count = 0;
|
||||
let mut in_string = false;
|
||||
let mut string_char = ' ';
|
||||
let mut escape_next = false;
|
||||
|
||||
for i in start_idx..chars.len() {
|
||||
let ch = chars[i];
|
||||
|
||||
if escape_next {
|
||||
escape_next = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
if ch == '\\' && in_string {
|
||||
escape_next = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if !in_string && (ch == '"' || ch == '\'') {
|
||||
in_string = true;
|
||||
string_char = ch;
|
||||
} else if in_string && ch == string_char && !escape_next {
|
||||
in_string = false;
|
||||
} else if !in_string {
|
||||
match ch {
|
||||
'[' => bracket_count += 1,
|
||||
']' => {
|
||||
bracket_count -= 1;
|
||||
if bracket_count == 0 {
|
||||
// Found the matching bracket
|
||||
let extracted: String = chars[start_idx..=i].iter().collect();
|
||||
// Verify this actually contains a function call
|
||||
if extracted.contains('(') && extracted.contains(')') {
|
||||
return Some(extracted);
|
||||
}
|
||||
}
|
||||
}
|
||||
'(' => _paren_count += 1,
|
||||
')' => _paren_count -= 1,
|
||||
'{' => _brace_count += 1,
|
||||
'}' => _brace_count -= 1,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Strip special tokens that Llama 4 might output
|
||||
fn strip_special_tokens(text: &str) -> String {
|
||||
text.replace("<|python_start|>", "")
|
||||
.replace("<|python_end|>", "")
|
||||
}
|
||||
|
||||
/// Parse a single function call from Python syntax
|
||||
fn parse_function_call(&self, call_str: &str) -> ToolParserResult<Option<ToolCall>> {
|
||||
// Use cached regex instead of creating new one
|
||||
if let Some(captures) = self.call_regex.captures(call_str.trim()) {
|
||||
let function_name = captures.get(1).unwrap().as_str();
|
||||
let args_str = captures.get(2).unwrap().as_str();
|
||||
|
||||
// Parse arguments
|
||||
let arguments = self.parse_arguments(args_str)?;
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id: format!("call_{}", uuid::Uuid::new_v4()),
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: function_name.to_string(),
|
||||
arguments: serde_json::to_string(&arguments)?,
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse Python-style arguments into JSON
|
||||
fn parse_arguments(&self, args_str: &str) -> ToolParserResult<Value> {
|
||||
if args_str.trim().is_empty() {
|
||||
return Ok(json!({}));
|
||||
}
|
||||
|
||||
let mut result = serde_json::Map::new();
|
||||
let mut current_key = String::new();
|
||||
let mut current_value = String::new();
|
||||
let mut in_key = true;
|
||||
let mut depth = 0;
|
||||
let mut in_string = false;
|
||||
let mut string_char = ' ';
|
||||
let mut escape_next = false;
|
||||
|
||||
let chars: Vec<char> = args_str.chars().collect();
|
||||
let mut i = 0;
|
||||
|
||||
while i < chars.len() {
|
||||
let ch = chars[i];
|
||||
|
||||
if escape_next {
|
||||
if in_key {
|
||||
current_key.push(ch);
|
||||
} else {
|
||||
current_value.push(ch);
|
||||
}
|
||||
escape_next = false;
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if ch == '\\' && in_string {
|
||||
escape_next = true;
|
||||
current_value.push(ch);
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle string literals
|
||||
if !in_string && (ch == '"' || ch == '\'') {
|
||||
in_string = true;
|
||||
string_char = ch;
|
||||
if !in_key {
|
||||
current_value.push(ch);
|
||||
}
|
||||
} else if in_string && ch == string_char && !escape_next {
|
||||
in_string = false;
|
||||
if !in_key {
|
||||
current_value.push(ch);
|
||||
}
|
||||
} else if in_string {
|
||||
if in_key {
|
||||
current_key.push(ch);
|
||||
} else {
|
||||
current_value.push(ch);
|
||||
}
|
||||
} else {
|
||||
// Not in string
|
||||
match ch {
|
||||
'=' if in_key && depth == 0 => {
|
||||
in_key = false;
|
||||
}
|
||||
',' if depth == 0 => {
|
||||
// End of current argument
|
||||
if !current_key.is_empty() {
|
||||
let value = parse_python_literal(current_value.trim())?;
|
||||
result.insert(current_key.trim().to_string(), value);
|
||||
}
|
||||
current_key.clear();
|
||||
current_value.clear();
|
||||
in_key = true;
|
||||
}
|
||||
'[' | '{' | '(' => {
|
||||
depth += 1;
|
||||
if !in_key {
|
||||
current_value.push(ch);
|
||||
}
|
||||
}
|
||||
']' | '}' | ')' => {
|
||||
depth -= 1;
|
||||
if !in_key {
|
||||
current_value.push(ch);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if in_key {
|
||||
if !ch.is_whitespace() || !current_key.is_empty() {
|
||||
current_key.push(ch);
|
||||
}
|
||||
} else {
|
||||
current_value.push(ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
i += 1;
|
||||
}
|
||||
|
||||
// Handle the last argument
|
||||
if !current_key.is_empty() {
|
||||
let value = parse_python_literal(current_value.trim())?;
|
||||
result.insert(current_key.trim().to_string(), value);
|
||||
}
|
||||
|
||||
Ok(Value::Object(result))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for PythonicParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
let cleaned = Self::strip_special_tokens(text);
|
||||
|
||||
// Extract tool calls using bracket counting
|
||||
if let Some(tool_calls_text) = self.extract_tool_calls(&cleaned) {
|
||||
// Remove the outer brackets
|
||||
let tool_calls_str = &tool_calls_text[1..tool_calls_text.len() - 1];
|
||||
|
||||
// Split into individual function calls
|
||||
let mut calls = Vec::new();
|
||||
let mut current_call = String::new();
|
||||
let mut paren_depth = 0;
|
||||
let mut in_string = false;
|
||||
let mut string_char = ' ';
|
||||
|
||||
for ch in tool_calls_str.chars() {
|
||||
if !in_string && (ch == '"' || ch == '\'') {
|
||||
in_string = true;
|
||||
string_char = ch;
|
||||
current_call.push(ch);
|
||||
} else if in_string && ch == string_char {
|
||||
in_string = false;
|
||||
current_call.push(ch);
|
||||
} else if in_string {
|
||||
current_call.push(ch);
|
||||
} else {
|
||||
match ch {
|
||||
'(' => {
|
||||
paren_depth += 1;
|
||||
current_call.push(ch);
|
||||
}
|
||||
')' => {
|
||||
paren_depth -= 1;
|
||||
current_call.push(ch);
|
||||
}
|
||||
',' if paren_depth == 0 => {
|
||||
// End of current function call
|
||||
if let Some(call) = self.parse_function_call(current_call.trim())? {
|
||||
calls.push(call);
|
||||
}
|
||||
current_call.clear();
|
||||
}
|
||||
_ => {
|
||||
if !ch.is_whitespace() || !current_call.is_empty() {
|
||||
current_call.push(ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the last call (important for single calls or the last call in a list)
|
||||
if !current_call.trim().is_empty() {
|
||||
if let Some(call) = self.parse_function_call(current_call.trim())? {
|
||||
calls.push(call);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(calls)
|
||||
} else {
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
// For Pythonic format, we accumulate until we have a complete tool call
|
||||
// This is a simplified implementation
|
||||
state.buffer.push_str(chunk);
|
||||
|
||||
// Try to parse if we have a complete tool call
|
||||
let cleaned = Self::strip_special_tokens(&state.buffer);
|
||||
if self.extract_tool_calls(&cleaned).is_some() {
|
||||
let result = self.parse_complete(&state.buffer).await?;
|
||||
if !result.is_empty() {
|
||||
state.buffer.clear();
|
||||
return Ok(StreamResult::ToolComplete(
|
||||
result.into_iter().next().unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
let cleaned = Self::strip_special_tokens(text);
|
||||
self.tool_call_regex.is_match(&cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PythonicParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_single_function_call() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[search_web(query="Rust programming", max_results=5)]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search_web");
|
||||
|
||||
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["query"], "Rust programming");
|
||||
assert_eq!(args["max_results"], 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_function_calls() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[get_weather(city="Tokyo"), search(query="news")]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "search");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_python_literals() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[test(flag=True, disabled=False, optional=None)]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["flag"], true);
|
||||
assert_eq!(args["disabled"], false);
|
||||
assert_eq!(args["optional"], Value::Null);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_special_tokens() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"<|python_start|>[calculate(x=10, y=20)]<|python_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
|
||||
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["x"], 10);
|
||||
assert_eq!(args["y"], 20);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama4_format() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[get_weather(city="London", units="celsius")]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
|
||||
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["city"], "London");
|
||||
assert_eq!(args["units"], "celsius");
|
||||
}
|
||||
}
|
||||
396
sgl-router/src/tool_parser/parsers/qwen_parser.rs
Normal file
396
sgl-router/src/tool_parser/parsers/qwen_parser.rs
Normal file
@@ -0,0 +1,396 @@
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
};
|
||||
|
||||
/// Qwen format parser for tool calls
|
||||
///
|
||||
/// Handles the Qwen 2.5/3 specific format:
|
||||
/// `<tool_call>\n{"name": "func", "arguments": {...}}\n</tool_call>`
|
||||
///
|
||||
/// Features:
|
||||
/// - XML-style tags with JSON content
|
||||
/// - Support for multiple sequential tool calls
|
||||
/// - Newline-aware parsing
|
||||
pub struct QwenParser {
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
/// Regex for extracting tool calls
|
||||
extractor: Regex,
|
||||
}
|
||||
|
||||
impl QwenParser {
|
||||
/// Create a new Qwen parser
|
||||
pub fn new() -> Self {
|
||||
// Use (?s) flag for DOTALL mode to handle newlines
|
||||
let pattern = r"(?s)<tool_call>\n(.*?)\n</tool_call>";
|
||||
let extractor = Regex::new(pattern).expect("Valid regex pattern");
|
||||
|
||||
Self {
|
||||
partial_json: PartialJson::default(),
|
||||
extractor,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract all tool call blocks from text
|
||||
fn extract_tool_calls<'a>(&self, text: &'a str) -> Vec<&'a str> {
|
||||
self.extractor
|
||||
.captures_iter(text)
|
||||
.filter_map(|cap| cap.get(1).map(|m| m.as_str()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Parse a single JSON object into a ToolCall
|
||||
fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult<Option<ToolCall>> {
|
||||
let name = obj.get("name").and_then(|v| v.as_str());
|
||||
|
||||
if let Some(name) = name {
|
||||
// Get arguments - Qwen uses "arguments" key
|
||||
let empty_obj = Value::Object(serde_json::Map::new());
|
||||
let args = obj.get("arguments").unwrap_or(&empty_obj);
|
||||
|
||||
// Convert arguments to JSON string
|
||||
let arguments = serde_json::to_string(args)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
// Generate ID with index for multiple tools
|
||||
let id = format!("qwen_call_{}", index);
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: name.to_string(),
|
||||
arguments,
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if text contains Qwen tool markers
|
||||
fn has_tool_markers(&self, text: &str) -> bool {
|
||||
text.contains("<tool_call>")
|
||||
}
|
||||
|
||||
/// Find the start position of a tool call
|
||||
fn find_tool_start(&self, text: &str) -> Option<usize> {
|
||||
text.find("<tool_call>\n")
|
||||
}
|
||||
|
||||
/// Find the end position of a tool call
|
||||
fn find_tool_end(&self, text: &str, start_pos: usize) -> Option<usize> {
|
||||
let search_from = start_pos + "<tool_call>\n".len();
|
||||
text[search_from..]
|
||||
.find("\n</tool_call>")
|
||||
.map(|pos| search_from + pos + "\n</tool_call>".len())
|
||||
}
|
||||
|
||||
/// Check if buffer ends with a partial token
|
||||
fn ends_with_partial_token(&self, buffer: &str) -> Option<usize> {
|
||||
// Check for partial start token
|
||||
let start_token = "<tool_call>\n";
|
||||
// Use inclusive range to check if entire buffer could be a prefix
|
||||
for i in 1..=start_token.len().min(buffer.len()) {
|
||||
if start_token.starts_with(&buffer[buffer.len() - i..]) {
|
||||
return Some(i);
|
||||
}
|
||||
}
|
||||
|
||||
// Check for partial end token
|
||||
let end_token = "\n</tool_call>";
|
||||
// Only check if buffer ends with a partial match (not the complete token without newline)
|
||||
// If buffer ends with "</tool_call>", that's not a partial token - it's missing the newline
|
||||
if buffer.ends_with("</tool_call>") {
|
||||
// This is a complete end tag, just missing the leading newline
|
||||
// Not a partial token situation
|
||||
return None;
|
||||
}
|
||||
// Use inclusive range to check if entire buffer could be a prefix
|
||||
(1..=end_token.len().min(buffer.len()))
|
||||
.find(|&i| end_token.starts_with(&buffer[buffer.len() - i..]))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QwenParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for QwenParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
// Check if text contains Qwen format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
// Extract all tool call blocks
|
||||
let tool_blocks = self.extract_tool_calls(text);
|
||||
let mut tools = Vec::new();
|
||||
|
||||
for (index, json_str) in tool_blocks.iter().enumerate() {
|
||||
// Parse each JSON block
|
||||
match serde_json::from_str::<Value>(json_str.trim()) {
|
||||
Ok(value) => {
|
||||
if let Some(tool) = self.parse_single_object(&value, index)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Skip malformed JSON blocks
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
|
||||
// Check for partial token at end of buffer
|
||||
if let Some(_partial_len) = self.ends_with_partial_token(&state.buffer) {
|
||||
// Hold back the partial token
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
|
||||
// Check if we have the start marker
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
|
||||
// Find start and end positions
|
||||
if let Some(start_pos) = self.find_tool_start(&state.buffer) {
|
||||
// Check if we have the complete tool call
|
||||
if let Some(end_pos) = self.find_tool_end(&state.buffer, start_pos) {
|
||||
// Extract the JSON content
|
||||
let json_start = start_pos + "<tool_call>\n".len();
|
||||
let json_end = end_pos - "\n</tool_call>".len();
|
||||
let json_str = &state.buffer[json_start..json_end];
|
||||
|
||||
// Parse the complete JSON
|
||||
match serde_json::from_str::<Value>(json_str.trim()) {
|
||||
Ok(value) => {
|
||||
if let Some(tool) = self.parse_single_object(&value, 0)? {
|
||||
// Clear the consumed part from buffer using drain for efficiency
|
||||
state.buffer.drain(..end_pos);
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// JSON parsing failed, might be incomplete
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// We have start but no end yet - try partial parsing
|
||||
let json_start = start_pos + "<tool_call>\n".len();
|
||||
let partial_json = &state.buffer[json_start..];
|
||||
|
||||
// Remove trailing newline if present (might be start of end token)
|
||||
let partial_json = partial_json.trim_end();
|
||||
|
||||
// Try to parse with partial JSON parser
|
||||
match self.partial_json.parse_value(partial_json) {
|
||||
Ok((value, _consumed)) => {
|
||||
// Extract tool name if available
|
||||
if let Some(name) = value.get("name").and_then(|v| v.as_str()) {
|
||||
// Check if we've already sent the name
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Use as flag for "name sent"
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check for arguments
|
||||
if let Some(args) = value.get("arguments") {
|
||||
if let Ok(args_str) = serde_json::to_string(args) {
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Failed to parse even as partial JSON
|
||||
// Keep buffering
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
// Check if text contains Qwen-specific markers. If not, it's not this format.
|
||||
if !self.has_tool_markers(text) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Try to extract tool calls to see if we have a complete, valid one.
|
||||
let tool_blocks = self.extract_tool_calls(text);
|
||||
for json_str in &tool_blocks {
|
||||
if let Ok(value) = serde_json::from_str::<Value>(json_str.trim()) {
|
||||
if let Some(obj) = value.as_object() {
|
||||
if obj.contains_key("name") && obj.contains_key("arguments") {
|
||||
// Found a valid, complete tool call.
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we have the marker but no valid complete tool call,
|
||||
// it could be a partial stream. We should detect this as the format.
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_qwen_format() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"<tool_call>
|
||||
{"name": "get_weather", "arguments": {"location": "Beijing", "units": "celsius"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Beijing"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_multiple_tools() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"<tool_call>
|
||||
{"name": "search", "arguments": {"query": "rust programming"}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name": "calculate", "arguments": {"expression": "2 + 2"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "calculate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_with_normal_text() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"Let me help you with that.
|
||||
<tool_call>
|
||||
{"name": "get_info", "arguments": {"topic": "Rust"}}
|
||||
</tool_call>
|
||||
Here are the results."#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_info");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_nested_json_structures() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"<tool_call>
|
||||
{
|
||||
"name": "process_data",
|
||||
"arguments": {
|
||||
"data": {
|
||||
"nested": {
|
||||
"array": [1, 2, 3],
|
||||
"object": {"key": "value"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process_data");
|
||||
assert!(result[0].function.arguments.contains("nested"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = QwenParser::new();
|
||||
|
||||
assert!(parser.detect_format(
|
||||
r#"<tool_call>
|
||||
{"name": "test", "arguments": {}}
|
||||
</tool_call>"#
|
||||
));
|
||||
|
||||
assert!(parser.detect_format(
|
||||
r#"Text before <tool_call>
|
||||
{"name": "test", "arguments": {}}
|
||||
</tool_call> text after"#
|
||||
));
|
||||
|
||||
assert!(!parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
|
||||
// Partial format should still be detected
|
||||
assert!(parser.detect_format("<tool_call>"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_partial() {
|
||||
let parser = QwenParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Simulate streaming chunks
|
||||
let chunks = vec![
|
||||
"<tool_call>\n",
|
||||
r#"{"name": "search","#,
|
||||
r#" "arguments": {"query":"#,
|
||||
r#" "rust"}}"#,
|
||||
"\n</tool_call>",
|
||||
];
|
||||
|
||||
let mut found_name = false;
|
||||
let mut found_complete = false;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolName { name, .. } => {
|
||||
assert_eq!(name, "search");
|
||||
found_name = true;
|
||||
}
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "search");
|
||||
found_complete = true;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(found_name || found_complete); // At least one should be found
|
||||
}
|
||||
}
|
||||
348
sgl-router/src/tool_parser/parsers/step3_parser.rs
Normal file
348
sgl-router/src/tool_parser/parsers/step3_parser.rs
Normal file
@@ -0,0 +1,348 @@
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
};
|
||||
|
||||
/// Step3 format parser for tool calls
|
||||
///
|
||||
/// Handles the Step3 specific format with steptml XML:
|
||||
/// `<|tool_calls_begin|><|tool_call_begin|>function<|tool_sep|><steptml:invoke name="{name}"><steptml:parameter name="{k}">{v}</steptml:parameter></steptml:invoke><|tool_call_end|><|tool_calls_end|>`
|
||||
///
|
||||
/// Features:
|
||||
/// - Unicode token delimiters
|
||||
/// - StepTML XML format for invocations
|
||||
/// - Support for multiple sequential tool calls
|
||||
pub struct Step3Parser {
|
||||
/// Regex for extracting tool call blocks
|
||||
tool_call_extractor: Regex,
|
||||
/// Regex for extracting steptml invocations
|
||||
invoke_extractor: Regex,
|
||||
/// Regex for extracting parameters
|
||||
param_extractor: Regex,
|
||||
}
|
||||
|
||||
impl Step3Parser {
|
||||
/// Create a new Step3 parser
|
||||
pub fn new() -> Self {
|
||||
// Pattern for individual tool calls
|
||||
let tool_call_pattern = r"(?s)<|tool_call_begin|>.*?<|tool_call_end|>";
|
||||
let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
|
||||
|
||||
// Pattern for steptml invocations
|
||||
let invoke_pattern = r#"(?s)<steptml:invoke name="([^"]+)">(.+?)</steptml:invoke>"#;
|
||||
let invoke_extractor = Regex::new(invoke_pattern).expect("Valid regex pattern");
|
||||
|
||||
// Pattern for steptml parameters - using non-greedy match for values to handle < characters
|
||||
let param_pattern = r#"(?s)<steptml:parameter name="([^"]+)">(.+?)</steptml:parameter>"#;
|
||||
let param_extractor = Regex::new(param_pattern).expect("Valid regex pattern");
|
||||
|
||||
Self {
|
||||
tool_call_extractor,
|
||||
invoke_extractor,
|
||||
param_extractor,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if text contains Step3 tool markers
|
||||
fn has_tool_markers(&self, text: &str) -> bool {
|
||||
text.contains("<|tool_calls_begin|>")
|
||||
}
|
||||
|
||||
/// Parse parameters from steptml format
|
||||
fn parse_steptml_parameters(
|
||||
&self,
|
||||
params_text: &str,
|
||||
) -> ToolParserResult<serde_json::Map<String, Value>> {
|
||||
let mut parameters = serde_json::Map::new();
|
||||
|
||||
for capture in self.param_extractor.captures_iter(params_text) {
|
||||
let param_name = capture.get(1).map_or("", |m| m.as_str()).trim();
|
||||
let param_value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
|
||||
|
||||
// Try to parse the value as JSON first, fallback to string
|
||||
let param_value = if let Ok(json_val) = serde_json::from_str::<Value>(param_value_str) {
|
||||
json_val
|
||||
} else {
|
||||
// Try parsing as Python literal
|
||||
if param_value_str == "true" || param_value_str == "True" {
|
||||
Value::Bool(true)
|
||||
} else if param_value_str == "false" || param_value_str == "False" {
|
||||
Value::Bool(false)
|
||||
} else if param_value_str == "null" || param_value_str == "None" {
|
||||
Value::Null
|
||||
} else if let Ok(num) = param_value_str.parse::<i64>() {
|
||||
Value::Number(num.into())
|
||||
} else if let Ok(num) = param_value_str.parse::<f64>() {
|
||||
if let Some(n) = serde_json::Number::from_f64(num) {
|
||||
Value::Number(n)
|
||||
} else {
|
||||
Value::String(param_value_str.to_string())
|
||||
}
|
||||
} else {
|
||||
Value::String(param_value_str.to_string())
|
||||
}
|
||||
};
|
||||
|
||||
parameters.insert(param_name.to_string(), param_value);
|
||||
}
|
||||
|
||||
Ok(parameters)
|
||||
}
|
||||
|
||||
/// Parse a single tool call block
|
||||
fn parse_tool_call(&self, block: &str) -> ToolParserResult<Option<ToolCall>> {
|
||||
// Check if it contains function marker and tool separator
|
||||
if !block.contains("function") || !block.contains("<|tool_sep|>") {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Split by tool separator
|
||||
let parts: Vec<&str> = block.split("<|tool_sep|>").collect();
|
||||
if parts.len() != 2 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Check if it's a function type
|
||||
if !parts[0].contains("function") {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let invoke_part = parts[1];
|
||||
|
||||
// Extract steptml invoke
|
||||
if let Some(captures) = self.invoke_extractor.captures(invoke_part) {
|
||||
let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
|
||||
|
||||
// Validate function name is not empty
|
||||
if func_name.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let params_text = captures.get(2).map_or("", |m| m.as_str());
|
||||
|
||||
// Parse parameters
|
||||
let parameters = self.parse_steptml_parameters(params_text)?;
|
||||
|
||||
let arguments_str = serde_json::to_string(¶meters)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
// Generate ID
|
||||
let id = format!("step3_call_{}", uuid::Uuid::new_v4());
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: func_name.to_string(),
|
||||
arguments: arguments_str,
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Step3Parser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for Step3Parser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
// Check if text contains Step3 format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
// Find the tool calls section
|
||||
if let Some(start_pos) = text.find("<|tool_calls_begin|>") {
|
||||
let search_from = start_pos + "<|tool_calls_begin|>".len();
|
||||
|
||||
// Find the end of tool calls section
|
||||
if let Some(end_pos) = text[search_from..].find("<|tool_calls_end|>") {
|
||||
let tool_section = &text[search_from..search_from + end_pos];
|
||||
|
||||
// Extract all tool call blocks
|
||||
let mut tools = Vec::new();
|
||||
for mat in self.tool_call_extractor.find_iter(tool_section) {
|
||||
if let Some(tool) = self.parse_tool_call(mat.as_str())? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
|
||||
return Ok(tools);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
|
||||
// Check for tool markers
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
// No markers found, return as incomplete
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
|
||||
// Look for start of tool calls
|
||||
if let Some(start_pos) = state.buffer.find("<|tool_calls_begin|>") {
|
||||
let search_from = start_pos + "<|tool_calls_begin|>".len();
|
||||
|
||||
// Look for individual tool call start
|
||||
if let Some(call_start) = state.buffer[search_from..].find("<|tool_call_begin|>") {
|
||||
let call_start_abs = search_from + call_start;
|
||||
|
||||
// Look for the end of this tool call
|
||||
let search_end_from = call_start_abs + "<|tool_call_begin|>".len();
|
||||
if let Some(call_end) = state.buffer[search_end_from..].find("<|tool_call_end|>")
|
||||
{
|
||||
let call_end_abs = search_end_from + call_end + "<|tool_call_end|>".len();
|
||||
|
||||
// Extract and parse the complete tool call
|
||||
let tool_call_text = &state.buffer[call_start_abs..call_end_abs];
|
||||
|
||||
if let Some(tool) = self.parse_tool_call(tool_call_text)? {
|
||||
// Remove the processed part from buffer
|
||||
state.buffer.drain(..call_end_abs);
|
||||
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
} else {
|
||||
// Tool call not complete yet, try to extract partial info
|
||||
let partial = &state.buffer[search_end_from..];
|
||||
|
||||
// Check for tool separator
|
||||
if let Some(sep_pos) = partial.find("<|tool_sep|>") {
|
||||
// Check if it's a function
|
||||
if partial[..sep_pos].contains("function") {
|
||||
let after_sep = &partial[sep_pos + "<|tool_sep|>".len()..];
|
||||
|
||||
// Try to extract function name from steptml:invoke
|
||||
if let Some(name_match) = self.invoke_extractor.captures(after_sep) {
|
||||
let func_name = name_match.get(1).map_or("", |m| m.as_str()).trim();
|
||||
|
||||
if !state.in_string && !func_name.is_empty() {
|
||||
state.in_string = true; // Mark name as sent
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: func_name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Try to extract partial parameters
|
||||
if let Some(params_text) = name_match.get(2) {
|
||||
let parameters =
|
||||
self.parse_steptml_parameters(params_text.as_str())?;
|
||||
|
||||
if !parameters.is_empty() {
|
||||
let args_str = serde_json::to_string(¶meters)
|
||||
.unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
self.has_tool_markers(text)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_step3_single_tool() {
|
||||
let parser = Step3Parser::new();
|
||||
let input = r#"Some text
|
||||
<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="get_weather">
|
||||
<steptml:parameter name="location">Tokyo</steptml:parameter>
|
||||
<steptml:parameter name="units">celsius</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>More text"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Tokyo"));
|
||||
assert!(result[0].function.arguments.contains("celsius"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_step3_multiple_tools() {
|
||||
let parser = Step3Parser::new();
|
||||
let input = r#"<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="search">
|
||||
<steptml:parameter name="query">rust programming</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="calculate">
|
||||
<steptml:parameter name="expression">2 + 2</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "calculate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_step3_mixed_types() {
|
||||
let parser = Step3Parser::new();
|
||||
let input = r#"<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="process_data">
|
||||
<steptml:parameter name="count">42</steptml:parameter>
|
||||
<steptml:parameter name="active">true</steptml:parameter>
|
||||
<steptml:parameter name="rate">1.5</steptml:parameter>
|
||||
<steptml:parameter name="name">test</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process_data");
|
||||
|
||||
// Parse arguments to check types
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["count"], 42);
|
||||
assert_eq!(args["active"], true);
|
||||
assert_eq!(args["rate"], 1.5);
|
||||
assert_eq!(args["name"], "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = Step3Parser::new();
|
||||
assert!(parser.detect_format("<|tool_calls_begin|>"));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format("[TOOL_CALLS]"));
|
||||
}
|
||||
}
|
||||
527
sgl-router/src/tool_parser/partial_json.rs
Normal file
527
sgl-router/src/tool_parser/partial_json.rs
Normal file
@@ -0,0 +1,527 @@
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
traits::PartialJsonParser,
|
||||
};
|
||||
use serde_json::{Map, Value};
|
||||
|
||||
/// Parser for incomplete JSON
|
||||
pub struct PartialJson {
|
||||
/// Maximum depth for nested structures
|
||||
max_depth: usize,
|
||||
/// Whether to allow incomplete values
|
||||
allow_incomplete: bool,
|
||||
}
|
||||
|
||||
impl PartialJson {
|
||||
/// Create a new partial JSON parser
|
||||
pub fn new(max_depth: usize, allow_incomplete: bool) -> Self {
|
||||
Self {
|
||||
max_depth,
|
||||
allow_incomplete,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse potentially incomplete JSON, returning parsed value and consumed bytes
|
||||
pub fn parse_value(&self, input: &str) -> ToolParserResult<(Value, usize)> {
|
||||
let mut parser = Parser::new(input, self.max_depth, self.allow_incomplete);
|
||||
let value = parser.parse_value(0)?;
|
||||
Ok((value, parser.position))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PartialJson {
|
||||
fn default() -> Self {
|
||||
Self::new(32, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialJsonParser for PartialJson {
|
||||
fn parse(&self, input: &str) -> ToolParserResult<(Value, usize)> {
|
||||
self.parse_value(input)
|
||||
}
|
||||
|
||||
fn is_complete(&self, input: &str) -> bool {
|
||||
// Try to parse as complete JSON
|
||||
serde_json::from_str::<Value>(input).is_ok()
|
||||
}
|
||||
|
||||
fn max_depth(&self) -> usize {
|
||||
self.max_depth
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal parser state
|
||||
struct Parser<'a> {
|
||||
chars: std::iter::Peekable<std::str::Chars<'a>>,
|
||||
position: usize,
|
||||
max_depth: usize,
|
||||
allow_incomplete: bool,
|
||||
}
|
||||
|
||||
impl<'a> Parser<'a> {
|
||||
fn new(input: &'a str, max_depth: usize, allow_incomplete: bool) -> Self {
|
||||
Self {
|
||||
chars: input.chars().peekable(),
|
||||
position: 0,
|
||||
max_depth,
|
||||
allow_incomplete,
|
||||
}
|
||||
}
|
||||
|
||||
fn peek(&mut self) -> Option<char> {
|
||||
self.chars.peek().copied()
|
||||
}
|
||||
|
||||
fn advance(&mut self) {
|
||||
if self.chars.next().is_some() {
|
||||
self.position += 1;
|
||||
}
|
||||
}
|
||||
|
||||
fn skip_whitespace(&mut self) {
|
||||
while let Some(ch) = self.peek() {
|
||||
if ch.is_whitespace() {
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_value(&mut self, depth: usize) -> ToolParserResult<Value> {
|
||||
if depth > self.max_depth {
|
||||
return Err(ToolParserError::DepthExceeded(self.max_depth));
|
||||
}
|
||||
|
||||
self.skip_whitespace();
|
||||
|
||||
match self.peek() {
|
||||
Some('{') => self.parse_object(depth + 1),
|
||||
Some('[') => self.parse_array(depth + 1),
|
||||
Some('"') => self.parse_string(),
|
||||
Some('t') | Some('f') => self.parse_bool(),
|
||||
Some('n') => self.parse_null(),
|
||||
Some(c) if c == '-' || c.is_ascii_digit() => self.parse_number(),
|
||||
_ => {
|
||||
if self.allow_incomplete {
|
||||
Ok(Value::Null)
|
||||
} else {
|
||||
Err(ToolParserError::ParsingFailed(
|
||||
"Unexpected character".into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_object(&mut self, depth: usize) -> ToolParserResult<Value> {
|
||||
if depth > self.max_depth {
|
||||
return Err(ToolParserError::DepthExceeded(self.max_depth));
|
||||
}
|
||||
|
||||
let mut object = Map::new();
|
||||
|
||||
// Consume '{'
|
||||
self.advance();
|
||||
self.skip_whitespace();
|
||||
|
||||
// Check for empty object
|
||||
if self.peek() == Some('}') {
|
||||
self.advance();
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
|
||||
loop {
|
||||
// Parse key
|
||||
let key = match self.parse_string() {
|
||||
Ok(Value::String(s)) => s,
|
||||
Err(_) if self.allow_incomplete => {
|
||||
// Incomplete object
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
_ => return Err(ToolParserError::ParsingFailed("Expected string key".into())),
|
||||
};
|
||||
|
||||
self.skip_whitespace();
|
||||
|
||||
// Expect ':'
|
||||
if self.peek() != Some(':') {
|
||||
if self.allow_incomplete {
|
||||
// Add null value for incomplete pair
|
||||
object.insert(key, Value::Null);
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
return Err(ToolParserError::ParsingFailed("Expected ':'".into()));
|
||||
}
|
||||
self.advance();
|
||||
self.skip_whitespace();
|
||||
|
||||
// Parse value (keep same depth - we already incremented in parse_object)
|
||||
let value = match self.parse_value(depth) {
|
||||
Ok(v) => v,
|
||||
Err(_) if self.allow_incomplete => {
|
||||
// Add null for incomplete value
|
||||
object.insert(key, Value::Null);
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
object.insert(key, value);
|
||||
self.skip_whitespace();
|
||||
|
||||
match self.peek() {
|
||||
Some(',') => {
|
||||
self.advance();
|
||||
self.skip_whitespace();
|
||||
// Check for trailing comma
|
||||
if self.peek() == Some('}') {
|
||||
self.advance();
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
}
|
||||
Some('}') => {
|
||||
self.advance();
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
None if self.allow_incomplete => {
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
_ => {
|
||||
if self.allow_incomplete {
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
return Err(ToolParserError::ParsingFailed("Expected ',' or '}'".into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_array(&mut self, depth: usize) -> ToolParserResult<Value> {
|
||||
if depth > self.max_depth {
|
||||
return Err(ToolParserError::DepthExceeded(self.max_depth));
|
||||
}
|
||||
|
||||
let mut array = Vec::new();
|
||||
|
||||
// Consume '['
|
||||
self.advance();
|
||||
self.skip_whitespace();
|
||||
|
||||
// Check for empty array
|
||||
if self.peek() == Some(']') {
|
||||
self.advance();
|
||||
return Ok(Value::Array(array));
|
||||
}
|
||||
|
||||
loop {
|
||||
// Parse value (keep same depth - we already incremented in parse_object)
|
||||
let value = match self.parse_value(depth) {
|
||||
Ok(v) => v,
|
||||
Err(_) if self.allow_incomplete => {
|
||||
return Ok(Value::Array(array));
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
array.push(value);
|
||||
self.skip_whitespace();
|
||||
|
||||
match self.peek() {
|
||||
Some(',') => {
|
||||
self.advance();
|
||||
self.skip_whitespace();
|
||||
// Check for trailing comma
|
||||
if self.peek() == Some(']') {
|
||||
self.advance();
|
||||
return Ok(Value::Array(array));
|
||||
}
|
||||
}
|
||||
Some(']') => {
|
||||
self.advance();
|
||||
return Ok(Value::Array(array));
|
||||
}
|
||||
None if self.allow_incomplete => {
|
||||
return Ok(Value::Array(array));
|
||||
}
|
||||
_ => {
|
||||
if self.allow_incomplete {
|
||||
return Ok(Value::Array(array));
|
||||
}
|
||||
return Err(ToolParserError::ParsingFailed("Expected ',' or ']'".into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_string(&mut self) -> ToolParserResult<Value> {
|
||||
if self.peek() != Some('"') {
|
||||
return Err(ToolParserError::ParsingFailed("Expected '\"'".into()));
|
||||
}
|
||||
|
||||
// Consume opening quote
|
||||
self.advance();
|
||||
|
||||
let mut string = String::new();
|
||||
let mut escaped = false;
|
||||
|
||||
while let Some(ch) = self.peek() {
|
||||
if escaped {
|
||||
// Handle escape sequences
|
||||
let escaped_char = match ch {
|
||||
'"' | '\\' | '/' => ch,
|
||||
'b' => '\u{0008}',
|
||||
'f' => '\u{000C}',
|
||||
'n' => '\n',
|
||||
'r' => '\r',
|
||||
't' => '\t',
|
||||
'u' => {
|
||||
// Unicode escape
|
||||
self.advance();
|
||||
let hex = self.parse_unicode_escape()?;
|
||||
string.push(hex);
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
_ => ch, // Invalid escape, but be lenient
|
||||
};
|
||||
string.push(escaped_char);
|
||||
escaped = false;
|
||||
} else if ch == '\\' {
|
||||
escaped = true;
|
||||
} else if ch == '"' {
|
||||
// End of string
|
||||
self.advance();
|
||||
return Ok(Value::String(string));
|
||||
} else {
|
||||
string.push(ch);
|
||||
}
|
||||
self.advance();
|
||||
}
|
||||
|
||||
// Incomplete string
|
||||
if self.allow_incomplete {
|
||||
Ok(Value::String(string))
|
||||
} else {
|
||||
Err(ToolParserError::ParsingFailed("Unterminated string".into()))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_unicode_escape(&mut self) -> ToolParserResult<char> {
|
||||
let mut hex = String::new();
|
||||
for _ in 0..4 {
|
||||
if let Some(ch) = self.peek() {
|
||||
if ch.is_ascii_hexdigit() {
|
||||
hex.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if hex.len() == 4 {
|
||||
u32::from_str_radix(&hex, 16)
|
||||
.ok()
|
||||
.and_then(char::from_u32)
|
||||
.ok_or_else(|| ToolParserError::ParsingFailed("Invalid unicode escape".into()))
|
||||
} else if self.allow_incomplete {
|
||||
Ok('\u{FFFD}') // Replacement character
|
||||
} else {
|
||||
Err(ToolParserError::ParsingFailed(
|
||||
"Incomplete unicode escape".into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_number(&mut self) -> ToolParserResult<Value> {
|
||||
let mut number = String::new();
|
||||
|
||||
// Handle negative sign
|
||||
if self.peek() == Some('-') {
|
||||
number.push('-');
|
||||
self.advance();
|
||||
}
|
||||
|
||||
// Parse integer part
|
||||
if self.peek() == Some('0') {
|
||||
number.push('0');
|
||||
self.advance();
|
||||
} else {
|
||||
while let Some(ch) = self.peek() {
|
||||
if ch.is_ascii_digit() {
|
||||
number.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse decimal part
|
||||
if self.peek() == Some('.') {
|
||||
number.push('.');
|
||||
self.advance();
|
||||
|
||||
while let Some(ch) = self.peek() {
|
||||
if ch.is_ascii_digit() {
|
||||
number.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse exponent
|
||||
if let Some(ch) = self.peek() {
|
||||
if ch == 'e' || ch == 'E' {
|
||||
number.push(ch);
|
||||
self.advance();
|
||||
|
||||
if let Some(sign) = self.peek() {
|
||||
if sign == '+' || sign == '-' {
|
||||
number.push(sign);
|
||||
self.advance();
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(ch) = self.peek() {
|
||||
if ch.is_ascii_digit() {
|
||||
number.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try to parse as integer first, then as float
|
||||
if let Ok(n) = number.parse::<i64>() {
|
||||
Ok(Value::Number(serde_json::Number::from(n)))
|
||||
} else if let Ok(n) = number.parse::<f64>() {
|
||||
Ok(Value::Number(
|
||||
serde_json::Number::from_f64(n).unwrap_or_else(|| serde_json::Number::from(0)),
|
||||
))
|
||||
} else if self.allow_incomplete {
|
||||
Ok(Value::Number(serde_json::Number::from(0)))
|
||||
} else {
|
||||
Err(ToolParserError::ParsingFailed("Invalid number".into()))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_bool(&mut self) -> ToolParserResult<Value> {
|
||||
let mut word = String::new();
|
||||
|
||||
// Peek at upcoming characters to validate it looks like a boolean
|
||||
let mut temp_chars = self.chars.clone();
|
||||
while let Some(&ch) = temp_chars.peek() {
|
||||
if ch.is_alphabetic() && word.len() < 5 {
|
||||
// "false" is 5 chars
|
||||
word.push(ch);
|
||||
temp_chars.next();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if it's a valid boolean prefix
|
||||
let is_valid = word == "true"
|
||||
|| word == "false"
|
||||
|| (self.allow_incomplete && ("true".starts_with(&word) || "false".starts_with(&word)));
|
||||
|
||||
if !is_valid {
|
||||
return Err(ToolParserError::ParsingFailed("Invalid boolean".into()));
|
||||
}
|
||||
|
||||
// Now actually consume the characters
|
||||
word.clear();
|
||||
while let Some(ch) = self.peek() {
|
||||
if ch.is_alphabetic() {
|
||||
word.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
match word.as_str() {
|
||||
"true" => Ok(Value::Bool(true)),
|
||||
"false" => Ok(Value::Bool(false)),
|
||||
partial if self.allow_incomplete => {
|
||||
if "true".starts_with(partial) {
|
||||
Ok(Value::Bool(true))
|
||||
} else if "false".starts_with(partial) {
|
||||
Ok(Value::Bool(false))
|
||||
} else {
|
||||
Err(ToolParserError::ParsingFailed("Invalid boolean".into()))
|
||||
}
|
||||
}
|
||||
_ => Err(ToolParserError::ParsingFailed("Invalid boolean".into())),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_null(&mut self) -> ToolParserResult<Value> {
|
||||
let mut word = String::new();
|
||||
|
||||
// Peek at upcoming characters to validate it looks like "null"
|
||||
let mut temp_chars = self.chars.clone();
|
||||
while let Some(&ch) = temp_chars.peek() {
|
||||
if ch.is_alphabetic() && word.len() < 4 {
|
||||
// "null" is 4 chars
|
||||
word.push(ch);
|
||||
temp_chars.next();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if it's a valid null prefix
|
||||
let is_valid = word == "null" || (self.allow_incomplete && "null".starts_with(&word));
|
||||
|
||||
if !is_valid {
|
||||
return Err(ToolParserError::ParsingFailed("Invalid null".into()));
|
||||
}
|
||||
|
||||
// Now actually consume the characters
|
||||
word.clear();
|
||||
while let Some(ch) = self.peek() {
|
||||
if ch.is_alphabetic() {
|
||||
word.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if word == "null" || (self.allow_incomplete && "null".starts_with(&word)) {
|
||||
Ok(Value::Null)
|
||||
} else {
|
||||
Err(ToolParserError::ParsingFailed("Invalid null".into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Utility function to check if a string contains complete JSON
|
||||
pub fn is_complete_json(input: &str) -> bool {
|
||||
serde_json::from_str::<Value>(input).is_ok()
|
||||
}
|
||||
|
||||
/// Utility function to find common prefix between two strings
|
||||
pub fn find_common_prefix(s1: &str, s2: &str) -> usize {
|
||||
s1.chars()
|
||||
.zip(s2.chars())
|
||||
.take_while(|(a, b)| a == b)
|
||||
.count()
|
||||
}
|
||||
|
||||
/// Utility function to compute diff between old and new strings
|
||||
pub fn compute_diff(old: &str, new: &str) -> String {
|
||||
let common_len = find_common_prefix(old, new);
|
||||
// Convert character count to byte offset
|
||||
new.chars().skip(common_len).collect()
|
||||
}
|
||||
442
sgl-router/src/tool_parser/python_literal_parser.rs
Normal file
442
sgl-router/src/tool_parser/python_literal_parser.rs
Normal file
@@ -0,0 +1,442 @@
|
||||
/// Minimal Python literal parser for Pythonic tool call format
|
||||
///
|
||||
/// This module provides a recursive descent parser for Python literals
|
||||
/// (strings, numbers, booleans, None, lists, dicts) without requiring
|
||||
/// a full Python AST parser.
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::tool_parser::errors::{ToolParserError, ToolParserResult};
|
||||
|
||||
/// Token types for Python literals
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum Token {
|
||||
// Literals
|
||||
String(String),
|
||||
Number(String),
|
||||
True,
|
||||
False,
|
||||
None,
|
||||
|
||||
// Delimiters
|
||||
LeftBracket, // [
|
||||
RightBracket, // ]
|
||||
LeftBrace, // {
|
||||
RightBrace, // }
|
||||
LeftParen, // (
|
||||
RightParen, // )
|
||||
Comma, // ,
|
||||
Colon, // :
|
||||
Equals, // =
|
||||
|
||||
// Identifier for function names
|
||||
Identifier(String),
|
||||
|
||||
// End of input
|
||||
Eof,
|
||||
}
|
||||
|
||||
/// Lexer for Python literals
|
||||
struct Lexer {
|
||||
input: Vec<char>,
|
||||
position: usize,
|
||||
}
|
||||
|
||||
impl Lexer {
|
||||
fn new(input: &str) -> Self {
|
||||
Self {
|
||||
input: input.chars().collect(),
|
||||
position: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn current_char(&self) -> Option<char> {
|
||||
self.input.get(self.position).copied()
|
||||
}
|
||||
|
||||
fn advance(&mut self) {
|
||||
if self.position < self.input.len() {
|
||||
self.position += 1;
|
||||
}
|
||||
}
|
||||
|
||||
fn skip_whitespace(&mut self) {
|
||||
while let Some(ch) = self.current_char() {
|
||||
if ch.is_whitespace() {
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn read_string(&mut self, quote_char: char) -> Result<String, ToolParserError> {
|
||||
let mut result = String::new();
|
||||
self.advance(); // Skip opening quote
|
||||
|
||||
while let Some(ch) = self.current_char() {
|
||||
if ch == '\\' {
|
||||
self.advance();
|
||||
if let Some(escaped) = self.current_char() {
|
||||
match escaped {
|
||||
'n' => result.push('\n'),
|
||||
't' => result.push('\t'),
|
||||
'r' => result.push('\r'),
|
||||
'\\' => result.push('\\'),
|
||||
'\'' => result.push('\''),
|
||||
'"' => result.push('"'),
|
||||
_ => {
|
||||
result.push('\\');
|
||||
result.push(escaped);
|
||||
}
|
||||
}
|
||||
self.advance();
|
||||
}
|
||||
} else if ch == quote_char {
|
||||
self.advance(); // Skip closing quote
|
||||
return Ok(result);
|
||||
} else {
|
||||
result.push(ch);
|
||||
self.advance();
|
||||
}
|
||||
}
|
||||
|
||||
Err(ToolParserError::ParsingFailed("Unterminated string".into()))
|
||||
}
|
||||
|
||||
fn read_number(&mut self) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
// Handle negative numbers
|
||||
if self.current_char() == Some('-') {
|
||||
result.push('-');
|
||||
self.advance();
|
||||
}
|
||||
|
||||
// Read digits and decimal point
|
||||
while let Some(ch) = self.current_char() {
|
||||
if ch.is_ascii_digit() || ch == '.' || ch == 'e' || ch == 'E' || ch == '+' || ch == '-'
|
||||
{
|
||||
result.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn read_identifier(&mut self) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
while let Some(ch) = self.current_char() {
|
||||
if ch.is_alphanumeric() || ch == '_' {
|
||||
result.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn next_token(&mut self) -> Result<Token, ToolParserError> {
|
||||
self.skip_whitespace();
|
||||
|
||||
match self.current_char() {
|
||||
None => Ok(Token::Eof),
|
||||
Some('[') => {
|
||||
self.advance();
|
||||
Ok(Token::LeftBracket)
|
||||
}
|
||||
Some(']') => {
|
||||
self.advance();
|
||||
Ok(Token::RightBracket)
|
||||
}
|
||||
Some('{') => {
|
||||
self.advance();
|
||||
Ok(Token::LeftBrace)
|
||||
}
|
||||
Some('}') => {
|
||||
self.advance();
|
||||
Ok(Token::RightBrace)
|
||||
}
|
||||
Some('(') => {
|
||||
self.advance();
|
||||
Ok(Token::LeftParen)
|
||||
}
|
||||
Some(')') => {
|
||||
self.advance();
|
||||
Ok(Token::RightParen)
|
||||
}
|
||||
Some(',') => {
|
||||
self.advance();
|
||||
Ok(Token::Comma)
|
||||
}
|
||||
Some(':') => {
|
||||
self.advance();
|
||||
Ok(Token::Colon)
|
||||
}
|
||||
Some('=') => {
|
||||
self.advance();
|
||||
Ok(Token::Equals)
|
||||
}
|
||||
Some('"') => Ok(Token::String(self.read_string('"')?)),
|
||||
Some('\'') => Ok(Token::String(self.read_string('\'')?)),
|
||||
Some(ch) if ch == '-' || ch.is_ascii_digit() => Ok(Token::Number(self.read_number())),
|
||||
Some(ch) if ch.is_alphabetic() || ch == '_' => {
|
||||
let ident = self.read_identifier();
|
||||
match ident.as_str() {
|
||||
"True" => Ok(Token::True),
|
||||
"False" => Ok(Token::False),
|
||||
"None" => Ok(Token::None),
|
||||
_ => Ok(Token::Identifier(ident)),
|
||||
}
|
||||
}
|
||||
Some(ch) => Err(ToolParserError::ParsingFailed(format!(
|
||||
"Unexpected character: {}",
|
||||
ch
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parser for Python literals
|
||||
pub struct PythonLiteralParser {
|
||||
lexer: Lexer,
|
||||
current_token: Token,
|
||||
}
|
||||
|
||||
impl PythonLiteralParser {
|
||||
pub fn new(input: &str) -> Result<Self, ToolParserError> {
|
||||
let mut lexer = Lexer::new(input);
|
||||
let current_token = lexer.next_token()?;
|
||||
Ok(Self {
|
||||
lexer,
|
||||
current_token,
|
||||
})
|
||||
}
|
||||
|
||||
fn advance(&mut self) -> Result<(), ToolParserError> {
|
||||
self.current_token = self.lexer.next_token()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn expect(&mut self, expected: Token) -> Result<(), ToolParserError> {
|
||||
if self.current_token == expected {
|
||||
self.advance()?;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ToolParserError::ParsingFailed(format!(
|
||||
"Expected {:?}, got {:?}",
|
||||
expected, self.current_token
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Python literal value
|
||||
pub fn parse_value(&mut self) -> Result<Value, ToolParserError> {
|
||||
match &self.current_token.clone() {
|
||||
Token::String(s) => {
|
||||
let value = s.clone();
|
||||
self.advance()?;
|
||||
Ok(json!(value))
|
||||
}
|
||||
Token::Number(n) => {
|
||||
let value = if let Ok(int_val) = n.parse::<i64>() {
|
||||
json!(int_val)
|
||||
} else if let Ok(float_val) = n.parse::<f64>() {
|
||||
json!(float_val)
|
||||
} else {
|
||||
return Err(ToolParserError::ParsingFailed(format!(
|
||||
"Invalid number: {}",
|
||||
n
|
||||
)));
|
||||
};
|
||||
self.advance()?;
|
||||
Ok(value)
|
||||
}
|
||||
Token::True => {
|
||||
self.advance()?;
|
||||
Ok(json!(true))
|
||||
}
|
||||
Token::False => {
|
||||
self.advance()?;
|
||||
Ok(json!(false))
|
||||
}
|
||||
Token::None => {
|
||||
self.advance()?;
|
||||
Ok(Value::Null)
|
||||
}
|
||||
Token::LeftBracket => self.parse_list(),
|
||||
Token::LeftBrace => self.parse_dict(),
|
||||
_ => Err(ToolParserError::ParsingFailed(format!(
|
||||
"Unexpected token: {:?}",
|
||||
self.current_token
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Python list: [item1, item2, ...]
|
||||
fn parse_list(&mut self) -> Result<Value, ToolParserError> {
|
||||
self.expect(Token::LeftBracket)?;
|
||||
let mut items = Vec::new();
|
||||
|
||||
// Handle empty list
|
||||
if self.current_token == Token::RightBracket {
|
||||
self.advance()?;
|
||||
return Ok(json!(items));
|
||||
}
|
||||
|
||||
loop {
|
||||
items.push(self.parse_value()?);
|
||||
|
||||
if self.current_token == Token::Comma {
|
||||
self.advance()?;
|
||||
// Handle trailing comma
|
||||
if self.current_token == Token::RightBracket {
|
||||
break;
|
||||
}
|
||||
} else if self.current_token == Token::RightBracket {
|
||||
break;
|
||||
} else {
|
||||
return Err(ToolParserError::ParsingFailed(format!(
|
||||
"Expected ',' or ']', got {:?}",
|
||||
self.current_token
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
self.expect(Token::RightBracket)?;
|
||||
Ok(json!(items))
|
||||
}
|
||||
|
||||
/// Parse a Python dict: {key1: value1, key2: value2, ...}
|
||||
fn parse_dict(&mut self) -> Result<Value, ToolParserError> {
|
||||
self.expect(Token::LeftBrace)?;
|
||||
let mut map = HashMap::new();
|
||||
|
||||
// Handle empty dict
|
||||
if self.current_token == Token::RightBrace {
|
||||
self.advance()?;
|
||||
return Ok(json!(map));
|
||||
}
|
||||
|
||||
loop {
|
||||
// Parse key (must be a string)
|
||||
let key = match &self.current_token {
|
||||
Token::String(s) => {
|
||||
let k = s.clone();
|
||||
self.advance()?;
|
||||
k
|
||||
}
|
||||
_ => {
|
||||
return Err(ToolParserError::ParsingFailed(format!(
|
||||
"Expected string key, got {:?}",
|
||||
self.current_token
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
self.expect(Token::Colon)?;
|
||||
|
||||
// Parse value
|
||||
let value = self.parse_value()?;
|
||||
map.insert(key, value);
|
||||
|
||||
if self.current_token == Token::Comma {
|
||||
self.advance()?;
|
||||
// Handle trailing comma
|
||||
if self.current_token == Token::RightBrace {
|
||||
break;
|
||||
}
|
||||
} else if self.current_token == Token::RightBrace {
|
||||
break;
|
||||
} else {
|
||||
return Err(ToolParserError::ParsingFailed(format!(
|
||||
"Expected ',' or '}}', got {:?}",
|
||||
self.current_token
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
self.expect(Token::RightBrace)?;
|
||||
Ok(json!(map))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Python literal string into a JSON value
|
||||
pub fn parse_python_literal(input: &str) -> ToolParserResult<Value> {
|
||||
let mut parser = PythonLiteralParser::new(input)?;
|
||||
parser.parse_value()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_primitives() {
|
||||
assert_eq!(parse_python_literal("True").unwrap(), json!(true));
|
||||
assert_eq!(parse_python_literal("False").unwrap(), json!(false));
|
||||
assert_eq!(parse_python_literal("None").unwrap(), Value::Null);
|
||||
assert_eq!(parse_python_literal("42").unwrap(), json!(42));
|
||||
assert_eq!(parse_python_literal("12.345").unwrap(), json!(12.345));
|
||||
assert_eq!(parse_python_literal("-42").unwrap(), json!(-42));
|
||||
assert_eq!(parse_python_literal("\"hello\"").unwrap(), json!("hello"));
|
||||
assert_eq!(parse_python_literal("'world'").unwrap(), json!("world"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_list() {
|
||||
assert_eq!(parse_python_literal("[]").unwrap(), json!([]));
|
||||
assert_eq!(parse_python_literal("[1, 2, 3]").unwrap(), json!([1, 2, 3]));
|
||||
assert_eq!(
|
||||
parse_python_literal("[\"a\", \"b\", \"c\"]").unwrap(),
|
||||
json!(["a", "b", "c"])
|
||||
);
|
||||
assert_eq!(
|
||||
parse_python_literal("[True, False, None]").unwrap(),
|
||||
json!([true, false, null])
|
||||
);
|
||||
// Nested list
|
||||
assert_eq!(
|
||||
parse_python_literal("[[1, 2], [3, 4]]").unwrap(),
|
||||
json!([[1, 2], [3, 4]])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_dict() {
|
||||
assert_eq!(parse_python_literal("{}").unwrap(), json!({}));
|
||||
assert_eq!(
|
||||
parse_python_literal("{\"a\": 1, \"b\": 2}").unwrap(),
|
||||
json!({"a": 1, "b": 2})
|
||||
);
|
||||
assert_eq!(
|
||||
parse_python_literal("{'x': True, 'y': False}").unwrap(),
|
||||
json!({"x": true, "y": false})
|
||||
);
|
||||
// Nested dict
|
||||
assert_eq!(
|
||||
parse_python_literal("{\"nested\": {\"value\": [1, 2, 3]}}").unwrap(),
|
||||
json!({"nested": {"value": [1, 2, 3]}})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_complex_nested() {
|
||||
let input = r#"{"config": {"nested": {"value": [1, 2, 3]}, "enabled": True}}"#;
|
||||
let expected = json!({
|
||||
"config": {
|
||||
"nested": {
|
||||
"value": [1, 2, 3]
|
||||
},
|
||||
"enabled": true
|
||||
}
|
||||
});
|
||||
assert_eq!(parse_python_literal(input).unwrap(), expected);
|
||||
}
|
||||
}
|
||||
224
sgl-router/src/tool_parser/registry.rs
Normal file
224
sgl-router/src/tool_parser/registry.rs
Normal file
@@ -0,0 +1,224 @@
|
||||
use crate::tool_parser::parsers::{
|
||||
DeepSeekParser, Glm4MoeParser, GptOssParser, JsonParser, KimiK2Parser, LlamaParser,
|
||||
MistralParser, PythonicParser, QwenParser, Step3Parser,
|
||||
};
|
||||
use crate::tool_parser::traits::ToolParser;
|
||||
use once_cell::sync::Lazy;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Global singleton registry instance - created once and reused
|
||||
pub static GLOBAL_REGISTRY: Lazy<ParserRegistry> = Lazy::new(ParserRegistry::new_internal);
|
||||
|
||||
/// Registry for tool parsers and model mappings
|
||||
pub struct ParserRegistry {
|
||||
/// Map of parser name to parser instance
|
||||
parsers: HashMap<String, Arc<dyn ToolParser>>,
|
||||
/// Map of model name/pattern to parser name
|
||||
model_mapping: HashMap<String, String>,
|
||||
/// Default parser to use when no match found
|
||||
default_parser: String,
|
||||
}
|
||||
|
||||
impl ParserRegistry {
|
||||
/// Get the global singleton instance
|
||||
pub fn new() -> &'static Self {
|
||||
&GLOBAL_REGISTRY
|
||||
}
|
||||
|
||||
/// Create a new instance for testing (not the singleton)
|
||||
#[cfg(test)]
|
||||
pub fn new_for_testing() -> Self {
|
||||
Self::new_internal()
|
||||
}
|
||||
|
||||
/// Internal constructor for creating the singleton instance
|
||||
fn new_internal() -> Self {
|
||||
let mut registry = Self {
|
||||
parsers: HashMap::new(),
|
||||
model_mapping: HashMap::new(),
|
||||
default_parser: "json".to_string(),
|
||||
};
|
||||
|
||||
// Register default parsers
|
||||
registry.register_default_parsers();
|
||||
|
||||
// Register default model mappings
|
||||
registry.register_default_mappings();
|
||||
|
||||
registry
|
||||
}
|
||||
|
||||
/// Register a parser
|
||||
pub fn register_parser(&mut self, name: impl Into<String>, parser: Arc<dyn ToolParser>) {
|
||||
self.parsers.insert(name.into(), parser);
|
||||
}
|
||||
|
||||
/// Map a model name/pattern to a parser
|
||||
pub fn map_model(&mut self, model: impl Into<String>, parser: impl Into<String>) {
|
||||
self.model_mapping.insert(model.into(), parser.into());
|
||||
}
|
||||
|
||||
/// Get parser for a specific model
|
||||
pub fn get_parser(&self, model: &str) -> Option<Arc<dyn ToolParser>> {
|
||||
// Try exact match first
|
||||
if let Some(parser_name) = self.model_mapping.get(model) {
|
||||
if let Some(parser) = self.parsers.get(parser_name) {
|
||||
return Some(parser.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Try prefix matching with more specific patterns first
|
||||
// Collect all matching patterns and sort by specificity (longer = more specific)
|
||||
let mut matches: Vec<(&String, &String)> = self
|
||||
.model_mapping
|
||||
.iter()
|
||||
.filter(|(pattern, _)| {
|
||||
if pattern.ends_with('*') {
|
||||
let prefix = &pattern[..pattern.len() - 1];
|
||||
model.starts_with(prefix)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort by pattern length in descending order (longer patterns are more specific)
|
||||
matches.sort_by_key(|(pattern, _)| std::cmp::Reverse(pattern.len()));
|
||||
|
||||
// Return the first matching parser
|
||||
for (_, parser_name) in matches {
|
||||
if let Some(parser) = self.parsers.get(parser_name) {
|
||||
return Some(parser.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default parser if it exists
|
||||
self.parsers.get(&self.default_parser).cloned()
|
||||
}
|
||||
|
||||
/// List all registered parsers
|
||||
pub fn list_parsers(&self) -> Vec<&str> {
|
||||
self.parsers.keys().map(|s| s.as_str()).collect()
|
||||
}
|
||||
|
||||
/// List all model mappings
|
||||
pub fn list_mappings(&self) -> Vec<(&str, &str)> {
|
||||
self.model_mapping
|
||||
.iter()
|
||||
.map(|(k, v)| (k.as_str(), v.as_str()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Register default parsers
|
||||
fn register_default_parsers(&mut self) {
|
||||
// JSON parser - most common format
|
||||
self.register_parser("json", Arc::new(JsonParser::new()));
|
||||
|
||||
// Mistral parser - [TOOL_CALLS] [...] format
|
||||
self.register_parser("mistral", Arc::new(MistralParser::new()));
|
||||
|
||||
// Qwen parser - <tool_call>...</tool_call> format
|
||||
self.register_parser("qwen", Arc::new(QwenParser::new()));
|
||||
|
||||
// Pythonic parser - [func(arg=val)] format
|
||||
self.register_parser("pythonic", Arc::new(PythonicParser::new()));
|
||||
|
||||
// Llama parser - <|python_tag|>{...} or plain JSON format
|
||||
self.register_parser("llama", Arc::new(LlamaParser::new()));
|
||||
|
||||
// DeepSeek V3 parser - Unicode tokens with JSON blocks
|
||||
self.register_parser("deepseek", Arc::new(DeepSeekParser::new()));
|
||||
|
||||
// GLM-4 MoE parser - XML-style key-value format
|
||||
self.register_parser("glm4_moe", Arc::new(Glm4MoeParser::new()));
|
||||
|
||||
// Step3 parser - StepTML XML format
|
||||
self.register_parser("step3", Arc::new(Step3Parser::new()));
|
||||
|
||||
// Kimi K2 parser - Token-based with indexed functions
|
||||
self.register_parser("kimik2", Arc::new(KimiK2Parser::new()));
|
||||
|
||||
// GPT-OSS parser - Channel format
|
||||
self.register_parser("gpt_oss", Arc::new(GptOssParser::new()));
|
||||
}
|
||||
|
||||
/// Register default model mappings
|
||||
fn register_default_mappings(&mut self) {
|
||||
// OpenAI models
|
||||
self.map_model("gpt-4*", "json");
|
||||
self.map_model("gpt-3.5*", "json");
|
||||
self.map_model("gpt-4o*", "json");
|
||||
|
||||
// Anthropic models
|
||||
self.map_model("claude-*", "json");
|
||||
|
||||
// Mistral models - use Mistral parser
|
||||
self.map_model("mistral-*", "mistral");
|
||||
self.map_model("mixtral-*", "mistral");
|
||||
|
||||
// Qwen models - use Qwen parser
|
||||
self.map_model("qwen*", "qwen");
|
||||
self.map_model("Qwen*", "qwen");
|
||||
|
||||
// Llama models
|
||||
// Llama 4 uses pythonic format
|
||||
self.map_model("llama-4*", "pythonic");
|
||||
self.map_model("meta-llama-4*", "pythonic");
|
||||
// Llama 3.2 uses python_tag format
|
||||
self.map_model("llama-3.2*", "llama");
|
||||
self.map_model("meta-llama-3.2*", "llama");
|
||||
// Other Llama models use JSON
|
||||
self.map_model("llama-*", "json");
|
||||
self.map_model("meta-llama-*", "json");
|
||||
|
||||
// DeepSeek models
|
||||
// DeepSeek V3 uses custom Unicode token format
|
||||
self.map_model("deepseek-v3*", "deepseek");
|
||||
self.map_model("deepseek-ai/DeepSeek-V3*", "deepseek");
|
||||
// DeepSeek V2 uses pythonic format
|
||||
self.map_model("deepseek-*", "pythonic");
|
||||
|
||||
// GLM models
|
||||
// GLM-4 MoE uses XML-style format
|
||||
self.map_model("glm-4-moe*", "glm4_moe");
|
||||
self.map_model("THUDM/glm-4-moe*", "glm4_moe");
|
||||
self.map_model("glm-4.5*", "glm4_moe");
|
||||
// Other GLM models may use JSON
|
||||
self.map_model("glm-*", "json");
|
||||
|
||||
// Step3 models
|
||||
self.map_model("step3*", "step3");
|
||||
self.map_model("Step-3*", "step3");
|
||||
|
||||
// Kimi models
|
||||
self.map_model("kimi-k2*", "kimik2");
|
||||
self.map_model("Kimi-K2*", "kimik2");
|
||||
self.map_model("moonshot*/Kimi-K2*", "kimik2");
|
||||
|
||||
// GPT-OSS models (T4-style)
|
||||
self.map_model("gpt-oss*", "gpt_oss");
|
||||
self.map_model("t4-*", "gpt_oss");
|
||||
|
||||
// Other models default to JSON
|
||||
self.map_model("gemini-*", "json");
|
||||
self.map_model("palm-*", "json");
|
||||
self.map_model("gemma-*", "json");
|
||||
}
|
||||
|
||||
/// Set the default parser
|
||||
pub fn set_default_parser(&mut self, name: impl Into<String>) {
|
||||
self.default_parser = name.into();
|
||||
}
|
||||
|
||||
/// Check if a parser is registered
|
||||
pub fn has_parser(&self, name: &str) -> bool {
|
||||
self.parsers.contains_key(name)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for &'static ParserRegistry {
|
||||
fn default() -> Self {
|
||||
ParserRegistry::new()
|
||||
}
|
||||
}
|
||||
181
sgl-router/src/tool_parser/state.rs
Normal file
181
sgl-router/src/tool_parser/state.rs
Normal file
@@ -0,0 +1,181 @@
|
||||
use crate::tool_parser::types::{PartialToolCall, ToolCall};
|
||||
|
||||
/// Current phase of parsing
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ParsePhase {
|
||||
/// Looking for start of tool call
|
||||
Searching,
|
||||
/// Parsing function name
|
||||
InName,
|
||||
/// Parsing function arguments
|
||||
InArguments,
|
||||
/// Tool call complete
|
||||
Complete,
|
||||
}
|
||||
|
||||
/// State for streaming parser
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ParseState {
|
||||
/// Buffer for accumulating input
|
||||
pub buffer: String,
|
||||
/// Position of last consumed character
|
||||
pub consumed: usize,
|
||||
/// Current partial tool being parsed
|
||||
pub partial_tool: Option<PartialToolCall>,
|
||||
/// Completed tool calls
|
||||
pub completed_tools: Vec<ToolCall>,
|
||||
/// Current parsing phase
|
||||
pub phase: ParsePhase,
|
||||
/// Bracket/brace depth for JSON parsing
|
||||
pub bracket_depth: i32,
|
||||
/// Whether currently inside a string literal
|
||||
pub in_string: bool,
|
||||
/// Whether next character should be escaped
|
||||
pub escape_next: bool,
|
||||
/// Current tool index (for streaming)
|
||||
pub tool_index: usize,
|
||||
}
|
||||
|
||||
impl ParseState {
|
||||
/// Create a new parse state
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffer: String::new(),
|
||||
consumed: 0,
|
||||
partial_tool: None,
|
||||
completed_tools: Vec::new(),
|
||||
phase: ParsePhase::Searching,
|
||||
bracket_depth: 0,
|
||||
in_string: false,
|
||||
escape_next: false,
|
||||
tool_index: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset state for parsing next tool
|
||||
pub fn reset(&mut self) {
|
||||
self.partial_tool = None;
|
||||
self.phase = ParsePhase::Searching;
|
||||
self.bracket_depth = 0;
|
||||
self.in_string = false;
|
||||
self.escape_next = false;
|
||||
}
|
||||
|
||||
/// Process a single character for JSON parsing
|
||||
pub fn process_char(&mut self, ch: char) {
|
||||
// Handle escape sequences
|
||||
if self.escape_next {
|
||||
self.escape_next = false;
|
||||
self.buffer.push(ch);
|
||||
return;
|
||||
}
|
||||
|
||||
if ch == '\\' && self.in_string {
|
||||
self.escape_next = true;
|
||||
self.buffer.push(ch);
|
||||
return;
|
||||
}
|
||||
|
||||
// Track string boundaries
|
||||
if ch == '"' && !self.escape_next {
|
||||
self.in_string = !self.in_string;
|
||||
}
|
||||
|
||||
// Track bracket depth for JSON
|
||||
if !self.in_string {
|
||||
match ch {
|
||||
'{' | '[' => {
|
||||
self.bracket_depth += 1;
|
||||
}
|
||||
'}' | ']' => {
|
||||
self.bracket_depth -= 1;
|
||||
if self.bracket_depth == 0 && self.partial_tool.is_some() {
|
||||
// Complete tool call found
|
||||
self.phase = ParsePhase::Complete;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
self.buffer.push(ch);
|
||||
}
|
||||
|
||||
/// Check if we have a complete JSON object/array
|
||||
pub fn has_complete_json(&self) -> bool {
|
||||
self.bracket_depth == 0 && !self.in_string && !self.buffer.is_empty()
|
||||
}
|
||||
|
||||
/// Extract content from buffer starting at position
|
||||
pub fn extract_from(&self, start: usize) -> &str {
|
||||
if start >= self.buffer.len() {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Find the nearest character boundary at or after start
|
||||
let mut safe_start = start;
|
||||
while safe_start < self.buffer.len() && !self.buffer.is_char_boundary(safe_start) {
|
||||
safe_start += 1;
|
||||
}
|
||||
|
||||
if safe_start < self.buffer.len() {
|
||||
&self.buffer[safe_start..]
|
||||
} else {
|
||||
""
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark content as consumed up to position
|
||||
pub fn consume_to(&mut self, position: usize) {
|
||||
if position > self.consumed {
|
||||
self.consumed = position;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get unconsumed content
|
||||
pub fn unconsumed(&self) -> &str {
|
||||
if self.consumed >= self.buffer.len() {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Find the nearest character boundary at or after consumed
|
||||
let mut safe_consumed = self.consumed;
|
||||
while safe_consumed < self.buffer.len() && !self.buffer.is_char_boundary(safe_consumed) {
|
||||
safe_consumed += 1;
|
||||
}
|
||||
|
||||
if safe_consumed < self.buffer.len() {
|
||||
&self.buffer[safe_consumed..]
|
||||
} else {
|
||||
""
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear consumed content from buffer
|
||||
pub fn clear_consumed(&mut self) {
|
||||
if self.consumed > 0 {
|
||||
// Find the nearest character boundary at or before consumed
|
||||
let mut safe_consumed = self.consumed;
|
||||
while safe_consumed > 0 && !self.buffer.is_char_boundary(safe_consumed) {
|
||||
safe_consumed -= 1;
|
||||
}
|
||||
|
||||
if safe_consumed > 0 {
|
||||
self.buffer.drain(..safe_consumed);
|
||||
self.consumed = self.consumed.saturating_sub(safe_consumed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add completed tool
|
||||
pub fn add_completed_tool(&mut self, tool: ToolCall) {
|
||||
self.completed_tools.push(tool);
|
||||
self.tool_index += 1;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ParseState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
886
sgl-router/src/tool_parser/tests.rs
Normal file
886
sgl-router/src/tool_parser/tests.rs
Normal file
@@ -0,0 +1,886 @@
|
||||
use super::*;
|
||||
use crate::tool_parser::parsers::JsonParser;
|
||||
use crate::tool_parser::partial_json::{
|
||||
compute_diff, find_common_prefix, is_complete_json, PartialJson,
|
||||
};
|
||||
use crate::tool_parser::traits::ToolParser;
|
||||
use crate::tool_parser::types::TokenConfig;
|
||||
|
||||
#[test]
|
||||
fn test_parse_state_new() {
|
||||
let state = ParseState::new();
|
||||
assert_eq!(state.phase, ParsePhase::Searching);
|
||||
assert_eq!(state.buffer, "");
|
||||
assert_eq!(state.consumed, 0);
|
||||
assert_eq!(state.bracket_depth, 0);
|
||||
assert!(!state.in_string);
|
||||
assert!(!state.escape_next);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_state_process_char() {
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Test bracket tracking
|
||||
state.process_char('{');
|
||||
assert_eq!(state.bracket_depth, 1);
|
||||
|
||||
state.process_char('}');
|
||||
assert_eq!(state.bracket_depth, 0);
|
||||
|
||||
// Test string tracking
|
||||
state.process_char('"');
|
||||
assert!(state.in_string);
|
||||
|
||||
state.process_char('"');
|
||||
assert!(!state.in_string);
|
||||
|
||||
// Test escape handling
|
||||
state.process_char('"');
|
||||
state.process_char('\\');
|
||||
assert!(state.escape_next);
|
||||
|
||||
state.process_char('"');
|
||||
assert!(!state.escape_next);
|
||||
assert!(state.in_string); // Still in string because quote was escaped
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_config() {
|
||||
let config = TokenConfig {
|
||||
start_tokens: vec!["<start>".to_string(), "[".to_string()],
|
||||
end_tokens: vec!["</end>".to_string(), "]".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
};
|
||||
|
||||
let pairs: Vec<_> = config.iter_pairs().collect();
|
||||
assert_eq!(pairs.len(), 2);
|
||||
assert_eq!(pairs[0], ("<start>", "</end>"));
|
||||
assert_eq!(pairs[1], ("[", "]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parser_registry() {
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
// Test has default mappings
|
||||
assert!(!registry.list_mappings().is_empty());
|
||||
|
||||
// Test model pattern matching
|
||||
let mappings = registry.list_mappings();
|
||||
let has_gpt = mappings.iter().any(|(m, _)| m.starts_with("gpt"));
|
||||
assert!(has_gpt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parser_registry_pattern_matching() {
|
||||
let mut registry = ParserRegistry::new_for_testing();
|
||||
|
||||
// Test that model mappings work by checking the list
|
||||
registry.map_model("test-model", "json");
|
||||
|
||||
// Verify through list_mappings
|
||||
let mappings = registry.list_mappings();
|
||||
let has_test = mappings
|
||||
.iter()
|
||||
.any(|(m, p)| *m == "test-model" && *p == "json");
|
||||
assert!(has_test);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_serialization() {
|
||||
let tool_call = ToolCall {
|
||||
id: "call-123".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: "search".to_string(),
|
||||
arguments: r#"{"query": "rust programming"}"#.to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&tool_call).unwrap();
|
||||
assert!(json.contains("call-123"));
|
||||
assert!(json.contains("search"));
|
||||
assert!(json.contains("rust programming"));
|
||||
|
||||
let parsed: ToolCall = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.id, "call-123");
|
||||
assert_eq!(parsed.function.name, "search");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_json_parser() {
|
||||
let parser = PartialJson::default();
|
||||
|
||||
// Test complete JSON
|
||||
let input = r#"{"name": "test", "value": 42}"#;
|
||||
let (value, consumed) = parser.parse_value(input).unwrap();
|
||||
assert_eq!(value["name"], "test");
|
||||
assert_eq!(value["value"], 42);
|
||||
assert_eq!(consumed, input.len());
|
||||
|
||||
// Test incomplete JSON object
|
||||
let input = r#"{"name": "test", "value": "#;
|
||||
let (value, _consumed) = parser.parse_value(input).unwrap();
|
||||
assert_eq!(value["name"], "test");
|
||||
assert!(value["value"].is_null());
|
||||
|
||||
// Test incomplete string
|
||||
let input = r#"{"name": "tes"#;
|
||||
let (value, _consumed) = parser.parse_value(input).unwrap();
|
||||
assert_eq!(value["name"], "tes");
|
||||
|
||||
// Test incomplete array
|
||||
let input = r#"[1, 2, "#;
|
||||
let (value, _consumed) = parser.parse_value(input).unwrap();
|
||||
assert!(value.is_array());
|
||||
assert_eq!(value[0], 1);
|
||||
assert_eq!(value[1], 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_json_depth_limit() {
|
||||
// max_depth of 3 allows nesting up to 3 levels
|
||||
// Set allow_incomplete to false to get errors instead of partial results
|
||||
let parser = PartialJson::new(3, false);
|
||||
|
||||
// This should work (simple object)
|
||||
let input = r#"{"a": 1}"#;
|
||||
let result = parser.parse_value(input);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// This should work (nested to depth 3)
|
||||
let input = r#"{"a": {"b": {"c": 1}}}"#;
|
||||
let result = parser.parse_value(input);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// This should fail (nested to depth 4, exceeds limit)
|
||||
let input = r#"{"a": {"b": {"c": {"d": 1}}}}"#;
|
||||
let result = parser.parse_value(input);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_complete_json() {
|
||||
assert!(is_complete_json(r#"{"name": "test"}"#));
|
||||
assert!(is_complete_json(r#"[1, 2, 3]"#));
|
||||
assert!(is_complete_json(r#""string""#));
|
||||
assert!(is_complete_json("42"));
|
||||
assert!(is_complete_json("true"));
|
||||
assert!(is_complete_json("null"));
|
||||
|
||||
assert!(!is_complete_json(r#"{"name": "#));
|
||||
assert!(!is_complete_json(r#"[1, 2, "#));
|
||||
assert!(!is_complete_json(r#""unclosed"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_common_prefix() {
|
||||
assert_eq!(find_common_prefix("hello", "hello"), 5);
|
||||
assert_eq!(find_common_prefix("hello", "help"), 3);
|
||||
assert_eq!(find_common_prefix("hello", "world"), 0);
|
||||
assert_eq!(find_common_prefix("", "hello"), 0);
|
||||
assert_eq!(find_common_prefix("hello", ""), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_diff() {
|
||||
assert_eq!(compute_diff("hello", "hello world"), " world");
|
||||
assert_eq!(compute_diff("", "hello"), "hello");
|
||||
assert_eq!(compute_diff("hello", "hello"), "");
|
||||
assert_eq!(compute_diff("test", "hello"), "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stream_result_variants() {
|
||||
// Test Incomplete
|
||||
let result = StreamResult::Incomplete;
|
||||
matches!(result, StreamResult::Incomplete);
|
||||
|
||||
// Test ToolName
|
||||
let result = StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: "test".to_string(),
|
||||
};
|
||||
if let StreamResult::ToolName { index, name } = result {
|
||||
assert_eq!(index, 0);
|
||||
assert_eq!(name, "test");
|
||||
} else {
|
||||
panic!("Expected ToolName variant");
|
||||
}
|
||||
|
||||
// Test ToolComplete
|
||||
let tool = ToolCall {
|
||||
id: "123".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: "test".to_string(),
|
||||
arguments: "{}".to_string(),
|
||||
},
|
||||
};
|
||||
let result = StreamResult::ToolComplete(tool.clone());
|
||||
if let StreamResult::ToolComplete(t) = result {
|
||||
assert_eq!(t.id, "123");
|
||||
} else {
|
||||
panic!("Expected ToolComplete variant");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_tool_call() {
|
||||
let mut partial = PartialToolCall {
|
||||
name: None,
|
||||
arguments_buffer: String::new(),
|
||||
start_position: 0,
|
||||
name_sent: false,
|
||||
streamed_args: String::new(),
|
||||
};
|
||||
|
||||
// Set name
|
||||
partial.name = Some("test_function".to_string());
|
||||
assert_eq!(partial.name.as_ref().unwrap(), "test_function");
|
||||
|
||||
// Append arguments
|
||||
partial.arguments_buffer.push_str(r#"{"key": "value"}"#);
|
||||
assert_eq!(partial.arguments_buffer, r#"{"key": "value"}"#);
|
||||
|
||||
// Update streaming state
|
||||
partial.name_sent = true;
|
||||
partial.streamed_args = r#"{"key": "#.to_string();
|
||||
assert!(partial.name_sent);
|
||||
assert_eq!(partial.streamed_args, r#"{"key": "#);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_parser_complete_single() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Test single tool call with arguments
|
||||
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco", "units": "celsius"}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("San Francisco"));
|
||||
assert!(result[0].function.arguments.contains("celsius"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_parser_complete_array() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Test array of tool calls
|
||||
let input = r#"[
|
||||
{"name": "get_weather", "arguments": {"location": "SF"}},
|
||||
{"name": "get_news", "arguments": {"query": "technology"}}
|
||||
]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "get_news");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_parser_with_parameters() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Test with "parameters" instead of "arguments"
|
||||
let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20, "operation": "add"}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
assert!(result[0].function.arguments.contains("10"));
|
||||
assert!(result[0].function.arguments.contains("20"));
|
||||
assert!(result[0].function.arguments.contains("add"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_parser_with_tokens() {
|
||||
// Test with custom wrapper tokens
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["[TOOL_CALLS] [".to_string()],
|
||||
end_tokens: vec!["]".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
let input = r#"[TOOL_CALLS] [{"name": "search", "arguments": {"query": "rust programming"}}]"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiline_json_with_tokens() {
|
||||
// Test that regex with (?s) flag properly handles multi-line JSON
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<tool>".to_string()],
|
||||
end_tokens: vec!["</tool>".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
// Pretty-printed multi-line JSON
|
||||
let input = r#"<tool>{
|
||||
"name": "get_weather",
|
||||
"arguments": {
|
||||
"location": "San Francisco",
|
||||
"units": "celsius",
|
||||
"include_forecast": true
|
||||
}
|
||||
}</tool>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("San Francisco"));
|
||||
assert!(result[0].function.arguments.contains("celsius"));
|
||||
assert!(result[0].function.arguments.contains("true"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiline_json_array() {
|
||||
// Test multi-line JSON array without wrapper tokens
|
||||
let parser = JsonParser::new();
|
||||
|
||||
let input = r#"[
|
||||
{
|
||||
"name": "function1",
|
||||
"arguments": {
|
||||
"param1": "value1",
|
||||
"param2": 42
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "function2",
|
||||
"parameters": {
|
||||
"data": [1, 2, 3],
|
||||
"flag": false
|
||||
}
|
||||
}
|
||||
]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "function1");
|
||||
assert_eq!(result[1].function.name, "function2");
|
||||
assert!(result[0].function.arguments.contains("value1"));
|
||||
assert!(result[1].function.arguments.contains("[1,2,3]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_parser_format_detection() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Should detect valid tool call formats
|
||||
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(parser.detect_format(r#"{"name": "test", "parameters": {"x": 1}}"#));
|
||||
assert!(parser.detect_format(r#"[{"name": "test"}]"#));
|
||||
|
||||
// Should not detect non-tool formats
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format(r#"{"key": "value"}"#));
|
||||
assert!(!parser.detect_format(r#"{"data": {"nested": true}}"#));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_parser_streaming() {
|
||||
let parser = JsonParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Test with complete JSON
|
||||
let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
|
||||
|
||||
let result = parser
|
||||
.parse_incremental(full_json, &mut state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
assert!(tool.function.arguments.contains("San Francisco"));
|
||||
}
|
||||
_ => panic!("Expected ToolComplete for complete JSON"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_registry_with_json_parser() {
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
// JSON parser should be registered by default
|
||||
assert!(registry.has_parser("json"));
|
||||
|
||||
// Should get JSON parser for OpenAI models
|
||||
let parser = registry.get_parser("gpt-4-turbo").unwrap();
|
||||
|
||||
// Test that the parser works
|
||||
let input = r#"{"name": "test", "arguments": {"x": 1}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_parser_invalid_input() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Invalid JSON should return empty results
|
||||
assert_eq!(parser.parse_complete("not json").await.unwrap().len(), 0);
|
||||
assert_eq!(parser.parse_complete("{invalid}").await.unwrap().len(), 0);
|
||||
assert_eq!(parser.parse_complete("").await.unwrap().len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_parser_empty_arguments() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Tool call with no arguments
|
||||
let input = r#"{"name": "get_time"}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_time");
|
||||
assert_eq!(result[0].function.arguments, "{}");
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod failure_cases {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_malformed_tool_missing_name() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Missing name field
|
||||
let input = r#"{"arguments": {"x": 1}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0, "Should return empty for tool without name");
|
||||
|
||||
// Empty name
|
||||
let input = r#"{"name": "", "arguments": {"x": 1}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1, "Should accept empty name string");
|
||||
assert_eq!(result[0].function.name, "");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invalid_arguments_json() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Arguments is a string instead of object
|
||||
let input = r#"{"name": "test", "arguments": "not an object"}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
// Should serialize the string as JSON
|
||||
assert!(result[0].function.arguments.contains("not an object"));
|
||||
|
||||
// Arguments is a number
|
||||
let input = r#"{"name": "test", "arguments": 42}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.arguments, "42");
|
||||
|
||||
// Arguments is null
|
||||
let input = r#"{"name": "test", "arguments": null}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.arguments, "null");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_broken_wrapper_tokens() {
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<tool>".to_string()],
|
||||
end_tokens: vec!["</tool>".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
// Missing end token
|
||||
let input = r#"<tool>{"name": "test", "arguments": {}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(
|
||||
result.len(),
|
||||
0,
|
||||
"Should fail to parse without complete wrapper"
|
||||
);
|
||||
|
||||
// Missing start token - parser looks for complete wrapper, so this won't parse
|
||||
let input = r#"{"name": "test", "arguments": {}}</tool>"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(
|
||||
result.len(),
|
||||
0,
|
||||
"Should not parse JSON with incomplete wrapper"
|
||||
);
|
||||
|
||||
// Mismatched tokens
|
||||
let input = r#"<tool>{"name": "test", "arguments": {}}</wrong>"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0, "Should fail with mismatched tokens");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invalid_json_structures() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Trailing comma
|
||||
let input = r#"{"name": "test", "arguments": {"x": 1,}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0, "Should reject JSON with trailing comma");
|
||||
|
||||
// Missing quotes on keys
|
||||
let input = r#"{name: "test", arguments: {}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0, "Should reject invalid JSON syntax");
|
||||
|
||||
// Unclosed object
|
||||
let input = r#"{"name": "test", "arguments": {"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0, "Should reject incomplete JSON");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod edge_cases {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unicode_in_names_and_arguments() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Unicode in function name
|
||||
let input = r#"{"name": "获取天气", "arguments": {"location": "北京"}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "获取天气");
|
||||
assert!(result[0].function.arguments.contains("北京"));
|
||||
|
||||
// Emoji in arguments
|
||||
let input = r#"{"name": "send_message", "arguments": {"text": "Hello 👋 World 🌍"}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("👋"));
|
||||
assert!(result[0].function.arguments.contains("🌍"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_escaped_characters() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Escaped quotes in arguments
|
||||
let input = r#"{"name": "echo", "arguments": {"text": "He said \"hello\""}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains(r#"\"hello\""#));
|
||||
|
||||
// Escaped backslashes
|
||||
let input = r#"{"name": "path", "arguments": {"dir": "C:\\Users\\test"}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("\\\\"));
|
||||
|
||||
// Newlines and tabs
|
||||
let input = r#"{"name": "format", "arguments": {"text": "line1\nline2\ttabbed"}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("\\n"));
|
||||
assert!(result[0].function.arguments.contains("\\t"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_very_large_payloads() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Large arguments object
|
||||
let mut large_args = r#"{"name": "process", "arguments": {"#.to_string();
|
||||
for i in 0..1000 {
|
||||
large_args.push_str(&format!(r#""field_{}": "value_{}","#, i, i));
|
||||
}
|
||||
large_args.push_str(r#""final": "value"}}"#);
|
||||
|
||||
let result = parser.parse_complete(&large_args).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process");
|
||||
assert!(result[0].function.arguments.contains("field_999"));
|
||||
|
||||
// Large array of tool calls
|
||||
let mut large_array = "[".to_string();
|
||||
for i in 0..100 {
|
||||
if i > 0 {
|
||||
large_array.push(',');
|
||||
}
|
||||
large_array.push_str(&format!(r#"{{"name": "func_{}", "arguments": {{}}}}"#, i));
|
||||
}
|
||||
large_array.push(']');
|
||||
|
||||
let result = parser.parse_complete(&large_array).await.unwrap();
|
||||
assert_eq!(result.len(), 100);
|
||||
assert_eq!(result[99].function.name, "func_99");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mixed_array_tools_and_non_tools() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Array with both tool calls and non-tool objects
|
||||
let input = r#"[
|
||||
{"name": "tool1", "arguments": {}},
|
||||
{"not_a_tool": "just_data"},
|
||||
{"name": "tool2", "parameters": {"x": 1}},
|
||||
{"key": "value", "another": "field"}
|
||||
]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2, "Should only parse valid tool calls");
|
||||
assert_eq!(result[0].function.name, "tool1");
|
||||
assert_eq!(result[1].function.name, "tool2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_duplicate_keys_in_json() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// JSON with duplicate keys (last one wins in most parsers)
|
||||
let input = r#"{"name": "first", "name": "second", "arguments": {"x": 1, "x": 2}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(
|
||||
result[0].function.name, "second",
|
||||
"Last duplicate key should win"
|
||||
);
|
||||
assert!(
|
||||
result[0].function.arguments.contains("2"),
|
||||
"Last duplicate value should win"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_null_values_in_arguments() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Null values in arguments
|
||||
let input = r#"{"name": "test", "arguments": {"required": "value", "optional": null}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("null"));
|
||||
|
||||
// Array with null
|
||||
let input = r#"{"name": "test", "arguments": {"items": [1, null, "three"]}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("null"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_token_pairs_with_conflicts() {
|
||||
// Test with overlapping token patterns
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<<".to_string(), "<tool>".to_string()],
|
||||
end_tokens: vec![">>".to_string(), "</tool>".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
// First pattern
|
||||
let input = r#"<<{"name": "test1", "arguments": {}}>>"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test1");
|
||||
|
||||
// Second pattern
|
||||
let input = r#"<tool>{"name": "test2", "arguments": {}}</tool>"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test2");
|
||||
|
||||
// Nested patterns (should use first match)
|
||||
let input = r#"<<tool>{"name": "test3", "arguments": {}}</tool>>"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
// This is tricky - depends on regex behavior
|
||||
// The parser should handle this gracefully
|
||||
assert!(result.len() <= 1, "Should not parse multiple times");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_with_partial_chunks() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Test 1: Very incomplete JSON (just opening brace) should return Incomplete
|
||||
let mut state1 = ParseState::new();
|
||||
let partial = r#"{"#;
|
||||
let result = parser
|
||||
.parse_incremental(partial, &mut state1)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(
|
||||
matches!(result, StreamResult::Incomplete),
|
||||
"Should return Incomplete for just opening brace"
|
||||
);
|
||||
|
||||
// Test 2: Complete JSON should return ToolComplete
|
||||
let mut state2 = ParseState::new();
|
||||
let complete = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;
|
||||
let result = parser
|
||||
.parse_incremental(complete, &mut state2)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
let args: serde_json::Value =
|
||||
serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert_eq!(args["location"], "SF");
|
||||
}
|
||||
_ => panic!("Expected ToolComplete for complete JSON"),
|
||||
}
|
||||
|
||||
// Test 3: Partial JSON with name
|
||||
// The PartialJson parser can complete partial JSON by filling in missing values
|
||||
let mut state3 = ParseState::new();
|
||||
let partial_with_name = r#"{"name": "test", "argum"#;
|
||||
let result = parser
|
||||
.parse_incremental(partial_with_name, &mut state3)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "test");
|
||||
// Arguments will be empty object since "argum" is incomplete
|
||||
assert_eq!(tool.function.arguments, "{}");
|
||||
}
|
||||
StreamResult::ToolName { name, .. } => {
|
||||
assert_eq!(name, "test");
|
||||
}
|
||||
StreamResult::Incomplete => {
|
||||
// Also acceptable if parser decides to wait
|
||||
}
|
||||
_ => panic!("Unexpected result for partial JSON with name"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_special_json_values() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Boolean values
|
||||
let input = r#"{"name": "toggle", "arguments": {"enabled": true, "disabled": false}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("true"));
|
||||
assert!(result[0].function.arguments.contains("false"));
|
||||
|
||||
// Numbers (including float and negative)
|
||||
let input = r#"{"name": "calc", "arguments": {"int": 42, "float": 3.14, "negative": -17}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("42"));
|
||||
assert!(result[0].function.arguments.contains("3.14"));
|
||||
assert!(result[0].function.arguments.contains("-17"));
|
||||
|
||||
// Empty arrays and objects
|
||||
let input = r#"{"name": "test", "arguments": {"empty_arr": [], "empty_obj": {}}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("[]"));
|
||||
assert!(result[0].function.arguments.contains("{}"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_function_field_alternative() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Using "function" instead of "name"
|
||||
let input = r#"{"function": "test_func", "arguments": {"x": 1}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test_func");
|
||||
|
||||
// Both "name" and "function" present (name should take precedence)
|
||||
let input = r#"{"name": "primary", "function": "secondary", "arguments": {}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "primary");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_whitespace_handling() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Extra whitespace everywhere
|
||||
let input = r#" {
|
||||
"name" : "test" ,
|
||||
"arguments" : {
|
||||
"key" : "value"
|
||||
}
|
||||
} "#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
|
||||
// Minified JSON (no whitespace)
|
||||
let input = r#"{"name":"compact","arguments":{"a":1,"b":2}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "compact");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod stress_tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_deeply_nested_arguments() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Deeply nested structure
|
||||
let input = r#"{
|
||||
"name": "nested",
|
||||
"arguments": {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"level3": {
|
||||
"level4": {
|
||||
"level5": {
|
||||
"value": "deep"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("deep"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_parser_usage() {
|
||||
// Test that parser can be used concurrently
|
||||
let parser = std::sync::Arc::new(JsonParser::new());
|
||||
|
||||
let mut handles = vec![];
|
||||
|
||||
for i in 0..10 {
|
||||
let parser_clone = parser.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
let input = format!(r#"{{"name": "func_{}", "arguments": {{}}}}"#, i);
|
||||
let result = parser_clone.parse_complete(&input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, format!("func_{}", i));
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
35
sgl-router/src/tool_parser/traits.rs
Normal file
35
sgl-router/src/tool_parser/traits.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use crate::tool_parser::{
|
||||
errors::ToolParserResult,
|
||||
state::ParseState,
|
||||
types::{StreamResult, ToolCall},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// Core trait for all tool parsers
|
||||
#[async_trait]
|
||||
pub trait ToolParser: Send + Sync {
|
||||
/// Parse complete tool calls from final output
|
||||
async fn parse_complete(&self, output: &str) -> ToolParserResult<Vec<ToolCall>>;
|
||||
|
||||
/// Parse tool calls from model output (streaming)
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult>;
|
||||
|
||||
/// Check if text contains tool calls in this parser's format
|
||||
fn detect_format(&self, text: &str) -> bool;
|
||||
}
|
||||
|
||||
/// Trait for partial JSON parsing
|
||||
pub trait PartialJsonParser: Send + Sync {
|
||||
/// Parse potentially incomplete JSON
|
||||
fn parse(&self, input: &str) -> ToolParserResult<(serde_json::Value, usize)>;
|
||||
|
||||
/// Check if JSON is complete
|
||||
fn is_complete(&self, input: &str) -> bool;
|
||||
|
||||
/// Get the maximum parsing depth
|
||||
fn max_depth(&self) -> usize;
|
||||
}
|
||||
73
sgl-router/src/tool_parser/types.rs
Normal file
73
sgl-router/src/tool_parser/types.rs
Normal file
@@ -0,0 +1,73 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Parsed tool call from model output (OpenAI format)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
/// Unique identifier for the tool call
|
||||
pub id: String,
|
||||
/// Type of tool call (currently always "function")
|
||||
#[serde(rename = "type")]
|
||||
pub r#type: String,
|
||||
/// Function call details
|
||||
pub function: FunctionCall,
|
||||
}
|
||||
|
||||
/// Function call within a tool call
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct FunctionCall {
|
||||
/// Name of the function to call
|
||||
pub name: String,
|
||||
/// Arguments as JSON string
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
/// Streaming parse result
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StreamResult {
|
||||
/// Need more data to continue parsing
|
||||
Incomplete,
|
||||
/// Found a tool name (for streaming)
|
||||
ToolName { index: usize, name: String },
|
||||
/// Found incremental arguments (for streaming)
|
||||
ToolArguments { index: usize, arguments: String },
|
||||
/// Completed parsing a tool
|
||||
ToolComplete(ToolCall),
|
||||
/// Normal text (not part of tool call)
|
||||
NormalText(String),
|
||||
}
|
||||
|
||||
/// Token configuration for parsing
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TokenConfig {
|
||||
/// Start tokens for tool calls
|
||||
pub start_tokens: Vec<String>,
|
||||
/// End tokens for tool calls
|
||||
pub end_tokens: Vec<String>,
|
||||
/// Separator between multiple tool calls
|
||||
pub separator: String,
|
||||
}
|
||||
|
||||
impl TokenConfig {
|
||||
/// Iterate over start/end token pairs
|
||||
pub fn iter_pairs(&self) -> impl Iterator<Item = (&str, &str)> {
|
||||
self.start_tokens
|
||||
.iter()
|
||||
.zip(self.end_tokens.iter())
|
||||
.map(|(s, e)| (s.as_str(), e.as_str()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple partial tool call for streaming
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PartialToolCall {
|
||||
/// Tool name (if parsed)
|
||||
pub name: Option<String>,
|
||||
/// Buffer for accumulating arguments
|
||||
pub arguments_buffer: String,
|
||||
/// Start position in the input buffer
|
||||
pub start_position: usize,
|
||||
/// Whether the name has been sent (for streaming)
|
||||
pub name_sent: bool,
|
||||
/// Arguments already streamed
|
||||
pub streamed_args: String,
|
||||
}
|
||||
1478
sgl-router/src/tree.rs
Normal file
1478
sgl-router/src/tree.rs
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user