Files
sglang/sgl-router/src/lib.rs
2025-01-20 14:36:54 -08:00

124 lines
3.6 KiB
Rust

use pyo3::prelude::*;
pub mod router;
pub mod server;
pub mod tree;
/// Routing policy a `Router` is configured with.
///
/// When the router is started, each variant is translated into the matching
/// `router::PolicyConfig` variant. `#[pyclass(eq)]` exposes the enum to Python
/// with `==` comparison backed by the derived `PartialEq`.
#[pyclass(eq)]
#[derive(PartialEq, Clone)]
pub enum PolicyType {
    /// Maps to `router::PolicyConfig::RandomConfig`.
    Random,
    /// Maps to `router::PolicyConfig::RoundRobinConfig`; this is the default
    /// `policy` argument of the `Router` constructor.
    RoundRobin,
    /// Maps to `router::PolicyConfig::CacheAwareConfig`, the only variant that
    /// also forwards the cache/balance/eviction/tree settings.
    CacheAware,
}
/// Settings holder exposed to Python as the `Router` class.
///
/// Constructing it only records configuration; calling `start()` from Python
/// launches the HTTP server using these values.
#[pyclass]
struct Router {
    /// Host forwarded to `server::ServerConfig`.
    host: String,
    /// Port forwarded to `server::ServerConfig`.
    port: u16,
    /// Payload size limit forwarded to `server::ServerConfig`.
    max_payload_size: usize,
    /// Verbosity flag forwarded to `server::ServerConfig`.
    verbose: bool,

    /// Worker endpoints forwarded to `server::ServerConfig`.
    worker_urls: Vec<String>,
    /// Forwarded as `timeout_secs` to every `router::PolicyConfig` variant.
    worker_startup_timeout_secs: u64,
    /// Forwarded as `interval_secs` to every `router::PolicyConfig` variant.
    worker_startup_check_interval: u64,

    /// Selects which `router::PolicyConfig` variant is built.
    policy: PolicyType,
    /// Cache-aware tuning; only forwarded for `PolicyType::CacheAware`.
    cache_threshold: f32,
    /// Cache-aware tuning; only forwarded for `PolicyType::CacheAware`.
    balance_abs_threshold: usize,
    /// Cache-aware tuning; only forwarded for `PolicyType::CacheAware`.
    balance_rel_threshold: f32,
    /// Cache-aware tuning; only forwarded for `PolicyType::CacheAware`.
    eviction_interval_secs: u64,
    /// Cache-aware tuning; only forwarded for `PolicyType::CacheAware`.
    max_tree_size: usize,
}
#[pymethods]
impl Router {
#[new]
#[pyo3(signature = (
worker_urls,
policy = PolicyType::RoundRobin,
host = String::from("127.0.0.1"),
port = 3001,
worker_startup_timeout_secs = 300,
worker_startup_check_interval = 10,
cache_threshold = 0.50,
balance_abs_threshold = 32,
balance_rel_threshold = 1.0001,
eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24),
max_payload_size = 4 * 1024 * 1024,
verbose = false
))]
fn new(
worker_urls: Vec<String>,
policy: PolicyType,
host: String,
port: u16,
worker_startup_timeout_secs: u64,
worker_startup_check_interval: u64,
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
max_payload_size: usize,
verbose: bool,
) -> PyResult<Self> {
Ok(Router {
host,
port,
worker_urls,
policy,
worker_startup_timeout_secs,
worker_startup_check_interval,
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
eviction_interval_secs,
max_tree_size,
max_payload_size,
verbose,
})
}
fn start(&self) -> PyResult<()> {
let policy_config = match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
},
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
},
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
};
actix_web::rt::System::new().block_on(async move {
server::startup(server::ServerConfig {
host: self.host.clone(),
port: self.port,
worker_urls: self.worker_urls.clone(),
policy_config,
verbose: self.verbose,
max_payload_size: self.max_payload_size,
})
.await
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
Ok(())
})
}
}
/// Initializer for the `sglang_router_rs` Python extension module.
///
/// Exposes the `PolicyType` enum and the `Router` class to Python. Any
/// registration failure is returned to the importer as a Python exception.
#[pymodule]
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PolicyType>()?;
    m.add_class::<Router>()
}