[router] add grpc pd and regular router init (#9893)
This commit is contained in:
@@ -2,6 +2,7 @@ use pyo3::prelude::*;
|
||||
pub mod config;
|
||||
pub mod logging;
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub mod core;
|
||||
#[cfg(feature = "grpc-client")]
|
||||
pub mod grpc;
|
||||
@@ -89,9 +90,39 @@ struct Router {
|
||||
queue_size: usize,
|
||||
queue_timeout_secs: u64,
|
||||
rate_limit_tokens_per_second: Option<usize>,
|
||||
// Connection mode (determined from worker URLs)
|
||||
connection_mode: config::ConnectionMode,
|
||||
// Model path for tokenizer
|
||||
model_path: Option<String>,
|
||||
// Explicit tokenizer path
|
||||
tokenizer_path: Option<String>,
|
||||
}
|
||||
|
||||
impl Router {
|
||||
/// Determine connection mode from worker URLs
|
||||
fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
|
||||
// Check if any URL is a gRPC endpoint (starts with grpc:// or has port that commonly indicates gRPC)
|
||||
for url in worker_urls {
|
||||
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
|
||||
return config::ConnectionMode::Grpc;
|
||||
}
|
||||
// Also check for common gRPC ports if the scheme isn't specified
|
||||
if let Ok(parsed_url) = url::Url::parse(url) {
|
||||
if let Some(port) = parsed_url.port() {
|
||||
// Common gRPC ports
|
||||
if port == 50051 || port == 9090 || ((50000..=50100).contains(&port)) {
|
||||
return config::ConnectionMode::Grpc;
|
||||
}
|
||||
}
|
||||
} else if url.contains(":50051") || url.contains(":9090") || url.contains(":5000") {
|
||||
// Fallback check for URLs that might not parse correctly
|
||||
return config::ConnectionMode::Grpc;
|
||||
}
|
||||
}
|
||||
// Default to HTTP
|
||||
config::ConnectionMode::Http
|
||||
}
|
||||
|
||||
/// Convert PyO3 Router to RouterConfig
|
||||
pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
|
||||
use config::{
|
||||
@@ -168,6 +199,7 @@ impl Router {
|
||||
policy,
|
||||
host: self.host.clone(),
|
||||
port: self.port,
|
||||
connection_mode: self.connection_mode.clone(),
|
||||
max_payload_size: self.max_payload_size,
|
||||
request_timeout_secs: self.request_timeout_secs,
|
||||
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
|
||||
@@ -207,6 +239,8 @@ impl Router {
|
||||
endpoint: self.health_check_endpoint.clone(),
|
||||
},
|
||||
enable_igw: self.enable_igw,
|
||||
model_path: self.model_path.clone(),
|
||||
tokenizer_path: self.tokenizer_path.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -273,6 +307,9 @@ impl Router {
|
||||
queue_size = 100,
|
||||
queue_timeout_secs = 60,
|
||||
rate_limit_tokens_per_second = None,
|
||||
// Tokenizer defaults
|
||||
model_path = None,
|
||||
tokenizer_path = None,
|
||||
))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
@@ -330,7 +367,26 @@ impl Router {
|
||||
queue_size: usize,
|
||||
queue_timeout_secs: u64,
|
||||
rate_limit_tokens_per_second: Option<usize>,
|
||||
model_path: Option<String>,
|
||||
tokenizer_path: Option<String>,
|
||||
) -> PyResult<Self> {
|
||||
// Determine connection mode from worker URLs
|
||||
let mut all_urls = worker_urls.clone();
|
||||
|
||||
// Add prefill URLs if in PD mode
|
||||
if let Some(ref prefill_urls) = prefill_urls {
|
||||
for (url, _) in prefill_urls {
|
||||
all_urls.push(url.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Add decode URLs if in PD mode
|
||||
if let Some(ref decode_urls) = decode_urls {
|
||||
all_urls.extend(decode_urls.clone());
|
||||
}
|
||||
|
||||
let connection_mode = Self::determine_connection_mode(&all_urls);
|
||||
|
||||
Ok(Router {
|
||||
host,
|
||||
port,
|
||||
@@ -386,6 +442,9 @@ impl Router {
|
||||
queue_size,
|
||||
queue_timeout_secs,
|
||||
rate_limit_tokens_per_second,
|
||||
connection_mode,
|
||||
model_path,
|
||||
tokenizer_path,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user