[router] add grpc pd and regular router init (#9893)

This commit is contained in:
Chang Su
2025-09-01 20:06:15 -07:00
committed by GitHub
parent b5245064f6
commit 9a0cac1be0
14 changed files with 783 additions and 58 deletions

View File

@@ -2,6 +2,7 @@ use pyo3::prelude::*;
pub mod config;
pub mod logging;
use std::collections::HashMap;
pub mod core;
#[cfg(feature = "grpc-client")]
pub mod grpc;
@@ -89,9 +90,39 @@ struct Router {
queue_size: usize,
queue_timeout_secs: u64,
rate_limit_tokens_per_second: Option<usize>,
// Connection mode (determined from worker URLs)
connection_mode: config::ConnectionMode,
// Model path for tokenizer
model_path: Option<String>,
// Explicit tokenizer path
tokenizer_path: Option<String>,
}
impl Router {
/// Determine the connection mode implied by a set of worker URLs.
///
/// Returns `ConnectionMode::Grpc` as soon as any URL either
/// * uses an explicit `grpc://` or `grpcs://` scheme, or
/// * names a port conventionally used for gRPC: `9090`, or anything in
///   `50000..=50100` (which includes the default `50051`).
///
/// Returns `ConnectionMode::Http` otherwise, including for an empty slice.
///
/// The port is read textually from the authority part of each URL, so the
/// same rule applies whether or not a scheme is present. A generic URL parser
/// treats `localhost:50051` as scheme `localhost` with no port, and substring
/// checks such as `contains(":5000")` also match the unrelated port `5000` or
/// a `:9090` that appears inside a path.
fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
    // Ports treated as gRPC when the scheme does not say so.
    fn is_common_grpc_port(port: u16) -> bool {
        port == 9090 || (50000..=50100).contains(&port)
    }

    // Extract an explicitly written port from `url`, with or without a scheme.
    fn explicit_port(url: &str) -> Option<u16> {
        // Only strip `<scheme>://` when the prefix really looks like a scheme,
        // so a `://` inside a path or query string is not mistaken for one.
        let looks_like_scheme = |s: &str| {
            !s.is_empty()
                && s.chars()
                    .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
        };
        let after_scheme = match url.split_once("://") {
            Some((scheme, rest)) if looks_like_scheme(scheme) => rest,
            _ => url,
        };

        // The authority ends at the first path, query or fragment delimiter.
        let authority = after_scheme
            .split(|c: char| matches!(c, '/' | '?' | '#'))
            .next()
            .unwrap_or(after_scheme);

        // Drop any `user:password@` prefix so its colon is not read as a port.
        let host_port = authority
            .rsplit_once('@')
            .map_or(authority, |(_, host_port)| host_port);

        // The last colon separates host and port; this also works for
        // bracketed IPv6 hosts such as `[::1]:50051`. A bare `[::1]` yields
        // `1]`, which fails to parse and is correctly reported as "no port".
        let (_, port) = host_port.rsplit_once(':')?;
        port.parse().ok()
    }

    let any_grpc = worker_urls.iter().any(|url| {
        url.starts_with("grpc://")
            || url.starts_with("grpcs://")
            || matches!(explicit_port(url), Some(port) if is_common_grpc_port(port))
    });

    if any_grpc {
        config::ConnectionMode::Grpc
    } else {
        // Default to HTTP
        config::ConnectionMode::Http
    }
}
/// Convert PyO3 Router to RouterConfig
pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
use config::{
@@ -168,6 +199,7 @@ impl Router {
policy,
host: self.host.clone(),
port: self.port,
connection_mode: self.connection_mode.clone(),
max_payload_size: self.max_payload_size,
request_timeout_secs: self.request_timeout_secs,
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
@@ -207,6 +239,8 @@ impl Router {
endpoint: self.health_check_endpoint.clone(),
},
enable_igw: self.enable_igw,
model_path: self.model_path.clone(),
tokenizer_path: self.tokenizer_path.clone(),
})
}
}
@@ -273,6 +307,9 @@ impl Router {
queue_size = 100,
queue_timeout_secs = 60,
rate_limit_tokens_per_second = None,
// Tokenizer defaults
model_path = None,
tokenizer_path = None,
))]
#[allow(clippy::too_many_arguments)]
fn new(
@@ -330,7 +367,26 @@ impl Router {
queue_size: usize,
queue_timeout_secs: u64,
rate_limit_tokens_per_second: Option<usize>,
model_path: Option<String>,
tokenizer_path: Option<String>,
) -> PyResult<Self> {
// Determine connection mode from worker URLs
let mut all_urls = worker_urls.clone();
// Add prefill URLs if in PD mode
if let Some(ref prefill_urls) = prefill_urls {
for (url, _) in prefill_urls {
all_urls.push(url.clone());
}
}
// Add decode URLs if in PD mode
if let Some(ref decode_urls) = decode_urls {
all_urls.extend(decode_urls.clone());
}
let connection_mode = Self::determine_connection_mode(&all_urls);
Ok(Router {
host,
port,
@@ -386,6 +442,9 @@ impl Router {
queue_size,
queue_timeout_secs,
rate_limit_tokens_per_second,
connection_mode,
model_path,
tokenizer_path,
})
}