[router] add grpc pd and regular router init (#9893)
This commit is contained in:
@@ -99,6 +99,9 @@ class RouterArgs:
|
|||||||
cb_timeout_duration_secs: int = 60
|
cb_timeout_duration_secs: int = 60
|
||||||
cb_window_duration_secs: int = 120
|
cb_window_duration_secs: int = 120
|
||||||
disable_circuit_breaker: bool = False
|
disable_circuit_breaker: bool = False
|
||||||
|
# Tokenizer configuration
|
||||||
|
model_path: Optional[str] = None
|
||||||
|
tokenizer_path: Optional[str] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(
|
def add_cli_args(
|
||||||
@@ -433,6 +436,19 @@ class RouterArgs:
|
|||||||
default=[],
|
default=[],
|
||||||
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
|
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
|
||||||
)
|
)
|
||||||
|
# Tokenizer configuration
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}model-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Model path for loading tokenizer (HuggingFace model ID or local path)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}tokenizer-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(
|
def from_cli_args(
|
||||||
@@ -554,6 +570,8 @@ class RouterArgs:
|
|||||||
health_check_endpoint=getattr(
|
health_check_endpoint=getattr(
|
||||||
args, f"{prefix}health_check_endpoint", RouterArgs.health_check_endpoint
|
args, f"{prefix}health_check_endpoint", RouterArgs.health_check_endpoint
|
||||||
),
|
),
|
||||||
|
model_path=getattr(args, f"{prefix}model_path", None),
|
||||||
|
tokenizer_path=getattr(args, f"{prefix}tokenizer_path", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -759,6 +777,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
|||||||
health_check_timeout_secs=router_args.health_check_timeout_secs,
|
health_check_timeout_secs=router_args.health_check_timeout_secs,
|
||||||
health_check_interval_secs=router_args.health_check_interval_secs,
|
health_check_interval_secs=router_args.health_check_interval_secs,
|
||||||
health_check_endpoint=router_args.health_check_endpoint,
|
health_check_endpoint=router_args.health_check_endpoint,
|
||||||
|
model_path=router_args.model_path,
|
||||||
|
tokenizer_path=router_args.tokenizer_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
router.start()
|
router.start()
|
||||||
|
|||||||
@@ -74,6 +74,8 @@ class Router:
|
|||||||
health_check_timeout_secs: Timeout in seconds for health check requests. Default: 5
|
health_check_timeout_secs: Timeout in seconds for health check requests. Default: 5
|
||||||
health_check_interval_secs: Interval in seconds between runtime health checks. Default: 60
|
health_check_interval_secs: Interval in seconds between runtime health checks. Default: 60
|
||||||
health_check_endpoint: Health check endpoint path. Default: '/health'
|
health_check_endpoint: Health check endpoint path. Default: '/health'
|
||||||
|
model_path: Model path for loading tokenizer (HuggingFace model ID or local path). Default: None
|
||||||
|
tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -131,6 +133,8 @@ class Router:
|
|||||||
health_check_timeout_secs: int = 5,
|
health_check_timeout_secs: int = 5,
|
||||||
health_check_interval_secs: int = 60,
|
health_check_interval_secs: int = 60,
|
||||||
health_check_endpoint: str = "/health",
|
health_check_endpoint: str = "/health",
|
||||||
|
model_path: Optional[str] = None,
|
||||||
|
tokenizer_path: Optional[str] = None,
|
||||||
):
|
):
|
||||||
if selector is None:
|
if selector is None:
|
||||||
selector = {}
|
selector = {}
|
||||||
@@ -195,6 +199,8 @@ class Router:
|
|||||||
health_check_timeout_secs=health_check_timeout_secs,
|
health_check_timeout_secs=health_check_timeout_secs,
|
||||||
health_check_interval_secs=health_check_interval_secs,
|
health_check_interval_secs=health_check_interval_secs,
|
||||||
health_check_endpoint=health_check_endpoint,
|
health_check_endpoint=health_check_endpoint,
|
||||||
|
model_path=model_path,
|
||||||
|
tokenizer_path=tokenizer_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
cb_window_duration_secs=60,
|
cb_window_duration_secs=60,
|
||||||
disable_retries=False,
|
disable_retries=False,
|
||||||
disable_circuit_breaker=False,
|
disable_circuit_breaker=False,
|
||||||
|
model_path=None,
|
||||||
|
tokenizer_path=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_router_args(self, **kwargs):
|
def create_router_args(self, **kwargs):
|
||||||
|
|||||||
@@ -7,6 +7,9 @@ use std::collections::HashMap;
|
|||||||
pub struct RouterConfig {
|
pub struct RouterConfig {
|
||||||
/// Routing mode configuration
|
/// Routing mode configuration
|
||||||
pub mode: RoutingMode,
|
pub mode: RoutingMode,
|
||||||
|
/// Worker connection mode
|
||||||
|
#[serde(default)]
|
||||||
|
pub connection_mode: ConnectionMode,
|
||||||
/// Policy configuration
|
/// Policy configuration
|
||||||
pub policy: PolicyConfig,
|
pub policy: PolicyConfig,
|
||||||
/// Server host address
|
/// Server host address
|
||||||
@@ -60,6 +63,20 @@ pub struct RouterConfig {
|
|||||||
/// Enable Inference Gateway mode (false = proxy mode, true = IGW mode)
|
/// Enable Inference Gateway mode (false = proxy mode, true = IGW mode)
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub enable_igw: bool,
|
pub enable_igw: bool,
|
||||||
|
/// Model path for loading tokenizer (can be a HuggingFace model ID or local path)
|
||||||
|
pub model_path: Option<String>,
|
||||||
|
/// Explicit tokenizer path (overrides model_path tokenizer if provided)
|
||||||
|
pub tokenizer_path: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum ConnectionMode {
|
||||||
|
#[default]
|
||||||
|
#[serde(rename = "http")]
|
||||||
|
Http,
|
||||||
|
#[serde(rename = "grpc")]
|
||||||
|
Grpc,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Routing mode configuration
|
/// Routing mode configuration
|
||||||
@@ -336,6 +353,9 @@ impl Default for RouterConfig {
|
|||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: HealthCheckConfig::default(),
|
health_check: HealthCheckConfig::default(),
|
||||||
enable_igw: false,
|
enable_igw: false,
|
||||||
|
connection_mode: ConnectionMode::Http,
|
||||||
|
model_path: None,
|
||||||
|
tokenizer_path: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -478,6 +498,9 @@ mod tests {
|
|||||||
queue_size: 100,
|
queue_size: 100,
|
||||||
queue_timeout_secs: 60,
|
queue_timeout_secs: 60,
|
||||||
rate_limit_tokens_per_second: None,
|
rate_limit_tokens_per_second: None,
|
||||||
|
connection_mode: ConnectionMode::Http,
|
||||||
|
model_path: None,
|
||||||
|
tokenizer_path: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let json = serde_json::to_string(&config).unwrap();
|
let json = serde_json::to_string(&config).unwrap();
|
||||||
@@ -914,6 +937,9 @@ mod tests {
|
|||||||
queue_size: 100,
|
queue_size: 100,
|
||||||
queue_timeout_secs: 60,
|
queue_timeout_secs: 60,
|
||||||
rate_limit_tokens_per_second: None,
|
rate_limit_tokens_per_second: None,
|
||||||
|
connection_mode: ConnectionMode::Http,
|
||||||
|
model_path: None,
|
||||||
|
tokenizer_path: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(config.mode.is_pd_mode());
|
assert!(config.mode.is_pd_mode());
|
||||||
@@ -974,6 +1000,9 @@ mod tests {
|
|||||||
queue_size: 100,
|
queue_size: 100,
|
||||||
queue_timeout_secs: 60,
|
queue_timeout_secs: 60,
|
||||||
rate_limit_tokens_per_second: None,
|
rate_limit_tokens_per_second: None,
|
||||||
|
connection_mode: ConnectionMode::Http,
|
||||||
|
model_path: None,
|
||||||
|
tokenizer_path: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(!config.mode.is_pd_mode());
|
assert!(!config.mode.is_pd_mode());
|
||||||
@@ -1030,6 +1059,9 @@ mod tests {
|
|||||||
queue_size: 100,
|
queue_size: 100,
|
||||||
queue_timeout_secs: 60,
|
queue_timeout_secs: 60,
|
||||||
rate_limit_tokens_per_second: None,
|
rate_limit_tokens_per_second: None,
|
||||||
|
connection_mode: ConnectionMode::Http,
|
||||||
|
model_path: None,
|
||||||
|
tokenizer_path: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(config.has_service_discovery());
|
assert!(config.has_service_discovery());
|
||||||
|
|||||||
@@ -349,6 +349,16 @@ impl ConfigValidator {
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate gRPC connection mode requires tokenizer configuration
|
||||||
|
if config.connection_mode == ConnectionMode::Grpc
|
||||||
|
&& config.tokenizer_path.is_none()
|
||||||
|
&& config.model_path.is_none()
|
||||||
|
{
|
||||||
|
return Err(ConfigError::ValidationFailed {
|
||||||
|
reason: "gRPC connection mode requires either --tokenizer-path or --model-path to be specified".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// All policies are now supported for both router types thanks to the unified trait design
|
// All policies are now supported for both router types thanks to the unified trait design
|
||||||
// No mode/policy restrictions needed anymore
|
// No mode/policy restrictions needed anymore
|
||||||
|
|
||||||
@@ -419,11 +429,14 @@ impl ConfigValidator {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if !url.starts_with("http://") && !url.starts_with("https://") {
|
if !url.starts_with("http://")
|
||||||
|
&& !url.starts_with("https://")
|
||||||
|
&& !url.starts_with("grpc://")
|
||||||
|
{
|
||||||
return Err(ConfigError::InvalidValue {
|
return Err(ConfigError::InvalidValue {
|
||||||
field: "worker_url".to_string(),
|
field: "worker_url".to_string(),
|
||||||
value: url.clone(),
|
value: url.clone(),
|
||||||
reason: "URL must start with http:// or https://".to_string(),
|
reason: "URL must start with http://, https://, or grpc://".to_string(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -684,4 +697,60 @@ mod tests {
|
|||||||
assert!(e.to_string().contains("prefill requires at least 2"));
|
assert!(e.to_string().contains("prefill requires at least 2"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
fn test_validate_grpc_requires_tokenizer() {
    // gRPC mode with neither a tokenizer path nor a model path must be rejected.
    let routing = RoutingMode::Regular {
        worker_urls: vec![String::from("grpc://worker:50051")],
    };
    let mut config = RouterConfig::new(routing, PolicyConfig::Random);

    // Force gRPC and make the absence of tokenizer sources explicit.
    config.connection_mode = ConnectionMode::Grpc;
    config.tokenizer_path = None;
    config.model_path = None;

    let outcome = ConfigValidator::validate(&config);
    assert!(outcome.is_err());

    // The failure must be the gRPC/tokenizer rule, not some unrelated check.
    let message = outcome.err().map(|e| e.to_string()).unwrap_or_default();
    assert!(message.contains("gRPC connection mode requires"));
}
|
||||||
|
|
||||||
|
#[test]
fn test_validate_grpc_with_model_path() {
    // Supplying only model_path is enough to satisfy the gRPC tokenizer rule.
    let routing = RoutingMode::Regular {
        worker_urls: vec![String::from("grpc://worker:50051")],
    };
    let mut config = RouterConfig::new(routing, PolicyConfig::Random);

    config.connection_mode = ConnectionMode::Grpc;
    config.model_path = Some(String::from("meta-llama/Llama-3-8B"));

    assert!(ConfigValidator::validate(&config).is_ok());
}
|
||||||
|
|
||||||
|
#[test]
fn test_validate_grpc_with_tokenizer_path() {
    // Supplying only tokenizer_path is enough to satisfy the gRPC tokenizer rule.
    let routing = RoutingMode::Regular {
        worker_urls: vec![String::from("grpc://worker:50051")],
    };
    let mut config = RouterConfig::new(routing, PolicyConfig::Random);

    config.connection_mode = ConnectionMode::Grpc;
    config.tokenizer_path = Some(String::from("/path/to/tokenizer.json"));

    assert!(ConfigValidator::validate(&config).is_ok());
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ use pyo3::prelude::*;
|
|||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
pub mod core;
|
pub mod core;
|
||||||
#[cfg(feature = "grpc-client")]
|
#[cfg(feature = "grpc-client")]
|
||||||
pub mod grpc;
|
pub mod grpc;
|
||||||
@@ -89,9 +90,39 @@ struct Router {
|
|||||||
queue_size: usize,
|
queue_size: usize,
|
||||||
queue_timeout_secs: u64,
|
queue_timeout_secs: u64,
|
||||||
rate_limit_tokens_per_second: Option<usize>,
|
rate_limit_tokens_per_second: Option<usize>,
|
||||||
|
// Connection mode (determined from worker URLs)
|
||||||
|
connection_mode: config::ConnectionMode,
|
||||||
|
// Model path for tokenizer
|
||||||
|
model_path: Option<String>,
|
||||||
|
// Explicit tokenizer path
|
||||||
|
tokenizer_path: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Router {
|
impl Router {
|
||||||
|
/// Determine connection mode from worker URLs
|
||||||
|
fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
|
||||||
|
// Check if any URL is a gRPC endpoint (starts with grpc:// or has port that commonly indicates gRPC)
|
||||||
|
for url in worker_urls {
|
||||||
|
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
|
||||||
|
return config::ConnectionMode::Grpc;
|
||||||
|
}
|
||||||
|
// Also check for common gRPC ports if the scheme isn't specified
|
||||||
|
if let Ok(parsed_url) = url::Url::parse(url) {
|
||||||
|
if let Some(port) = parsed_url.port() {
|
||||||
|
// Common gRPC ports
|
||||||
|
if port == 50051 || port == 9090 || ((50000..=50100).contains(&port)) {
|
||||||
|
return config::ConnectionMode::Grpc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if url.contains(":50051") || url.contains(":9090") || url.contains(":5000") {
|
||||||
|
// Fallback check for URLs that might not parse correctly
|
||||||
|
return config::ConnectionMode::Grpc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Default to HTTP
|
||||||
|
config::ConnectionMode::Http
|
||||||
|
}
|
||||||
|
|
||||||
/// Convert PyO3 Router to RouterConfig
|
/// Convert PyO3 Router to RouterConfig
|
||||||
pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
|
pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
|
||||||
use config::{
|
use config::{
|
||||||
@@ -168,6 +199,7 @@ impl Router {
|
|||||||
policy,
|
policy,
|
||||||
host: self.host.clone(),
|
host: self.host.clone(),
|
||||||
port: self.port,
|
port: self.port,
|
||||||
|
connection_mode: self.connection_mode.clone(),
|
||||||
max_payload_size: self.max_payload_size,
|
max_payload_size: self.max_payload_size,
|
||||||
request_timeout_secs: self.request_timeout_secs,
|
request_timeout_secs: self.request_timeout_secs,
|
||||||
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
|
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
|
||||||
@@ -207,6 +239,8 @@ impl Router {
|
|||||||
endpoint: self.health_check_endpoint.clone(),
|
endpoint: self.health_check_endpoint.clone(),
|
||||||
},
|
},
|
||||||
enable_igw: self.enable_igw,
|
enable_igw: self.enable_igw,
|
||||||
|
model_path: self.model_path.clone(),
|
||||||
|
tokenizer_path: self.tokenizer_path.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -273,6 +307,9 @@ impl Router {
|
|||||||
queue_size = 100,
|
queue_size = 100,
|
||||||
queue_timeout_secs = 60,
|
queue_timeout_secs = 60,
|
||||||
rate_limit_tokens_per_second = None,
|
rate_limit_tokens_per_second = None,
|
||||||
|
// Tokenizer defaults
|
||||||
|
model_path = None,
|
||||||
|
tokenizer_path = None,
|
||||||
))]
|
))]
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn new(
|
fn new(
|
||||||
@@ -330,7 +367,26 @@ impl Router {
|
|||||||
queue_size: usize,
|
queue_size: usize,
|
||||||
queue_timeout_secs: u64,
|
queue_timeout_secs: u64,
|
||||||
rate_limit_tokens_per_second: Option<usize>,
|
rate_limit_tokens_per_second: Option<usize>,
|
||||||
|
model_path: Option<String>,
|
||||||
|
tokenizer_path: Option<String>,
|
||||||
) -> PyResult<Self> {
|
) -> PyResult<Self> {
|
||||||
|
// Determine connection mode from worker URLs
|
||||||
|
let mut all_urls = worker_urls.clone();
|
||||||
|
|
||||||
|
// Add prefill URLs if in PD mode
|
||||||
|
if let Some(ref prefill_urls) = prefill_urls {
|
||||||
|
for (url, _) in prefill_urls {
|
||||||
|
all_urls.push(url.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add decode URLs if in PD mode
|
||||||
|
if let Some(ref decode_urls) = decode_urls {
|
||||||
|
all_urls.extend(decode_urls.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let connection_mode = Self::determine_connection_mode(&all_urls);
|
||||||
|
|
||||||
Ok(Router {
|
Ok(Router {
|
||||||
host,
|
host,
|
||||||
port,
|
port,
|
||||||
@@ -386,6 +442,9 @@ impl Router {
|
|||||||
queue_size,
|
queue_size,
|
||||||
queue_timeout_secs,
|
queue_timeout_secs,
|
||||||
rate_limit_tokens_per_second,
|
rate_limit_tokens_per_second,
|
||||||
|
connection_mode,
|
||||||
|
model_path,
|
||||||
|
tokenizer_path,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use clap::{ArgAction, Parser};
|
use clap::{ArgAction, Parser};
|
||||||
use sglang_router_rs::config::{
|
use sglang_router_rs::config::{
|
||||||
CircuitBreakerConfig, ConfigError, ConfigResult, DiscoveryConfig, HealthCheckConfig,
|
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
|
||||||
MetricsConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
HealthCheckConfig, MetricsConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||||
};
|
};
|
||||||
use sglang_router_rs::metrics::PrometheusConfig;
|
use sglang_router_rs::metrics::PrometheusConfig;
|
||||||
use sglang_router_rs::server::{self, ServerConfig};
|
use sglang_router_rs::server::{self, ServerConfig};
|
||||||
@@ -272,9 +272,42 @@ struct CliArgs {
|
|||||||
/// Enable Inference Gateway mode
|
/// Enable Inference Gateway mode
|
||||||
#[arg(long, default_value_t = false)]
|
#[arg(long, default_value_t = false)]
|
||||||
enable_igw: bool,
|
enable_igw: bool,
|
||||||
|
|
||||||
|
// Tokenizer configuration
|
||||||
|
/// Model path for loading tokenizer (HuggingFace model ID or local path)
|
||||||
|
#[arg(long)]
|
||||||
|
model_path: Option<String>,
|
||||||
|
|
||||||
|
/// Explicit tokenizer path (overrides model_path tokenizer if provided)
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer_path: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CliArgs {
|
impl CliArgs {
|
||||||
|
/// Determine connection mode from worker URLs
|
||||||
|
fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode {
|
||||||
|
// Check if any URL is a gRPC endpoint (starts with grpc:// or has port that commonly indicates gRPC)
|
||||||
|
for url in worker_urls {
|
||||||
|
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
|
||||||
|
return ConnectionMode::Grpc;
|
||||||
|
}
|
||||||
|
// Also check for common gRPC ports if the scheme isn't specified
|
||||||
|
if let Ok(parsed_url) = url::Url::parse(url) {
|
||||||
|
if let Some(port) = parsed_url.port() {
|
||||||
|
// Common gRPC ports
|
||||||
|
if port == 50051 || port == 9090 || ((50000..=50100).contains(&port)) {
|
||||||
|
return ConnectionMode::Grpc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if url.contains(":50051") || url.contains(":9090") || url.contains(":5000") {
|
||||||
|
// Fallback check for URLs that might not parse correctly
|
||||||
|
return ConnectionMode::Grpc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Default to HTTP
|
||||||
|
ConnectionMode::Http
|
||||||
|
}
|
||||||
|
|
||||||
/// Parse selector strings into HashMap
|
/// Parse selector strings into HashMap
|
||||||
fn parse_selector(selector_list: &[String]) -> HashMap<String, String> {
|
fn parse_selector(selector_list: &[String]) -> HashMap<String, String> {
|
||||||
let mut map = HashMap::new();
|
let mut map = HashMap::new();
|
||||||
@@ -372,10 +405,30 @@ impl CliArgs {
|
|||||||
host: self.prometheus_host.clone(),
|
host: self.prometheus_host.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Determine connection mode from all worker URLs
|
||||||
|
let mut all_urls = Vec::new();
|
||||||
|
match &mode {
|
||||||
|
RoutingMode::Regular { worker_urls } => {
|
||||||
|
all_urls.extend(worker_urls.clone());
|
||||||
|
}
|
||||||
|
RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls,
|
||||||
|
decode_urls,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
for (url, _) in prefill_urls {
|
||||||
|
all_urls.push(url.clone());
|
||||||
|
}
|
||||||
|
all_urls.extend(decode_urls.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let connection_mode = Self::determine_connection_mode(&all_urls);
|
||||||
|
|
||||||
// Build RouterConfig
|
// Build RouterConfig
|
||||||
Ok(RouterConfig {
|
Ok(RouterConfig {
|
||||||
mode,
|
mode,
|
||||||
policy,
|
policy,
|
||||||
|
connection_mode,
|
||||||
host: self.host.clone(),
|
host: self.host.clone(),
|
||||||
port: self.port,
|
port: self.port,
|
||||||
max_payload_size: self.max_payload_size,
|
max_payload_size: self.max_payload_size,
|
||||||
@@ -421,6 +474,8 @@ impl CliArgs {
|
|||||||
},
|
},
|
||||||
enable_igw: self.enable_igw,
|
enable_igw: self.enable_igw,
|
||||||
rate_limit_tokens_per_second: None,
|
rate_limit_tokens_per_second: None,
|
||||||
|
model_path: self.model_path.clone(),
|
||||||
|
tokenizer_path: self.tokenizer_path.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use super::{
|
|||||||
http::{pd_router::PDRouter, router::Router},
|
http::{pd_router::PDRouter, router::Router},
|
||||||
RouterTrait,
|
RouterTrait,
|
||||||
};
|
};
|
||||||
use crate::config::{PolicyConfig, RoutingMode};
|
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
|
||||||
use crate::policies::PolicyFactory;
|
use crate::policies::PolicyFactory;
|
||||||
use crate::server::AppContext;
|
use crate::server::AppContext;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -20,28 +20,56 @@ impl RouterFactory {
|
|||||||
return Self::create_igw_router(ctx).await;
|
return Self::create_igw_router(ctx).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Add gRPC mode check here when implementing gRPC support
|
// Check connection mode and route to appropriate implementation
|
||||||
|
match ctx.router_config.connection_mode {
|
||||||
// Default to HTTP proxy mode
|
ConnectionMode::Grpc => {
|
||||||
match &ctx.router_config.mode {
|
// Route to gRPC implementation based on routing mode
|
||||||
RoutingMode::Regular { worker_urls } => {
|
match &ctx.router_config.mode {
|
||||||
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await
|
RoutingMode::Regular { worker_urls } => {
|
||||||
|
Self::create_grpc_router(worker_urls, &ctx.router_config.policy, ctx).await
|
||||||
|
}
|
||||||
|
RoutingMode::PrefillDecode {
|
||||||
|
prefill_urls,
|
||||||
|
decode_urls,
|
||||||
|
prefill_policy,
|
||||||
|
decode_policy,
|
||||||
|
} => {
|
||||||
|
Self::create_grpc_pd_router(
|
||||||
|
prefill_urls,
|
||||||
|
decode_urls,
|
||||||
|
prefill_policy.as_ref(),
|
||||||
|
decode_policy.as_ref(),
|
||||||
|
&ctx.router_config.policy,
|
||||||
|
ctx,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
RoutingMode::PrefillDecode {
|
ConnectionMode::Http => {
|
||||||
prefill_urls,
|
// Route to HTTP implementation based on routing mode
|
||||||
decode_urls,
|
match &ctx.router_config.mode {
|
||||||
prefill_policy,
|
RoutingMode::Regular { worker_urls } => {
|
||||||
decode_policy,
|
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
|
||||||
} => {
|
.await
|
||||||
Self::create_pd_router(
|
}
|
||||||
prefill_urls,
|
RoutingMode::PrefillDecode {
|
||||||
decode_urls,
|
prefill_urls,
|
||||||
prefill_policy.as_ref(),
|
decode_urls,
|
||||||
decode_policy.as_ref(),
|
prefill_policy,
|
||||||
&ctx.router_config.policy,
|
decode_policy,
|
||||||
ctx,
|
} => {
|
||||||
)
|
Self::create_pd_router(
|
||||||
.await
|
prefill_urls,
|
||||||
|
decode_urls,
|
||||||
|
prefill_policy.as_ref(),
|
||||||
|
decode_policy.as_ref(),
|
||||||
|
&ctx.router_config.policy,
|
||||||
|
ctx,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -109,25 +137,92 @@ impl RouterFactory {
|
|||||||
|
|
||||||
/// Create a gRPC router with injected policy
|
/// Create a gRPC router with injected policy
|
||||||
pub async fn create_grpc_router(
|
pub async fn create_grpc_router(
|
||||||
_worker_urls: &[String],
|
worker_urls: &[String],
|
||||||
_policy_config: &PolicyConfig,
|
policy_config: &PolicyConfig,
|
||||||
_ctx: &Arc<AppContext>,
|
ctx: &Arc<AppContext>,
|
||||||
) -> Result<Box<dyn RouterTrait>, String> {
|
) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
// For now, return an error as gRPC router is not yet implemented
|
use super::grpc::router::GrpcRouter;
|
||||||
Err("gRPC router is not yet implemented".to_string())
|
|
||||||
|
// Create policy
|
||||||
|
let policy = PolicyFactory::create_from_config(policy_config);
|
||||||
|
|
||||||
|
// Determine which tokenizer path to use
|
||||||
|
// Priority: tokenizer_path > model_path
|
||||||
|
let tokenizer_path = ctx
|
||||||
|
.router_config
|
||||||
|
.tokenizer_path
|
||||||
|
.clone()
|
||||||
|
.or_else(|| ctx.router_config.model_path.clone())
|
||||||
|
.ok_or_else(|| {
|
||||||
|
"gRPC router requires either --tokenizer-path or --model-path to be specified"
|
||||||
|
.to_string()
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Create gRPC router
|
||||||
|
let router = GrpcRouter::new(
|
||||||
|
worker_urls.to_vec(),
|
||||||
|
policy,
|
||||||
|
ctx.router_config.worker_startup_timeout_secs,
|
||||||
|
ctx.router_config.worker_startup_check_interval_secs,
|
||||||
|
ctx.router_config.dp_aware,
|
||||||
|
ctx.router_config.api_key.clone(),
|
||||||
|
ctx.router_config.effective_retry_config(),
|
||||||
|
ctx.router_config.effective_circuit_breaker_config(),
|
||||||
|
ctx.router_config.health_check.clone(),
|
||||||
|
tokenizer_path,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Box::new(router))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a gRPC PD router (placeholder for now)
|
/// Create a gRPC PD router with tokenizer and worker configuration
|
||||||
pub async fn create_grpc_pd_router(
|
pub async fn create_grpc_pd_router(
|
||||||
_prefill_urls: &[(String, Option<u16>)],
|
prefill_urls: &[(String, Option<u16>)],
|
||||||
_decode_urls: &[String],
|
decode_urls: &[String],
|
||||||
_prefill_policy_config: Option<&PolicyConfig>,
|
prefill_policy_config: Option<&PolicyConfig>,
|
||||||
_decode_policy_config: Option<&PolicyConfig>,
|
decode_policy_config: Option<&PolicyConfig>,
|
||||||
_main_policy_config: &PolicyConfig,
|
main_policy_config: &PolicyConfig,
|
||||||
_ctx: &Arc<AppContext>,
|
ctx: &Arc<AppContext>,
|
||||||
) -> Result<Box<dyn RouterTrait>, String> {
|
) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
// For now, return an error as gRPC PD router is not yet implemented
|
use super::grpc::pd_router::GrpcPDRouter;
|
||||||
Err("gRPC PD router is not yet implemented".to_string())
|
|
||||||
|
// Create policies - use specific policies if provided, otherwise fall back to main policy
|
||||||
|
let prefill_policy =
|
||||||
|
PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
|
||||||
|
let decode_policy =
|
||||||
|
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
|
||||||
|
|
||||||
|
// Determine which tokenizer path to use
|
||||||
|
// Priority: tokenizer_path > model_path
|
||||||
|
let tokenizer_path = ctx
|
||||||
|
.router_config
|
||||||
|
.tokenizer_path
|
||||||
|
.clone()
|
||||||
|
.or_else(|| ctx.router_config.model_path.clone())
|
||||||
|
.ok_or_else(|| {
|
||||||
|
"gRPC PD router requires either --tokenizer-path or --model-path to be specified"
|
||||||
|
.to_string()
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Create gRPC PD router
|
||||||
|
let router = GrpcPDRouter::new(
|
||||||
|
prefill_urls.to_vec(),
|
||||||
|
decode_urls.to_vec(),
|
||||||
|
prefill_policy,
|
||||||
|
decode_policy,
|
||||||
|
ctx.router_config.worker_startup_timeout_secs,
|
||||||
|
ctx.router_config.worker_startup_check_interval_secs,
|
||||||
|
ctx.router_config.dp_aware,
|
||||||
|
ctx.router_config.api_key.clone(),
|
||||||
|
ctx.router_config.effective_retry_config(),
|
||||||
|
ctx.router_config.effective_circuit_breaker_config(),
|
||||||
|
ctx.router_config.health_check.clone(),
|
||||||
|
tokenizer_path,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Box::new(router))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create an IGW router (placeholder for future implementation)
|
/// Create an IGW router (placeholder for future implementation)
|
||||||
|
|||||||
@@ -1,7 +1,19 @@
|
|||||||
// PD (Prefill-Decode) gRPC Router Implementation
|
// PD (Prefill-Decode) gRPC Router Implementation
|
||||||
// TODO: Implement gRPC-based PD router for disaggregated prefill-decode systems
|
|
||||||
|
|
||||||
|
use crate::config::types::{
|
||||||
|
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
|
||||||
|
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
|
||||||
|
};
|
||||||
|
use crate::core::{
|
||||||
|
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
|
||||||
|
};
|
||||||
|
use crate::grpc::SglangSchedulerClient;
|
||||||
|
use crate::metrics::RouterMetrics;
|
||||||
|
use crate::policies::LoadBalancingPolicy;
|
||||||
|
use crate::reasoning_parser::ParserFactory;
|
||||||
use crate::routers::{RouterTrait, WorkerManagement};
|
use crate::routers::{RouterTrait, WorkerManagement};
|
||||||
|
use crate::tokenizer::{factory, traits::Tokenizer};
|
||||||
|
use crate::tool_parser::ParserRegistry;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
@@ -9,15 +21,222 @@ use axum::{
|
|||||||
http::{HeaderMap, StatusCode},
|
http::{HeaderMap, StatusCode},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::{Arc, RwLock};
|
||||||
|
use std::time::Duration;
|
||||||
|
use tracing::{info, warn};
|
||||||
|
|
||||||
/// Placeholder for gRPC PD router
|
/// gRPC PD (Prefill-Decode) router implementation for SGLang
|
||||||
#[derive(Debug)]
|
#[allow(dead_code)] // Fields will be used once implementation is complete
|
||||||
pub struct GrpcPDRouter;
|
pub struct GrpcPDRouter {
|
||||||
|
/// Prefill worker connections
|
||||||
|
prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||||
|
/// Decode worker connections
|
||||||
|
decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||||
|
/// gRPC clients for prefill workers
|
||||||
|
prefill_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
|
||||||
|
/// gRPC clients for decode workers
|
||||||
|
decode_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
|
||||||
|
/// Load balancing policy for prefill
|
||||||
|
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
|
/// Load balancing policy for decode
|
||||||
|
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
|
/// Tokenizer for handling text encoding/decoding
|
||||||
|
tokenizer: Arc<dyn Tokenizer>,
|
||||||
|
/// Reasoning parser factory for structured reasoning outputs
|
||||||
|
reasoning_parser_factory: ParserFactory,
|
||||||
|
/// Tool parser registry for function/tool calls
|
||||||
|
tool_parser_registry: &'static ParserRegistry,
|
||||||
|
/// Worker health checkers
|
||||||
|
_prefill_health_checker: Option<HealthChecker>,
|
||||||
|
_decode_health_checker: Option<HealthChecker>,
|
||||||
|
/// Configuration
|
||||||
|
timeout_secs: u64,
|
||||||
|
interval_secs: u64,
|
||||||
|
dp_aware: bool,
|
||||||
|
api_key: Option<String>,
|
||||||
|
retry_config: RetryConfig,
|
||||||
|
circuit_breaker_config: CircuitBreakerConfig,
|
||||||
|
}
|
||||||
|
|
||||||
impl GrpcPDRouter {
|
impl GrpcPDRouter {
|
||||||
pub async fn new() -> Result<Self, String> {
|
/// Create a new gRPC PD router
|
||||||
// TODO: Implement gRPC PD router initialization
|
#[allow(clippy::too_many_arguments)]
|
||||||
Err("gRPC PD router not yet implemented".to_string())
|
pub async fn new(
|
||||||
|
prefill_urls: Vec<(String, Option<u16>)>,
|
||||||
|
decode_urls: Vec<String>,
|
||||||
|
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
|
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
|
timeout_secs: u64,
|
||||||
|
interval_secs: u64,
|
||||||
|
dp_aware: bool,
|
||||||
|
api_key: Option<String>,
|
||||||
|
retry_config: RetryConfig,
|
||||||
|
circuit_breaker_config: ConfigCircuitBreakerConfig,
|
||||||
|
health_check_config: ConfigHealthCheckConfig,
|
||||||
|
tokenizer_path_or_model: String,
|
||||||
|
) -> Result<Self, String> {
|
||||||
|
// Update metrics
|
||||||
|
RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len());
|
||||||
|
|
||||||
|
// Initialize tokenizer
|
||||||
|
let tokenizer = factory::create_tokenizer(&tokenizer_path_or_model)
|
||||||
|
.map_err(|e| format!("Failed to create tokenizer: {}", e))?;
|
||||||
|
|
||||||
|
// Initialize reasoning parser factory
|
||||||
|
let reasoning_parser_factory = ParserFactory::new();
|
||||||
|
|
||||||
|
// Get tool parser registry
|
||||||
|
let tool_parser_registry = ParserRegistry::new();
|
||||||
|
|
||||||
|
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||||
|
let core_cb_config = CircuitBreakerConfig {
|
||||||
|
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||||
|
success_threshold: circuit_breaker_config.success_threshold,
|
||||||
|
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
|
||||||
|
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create gRPC clients for prefill workers
|
||||||
|
let mut prefill_grpc_clients = HashMap::new();
|
||||||
|
for (url, _bootstrap_port) in &prefill_urls {
|
||||||
|
match SglangSchedulerClient::connect(url).await {
|
||||||
|
Ok(client) => {
|
||||||
|
prefill_grpc_clients.insert(url.clone(), client);
|
||||||
|
info!("Connected to gRPC prefill worker at {}", url);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to connect to gRPC prefill worker at {}: {}", url, e);
|
||||||
|
// Continue with other workers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create gRPC clients for decode workers
|
||||||
|
let mut decode_grpc_clients = HashMap::new();
|
||||||
|
for url in &decode_urls {
|
||||||
|
match SglangSchedulerClient::connect(url).await {
|
||||||
|
Ok(client) => {
|
||||||
|
decode_grpc_clients.insert(url.clone(), client);
|
||||||
|
info!("Connected to gRPC decode worker at {}", url);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to connect to gRPC decode worker at {}: {}", url, e);
|
||||||
|
// Continue with other workers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefill_grpc_clients.is_empty() && decode_grpc_clients.is_empty() {
|
||||||
|
return Err("Failed to connect to any gRPC workers".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create Prefill Worker trait objects with gRPC connection mode
|
||||||
|
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
|
||||||
|
.iter()
|
||||||
|
.map(|(url, bootstrap_port)| {
|
||||||
|
let worker = BasicWorker::with_connection_mode(
|
||||||
|
url.clone(),
|
||||||
|
WorkerType::Prefill {
|
||||||
|
bootstrap_port: *bootstrap_port,
|
||||||
|
},
|
||||||
|
crate::core::ConnectionMode::Grpc {
|
||||||
|
port: *bootstrap_port,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.with_circuit_breaker_config(core_cb_config.clone())
|
||||||
|
.with_health_config(HealthConfig {
|
||||||
|
timeout_secs: health_check_config.timeout_secs,
|
||||||
|
check_interval_secs: health_check_config.check_interval_secs,
|
||||||
|
endpoint: health_check_config.endpoint.clone(),
|
||||||
|
failure_threshold: health_check_config.failure_threshold,
|
||||||
|
success_threshold: health_check_config.success_threshold,
|
||||||
|
});
|
||||||
|
Box::new(worker) as Box<dyn Worker>
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Create Decode Worker trait objects with gRPC connection mode
|
||||||
|
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
|
||||||
|
.iter()
|
||||||
|
.map(|url| {
|
||||||
|
let worker = BasicWorker::with_connection_mode(
|
||||||
|
url.clone(),
|
||||||
|
WorkerType::Decode,
|
||||||
|
crate::core::ConnectionMode::Grpc { port: None },
|
||||||
|
)
|
||||||
|
.with_circuit_breaker_config(core_cb_config.clone())
|
||||||
|
.with_health_config(HealthConfig {
|
||||||
|
timeout_secs: health_check_config.timeout_secs,
|
||||||
|
check_interval_secs: health_check_config.check_interval_secs,
|
||||||
|
endpoint: health_check_config.endpoint.clone(),
|
||||||
|
failure_threshold: health_check_config.failure_threshold,
|
||||||
|
success_threshold: health_check_config.success_threshold,
|
||||||
|
});
|
||||||
|
Box::new(worker) as Box<dyn Worker>
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Initialize policies with workers if needed
|
||||||
|
if let Some(cache_aware) = prefill_policy
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
|
{
|
||||||
|
cache_aware.init_workers(&prefill_workers);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(cache_aware) = decode_policy
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
|
{
|
||||||
|
cache_aware.init_workers(&decode_workers);
|
||||||
|
}
|
||||||
|
|
||||||
|
let prefill_workers = Arc::new(RwLock::new(prefill_workers));
|
||||||
|
let decode_workers = Arc::new(RwLock::new(decode_workers));
|
||||||
|
|
||||||
|
let prefill_health_checker =
|
||||||
|
crate::core::start_health_checker(Arc::clone(&prefill_workers), interval_secs);
|
||||||
|
let decode_health_checker =
|
||||||
|
crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs);
|
||||||
|
|
||||||
|
Ok(GrpcPDRouter {
|
||||||
|
prefill_workers,
|
||||||
|
decode_workers,
|
||||||
|
prefill_grpc_clients: Arc::new(RwLock::new(prefill_grpc_clients)),
|
||||||
|
decode_grpc_clients: Arc::new(RwLock::new(decode_grpc_clients)),
|
||||||
|
prefill_policy,
|
||||||
|
decode_policy,
|
||||||
|
tokenizer,
|
||||||
|
reasoning_parser_factory,
|
||||||
|
tool_parser_registry,
|
||||||
|
_prefill_health_checker: Some(prefill_health_checker),
|
||||||
|
_decode_health_checker: Some(decode_health_checker),
|
||||||
|
timeout_secs,
|
||||||
|
interval_secs,
|
||||||
|
dp_aware,
|
||||||
|
api_key,
|
||||||
|
retry_config,
|
||||||
|
circuit_breaker_config: core_cb_config,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for GrpcPDRouter {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("GrpcPDRouter")
|
||||||
|
.field(
|
||||||
|
"prefill_workers_count",
|
||||||
|
&self.prefill_workers.read().unwrap().len(),
|
||||||
|
)
|
||||||
|
.field(
|
||||||
|
"decode_workers_count",
|
||||||
|
&self.decode_workers.read().unwrap().len(),
|
||||||
|
)
|
||||||
|
.field("timeout_secs", &self.timeout_secs)
|
||||||
|
.field("interval_secs", &self.interval_secs)
|
||||||
|
.field("dp_aware", &self.dp_aware)
|
||||||
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,19 @@
|
|||||||
// gRPC Router Implementation
|
// gRPC Router Implementation
|
||||||
// TODO: Implement gRPC-based router
|
|
||||||
|
|
||||||
|
use crate::config::types::{
|
||||||
|
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
|
||||||
|
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
|
||||||
|
};
|
||||||
|
use crate::core::{
|
||||||
|
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
|
||||||
|
};
|
||||||
|
use crate::grpc::SglangSchedulerClient;
|
||||||
|
use crate::metrics::RouterMetrics;
|
||||||
|
use crate::policies::LoadBalancingPolicy;
|
||||||
|
use crate::reasoning_parser::ParserFactory;
|
||||||
use crate::routers::{RouterTrait, WorkerManagement};
|
use crate::routers::{RouterTrait, WorkerManagement};
|
||||||
|
use crate::tokenizer::{factory, traits::Tokenizer};
|
||||||
|
use crate::tool_parser::ParserRegistry;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
@@ -9,15 +21,150 @@ use axum::{
|
|||||||
http::{HeaderMap, StatusCode},
|
http::{HeaderMap, StatusCode},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::{Arc, RwLock};
|
||||||
|
use std::time::Duration;
|
||||||
|
use tracing::{info, warn};
|
||||||
|
|
||||||
/// Placeholder for gRPC router
|
/// gRPC router implementation for SGLang
|
||||||
#[derive(Debug)]
|
#[allow(dead_code)] // Fields will be used once implementation is complete
|
||||||
pub struct GrpcRouter;
|
pub struct GrpcRouter {
|
||||||
|
/// Worker connections
|
||||||
|
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||||
|
/// gRPC clients for each worker
|
||||||
|
grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
|
||||||
|
/// Load balancing policy
|
||||||
|
policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
|
/// Tokenizer for handling text encoding/decoding
|
||||||
|
tokenizer: Arc<dyn Tokenizer>,
|
||||||
|
/// Reasoning parser factory for structured reasoning outputs
|
||||||
|
reasoning_parser_factory: ParserFactory,
|
||||||
|
/// Tool parser registry for function/tool calls
|
||||||
|
tool_parser_registry: &'static ParserRegistry,
|
||||||
|
/// Worker health checker
|
||||||
|
_health_checker: Option<HealthChecker>,
|
||||||
|
/// Configuration
|
||||||
|
timeout_secs: u64,
|
||||||
|
interval_secs: u64,
|
||||||
|
dp_aware: bool,
|
||||||
|
api_key: Option<String>,
|
||||||
|
retry_config: RetryConfig,
|
||||||
|
circuit_breaker_config: CircuitBreakerConfig,
|
||||||
|
}
|
||||||
|
|
||||||
impl GrpcRouter {
|
impl GrpcRouter {
|
||||||
pub async fn new() -> Result<Self, String> {
|
/// Create a new gRPC router
|
||||||
// TODO: Implement gRPC router initialization
|
#[allow(clippy::too_many_arguments)]
|
||||||
Err("gRPC router not yet implemented".to_string())
|
pub async fn new(
|
||||||
|
worker_urls: Vec<String>,
|
||||||
|
policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
|
timeout_secs: u64,
|
||||||
|
interval_secs: u64,
|
||||||
|
dp_aware: bool,
|
||||||
|
api_key: Option<String>,
|
||||||
|
retry_config: RetryConfig,
|
||||||
|
circuit_breaker_config: ConfigCircuitBreakerConfig,
|
||||||
|
health_check_config: ConfigHealthCheckConfig,
|
||||||
|
tokenizer_path_or_model: String,
|
||||||
|
) -> Result<Self, String> {
|
||||||
|
// Update metrics
|
||||||
|
RouterMetrics::set_active_workers(worker_urls.len());
|
||||||
|
|
||||||
|
// Initialize tokenizer
|
||||||
|
let tokenizer = factory::create_tokenizer(&tokenizer_path_or_model)
|
||||||
|
.map_err(|e| format!("Failed to create tokenizer: {}", e))?;
|
||||||
|
|
||||||
|
// Initialize reasoning parser factory
|
||||||
|
let reasoning_parser_factory = ParserFactory::new();
|
||||||
|
|
||||||
|
// Get tool parser registry
|
||||||
|
let tool_parser_registry = ParserRegistry::new();
|
||||||
|
|
||||||
|
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||||
|
let core_cb_config = CircuitBreakerConfig {
|
||||||
|
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||||
|
success_threshold: circuit_breaker_config.success_threshold,
|
||||||
|
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
|
||||||
|
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create gRPC clients for each worker
|
||||||
|
let mut grpc_clients = HashMap::new();
|
||||||
|
for url in &worker_urls {
|
||||||
|
match SglangSchedulerClient::connect(url).await {
|
||||||
|
Ok(client) => {
|
||||||
|
grpc_clients.insert(url.clone(), client);
|
||||||
|
info!("Connected to gRPC worker at {}", url);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to connect to gRPC worker at {}: {}", url, e);
|
||||||
|
// Continue with other workers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if grpc_clients.is_empty() {
|
||||||
|
return Err("Failed to connect to any gRPC workers".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create Worker trait objects with gRPC connection mode
|
||||||
|
let workers: Vec<Box<dyn Worker>> = worker_urls
|
||||||
|
.iter()
|
||||||
|
.map(|url| {
|
||||||
|
let worker = BasicWorker::with_connection_mode(
|
||||||
|
url.clone(),
|
||||||
|
WorkerType::Regular,
|
||||||
|
crate::core::ConnectionMode::Grpc { port: None },
|
||||||
|
)
|
||||||
|
.with_circuit_breaker_config(core_cb_config.clone())
|
||||||
|
.with_health_config(HealthConfig {
|
||||||
|
timeout_secs: health_check_config.timeout_secs,
|
||||||
|
check_interval_secs: health_check_config.check_interval_secs,
|
||||||
|
endpoint: health_check_config.endpoint.clone(),
|
||||||
|
failure_threshold: health_check_config.failure_threshold,
|
||||||
|
success_threshold: health_check_config.success_threshold,
|
||||||
|
});
|
||||||
|
Box::new(worker) as Box<dyn Worker>
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Initialize policy with workers if needed
|
||||||
|
if let Some(cache_aware) = policy
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
|
{
|
||||||
|
cache_aware.init_workers(&workers);
|
||||||
|
}
|
||||||
|
|
||||||
|
let workers = Arc::new(RwLock::new(workers));
|
||||||
|
let health_checker = crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
|
||||||
|
|
||||||
|
Ok(GrpcRouter {
|
||||||
|
workers,
|
||||||
|
grpc_clients: Arc::new(RwLock::new(grpc_clients)),
|
||||||
|
policy,
|
||||||
|
tokenizer,
|
||||||
|
reasoning_parser_factory,
|
||||||
|
tool_parser_registry,
|
||||||
|
_health_checker: Some(health_checker),
|
||||||
|
timeout_secs,
|
||||||
|
interval_secs,
|
||||||
|
dp_aware,
|
||||||
|
api_key,
|
||||||
|
retry_config,
|
||||||
|
circuit_breaker_config: core_cb_config,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for GrpcRouter {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("GrpcRouter")
|
||||||
|
.field("workers_count", &self.workers.read().unwrap().len())
|
||||||
|
.field("timeout_secs", &self.timeout_secs)
|
||||||
|
.field("interval_secs", &self.interval_secs)
|
||||||
|
.field("dp_aware", &self.dp_aware)
|
||||||
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
|
|||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sglang_router_rs::config::{
|
use sglang_router_rs::config::{
|
||||||
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||||
};
|
};
|
||||||
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -55,6 +55,9 @@ impl TestContext {
|
|||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||||
enable_igw: false,
|
enable_igw: false,
|
||||||
|
connection_mode: ConnectionMode::Http,
|
||||||
|
model_path: None,
|
||||||
|
tokenizer_path: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
Self::new_with_config(config, worker_configs).await
|
Self::new_with_config(config, worker_configs).await
|
||||||
@@ -1101,6 +1104,9 @@ mod error_tests {
|
|||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||||
enable_igw: false,
|
enable_igw: false,
|
||||||
|
connection_mode: ConnectionMode::Http,
|
||||||
|
model_path: None,
|
||||||
|
tokenizer_path: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = TestContext::new_with_config(
|
let ctx = TestContext::new_with_config(
|
||||||
@@ -1456,6 +1462,9 @@ mod pd_mode_tests {
|
|||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||||
enable_igw: false,
|
enable_igw: false,
|
||||||
|
connection_mode: ConnectionMode::Http,
|
||||||
|
model_path: None,
|
||||||
|
tokenizer_path: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create app context
|
// Create app context
|
||||||
@@ -1615,6 +1624,9 @@ mod request_id_tests {
|
|||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||||
enable_igw: false,
|
enable_igw: false,
|
||||||
|
connection_mode: ConnectionMode::Http,
|
||||||
|
model_path: None,
|
||||||
|
tokenizer_path: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = TestContext::new_with_config(
|
let ctx = TestContext::new_with_config(
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
|
|||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sglang_router_rs::config::{
|
use sglang_router_rs::config::{
|
||||||
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||||
};
|
};
|
||||||
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -46,6 +46,9 @@ impl TestContext {
|
|||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||||
enable_igw: false,
|
enable_igw: false,
|
||||||
|
connection_mode: ConnectionMode::Http,
|
||||||
|
model_path: None,
|
||||||
|
tokenizer_path: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut workers = Vec::new();
|
let mut workers = Vec::new();
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ use futures_util::StreamExt;
|
|||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sglang_router_rs::config::{
|
use sglang_router_rs::config::{
|
||||||
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||||
};
|
};
|
||||||
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -47,6 +47,9 @@ impl TestContext {
|
|||||||
disable_circuit_breaker: false,
|
disable_circuit_breaker: false,
|
||||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||||
enable_igw: false,
|
enable_igw: false,
|
||||||
|
connection_mode: ConnectionMode::Http,
|
||||||
|
model_path: None,
|
||||||
|
tokenizer_path: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut workers = Vec::new();
|
let mut workers = Vec::new();
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
mod test_pd_routing {
|
mod test_pd_routing {
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sglang_router_rs::config::{
|
use sglang_router_rs::config::{
|
||||||
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||||
};
|
};
|
||||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||||
use sglang_router_rs::routers::http::pd_types::get_hostname;
|
use sglang_router_rs::routers::http::pd_types::get_hostname;
|
||||||
@@ -188,6 +188,9 @@ mod test_pd_routing {
|
|||||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||||
enable_igw: false,
|
enable_igw: false,
|
||||||
rate_limit_tokens_per_second: None,
|
rate_limit_tokens_per_second: None,
|
||||||
|
connection_mode: ConnectionMode::Http,
|
||||||
|
model_path: None,
|
||||||
|
tokenizer_path: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Router creation will fail due to health checks, but config should be valid
|
// Router creation will fail due to health checks, but config should be valid
|
||||||
|
|||||||
Reference in New Issue
Block a user