From 53c2934dcef8794449b7f270ad162f3fdabc01e3 Mon Sep 17 00:00:00 2001 From: Arthur Cheng Date: Thu, 23 Oct 2025 05:15:49 -0700 Subject: [PATCH] [Router] Consolidate ConnectionMode enum to core module (#11937) --- sgl-router/src/config/types.rs | 11 +---------- sgl-router/src/config/validation.rs | 9 +++++---- sgl-router/src/core/worker.rs | 7 ++++++- sgl-router/src/lib.rs | 8 ++++---- sgl-router/src/main.rs | 9 +++++---- sgl-router/src/routers/factory.rs | 5 +++-- sgl-router/src/routers/router_manager.rs | 10 +++++----- sgl-router/src/server.rs | 14 +++++++------- sgl-router/tests/api_endpoints_test.rs | 6 ++---- sgl-router/tests/responses_api_test.rs | 5 +++-- sgl-router/tests/test_pd_routing.rs | 7 ++----- 11 files changed, 43 insertions(+), 48 deletions(-) diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index 3f318634b..e0019ec77 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; use super::ConfigResult; +use crate::core::ConnectionMode; /// Main router configuration #[derive(Debug, Clone, Serialize, Deserialize)] @@ -208,16 +209,6 @@ impl std::fmt::Debug for OracleConfig { } } -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] -#[serde(tag = "type")] -pub enum ConnectionMode { - #[default] - #[serde(rename = "http")] - Http, - #[serde(rename = "grpc")] - Grpc, -} - /// Routing mode configuration #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] diff --git a/sgl-router/src/config/validation.rs b/sgl-router/src/config/validation.rs index f401bec9a..c4cbbaa88 100644 --- a/sgl-router/src/config/validation.rs +++ b/sgl-router/src/config/validation.rs @@ -1,4 +1,5 @@ use super::*; +use crate::core::ConnectionMode; /// Configuration validator pub struct ConfigValidator; @@ -476,7 +477,7 @@ impl ConfigValidator { } // Validate gRPC connection mode requires tokenizer configuration - if config.connection_mode == ConnectionMode::Grpc + if matches!(config.connection_mode, ConnectionMode::Grpc { .. }) && config.tokenizer_path.is_none() && config.model_path.is_none() { @@ -832,7 +833,7 @@ mod tests { ); // Set connection mode to gRPC without tokenizer config - config.connection_mode = ConnectionMode::Grpc; + config.connection_mode = ConnectionMode::Grpc { port: None }; config.tokenizer_path = None; config.model_path = None; @@ -852,7 +853,7 @@ mod tests { PolicyConfig::Random, ); - config.connection_mode = ConnectionMode::Grpc; + config.connection_mode = ConnectionMode::Grpc { port: None }; config.model_path = Some("meta-llama/Llama-3-8B".to_string()); let result = ConfigValidator::validate(&config); @@ -868,7 +869,7 @@ mod tests { PolicyConfig::Random, ); - config.connection_mode = ConnectionMode::Grpc; + config.connection_mode = ConnectionMode::Grpc { port: None }; config.tokenizer_path = Some("/path/to/tokenizer.json".to_string()); let result = ConfigValidator::validate(&config); diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index 3f5c05dbb..ca3311c41 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -8,6 +8,7 @@ use std::{ }; use async_trait::async_trait; +use serde::{Deserialize, Serialize}; use serde_json; use tokio::{sync::RwLock, time}; @@ -240,13 +241,17 @@ pub trait Worker: Send + Sync + fmt::Debug { } /// Connection mode for worker communication -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)] +#[serde(tag = "type", rename_all = "lowercase")] pub enum ConnectionMode { /// HTTP/REST connection + #[default] Http, /// gRPC connection Grpc { /// Optional port for gRPC endpoint (if different from URL) + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] port: Option, }, } diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index e08d7c5af..68c405aba 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -194,7 +194,7 @@ struct Router { queue_size: usize, queue_timeout_secs: u64, rate_limit_tokens_per_second: Option, - connection_mode: config::ConnectionMode, + connection_mode: core::ConnectionMode, model_path: Option, tokenizer_path: Option, chat_template: Option, @@ -211,13 +211,13 @@ struct Router { impl Router { /// Determine connection mode from worker URLs - fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode { + fn determine_connection_mode(worker_urls: &[String]) -> core::ConnectionMode { for url in worker_urls { if url.starts_with("grpc://") || url.starts_with("grpcs://") { - return config::ConnectionMode::Grpc; + return core::ConnectionMode::Grpc { port: None }; } } - config::ConnectionMode::Http + core::ConnectionMode::Http } pub fn to_router_config(&self) -> config::ConfigResult { diff --git a/sgl-router/src/main.rs b/sgl-router/src/main.rs index 4722c6a5b..bc971cab8 100644 --- a/sgl-router/src/main.rs +++ b/sgl-router/src/main.rs @@ -3,10 +3,11 @@ use std::collections::HashMap; use clap::{ArgAction, Parser, ValueEnum}; use sglang_router_rs::{ config::{ - CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig, - HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig, - RouterConfig, RoutingMode, TokenizerCacheConfig, + CircuitBreakerConfig, ConfigError, ConfigResult, DiscoveryConfig, HealthCheckConfig, + HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig, RouterConfig, + RoutingMode, TokenizerCacheConfig, }, + core::ConnectionMode, metrics::PrometheusConfig, server::{self, ServerConfig}, service_discovery::ServiceDiscoveryConfig, @@ -325,7 +326,7 @@ impl CliArgs { fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode { for url in worker_urls { if url.starts_with("grpc://") || url.starts_with("grpcs://") { - return ConnectionMode::Grpc; + return ConnectionMode::Grpc { port: None }; } } ConnectionMode::Http diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index dc7910048..bfb0528f6 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -9,7 +9,8 @@ use super::{ RouterTrait, }; use crate::{ - config::{ConnectionMode, PolicyConfig, RoutingMode}, + config::{PolicyConfig, RoutingMode}, + core::ConnectionMode, policies::PolicyFactory, server::AppContext, }; @@ -21,7 +22,7 @@ impl RouterFactory { /// Create a router instance from application context pub async fn create_router(ctx: &Arc) -> Result, String> { match ctx.router_config.connection_mode { - ConnectionMode::Grpc => match &ctx.router_config.mode { + ConnectionMode::Grpc { .. } => match &ctx.router_config.mode { RoutingMode::Regular { .. } => Self::create_grpc_router(ctx).await, RoutingMode::PrefillDecode { prefill_policy, diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs index 1bc739202..af9b5dc60 100644 --- a/sgl-router/src/routers/router_manager.rs +++ b/sgl-router/src/routers/router_manager.rs @@ -18,8 +18,8 @@ use serde_json::Value; use tracing::{debug, info, warn}; use crate::{ - config::{ConnectionMode, RoutingMode}, - core::{WorkerRegistry, WorkerType}, + config::RoutingMode, + core::{ConnectionMode, WorkerRegistry, WorkerType}, protocols::{ chat::ChatCompletionRequest, classify::ClassifyRequest, @@ -148,13 +148,13 @@ impl RouterManager { (ConnectionMode::Http, RoutingMode::OpenAI { .. }) => { RouterId::new("http-openai".to_string()) } - (ConnectionMode::Grpc, RoutingMode::Regular { .. }) => { + (ConnectionMode::Grpc { .. }, RoutingMode::Regular { .. }) => { RouterId::new("grpc-regular".to_string()) } - (ConnectionMode::Grpc, RoutingMode::PrefillDecode { .. }) => { + (ConnectionMode::Grpc { .. }, RoutingMode::PrefillDecode { .. }) => { RouterId::new("grpc-pd".to_string()) } - (ConnectionMode::Grpc, RoutingMode::OpenAI { .. }) => { + (ConnectionMode::Grpc { .. }, RoutingMode::OpenAI { .. }) => { RouterId::new("grpc-regular".to_string()) } } diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 4d6eb6a79..580bdabf4 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -20,14 +20,15 @@ use tokio::{net::TcpListener, signal, spawn}; use tracing::{error, info, warn, Level}; use crate::{ - config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode}, + config::{HistoryBackend, RouterConfig, RoutingMode}, core::{ worker_to_info, workflow::{ create_worker_registration_workflow, create_worker_removal_workflow, LoggingSubscriber, WorkflowEngine, }, - Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry, WorkerType, + ConnectionMode, Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry, + WorkerType, }, data_connector::{ MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage, @@ -825,11 +826,10 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box