[Router] Consolidate ConnectionMode enum to core module (#11937)
This commit is contained in:
@@ -3,6 +3,7 @@ use std::collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::ConfigResult;
|
||||
use crate::core::ConnectionMode;
|
||||
|
||||
/// Main router configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -208,16 +209,6 @@ impl std::fmt::Debug for OracleConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ConnectionMode {
|
||||
#[default]
|
||||
#[serde(rename = "http")]
|
||||
Http,
|
||||
#[serde(rename = "grpc")]
|
||||
Grpc,
|
||||
}
|
||||
|
||||
/// Routing mode configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use super::*;
|
||||
use crate::core::ConnectionMode;
|
||||
|
||||
/// Configuration validator
|
||||
pub struct ConfigValidator;
|
||||
@@ -476,7 +477,7 @@ impl ConfigValidator {
|
||||
}
|
||||
|
||||
// Validate gRPC connection mode requires tokenizer configuration
|
||||
if config.connection_mode == ConnectionMode::Grpc
|
||||
if matches!(config.connection_mode, ConnectionMode::Grpc { .. })
|
||||
&& config.tokenizer_path.is_none()
|
||||
&& config.model_path.is_none()
|
||||
{
|
||||
@@ -832,7 +833,7 @@ mod tests {
|
||||
);
|
||||
|
||||
// Set connection mode to gRPC without tokenizer config
|
||||
config.connection_mode = ConnectionMode::Grpc;
|
||||
config.connection_mode = ConnectionMode::Grpc { port: None };
|
||||
config.tokenizer_path = None;
|
||||
config.model_path = None;
|
||||
|
||||
@@ -852,7 +853,7 @@ mod tests {
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
config.connection_mode = ConnectionMode::Grpc;
|
||||
config.connection_mode = ConnectionMode::Grpc { port: None };
|
||||
config.model_path = Some("meta-llama/Llama-3-8B".to_string());
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
@@ -868,7 +869,7 @@ mod tests {
|
||||
PolicyConfig::Random,
|
||||
);
|
||||
|
||||
config.connection_mode = ConnectionMode::Grpc;
|
||||
config.connection_mode = ConnectionMode::Grpc { port: None };
|
||||
config.tokenizer_path = Some("/path/to/tokenizer.json".to_string());
|
||||
|
||||
let result = ConfigValidator::validate(&config);
|
||||
|
||||
@@ -8,6 +8,7 @@ use std::{
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
use tokio::{sync::RwLock, time};
|
||||
|
||||
@@ -240,13 +241,17 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
||||
}
|
||||
|
||||
/// Connection mode for worker communication
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum ConnectionMode {
|
||||
/// HTTP/REST connection
|
||||
#[default]
|
||||
Http,
|
||||
/// gRPC connection
|
||||
Grpc {
|
||||
/// Optional port for gRPC endpoint (if different from URL)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(default)]
|
||||
port: Option<u16>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -194,7 +194,7 @@ struct Router {
|
||||
queue_size: usize,
|
||||
queue_timeout_secs: u64,
|
||||
rate_limit_tokens_per_second: Option<i32>,
|
||||
connection_mode: config::ConnectionMode,
|
||||
connection_mode: core::ConnectionMode,
|
||||
model_path: Option<String>,
|
||||
tokenizer_path: Option<String>,
|
||||
chat_template: Option<String>,
|
||||
@@ -211,13 +211,13 @@ struct Router {
|
||||
|
||||
impl Router {
|
||||
/// Determine connection mode from worker URLs
|
||||
fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
|
||||
fn determine_connection_mode(worker_urls: &[String]) -> core::ConnectionMode {
|
||||
for url in worker_urls {
|
||||
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
|
||||
return config::ConnectionMode::Grpc;
|
||||
return core::ConnectionMode::Grpc { port: None };
|
||||
}
|
||||
}
|
||||
config::ConnectionMode::Http
|
||||
core::ConnectionMode::Http
|
||||
}
|
||||
|
||||
pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
|
||||
|
||||
@@ -3,10 +3,11 @@ use std::collections::HashMap;
|
||||
use clap::{ArgAction, Parser, ValueEnum};
|
||||
use sglang_router_rs::{
|
||||
config::{
|
||||
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
|
||||
HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig,
|
||||
RouterConfig, RoutingMode, TokenizerCacheConfig,
|
||||
CircuitBreakerConfig, ConfigError, ConfigResult, DiscoveryConfig, HealthCheckConfig,
|
||||
HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig, RouterConfig,
|
||||
RoutingMode, TokenizerCacheConfig,
|
||||
},
|
||||
core::ConnectionMode,
|
||||
metrics::PrometheusConfig,
|
||||
server::{self, ServerConfig},
|
||||
service_discovery::ServiceDiscoveryConfig,
|
||||
@@ -325,7 +326,7 @@ impl CliArgs {
|
||||
fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode {
|
||||
for url in worker_urls {
|
||||
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
|
||||
return ConnectionMode::Grpc;
|
||||
return ConnectionMode::Grpc { port: None };
|
||||
}
|
||||
}
|
||||
ConnectionMode::Http
|
||||
|
||||
@@ -9,7 +9,8 @@ use super::{
|
||||
RouterTrait,
|
||||
};
|
||||
use crate::{
|
||||
config::{ConnectionMode, PolicyConfig, RoutingMode},
|
||||
config::{PolicyConfig, RoutingMode},
|
||||
core::ConnectionMode,
|
||||
policies::PolicyFactory,
|
||||
server::AppContext,
|
||||
};
|
||||
@@ -21,7 +22,7 @@ impl RouterFactory {
|
||||
/// Create a router instance from application context
|
||||
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||
match ctx.router_config.connection_mode {
|
||||
ConnectionMode::Grpc => match &ctx.router_config.mode {
|
||||
ConnectionMode::Grpc { .. } => match &ctx.router_config.mode {
|
||||
RoutingMode::Regular { .. } => Self::create_grpc_router(ctx).await,
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_policy,
|
||||
|
||||
@@ -18,8 +18,8 @@ use serde_json::Value;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::{
|
||||
config::{ConnectionMode, RoutingMode},
|
||||
core::{WorkerRegistry, WorkerType},
|
||||
config::RoutingMode,
|
||||
core::{ConnectionMode, WorkerRegistry, WorkerType},
|
||||
protocols::{
|
||||
chat::ChatCompletionRequest,
|
||||
classify::ClassifyRequest,
|
||||
@@ -148,13 +148,13 @@ impl RouterManager {
|
||||
(ConnectionMode::Http, RoutingMode::OpenAI { .. }) => {
|
||||
RouterId::new("http-openai".to_string())
|
||||
}
|
||||
(ConnectionMode::Grpc, RoutingMode::Regular { .. }) => {
|
||||
(ConnectionMode::Grpc { .. }, RoutingMode::Regular { .. }) => {
|
||||
RouterId::new("grpc-regular".to_string())
|
||||
}
|
||||
(ConnectionMode::Grpc, RoutingMode::PrefillDecode { .. }) => {
|
||||
(ConnectionMode::Grpc { .. }, RoutingMode::PrefillDecode { .. }) => {
|
||||
RouterId::new("grpc-pd".to_string())
|
||||
}
|
||||
(ConnectionMode::Grpc, RoutingMode::OpenAI { .. }) => {
|
||||
(ConnectionMode::Grpc { .. }, RoutingMode::OpenAI { .. }) => {
|
||||
RouterId::new("grpc-regular".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,14 +20,15 @@ use tokio::{net::TcpListener, signal, spawn};
|
||||
use tracing::{error, info, warn, Level};
|
||||
|
||||
use crate::{
|
||||
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
|
||||
config::{HistoryBackend, RouterConfig, RoutingMode},
|
||||
core::{
|
||||
worker_to_info,
|
||||
workflow::{
|
||||
create_worker_registration_workflow, create_worker_removal_workflow, LoggingSubscriber,
|
||||
WorkflowEngine,
|
||||
},
|
||||
Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry, WorkerType,
|
||||
ConnectionMode, Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry,
|
||||
WorkerType,
|
||||
},
|
||||
data_connector::{
|
||||
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
|
||||
@@ -825,11 +826,10 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
||||
};
|
||||
|
||||
// Initialize tokenizer and parser factories for gRPC mode
|
||||
let (tokenizer, reasoning_parser_factory, tool_parser_factory) = if config
|
||||
.router_config
|
||||
.connection_mode
|
||||
== ConnectionMode::Grpc
|
||||
{
|
||||
let (tokenizer, reasoning_parser_factory, tool_parser_factory) = if matches!(
|
||||
config.router_config.connection_mode,
|
||||
ConnectionMode::Grpc { .. }
|
||||
) {
|
||||
let tokenizer_path = config
|
||||
.router_config
|
||||
.tokenizer_path
|
||||
|
||||
@@ -11,10 +11,8 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
|
||||
use reqwest::Client;
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::{
|
||||
config::{
|
||||
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||
},
|
||||
core::Job,
|
||||
config::{CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode},
|
||||
core::{ConnectionMode, Job},
|
||||
routers::{RouterFactory, RouterTrait},
|
||||
server::AppContext,
|
||||
};
|
||||
|
||||
@@ -16,9 +16,10 @@ use common::{
|
||||
};
|
||||
use sglang_router_rs::{
|
||||
config::{
|
||||
CircuitBreakerConfig, ConnectionMode, HealthCheckConfig, PolicyConfig, RetryConfig,
|
||||
RouterConfig, RoutingMode,
|
||||
CircuitBreakerConfig, HealthCheckConfig, PolicyConfig, RetryConfig, RouterConfig,
|
||||
RoutingMode,
|
||||
},
|
||||
core::ConnectionMode,
|
||||
routers::RouterFactory,
|
||||
};
|
||||
|
||||
|
||||
@@ -2,11 +2,8 @@
|
||||
mod test_pd_routing {
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::{
|
||||
config::{
|
||||
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig,
|
||||
RoutingMode,
|
||||
},
|
||||
core::{BasicWorkerBuilder, Worker, WorkerType},
|
||||
config::{CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode},
|
||||
core::{BasicWorkerBuilder, ConnectionMode, Worker, WorkerType},
|
||||
routers::{http::pd_types::PDSelectionPolicy, RouterFactory},
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user